server.go raw

   1  // Copyright (c) Microsoft Corporation.
   2  // Licensed under the MIT license.
   3  
   4  // Package local contains a local HTTP server used with interactive authentication.
   5  package local
   6  
   7  import (
   8  	"context"
   9  	"fmt"
  10  	"html"
  11  	"net"
  12  	"net/http"
  13  	"strconv"
  14  	"strings"
  15  	"time"
  16  )
  17  
  18  var okPage = []byte(`
  19  <!DOCTYPE html>
  20  <html>
  21  <head>
  22      <meta charset="utf-8" />
  23      <title>Authentication Complete</title>
  24  </head>
  25  <body>
  26      <p>Authentication complete. You can return to the application. Feel free to close this browser tab.</p>
  27  </body>
  28  </html>
  29  `)
  30  
  31  const failPage = `
  32  <!DOCTYPE html>
  33  <html>
  34  <head>
  35      <meta charset="utf-8" />
  36      <title>Authentication Failed</title>
  37  </head>
  38  <body>
  39  	<p>Authentication failed. You can return to the application. Feel free to close this browser tab.</p>
  40  	<p>Error details: error %s error_description: %s</p>
  41  </body>
  42  </html>
  43  `
  44  
  45  // Result is the result from the redirect.
  46  type Result struct {
  47  	// Code is the code sent by the authority server.
  48  	Code string
  49  	// Err is set if there was an error.
  50  	Err error
  51  }
  52  
  53  // Server is an HTTP server.
  54  type Server struct {
  55  	// Addr is the address the server is listening on.
  56  	Addr     string
  57  	resultCh chan Result
  58  	s        *http.Server
  59  	reqState string
  60  }
  61  
  62  // New creates a local HTTP server and starts it.
  63  func New(reqState string, port int) (*Server, error) {
  64  	var l net.Listener
  65  	var err error
  66  	var portStr string
  67  	if port > 0 {
  68  		// use port provided by caller
  69  		l, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
  70  		portStr = strconv.FormatInt(int64(port), 10)
  71  	} else {
  72  		// find a free port
  73  		for i := 0; i < 10; i++ {
  74  			l, err = net.Listen("tcp", "localhost:0")
  75  			if err != nil {
  76  				continue
  77  			}
  78  			addr := l.Addr().String()
  79  			portStr = addr[strings.LastIndex(addr, ":")+1:]
  80  			break
  81  		}
  82  	}
  83  	if err != nil {
  84  		return nil, err
  85  	}
  86  
  87  	serv := &Server{
  88  		Addr:     fmt.Sprintf("http://localhost:%s", portStr),
  89  		s:        &http.Server{Addr: "localhost:0", ReadHeaderTimeout: time.Second},
  90  		reqState: reqState,
  91  		resultCh: make(chan Result, 1),
  92  	}
  93  	serv.s.Handler = http.HandlerFunc(serv.handler)
  94  
  95  	if err := serv.start(l); err != nil {
  96  		return nil, err
  97  	}
  98  
  99  	return serv, nil
 100  }
 101  
 102  func (s *Server) start(l net.Listener) error {
 103  	go func() {
 104  		err := s.s.Serve(l)
 105  		if err != nil {
 106  			select {
 107  			case s.resultCh <- Result{Err: err}:
 108  			default:
 109  			}
 110  		}
 111  	}()
 112  
 113  	return nil
 114  }
 115  
 116  // Result gets the result of the redirect operation. Once a single result is returned, the server
 117  // is shutdown. ctx deadline will be honored.
 118  func (s *Server) Result(ctx context.Context) Result {
 119  	select {
 120  	case <-ctx.Done():
 121  		return Result{Err: ctx.Err()}
 122  	case r := <-s.resultCh:
 123  		return r
 124  	}
 125  }
 126  
 127  // Shutdown shuts down the server.
 128  func (s *Server) Shutdown() {
 129  	// Note: You might get clever and think you can do this in handler() as a defer, you can't.
 130  	_ = s.s.Shutdown(context.Background())
 131  }
 132  
 133  func (s *Server) putResult(r Result) {
 134  	select {
 135  	case s.resultCh <- r:
 136  	default:
 137  	}
 138  }
 139  
 140  func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
 141  	q := r.URL.Query()
 142  
 143  	headerErr := q.Get("error")
 144  	if headerErr != "" {
 145  		desc := html.EscapeString(q.Get("error_description"))
 146  		escapedHeaderErr := html.EscapeString(headerErr)
 147  		// Note: It is a little weird we handle some errors by not going to the failPage. If they all should,
 148  		// change this to s.error() and make s.error() write the failPage instead of an error code.
 149  		_, _ = w.Write([]byte(fmt.Sprintf(failPage, escapedHeaderErr, desc)))
 150  		s.putResult(Result{Err: fmt.Errorf("%s", desc)})
 151  
 152  		return
 153  	}
 154  
 155  	respState := q.Get("state")
 156  	switch respState {
 157  	case s.reqState:
 158  	case "":
 159  		s.error(w, http.StatusInternalServerError, "server didn't send OAuth state")
 160  		return
 161  	default:
 162  		s.error(w, http.StatusInternalServerError, "mismatched OAuth state, req(%s), resp(%s)", s.reqState, respState)
 163  		return
 164  	}
 165  
 166  	code := q.Get("code")
 167  	if code == "" {
 168  		s.error(w, http.StatusInternalServerError, "authorization code missing in query string")
 169  		return
 170  	}
 171  
 172  	_, _ = w.Write(okPage)
 173  	s.putResult(Result{Code: code})
 174  }
 175  
 176  func (s *Server) error(w http.ResponseWriter, code int, str string, i ...interface{}) {
 177  	err := fmt.Errorf(str, i...)
 178  	http.Error(w, err.Error(), code)
 179  	s.putResult(Result{Err: err})
 180  }
 181