oauthbearer.go raw

   1  package sasl
   2  
   3  import (
   4  	"bytes"
   5  	"encoding/json"
   6  	"errors"
   7  	"fmt"
   8  	"strconv"
   9  	"strings"
  10  )
  11  
// OAuthBearer is the name of the OAUTHBEARER mechanism.
const OAuthBearer = "OAUTHBEARER"
  14  
// OAuthBearerError is the JSON-encoded failure report used by the OAUTHBEARER
// mechanism: the server marshals it into a challenge and the client decodes
// it from one. It implements the error interface.
//
// Status carries the failure status (for example "invalid_request") and
// Schemes the authorization scheme (for example "bearer"). Scope maps to the
// "scope" JSON member.
type OAuthBearerError struct {
	Status  string `json:"status"`
	Schemes string `json:"schemes"`
	Scope   string `json:"scope"`
}
  20  
// OAuthBearerOptions holds the parameters of an OAUTHBEARER exchange. The
// client uses them to build its initial response, and the server passes the
// values parsed from that response to its OAuthBearerAuthenticator.
//
// Username is sent as the authorization identity ("a=..."), Token as the
// bearer token, and Host and Port as the optional "host" and "port" fields.
// The client omits an empty Username or Host and a zero Port.
type OAuthBearerOptions struct {
	Username string
	Token    string
	Host     string
	Port     int
}
  27  
  28  // Implements error
  29  func (err *OAuthBearerError) Error() string {
  30  	return fmt.Sprintf("OAUTHBEARER authentication error (%v)", err.Status)
  31  }
  32  
// oauthBearerClient is the client side of the OAUTHBEARER mechanism. It
// embeds the options used to build the initial response.
type oauthBearerClient struct {
	OAuthBearerOptions
}
  36  
  37  func (a *oauthBearerClient) Start() (mech string, ir []byte, err error) {
  38  	var authzid string
  39  	if a.Username != "" {
  40  		authzid = "a=" + a.Username
  41  	}
  42  	str := "n," + authzid + ","
  43  
  44  	if a.Host != "" {
  45  		str += "\x01host=" + a.Host
  46  	}
  47  
  48  	if a.Port != 0 {
  49  		str += "\x01port=" + strconv.Itoa(a.Port)
  50  	}
  51  	str += "\x01auth=Bearer " + a.Token + "\x01\x01"
  52  	ir = []byte(str)
  53  	return OAuthBearer, ir, nil
  54  }
  55  
  56  func (a *oauthBearerClient) Next(challenge []byte) ([]byte, error) {
  57  	authBearerErr := &OAuthBearerError{}
  58  	if err := json.Unmarshal(challenge, authBearerErr); err != nil {
  59  		return nil, err
  60  	} else {
  61  		return nil, authBearerErr
  62  	}
  63  }
  64  
  65  // An implementation of the OAUTHBEARER authentication mechanism, as
  66  // described in RFC 7628.
  67  func NewOAuthBearerClient(opt *OAuthBearerOptions) Client {
  68  	return &oauthBearerClient{*opt}
  69  }
  70  
// OAuthBearerAuthenticator validates the options extracted from a client
// response. It returns nil to accept the client, or a non-nil
// *OAuthBearerError to reject it; the server marshals that error to JSON and
// sends it to the client as a challenge.
type OAuthBearerAuthenticator func(opts OAuthBearerOptions) *OAuthBearerError
  72  
// oauthBearerServer is the server side of the OAUTHBEARER mechanism.
//
// done is set once a client response has been accepted for parsing, failErr
// holds the error to report on the call that follows a failure challenge, and
// authenticate is the callback that validates the parsed options.
type oauthBearerServer struct {
	done         bool
	failErr      error
	authenticate OAuthBearerAuthenticator
}
  78  
  79  func (a *oauthBearerServer) fail(descr string) ([]byte, bool, error) {
  80  	blob, err := json.Marshal(OAuthBearerError{
  81  		Status:  "invalid_request",
  82  		Schemes: "bearer",
  83  	})
  84  	if err != nil {
  85  		panic(err) // wtf
  86  	}
  87  	a.failErr = errors.New("sasl: client error: " + descr)
  88  	return blob, false, nil
  89  }
  90  
  91  func (a *oauthBearerServer) Next(response []byte) (challenge []byte, done bool, err error) {
  92  	// Per RFC, we cannot just send an error, we need to return JSON-structured
  93  	// value as a challenge and then after getting dummy response from the
  94  	// client stop the exchange.
  95  	if a.failErr != nil {
  96  		// Server libraries (go-smtp, go-imap) will not call Next on
  97  		// protocol-specific SASL cancel response ('*'). However, GS2 (and
  98  		// indirectly OAUTHBEARER) defines a protocol-independent way to do so
  99  		// using 0x01.
 100  		if len(response) != 1 && response[0] != 0x01 {
 101  			return nil, true, errors.New("sasl: invalid response")
 102  		}
 103  		return nil, true, a.failErr
 104  	}
 105  
 106  	if a.done {
 107  		err = ErrUnexpectedClientResponse
 108  		return
 109  	}
 110  
 111  	// Generate empty challenge.
 112  	if response == nil {
 113  		return []byte{}, false, nil
 114  	}
 115  
 116  	a.done = true
 117  
 118  	// Cut n,a=username,\x01host=...\x01auth=...
 119  	// into
 120  	//   n
 121  	//   a=username
 122  	//   \x01host=...\x01auth=...\x01\x01
 123  	parts := bytes.SplitN(response, []byte{','}, 3)
 124  	if len(parts) != 3 {
 125  		return a.fail("Invalid response")
 126  	}
 127  	flag := parts[0]
 128  	authzid := parts[1]
 129  	if !bytes.Equal(flag, []byte{'n'}) {
 130  		return a.fail("Invalid response, missing 'n' in gs2-cb-flag")
 131  	}
 132  	opts := OAuthBearerOptions{}
 133  	if len(authzid) > 0 {
 134  		if !bytes.HasPrefix(authzid, []byte("a=")) {
 135  			return a.fail("Invalid response, missing 'a=' in gs2-authzid")
 136  		}
 137  		opts.Username = string(bytes.TrimPrefix(authzid, []byte("a=")))
 138  	}
 139  
 140  	// Cut \x01host=...\x01auth=...\x01\x01
 141  	// into
 142  	//   *empty*
 143  	//   host=...
 144  	//   auth=...
 145  	//   *empty*
 146  	//
 147  	// Note that this code does not do a lot of checks to make sure the input
 148  	// follows the exact format specified by RFC.
 149  	params := bytes.Split(parts[2], []byte{0x01})
 150  	for _, p := range params {
 151  		// Skip empty fields (one at start and end).
 152  		if len(p) == 0 {
 153  			continue
 154  		}
 155  
 156  		pParts := bytes.SplitN(p, []byte{'='}, 2)
 157  		if len(pParts) != 2 {
 158  			return a.fail("Invalid response, missing '='")
 159  		}
 160  
 161  		switch string(pParts[0]) {
 162  		case "host":
 163  			opts.Host = string(pParts[1])
 164  		case "port":
 165  			port, err := strconv.ParseUint(string(pParts[1]), 10, 16)
 166  			if err != nil {
 167  				return a.fail("Invalid response, malformed 'port' value")
 168  			}
 169  			opts.Port = int(port)
 170  		case "auth":
 171  			const prefix = "bearer "
 172  			strValue := string(pParts[1])
 173  			// Token type is case-insensitive.
 174  			if !strings.HasPrefix(strings.ToLower(strValue), prefix) {
 175  				return a.fail("Unsupported token type")
 176  			}
 177  			opts.Token = strValue[len(prefix):]
 178  		default:
 179  			return a.fail("Invalid response, unknown parameter: " + string(pParts[0]))
 180  		}
 181  	}
 182  
 183  	authzErr := a.authenticate(opts)
 184  	if authzErr != nil {
 185  		blob, err := json.Marshal(authzErr)
 186  		if err != nil {
 187  			panic(err) // wtf
 188  		}
 189  		a.failErr = authzErr
 190  		return blob, false, nil
 191  	}
 192  
 193  	return nil, true, nil
 194  }
 195  
 196  func NewOAuthBearerServer(auth OAuthBearerAuthenticator) Server {
 197  	return &oauthBearerServer{authenticate: auth}
 198  }
 199