devicetoken.go raw

   1  package adal
   2  
   3  // Copyright 2017 Microsoft Corporation
   4  //
   5  //  Licensed under the Apache License, Version 2.0 (the "License");
   6  //  you may not use this file except in compliance with the License.
   7  //  You may obtain a copy of the License at
   8  //
   9  //      http://www.apache.org/licenses/LICENSE-2.0
  10  //
  11  //  Unless required by applicable law or agreed to in writing, software
  12  //  distributed under the License is distributed on an "AS IS" BASIS,
  13  //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14  //  See the License for the specific language governing permissions and
  15  //  limitations under the License.
  16  
/*
  This file is largely based on rjw57/oauth2device's code, with the following differences:
   * scope -> resource, and only allow a single one
   * receive "Message" in the DeviceCode struct and show it to users as the prompt
   * azure-xplat-cli has the following behavior that this emulates:
     - does not send client_secret during the token exchange
     - sends resource again in the token exchange request
*/
  25  
  26  import (
  27  	"context"
  28  	"encoding/json"
  29  	"fmt"
  30  	"io/ioutil"
  31  	"net/http"
  32  	"net/url"
  33  	"strings"
  34  	"time"
  35  )
  36  
  37  const (
  38  	logPrefix = "autorest/adal/devicetoken:"
  39  )
  40  
  41  var (
  42  	// ErrDeviceGeneric represents an unknown error from the token endpoint when using device flow
  43  	ErrDeviceGeneric = fmt.Errorf("%s Error while retrieving OAuth token: Unknown Error", logPrefix)
  44  
  45  	// ErrDeviceAccessDenied represents an access denied error from the token endpoint when using device flow
  46  	ErrDeviceAccessDenied = fmt.Errorf("%s Error while retrieving OAuth token: Access Denied", logPrefix)
  47  
  48  	// ErrDeviceAuthorizationPending represents the server waiting on the user to complete the device flow
  49  	ErrDeviceAuthorizationPending = fmt.Errorf("%s Error while retrieving OAuth token: Authorization Pending", logPrefix)
  50  
  51  	// ErrDeviceCodeExpired represents the server timing out and expiring the code during device flow
  52  	ErrDeviceCodeExpired = fmt.Errorf("%s Error while retrieving OAuth token: Code Expired", logPrefix)
  53  
  54  	// ErrDeviceSlowDown represents the service telling us we're polling too often during device flow
  55  	ErrDeviceSlowDown = fmt.Errorf("%s Error while retrieving OAuth token: Slow Down", logPrefix)
  56  
  57  	// ErrDeviceCodeEmpty represents an empty device code from the device endpoint while using device flow
  58  	ErrDeviceCodeEmpty = fmt.Errorf("%s Error while retrieving device code: Device Code Empty", logPrefix)
  59  
  60  	// ErrOAuthTokenEmpty represents an empty OAuth token from the token endpoint when using device flow
  61  	ErrOAuthTokenEmpty = fmt.Errorf("%s Error while retrieving OAuth token: Token Empty", logPrefix)
  62  
  63  	errCodeSendingFails   = "Error occurred while sending request for Device Authorization Code"
  64  	errCodeHandlingFails  = "Error occurred while handling response from the Device Endpoint"
  65  	errTokenSendingFails  = "Error occurred while sending request with device code for a token"
  66  	errTokenHandlingFails = "Error occurred while handling response from the Token Endpoint (during device flow)"
  67  	errStatusNotOK        = "Error HTTP status != 200"
  68  )
  69  
// DeviceCode is the object returned by the device auth endpoint.
// It contains information to instruct the user to complete the auth flow.
//
// The JSON-tagged fields are decoded from the device endpoint response.
// ExpiresIn and Interval use the ",string" tag option, so they are expected as
// JSON strings holding integers.
//
// Resource, OAuthConfig and ClientID are assigned by InitiateDeviceAuthWithContext
// after decoding, overwriting anything decoded into them. CheckForUserCompletionWithContext
// reuses those values for the token exchange.
type DeviceCode struct {
	DeviceCode      *string `json:"device_code,omitempty"`
	UserCode        *string `json:"user_code,omitempty"`
	VerificationURL *string `json:"verification_url,omitempty"`
	ExpiresIn       *int64  `json:"expires_in,string,omitempty"`
	Interval        *int64  `json:"interval,string,omitempty"`

	Message     *string `json:"message"` // Azure specific
	Resource    string  // this and the following fields are set when initiating and reused when exchanging
	OAuthConfig OAuthConfig
	ClientID    string
}
  84  
// TokenError is the object returned by the token exchange endpoint
// when something is amiss.
//
// CheckForUserCompletionWithContext treats a non-nil Error as a failed exchange.
// It maps the error codes it recognises to the package's ErrDevice* sentinel errors.
// For any other code, it includes ErrorDescription (when present) in the returned error.
type TokenError struct {
	Error            *string `json:"error,omitempty"`
	ErrorCodes       []int   `json:"error_codes,omitempty"`
	ErrorDescription *string `json:"error_description,omitempty"`
	Timestamp        *string `json:"timestamp,omitempty"`
	TraceID          *string `json:"trace_id,omitempty"`
}
  94  
// deviceToken is the object returned by the token exchange endpoint.
// The response can look either like a Token or like a TokenError, so both are
// embedded here. A non-nil Error field (from TokenError) indicates the endpoint
// reported an error rather than issuing a token.
type deviceToken struct {
	Token
	TokenError
}
 102  
 103  // InitiateDeviceAuth initiates a device auth flow. It returns a DeviceCode
 104  // that can be used with CheckForUserCompletion or WaitForUserCompletion.
 105  // Deprecated: use InitiateDeviceAuthWithContext() instead.
 106  func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
 107  	return InitiateDeviceAuthWithContext(context.Background(), sender, oauthConfig, clientID, resource)
 108  }
 109  
 110  // InitiateDeviceAuthWithContext initiates a device auth flow. It returns a DeviceCode
 111  // that can be used with CheckForUserCompletion or WaitForUserCompletion.
 112  func InitiateDeviceAuthWithContext(ctx context.Context, sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
 113  	v := url.Values{
 114  		"client_id": []string{clientID},
 115  		"resource":  []string{resource},
 116  	}
 117  
 118  	s := v.Encode()
 119  	body := ioutil.NopCloser(strings.NewReader(s))
 120  
 121  	req, err := http.NewRequest(http.MethodPost, oauthConfig.DeviceCodeEndpoint.String(), body)
 122  	if err != nil {
 123  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
 124  	}
 125  
 126  	req.ContentLength = int64(len(s))
 127  	req.Header.Set(contentType, mimeTypeFormPost)
 128  	resp, err := sender.Do(req.WithContext(ctx))
 129  	if err != nil {
 130  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
 131  	}
 132  	defer resp.Body.Close()
 133  
 134  	rb, err := ioutil.ReadAll(resp.Body)
 135  	if err != nil {
 136  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error())
 137  	}
 138  
 139  	if resp.StatusCode != http.StatusOK {
 140  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, errStatusNotOK)
 141  	}
 142  
 143  	if len(strings.Trim(string(rb), " ")) == 0 {
 144  		return nil, ErrDeviceCodeEmpty
 145  	}
 146  
 147  	var code DeviceCode
 148  	err = json.Unmarshal(rb, &code)
 149  	if err != nil {
 150  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error())
 151  	}
 152  
 153  	code.ClientID = clientID
 154  	code.Resource = resource
 155  	code.OAuthConfig = oauthConfig
 156  
 157  	return &code, nil
 158  }
 159  
 160  // CheckForUserCompletion takes a DeviceCode and checks with the Azure AD OAuth endpoint
 161  // to see if the device flow has: been completed, timed out, or otherwise failed
 162  // Deprecated: use CheckForUserCompletionWithContext() instead.
 163  func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
 164  	return CheckForUserCompletionWithContext(context.Background(), sender, code)
 165  }
 166  
 167  // CheckForUserCompletionWithContext takes a DeviceCode and checks with the Azure AD OAuth endpoint
 168  // to see if the device flow has: been completed, timed out, or otherwise failed
 169  func CheckForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) {
 170  	v := url.Values{
 171  		"client_id":  []string{code.ClientID},
 172  		"code":       []string{*code.DeviceCode},
 173  		"grant_type": []string{OAuthGrantTypeDeviceCode},
 174  		"resource":   []string{code.Resource},
 175  	}
 176  
 177  	s := v.Encode()
 178  	body := ioutil.NopCloser(strings.NewReader(s))
 179  
 180  	req, err := http.NewRequest(http.MethodPost, code.OAuthConfig.TokenEndpoint.String(), body)
 181  	if err != nil {
 182  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
 183  	}
 184  
 185  	req.ContentLength = int64(len(s))
 186  	req.Header.Set(contentType, mimeTypeFormPost)
 187  	resp, err := sender.Do(req.WithContext(ctx))
 188  	if err != nil {
 189  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
 190  	}
 191  	defer resp.Body.Close()
 192  
 193  	rb, err := ioutil.ReadAll(resp.Body)
 194  	if err != nil {
 195  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error())
 196  	}
 197  
 198  	if resp.StatusCode != http.StatusOK && len(strings.Trim(string(rb), " ")) == 0 {
 199  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, errStatusNotOK)
 200  	}
 201  	if len(strings.Trim(string(rb), " ")) == 0 {
 202  		return nil, ErrOAuthTokenEmpty
 203  	}
 204  
 205  	var token deviceToken
 206  	err = json.Unmarshal(rb, &token)
 207  	if err != nil {
 208  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error())
 209  	}
 210  
 211  	if token.Error == nil {
 212  		return &token.Token, nil
 213  	}
 214  
 215  	switch *token.Error {
 216  	case "authorization_pending":
 217  		return nil, ErrDeviceAuthorizationPending
 218  	case "slow_down":
 219  		return nil, ErrDeviceSlowDown
 220  	case "access_denied":
 221  		return nil, ErrDeviceAccessDenied
 222  	case "code_expired":
 223  		return nil, ErrDeviceCodeExpired
 224  	default:
 225  		// return a more meaningful error message if available
 226  		if token.ErrorDescription != nil {
 227  			return nil, fmt.Errorf("%s %s: %s", logPrefix, *token.Error, *token.ErrorDescription)
 228  		}
 229  		return nil, ErrDeviceGeneric
 230  	}
 231  }
 232  
 233  // WaitForUserCompletion calls CheckForUserCompletion repeatedly until a token is granted or an error state occurs.
 234  // This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
 235  // Deprecated: use WaitForUserCompletionWithContext() instead.
 236  func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
 237  	return WaitForUserCompletionWithContext(context.Background(), sender, code)
 238  }
 239  
 240  // WaitForUserCompletionWithContext calls CheckForUserCompletion repeatedly until a token is granted or an error
 241  // state occurs.  This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
 242  func WaitForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) {
 243  	intervalDuration := time.Duration(*code.Interval) * time.Second
 244  	waitDuration := intervalDuration
 245  
 246  	for {
 247  		token, err := CheckForUserCompletionWithContext(ctx, sender, code)
 248  
 249  		if err == nil {
 250  			return token, nil
 251  		}
 252  
 253  		switch err {
 254  		case ErrDeviceSlowDown:
 255  			waitDuration += waitDuration
 256  		case ErrDeviceAuthorizationPending:
 257  			// noop
 258  		default: // everything else is "fatal" to us
 259  			return nil, err
 260  		}
 261  
 262  		if waitDuration > (intervalDuration * 3) {
 263  			return nil, fmt.Errorf("%s Error waiting for user to complete device flow. Server told us to slow_down too much", logPrefix)
 264  		}
 265  
 266  		select {
 267  		case <-time.After(waitDuration):
 268  			// noop
 269  		case <-ctx.Done():
 270  			return nil, ctx.Err()
 271  		}
 272  	}
 273  }
 274