server.go raw
1 // Copyright (c) Microsoft Corporation.
2 // Licensed under the MIT license.
3
4 // Package local contains a local HTTP server used with interactive authentication.
5 package local
6
7 import (
8 "context"
9 "fmt"
10 "html"
11 "net"
12 "net/http"
13 "strconv"
14 "strings"
15 "time"
16 )
17
// okPage is the HTML document that handler writes to the browser once the
// redirect has been validated and an authorization code was received.
var okPage = []byte(`
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>Authentication Complete</title>
</head>
<body>
<p>Authentication complete. You can return to the application. Feel free to close this browser tab.</p>
</body>
</html>
`)
30
// failPage is the fmt format string for the HTML document shown when the
// redirect carries an "error" query parameter. It contains two %s verbs, filled
// in order with the error code and the error description; handler HTML-escapes
// both values before formatting them into the page.
const failPage = `
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>Authentication Failed</title>
</head>
<body>
<p>Authentication failed. You can return to the application. Feel free to close this browser tab.</p>
<p>Error details: error %s error_description: %s</p>
</body>
</html>
`
44
// Result is the result from the redirect. Exactly one of Code or Err is
// populated by the code in this file.
type Result struct {
	// Code is the code sent by the authority server, taken from the "code"
	// query parameter of the redirect request.
	Code string
	// Err is set if there was an error: the redirect reported one, the request
	// failed validation, the HTTP server stopped serving, or the context passed
	// to Server.Result was done.
	Err error
}
52
// Server is an HTTP server that listens on localhost for a single redirect
// and hands its outcome to the caller through Result.
type Server struct {
	// Addr is the address the server is listening on, in the form
	// "http://localhost:<port>".
	Addr string
	// resultCh carries the outcome of the redirect. New creates it with a
	// buffer of one and all sends are non-blocking, so a result arriving
	// while the buffer is full is dropped.
	resultCh chan Result
	// s is the underlying HTTP server whose Handler is set to handler.
	s *http.Server
	// reqState is the OAuth state value the redirect's "state" query
	// parameter must match.
	reqState string
}
61
62 // New creates a local HTTP server and starts it.
63 func New(reqState string, port int) (*Server, error) {
64 var l net.Listener
65 var err error
66 var portStr string
67 if port > 0 {
68 // use port provided by caller
69 l, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
70 portStr = strconv.FormatInt(int64(port), 10)
71 } else {
72 // find a free port
73 for i := 0; i < 10; i++ {
74 l, err = net.Listen("tcp", "localhost:0")
75 if err != nil {
76 continue
77 }
78 addr := l.Addr().String()
79 portStr = addr[strings.LastIndex(addr, ":")+1:]
80 break
81 }
82 }
83 if err != nil {
84 return nil, err
85 }
86
87 serv := &Server{
88 Addr: fmt.Sprintf("http://localhost:%s", portStr),
89 s: &http.Server{Addr: "localhost:0", ReadHeaderTimeout: time.Second},
90 reqState: reqState,
91 resultCh: make(chan Result, 1),
92 }
93 serv.s.Handler = http.HandlerFunc(serv.handler)
94
95 if err := serv.start(l); err != nil {
96 return nil, err
97 }
98
99 return serv, nil
100 }
101
102 func (s *Server) start(l net.Listener) error {
103 go func() {
104 err := s.s.Serve(l)
105 if err != nil {
106 select {
107 case s.resultCh <- Result{Err: err}:
108 default:
109 }
110 }
111 }()
112
113 return nil
114 }
115
116 // Result gets the result of the redirect operation. Once a single result is returned, the server
117 // is shutdown. ctx deadline will be honored.
118 func (s *Server) Result(ctx context.Context) Result {
119 select {
120 case <-ctx.Done():
121 return Result{Err: ctx.Err()}
122 case r := <-s.resultCh:
123 return r
124 }
125 }
126
127 // Shutdown shuts down the server.
128 func (s *Server) Shutdown() {
129 // Note: You might get clever and think you can do this in handler() as a defer, you can't.
130 _ = s.s.Shutdown(context.Background())
131 }
132
133 func (s *Server) putResult(r Result) {
134 select {
135 case s.resultCh <- r:
136 default:
137 }
138 }
139
140 func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
141 q := r.URL.Query()
142
143 headerErr := q.Get("error")
144 if headerErr != "" {
145 desc := html.EscapeString(q.Get("error_description"))
146 escapedHeaderErr := html.EscapeString(headerErr)
147 // Note: It is a little weird we handle some errors by not going to the failPage. If they all should,
148 // change this to s.error() and make s.error() write the failPage instead of an error code.
149 _, _ = w.Write([]byte(fmt.Sprintf(failPage, escapedHeaderErr, desc)))
150 s.putResult(Result{Err: fmt.Errorf("%s", desc)})
151
152 return
153 }
154
155 respState := q.Get("state")
156 switch respState {
157 case s.reqState:
158 case "":
159 s.error(w, http.StatusInternalServerError, "server didn't send OAuth state")
160 return
161 default:
162 s.error(w, http.StatusInternalServerError, "mismatched OAuth state, req(%s), resp(%s)", s.reqState, respState)
163 return
164 }
165
166 code := q.Get("code")
167 if code == "" {
168 s.error(w, http.StatusInternalServerError, "authorization code missing in query string")
169 return
170 }
171
172 _, _ = w.Write(okPage)
173 s.putResult(Result{Code: code})
174 }
175
176 func (s *Server) error(w http.ResponseWriter, code int, str string, i ...interface{}) {
177 err := fmt.Errorf(str, i...)
178 http.Error(w, err.Error(), code)
179 s.putResult(Result{Err: err})
180 }
181