oauthbearer.go raw
1 package sasl
2
3 import (
4 "bytes"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "strconv"
9 "strings"
10 )
11
// OAuthBearer is the name of the OAUTHBEARER mechanism.
const OAuthBearer = "OAUTHBEARER"
14
// OAuthBearerError is the structured error value exchanged by the OAUTHBEARER
// mechanism. The server marshals it to JSON and sends it as a challenge; the
// client unmarshals such a challenge back into this type. A pointer to it
// satisfies the error interface.
type OAuthBearerError struct {
	// Status is the error status, e.g. "invalid_request". It is the only
	// field reported by the Error method.
	Status string `json:"status"`
	// Schemes is carried under the "schemes" JSON key; the server sets it to
	// "bearer" when it rejects a malformed request.
	Schemes string `json:"schemes"`
	// Scope is carried under the "scope" JSON key.
	Scope string `json:"scope"`
}
20
// OAuthBearerOptions holds the parameters of an OAUTHBEARER exchange. The
// client builds its initial response from these values, and the server fills
// one in from the parsed client response before calling the authenticator.
type OAuthBearerOptions struct {
	// Username is sent as the "a=" authorization identity; it is omitted
	// from the message when empty.
	Username string
	// Token is the bearer token sent in the "auth" field.
	Token string
	// Host is sent in the "host" field; it is omitted when empty.
	Host string
	// Port is sent in the "port" field; it is omitted when zero.
	Port int
}
27
28 // Implements error
29 func (err *OAuthBearerError) Error() string {
30 return fmt.Sprintf("OAUTHBEARER authentication error (%v)", err.Status)
31 }
32
33 type oauthBearerClient struct {
34 OAuthBearerOptions
35 }
36
37 func (a *oauthBearerClient) Start() (mech string, ir []byte, err error) {
38 var authzid string
39 if a.Username != "" {
40 authzid = "a=" + a.Username
41 }
42 str := "n," + authzid + ","
43
44 if a.Host != "" {
45 str += "\x01host=" + a.Host
46 }
47
48 if a.Port != 0 {
49 str += "\x01port=" + strconv.Itoa(a.Port)
50 }
51 str += "\x01auth=Bearer " + a.Token + "\x01\x01"
52 ir = []byte(str)
53 return OAuthBearer, ir, nil
54 }
55
56 func (a *oauthBearerClient) Next(challenge []byte) ([]byte, error) {
57 authBearerErr := &OAuthBearerError{}
58 if err := json.Unmarshal(challenge, authBearerErr); err != nil {
59 return nil, err
60 } else {
61 return nil, authBearerErr
62 }
63 }
64
65 // An implementation of the OAUTHBEARER authentication mechanism, as
66 // described in RFC 7628.
67 func NewOAuthBearerClient(opt *OAuthBearerOptions) Client {
68 return &oauthBearerClient{*opt}
69 }
70
71 type OAuthBearerAuthenticator func(opts OAuthBearerOptions) *OAuthBearerError
72
73 type oauthBearerServer struct {
74 done bool
75 failErr error
76 authenticate OAuthBearerAuthenticator
77 }
78
79 func (a *oauthBearerServer) fail(descr string) ([]byte, bool, error) {
80 blob, err := json.Marshal(OAuthBearerError{
81 Status: "invalid_request",
82 Schemes: "bearer",
83 })
84 if err != nil {
85 panic(err) // wtf
86 }
87 a.failErr = errors.New("sasl: client error: " + descr)
88 return blob, false, nil
89 }
90
91 func (a *oauthBearerServer) Next(response []byte) (challenge []byte, done bool, err error) {
92 // Per RFC, we cannot just send an error, we need to return JSON-structured
93 // value as a challenge and then after getting dummy response from the
94 // client stop the exchange.
95 if a.failErr != nil {
96 // Server libraries (go-smtp, go-imap) will not call Next on
97 // protocol-specific SASL cancel response ('*'). However, GS2 (and
98 // indirectly OAUTHBEARER) defines a protocol-independent way to do so
99 // using 0x01.
100 if len(response) != 1 && response[0] != 0x01 {
101 return nil, true, errors.New("sasl: invalid response")
102 }
103 return nil, true, a.failErr
104 }
105
106 if a.done {
107 err = ErrUnexpectedClientResponse
108 return
109 }
110
111 // Generate empty challenge.
112 if response == nil {
113 return []byte{}, false, nil
114 }
115
116 a.done = true
117
118 // Cut n,a=username,\x01host=...\x01auth=...
119 // into
120 // n
121 // a=username
122 // \x01host=...\x01auth=...\x01\x01
123 parts := bytes.SplitN(response, []byte{','}, 3)
124 if len(parts) != 3 {
125 return a.fail("Invalid response")
126 }
127 flag := parts[0]
128 authzid := parts[1]
129 if !bytes.Equal(flag, []byte{'n'}) {
130 return a.fail("Invalid response, missing 'n' in gs2-cb-flag")
131 }
132 opts := OAuthBearerOptions{}
133 if len(authzid) > 0 {
134 if !bytes.HasPrefix(authzid, []byte("a=")) {
135 return a.fail("Invalid response, missing 'a=' in gs2-authzid")
136 }
137 opts.Username = string(bytes.TrimPrefix(authzid, []byte("a=")))
138 }
139
140 // Cut \x01host=...\x01auth=...\x01\x01
141 // into
142 // *empty*
143 // host=...
144 // auth=...
145 // *empty*
146 //
147 // Note that this code does not do a lot of checks to make sure the input
148 // follows the exact format specified by RFC.
149 params := bytes.Split(parts[2], []byte{0x01})
150 for _, p := range params {
151 // Skip empty fields (one at start and end).
152 if len(p) == 0 {
153 continue
154 }
155
156 pParts := bytes.SplitN(p, []byte{'='}, 2)
157 if len(pParts) != 2 {
158 return a.fail("Invalid response, missing '='")
159 }
160
161 switch string(pParts[0]) {
162 case "host":
163 opts.Host = string(pParts[1])
164 case "port":
165 port, err := strconv.ParseUint(string(pParts[1]), 10, 16)
166 if err != nil {
167 return a.fail("Invalid response, malformed 'port' value")
168 }
169 opts.Port = int(port)
170 case "auth":
171 const prefix = "bearer "
172 strValue := string(pParts[1])
173 // Token type is case-insensitive.
174 if !strings.HasPrefix(strings.ToLower(strValue), prefix) {
175 return a.fail("Unsupported token type")
176 }
177 opts.Token = strValue[len(prefix):]
178 default:
179 return a.fail("Invalid response, unknown parameter: " + string(pParts[0]))
180 }
181 }
182
183 authzErr := a.authenticate(opts)
184 if authzErr != nil {
185 blob, err := json.Marshal(authzErr)
186 if err != nil {
187 panic(err) // wtf
188 }
189 a.failErr = authzErr
190 return blob, false, nil
191 }
192
193 return nil, true, nil
194 }
195
196 func NewOAuthBearerServer(auth OAuthBearerAuthenticator) Server {
197 return &oauthBearerServer{authenticate: auth}
198 }
199