1 package adal
2 3 // Copyright 2017 Microsoft Corporation
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 17 /*
This file is largely based on rjw57/oauth2device's code, with the following differences:
19 * scope -> resource, and only allow a single one
20 * receive "Message" in the DeviceCode struct and show it to users as the prompt
21 * azure-xplat-cli has the following behavior that this emulates:
22 - does not send client_secret during the token exchange
23 - sends resource again in the token exchange request
24 */
25 26 import (
27 "context"
28 "encoding/json"
29 "fmt"
30 "io/ioutil"
31 "net/http"
32 "net/url"
33 "strings"
34 "time"
35 )
// logPrefix is placed at the start of every error message built in this file
// so that failures can be traced back to the device token flow.
const logPrefix = "autorest/adal/devicetoken:"
40 41 var (
42 // ErrDeviceGeneric represents an unknown error from the token endpoint when using device flow
43 ErrDeviceGeneric = fmt.Errorf("%s Error while retrieving OAuth token: Unknown Error", logPrefix)
44 45 // ErrDeviceAccessDenied represents an access denied error from the token endpoint when using device flow
46 ErrDeviceAccessDenied = fmt.Errorf("%s Error while retrieving OAuth token: Access Denied", logPrefix)
47 48 // ErrDeviceAuthorizationPending represents the server waiting on the user to complete the device flow
49 ErrDeviceAuthorizationPending = fmt.Errorf("%s Error while retrieving OAuth token: Authorization Pending", logPrefix)
50 51 // ErrDeviceCodeExpired represents the server timing out and expiring the code during device flow
52 ErrDeviceCodeExpired = fmt.Errorf("%s Error while retrieving OAuth token: Code Expired", logPrefix)
53 54 // ErrDeviceSlowDown represents the service telling us we're polling too often during device flow
55 ErrDeviceSlowDown = fmt.Errorf("%s Error while retrieving OAuth token: Slow Down", logPrefix)
56 57 // ErrDeviceCodeEmpty represents an empty device code from the device endpoint while using device flow
58 ErrDeviceCodeEmpty = fmt.Errorf("%s Error while retrieving device code: Device Code Empty", logPrefix)
59 60 // ErrOAuthTokenEmpty represents an empty OAuth token from the token endpoint when using device flow
61 ErrOAuthTokenEmpty = fmt.Errorf("%s Error while retrieving OAuth token: Token Empty", logPrefix)
62 63 errCodeSendingFails = "Error occurred while sending request for Device Authorization Code"
64 errCodeHandlingFails = "Error occurred while handling response from the Device Endpoint"
65 errTokenSendingFails = "Error occurred while sending request with device code for a token"
66 errTokenHandlingFails = "Error occurred while handling response from the Token Endpoint (during device flow)"
67 errStatusNotOK = "Error HTTP status != 200"
68 )
69 70 // DeviceCode is the object returned by the device auth endpoint
71 // It contains information to instruct the user to complete the auth flow
72 type DeviceCode struct {
73 DeviceCode *string `json:"device_code,omitempty"`
74 UserCode *string `json:"user_code,omitempty"`
75 VerificationURL *string `json:"verification_url,omitempty"`
76 ExpiresIn *int64 `json:"expires_in,string,omitempty"`
77 Interval *int64 `json:"interval,string,omitempty"`
78 79 Message *string `json:"message"` // Azure specific
80 Resource string // store the following, stored when initiating, used when exchanging
81 OAuthConfig OAuthConfig
82 ClientID string
83 }
// TokenError describes the payload sent by the token exchange endpoint when
// the exchange did not succeed. Every member is optional in the JSON form.
type TokenError struct {
	// Error is the error identifier ("error" member of the response).
	Error *string `json:"error,omitempty"`
	// ErrorCodes is the list found in the "error_codes" member.
	ErrorCodes []int `json:"error_codes,omitempty"`
	// ErrorDescription is the text found in the "error_description" member.
	ErrorDescription *string `json:"error_description,omitempty"`
	// Timestamp is the raw string found in the "timestamp" member.
	Timestamp *string `json:"timestamp,omitempty"`
	// TraceID is the raw string found in the "trace_id" member.
	TraceID *string `json:"trace_id,omitempty"`
}
94 95 // DeviceToken is the object return by the token exchange endpoint
96 // It can either look like a Token or an ErrorToken, so put both here
97 // and check for presence of "Error" to know if we are in error state
98 type deviceToken struct {
99 Token
100 TokenError
101 }
102 103 // InitiateDeviceAuth initiates a device auth flow. It returns a DeviceCode
104 // that can be used with CheckForUserCompletion or WaitForUserCompletion.
105 // Deprecated: use InitiateDeviceAuthWithContext() instead.
106 func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
107 return InitiateDeviceAuthWithContext(context.Background(), sender, oauthConfig, clientID, resource)
108 }
109 110 // InitiateDeviceAuthWithContext initiates a device auth flow. It returns a DeviceCode
111 // that can be used with CheckForUserCompletion or WaitForUserCompletion.
112 func InitiateDeviceAuthWithContext(ctx context.Context, sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
113 v := url.Values{
114 "client_id": []string{clientID},
115 "resource": []string{resource},
116 }
117 118 s := v.Encode()
119 body := ioutil.NopCloser(strings.NewReader(s))
120 121 req, err := http.NewRequest(http.MethodPost, oauthConfig.DeviceCodeEndpoint.String(), body)
122 if err != nil {
123 return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
124 }
125 126 req.ContentLength = int64(len(s))
127 req.Header.Set(contentType, mimeTypeFormPost)
128 resp, err := sender.Do(req.WithContext(ctx))
129 if err != nil {
130 return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
131 }
132 defer resp.Body.Close()
133 134 rb, err := ioutil.ReadAll(resp.Body)
135 if err != nil {
136 return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error())
137 }
138 139 if resp.StatusCode != http.StatusOK {
140 return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, errStatusNotOK)
141 }
142 143 if len(strings.Trim(string(rb), " ")) == 0 {
144 return nil, ErrDeviceCodeEmpty
145 }
146 147 var code DeviceCode
148 err = json.Unmarshal(rb, &code)
149 if err != nil {
150 return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error())
151 }
152 153 code.ClientID = clientID
154 code.Resource = resource
155 code.OAuthConfig = oauthConfig
156 157 return &code, nil
158 }
159 160 // CheckForUserCompletion takes a DeviceCode and checks with the Azure AD OAuth endpoint
161 // to see if the device flow has: been completed, timed out, or otherwise failed
162 // Deprecated: use CheckForUserCompletionWithContext() instead.
163 func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
164 return CheckForUserCompletionWithContext(context.Background(), sender, code)
165 }
166 167 // CheckForUserCompletionWithContext takes a DeviceCode and checks with the Azure AD OAuth endpoint
168 // to see if the device flow has: been completed, timed out, or otherwise failed
169 func CheckForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) {
170 v := url.Values{
171 "client_id": []string{code.ClientID},
172 "code": []string{*code.DeviceCode},
173 "grant_type": []string{OAuthGrantTypeDeviceCode},
174 "resource": []string{code.Resource},
175 }
176 177 s := v.Encode()
178 body := ioutil.NopCloser(strings.NewReader(s))
179 180 req, err := http.NewRequest(http.MethodPost, code.OAuthConfig.TokenEndpoint.String(), body)
181 if err != nil {
182 return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
183 }
184 185 req.ContentLength = int64(len(s))
186 req.Header.Set(contentType, mimeTypeFormPost)
187 resp, err := sender.Do(req.WithContext(ctx))
188 if err != nil {
189 return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
190 }
191 defer resp.Body.Close()
192 193 rb, err := ioutil.ReadAll(resp.Body)
194 if err != nil {
195 return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error())
196 }
197 198 if resp.StatusCode != http.StatusOK && len(strings.Trim(string(rb), " ")) == 0 {
199 return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, errStatusNotOK)
200 }
201 if len(strings.Trim(string(rb), " ")) == 0 {
202 return nil, ErrOAuthTokenEmpty
203 }
204 205 var token deviceToken
206 err = json.Unmarshal(rb, &token)
207 if err != nil {
208 return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error())
209 }
210 211 if token.Error == nil {
212 return &token.Token, nil
213 }
214 215 switch *token.Error {
216 case "authorization_pending":
217 return nil, ErrDeviceAuthorizationPending
218 case "slow_down":
219 return nil, ErrDeviceSlowDown
220 case "access_denied":
221 return nil, ErrDeviceAccessDenied
222 case "code_expired":
223 return nil, ErrDeviceCodeExpired
224 default:
225 // return a more meaningful error message if available
226 if token.ErrorDescription != nil {
227 return nil, fmt.Errorf("%s %s: %s", logPrefix, *token.Error, *token.ErrorDescription)
228 }
229 return nil, ErrDeviceGeneric
230 }
231 }
232 233 // WaitForUserCompletion calls CheckForUserCompletion repeatedly until a token is granted or an error state occurs.
234 // This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
235 // Deprecated: use WaitForUserCompletionWithContext() instead.
236 func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
237 return WaitForUserCompletionWithContext(context.Background(), sender, code)
238 }
239 240 // WaitForUserCompletionWithContext calls CheckForUserCompletion repeatedly until a token is granted or an error
241 // state occurs. This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
242 func WaitForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) {
243 intervalDuration := time.Duration(*code.Interval) * time.Second
244 waitDuration := intervalDuration
245 246 for {
247 token, err := CheckForUserCompletionWithContext(ctx, sender, code)
248 249 if err == nil {
250 return token, nil
251 }
252 253 switch err {
254 case ErrDeviceSlowDown:
255 waitDuration += waitDuration
256 case ErrDeviceAuthorizationPending:
257 // noop
258 default: // everything else is "fatal" to us
259 return nil, err
260 }
261 262 if waitDuration > (intervalDuration * 3) {
263 return nil, fmt.Errorf("%s Error waiting for user to complete device flow. Server told us to slow_down too much", logPrefix)
264 }
265 266 select {
267 case <-time.After(waitDuration):
268 // noop
269 case <-ctx.Done():
270 return nil, ctx.Err()
271 }
272 }
273 }
274