deviceauth.go raw
1 package oauth2
2
3 import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "mime"
10 "net/http"
11 "net/url"
12 "strings"
13 "time"
14
15 "golang.org/x/oauth2/internal"
16 )
17
// Error codes that the token endpoint may return while a device authorization
// grant is being polled, as listed in
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
const (
	// errAuthorizationPending: the user has not finished authorizing yet;
	// polling continues.
	errAuthorizationPending = "authorization_pending"

	// errSlowDown: still pending, but the polling interval must grow by
	// 5 seconds.
	errSlowDown = "slow_down"

	// errAccessDenied: the authorization request was denied.
	errAccessDenied = "access_denied"

	// errExpiredToken: the device code expired before authorization completed.
	errExpiredToken = "expired_token"
)
25
26 // DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
27 // https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
28 type DeviceAuthResponse struct {
29 // DeviceCode
30 DeviceCode string `json:"device_code"`
31 // UserCode is the code the user should enter at the verification uri
32 UserCode string `json:"user_code"`
33 // VerificationURI is where user should enter the user code
34 VerificationURI string `json:"verification_uri"`
35 // VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
36 VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
37 // Expiry is when the device code and user code expire
38 Expiry time.Time `json:"expires_in,omitempty"`
39 // Interval is the duration in seconds that Poll should wait between requests
40 Interval int64 `json:"interval,omitempty"`
41 }
42
43 func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
44 type Alias DeviceAuthResponse
45 var expiresIn int64
46 if !d.Expiry.IsZero() {
47 expiresIn = int64(time.Until(d.Expiry).Seconds())
48 }
49 return json.Marshal(&struct {
50 ExpiresIn int64 `json:"expires_in,omitempty"`
51 *Alias
52 }{
53 ExpiresIn: expiresIn,
54 Alias: (*Alias)(&d),
55 })
56
57 }
58
59 func (c *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
60 type Alias DeviceAuthResponse
61 aux := &struct {
62 ExpiresIn int64 `json:"expires_in"`
63 // workaround misspelling of verification_uri
64 VerificationURL string `json:"verification_url"`
65 *Alias
66 }{
67 Alias: (*Alias)(c),
68 }
69 if err := json.Unmarshal(data, &aux); err != nil {
70 return err
71 }
72 if aux.ExpiresIn != 0 {
73 c.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
74 }
75 if c.VerificationURI == "" {
76 c.VerificationURI = aux.VerificationURL
77 }
78 return nil
79 }
80
// DeviceAuth starts an RFC 8628 device authorization flow.
//
// It sends the client ID, the configured scopes (space-separated) and any
// extra options to the endpoint's DeviceAuthURL. It returns the device code
// together with the user code and verification URI that the user should be
// shown on another device.
func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
	// Request parameters: https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
	form := make(url.Values)
	form.Set("client_id", c.ClientID)
	if len(c.Scopes) != 0 {
		form.Set("scope", strings.Join(c.Scopes, " "))
	}
	for i := range opts {
		opts[i].setValue(form)
	}
	return retrieveDeviceAuth(ctx, c, form)
}
96
// retrieveDeviceAuth POSTs the form-encoded values v to c.Endpoint.DeviceAuthURL
// and decodes the device authorization response.
//
// A non-2xx status is returned as a *RetrieveError carrying the response and
// the raw body. Its ErrorCode, ErrorDescription and ErrorURI are filled in
// when the body can be parsed: as a query string for
// application/x-www-form-urlencoded or text/plain responses, and as JSON
// otherwise. On success, Expiry is pulled back by the time elapsed since the
// request was sent, so it errs on the early side.
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
	if c.Endpoint.DeviceAuthURL == "" {
		return nil, errors.New("endpoint missing DeviceAuthURL")
	}

	// Bind the request to ctx so cancellation and deadlines abort the call.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	t := time.Now()
	r, err := internal.ContextClient(ctx).Do(req)
	if err != nil {
		return nil, err
	}
	// The body is fully buffered below, so closing on return is safe even
	// though r is kept inside a RetrieveError.
	defer r.Body.Close()

	// Read at most 1 MiB of response body.
	body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
	if err != nil {
		return nil, fmt.Errorf("oauth2: cannot auth device: %w", err)
	}
	if code := r.StatusCode; code < 200 || code > 299 {
		retrieveError := &RetrieveError{
			Response: r,
			Body:     body,
		}

		content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
		switch content {
		case "application/x-www-form-urlencoded", "text/plain":
			// some endpoints return a query string
			vals, err := url.ParseQuery(string(body))
			if err != nil {
				return nil, retrieveError
			}
			retrieveError.ErrorCode = vals.Get("error")
			retrieveError.ErrorDescription = vals.Get("error_description")
			retrieveError.ErrorURI = vals.Get("error_uri")
		default:
			var tj struct {
				// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
				ErrorCode        string `json:"error"`
				ErrorDescription string `json:"error_description"`
				ErrorURI         string `json:"error_uri"`
			}
			if json.Unmarshal(body, &tj) != nil {
				return nil, retrieveError
			}
			retrieveError.ErrorCode = tj.ErrorCode
			retrieveError.ErrorDescription = tj.ErrorDescription
			retrieveError.ErrorURI = tj.ErrorURI
		}

		return nil, retrieveError
	}

	// Decode into da directly, not &da: a pointer-to-pointer target lets a
	// JSON null body reset da to nil, which would panic below.
	da := &DeviceAuthResponse{}
	if err := json.Unmarshal(body, da); err != nil {
		return nil, fmt.Errorf("unmarshal %w", err)
	}

	if !da.Expiry.IsZero() {
		// Make a small adjustment to account for time taken by the request
		da.Expiry = da.Expiry.Add(-time.Since(t))
	}

	return da, nil
}
167
// DeviceAccessToken polls the server to exchange a device code for a token.
//
// Following RFC 8628 sections 3.4 and 3.5, the first request is sent after
// one polling interval. The interval is da.Interval seconds, or 5 seconds
// when that value is zero or negative.
//
// Responses are handled as follows:
//   - "authorization_pending": polling continues at the same interval.
//   - "slow_down": the interval is increased by 5 seconds for all later requests.
//   - Any other error, including "access_denied" and "expired_token": returned
//     to the caller.
//
// If da.Expiry is set, polling stops with ctx.Err() once it passes.
func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
	if !da.Expiry.IsZero() {
		var cancel context.CancelFunc
		ctx, cancel = context.WithDeadline(ctx, da.Expiry)
		defer cancel()
	}

	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
	v := url.Values{
		"client_id":   {c.ClientID},
		"grant_type":  {"urn:ietf:params:oauth:grant-type:device_code"},
		"device_code": {da.DeviceCode},
	}
	if len(c.Scopes) > 0 {
		v.Set("scope", strings.Join(c.Scopes, " "))
	}
	for _, opt := range opts {
		opt.setValue(v)
	}

	// "If no value is provided, clients MUST use 5 as the default."
	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
	// Negative values get the same treatment. The interval comes from the
	// server, and time.NewTicker panics on a non-positive duration.
	interval := da.Interval
	if interval <= 0 {
		interval = 5
	}

	ticker := time.NewTicker(time.Duration(interval) * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-ticker.C:
			tok, err := retrieveToken(ctx, c, v)
			if err == nil {
				return tok, nil
			}

			var e *RetrieveError
			if !errors.As(err, &e) {
				return nil, err
			}
			switch e.ErrorCode {
			case errSlowDown:
				// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
				// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
				interval += 5
				ticker.Reset(time.Duration(interval) * time.Second)
			case errAuthorizationPending:
				// Keep polling at the current interval.
			case errAccessDenied, errExpiredToken:
				// Terminal per RFC 8628 section 3.5; handled like any other error.
				fallthrough
			default:
				return nil, err
			}
		}
	}
}
228