deviceauth.go raw

   1  package oauth2
   2  
   3  import (
   4  	"context"
   5  	"encoding/json"
   6  	"errors"
   7  	"fmt"
   8  	"io"
   9  	"mime"
  10  	"net/http"
  11  	"net/url"
  12  	"strings"
  13  	"time"
  14  
  15  	"golang.org/x/oauth2/internal"
  16  )
  17  
  // Error codes a token endpoint can return while a device code is being
  // polled. DeviceAccessToken keeps polling on the first group and stops on
  // the second.
  // https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
  const (
  	// Non-terminal: polling continues (slow_down also lengthens the interval).
  	errAuthorizationPending = "authorization_pending"
  	errSlowDown             = "slow_down"

  	// Terminal: polling stops and the error is returned to the caller.
  	errAccessDenied = "access_denied"
  	errExpiredToken = "expired_token"
  )
  25  
  // DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
  // https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
  type DeviceAuthResponse struct {
  	// DeviceCode is the device verification code that DeviceAccessToken sends
  	// to the token endpoint when polling for a token.
  	DeviceCode string `json:"device_code"`
  	// UserCode is the code the user should enter at the verification URI.
  	UserCode string `json:"user_code"`
  	// VerificationURI is where the user should enter the user code.
  	VerificationURI string `json:"verification_uri"`
  	// VerificationURIComplete (if populated) includes the user code in the
  	// verification URI. This is typically shown to the user in non-textual
  	// form, such as a QR code.
  	VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
  	// Expiry is when the device code and user code expire. On the wire it is
  	// the relative expires_in field (seconds); MarshalJSON and UnmarshalJSON
  	// convert between the two. A zero Expiry means no lifetime was reported.
  	Expiry time.Time `json:"expires_in,omitempty"`
  	// Interval is the number of seconds DeviceAccessToken waits between
  	// polling requests. Zero means the RFC 8628 default of 5 seconds is used.
  	Interval int64 `json:"interval,omitempty"`
  }

  // MarshalJSON encodes d, converting Expiry back into a relative expires_in
  // value: the whole seconds remaining from now. A zero Expiry omits
  // expires_in, and an Expiry already in the past yields a negative value.
  func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
  	// Alias has the same fields but none of the methods, so marshalling it
  	// does not recurse back into MarshalJSON.
  	type Alias DeviceAuthResponse
  	var expiresIn int64
  	if !d.Expiry.IsZero() {
  		expiresIn = int64(time.Until(d.Expiry).Seconds())
  	}
  	// The shallower ExpiresIn field shadows Alias.Expiry, which carries the
  	// same "expires_in" key, so only the relative value is emitted.
  	return json.Marshal(&struct {
  		ExpiresIn int64 `json:"expires_in,omitempty"`
  		*Alias
  	}{
  		ExpiresIn: expiresIn,
  		Alias:     (*Alias)(&d),
  	})
  }

  // UnmarshalJSON decodes a device authorization response into d. A non-zero
  // expires_in becomes an absolute Expiry measured from the moment of
  // decoding. The misspelled verification_url key is used as a fallback when
  // verification_uri is absent or empty.
  func (d *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
  	// Alias has no methods, so decoding into it does not recurse.
  	type Alias DeviceAuthResponse
  	aux := &struct {
  		ExpiresIn int64 `json:"expires_in"`
  		// workaround misspelling of verification_uri
  		VerificationURL string `json:"verification_url"`
  		*Alias
  	}{
  		Alias: (*Alias)(d),
  	}
  	if err := json.Unmarshal(data, aux); err != nil {
  		return err
  	}
  	if aux.ExpiresIn != 0 {
  		d.Expiry = time.Now().UTC().Add(time.Duration(aux.ExpiresIn) * time.Second)
  	}
  	if d.VerificationURI == "" {
  		d.VerificationURI = aux.VerificationURL
  	}
  	return nil
  }
  80  
  81  // DeviceAuth returns a device auth struct which contains a device code
  82  // and authorization information provided for users to enter on another device.
  83  func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
  84  	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
  85  	v := url.Values{
  86  		"client_id": {c.ClientID},
  87  	}
  88  	if len(c.Scopes) > 0 {
  89  		v.Set("scope", strings.Join(c.Scopes, " "))
  90  	}
  91  	for _, opt := range opts {
  92  		opt.setValue(v)
  93  	}
  94  	return retrieveDeviceAuth(ctx, c, v)
  95  }
  96  
  97  func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
  98  	if c.Endpoint.DeviceAuthURL == "" {
  99  		return nil, errors.New("endpoint missing DeviceAuthURL")
 100  	}
 101  
 102  	req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
 103  	if err != nil {
 104  		return nil, err
 105  	}
 106  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 107  	req.Header.Set("Accept", "application/json")
 108  
 109  	t := time.Now()
 110  	r, err := internal.ContextClient(ctx).Do(req)
 111  	if err != nil {
 112  		return nil, err
 113  	}
 114  
 115  	body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
 116  	if err != nil {
 117  		return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
 118  	}
 119  	if code := r.StatusCode; code < 200 || code > 299 {
 120  		retrieveError := &RetrieveError{
 121  			Response: r,
 122  			Body:     body,
 123  		}
 124  
 125  		content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
 126  		switch content {
 127  		case "application/x-www-form-urlencoded", "text/plain":
 128  			// some endpoints return a query string
 129  			vals, err := url.ParseQuery(string(body))
 130  			if err != nil {
 131  				return nil, retrieveError
 132  			}
 133  			retrieveError.ErrorCode = vals.Get("error")
 134  			retrieveError.ErrorDescription = vals.Get("error_description")
 135  			retrieveError.ErrorURI = vals.Get("error_uri")
 136  		default:
 137  			var tj struct {
 138  				// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
 139  				ErrorCode        string `json:"error"`
 140  				ErrorDescription string `json:"error_description"`
 141  				ErrorURI         string `json:"error_uri"`
 142  			}
 143  			if json.Unmarshal(body, &tj) != nil {
 144  				return nil, retrieveError
 145  			}
 146  			retrieveError.ErrorCode = tj.ErrorCode
 147  			retrieveError.ErrorDescription = tj.ErrorDescription
 148  			retrieveError.ErrorURI = tj.ErrorURI
 149  		}
 150  
 151  		return nil, retrieveError
 152  	}
 153  
 154  	da := &DeviceAuthResponse{}
 155  	err = json.Unmarshal(body, &da)
 156  	if err != nil {
 157  		return nil, fmt.Errorf("unmarshal %s", err)
 158  	}
 159  
 160  	if !da.Expiry.IsZero() {
 161  		// Make a small adjustment to account for time taken by the request
 162  		da.Expiry = da.Expiry.Add(-time.Since(t))
 163  	}
 164  
 165  	return da, nil
 166  }
 167  
 168  // DeviceAccessToken polls the server to exchange a device code for a token.
 169  func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
 170  	if !da.Expiry.IsZero() {
 171  		var cancel context.CancelFunc
 172  		ctx, cancel = context.WithDeadline(ctx, da.Expiry)
 173  		defer cancel()
 174  	}
 175  
 176  	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
 177  	v := url.Values{
 178  		"client_id":   {c.ClientID},
 179  		"grant_type":  {"urn:ietf:params:oauth:grant-type:device_code"},
 180  		"device_code": {da.DeviceCode},
 181  	}
 182  	if len(c.Scopes) > 0 {
 183  		v.Set("scope", strings.Join(c.Scopes, " "))
 184  	}
 185  	for _, opt := range opts {
 186  		opt.setValue(v)
 187  	}
 188  
 189  	// "If no value is provided, clients MUST use 5 as the default."
 190  	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
 191  	interval := da.Interval
 192  	if interval == 0 {
 193  		interval = 5
 194  	}
 195  
 196  	ticker := time.NewTicker(time.Duration(interval) * time.Second)
 197  	defer ticker.Stop()
 198  	for {
 199  		select {
 200  		case <-ctx.Done():
 201  			return nil, ctx.Err()
 202  		case <-ticker.C:
 203  			tok, err := retrieveToken(ctx, c, v)
 204  			if err == nil {
 205  				return tok, nil
 206  			}
 207  
 208  			e, ok := err.(*RetrieveError)
 209  			if !ok {
 210  				return nil, err
 211  			}
 212  			switch e.ErrorCode {
 213  			case errSlowDown:
 214  				// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
 215  				// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
 216  				interval += 5
 217  				ticker.Reset(time.Duration(interval) * time.Second)
 218  			case errAuthorizationPending:
 219  				// Do nothing.
 220  			case errAccessDenied, errExpiredToken:
 221  				fallthrough
 222  			default:
 223  				return tok, err
 224  			}
 225  		}
 226  	}
 227  }
 228