digest.go raw

   1  // Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com)
   2  // 2023 Segev Dagan (https://github.com/segevda)
   3  // 2024 Philipp Wolfer (https://github.com/phw)
   4  // All rights reserved.
   5  // resty source code and usage is governed by a MIT style
   6  // license that can be found in the LICENSE file.
   7  
   8  package resty
   9  
  10  import (
  11  	"crypto/md5"
  12  	"crypto/rand"
  13  	"crypto/sha256"
  14  	"crypto/sha512"
  15  	"errors"
  16  	"fmt"
  17  	"hash"
  18  	"io"
  19  	"net/http"
  20  	"strings"
  21  )
  22  
  23  var (
  24  	ErrDigestBadChallenge    = errors.New("digest: challenge is bad")
  25  	ErrDigestCharset         = errors.New("digest: unsupported charset")
  26  	ErrDigestAlgNotSupported = errors.New("digest: algorithm is not supported")
  27  	ErrDigestQopNotSupported = errors.New("digest: no supported qop in list")
  28  	ErrDigestNoQop           = errors.New("digest: qop must be specified")
  29  )
  30  
  31  var hashFuncs = map[string]func() hash.Hash{
  32  	"":                 md5.New,
  33  	"MD5":              md5.New,
  34  	"MD5-sess":         md5.New,
  35  	"SHA-256":          sha256.New,
  36  	"SHA-256-sess":     sha256.New,
  37  	"SHA-512-256":      sha512.New,
  38  	"SHA-512-256-sess": sha512.New,
  39  }
  40  
  41  type digestCredentials struct {
  42  	username, password string
  43  }
  44  
  45  type digestTransport struct {
  46  	digestCredentials
  47  	transport http.RoundTripper
  48  }
  49  
  50  func (dt *digestTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  51  	// Copy the request, so we don't modify the input.
  52  	req2 := new(http.Request)
  53  	*req2 = *req
  54  	req2.Header = make(http.Header)
  55  	for k, s := range req.Header {
  56  		req2.Header[k] = s
  57  	}
  58  
  59  	// Fix http: ContentLength=xxx with Body length 0
  60  	if req2.Body == nil {
  61  		req2.ContentLength = 0
  62  	} else if req2.GetBody != nil {
  63  		var err error
  64  		req2.Body, err = req2.GetBody()
  65  		if err != nil {
  66  			return nil, err
  67  		}
  68  	}
  69  
  70  	// Make a request to get the 401 that contains the challenge.
  71  	resp, err := dt.transport.RoundTrip(req)
  72  	if err != nil || resp.StatusCode != http.StatusUnauthorized {
  73  		return resp, err
  74  	}
  75  	chal := resp.Header.Get(hdrWwwAuthenticateKey)
  76  	if chal == "" {
  77  		return resp, ErrDigestBadChallenge
  78  	}
  79  
  80  	c, err := parseChallenge(chal)
  81  	if err != nil {
  82  		return resp, err
  83  	}
  84  
  85  	// Form credentials based on the challenge
  86  	cr := dt.newCredentials(req2, c)
  87  	auth, err := cr.authorize()
  88  	if err != nil {
  89  		return resp, err
  90  	}
  91  	err = resp.Body.Close()
  92  	if err != nil {
  93  		return nil, err
  94  	}
  95  
  96  	// Make authenticated request
  97  	req2.Header.Set(hdrAuthorizationKey, auth)
  98  	return dt.transport.RoundTrip(req2)
  99  }
 100  
 101  func (dt *digestTransport) newCredentials(req *http.Request, c *challenge) *credentials {
 102  	return &credentials{
 103  		username:   dt.username,
 104  		userhash:   c.userhash,
 105  		realm:      c.realm,
 106  		nonce:      c.nonce,
 107  		digestURI:  req.URL.RequestURI(),
 108  		algorithm:  c.algorithm,
 109  		sessionAlg: strings.HasSuffix(c.algorithm, "-sess"),
 110  		opaque:     c.opaque,
 111  		messageQop: c.qop,
 112  		nc:         0,
 113  		method:     req.Method,
 114  		password:   dt.password,
 115  	}
 116  }
 117  
 118  type challenge struct {
 119  	realm     string
 120  	domain    string
 121  	nonce     string
 122  	opaque    string
 123  	stale     string
 124  	algorithm string
 125  	qop       string
 126  	userhash  string
 127  }
 128  
 129  func (c *challenge) setValue(k, v string) error {
 130  	switch k {
 131  	case "realm":
 132  		c.realm = v
 133  	case "domain":
 134  		c.domain = v
 135  	case "nonce":
 136  		c.nonce = v
 137  	case "opaque":
 138  		c.opaque = v
 139  	case "stale":
 140  		c.stale = v
 141  	case "algorithm":
 142  		c.algorithm = v
 143  	case "qop":
 144  		c.qop = v
 145  	case "charset":
 146  		if strings.ToUpper(v) != "UTF-8" {
 147  			return ErrDigestCharset
 148  		}
 149  	case "userhash":
 150  		c.userhash = v
 151  	default:
 152  		return ErrDigestBadChallenge
 153  	}
 154  	return nil
 155  }
 156  
 157  func parseChallenge(input string) (*challenge, error) {
 158  	const ws = " \n\r\t"
 159  	s := strings.Trim(input, ws)
 160  	if !strings.HasPrefix(s, "Digest ") {
 161  		return nil, ErrDigestBadChallenge
 162  	}
 163  	s = strings.Trim(s[7:], ws)
 164  	c := &challenge{}
 165  	b := strings.Builder{}
 166  	key := ""
 167  	quoted := false
 168  	for _, r := range s {
 169  		switch r {
 170  		case '"':
 171  			quoted = !quoted
 172  		case ',':
 173  			if quoted {
 174  				b.WriteRune(r)
 175  			} else {
 176  				val := strings.Trim(b.String(), ws)
 177  				b.Reset()
 178  				if err := c.setValue(key, val); err != nil {
 179  					return nil, err
 180  				}
 181  				key = ""
 182  			}
 183  		case '=':
 184  			if quoted {
 185  				b.WriteRune(r)
 186  			} else {
 187  				key = strings.Trim(b.String(), ws)
 188  				b.Reset()
 189  			}
 190  		default:
 191  			b.WriteRune(r)
 192  		}
 193  	}
 194  	if quoted || (key == "" && b.Len() > 0) {
 195  		return nil, ErrDigestBadChallenge
 196  	}
 197  	if key != "" {
 198  		val := strings.Trim(b.String(), ws)
 199  		if err := c.setValue(key, val); err != nil {
 200  			return nil, err
 201  		}
 202  	}
 203  	return c, nil
 204  }
 205  
 206  type credentials struct {
 207  	username   string
 208  	userhash   string
 209  	realm      string
 210  	nonce      string
 211  	digestURI  string
 212  	algorithm  string
 213  	sessionAlg bool
 214  	cNonce     string
 215  	opaque     string
 216  	messageQop string
 217  	nc         int
 218  	method     string
 219  	password   string
 220  }
 221  
 222  func (c *credentials) authorize() (string, error) {
 223  	if _, ok := hashFuncs[c.algorithm]; !ok {
 224  		return "", ErrDigestAlgNotSupported
 225  	}
 226  
 227  	if err := c.validateQop(); err != nil {
 228  		return "", err
 229  	}
 230  
 231  	resp, err := c.resp()
 232  	if err != nil {
 233  		return "", err
 234  	}
 235  
 236  	sl := make([]string, 0, 10)
 237  	if c.userhash == "true" {
 238  		// RFC 7616 3.4.4
 239  		c.username = c.h(fmt.Sprintf("%s:%s", c.username, c.realm))
 240  		sl = append(sl, fmt.Sprintf(`userhash=%s`, c.userhash))
 241  	}
 242  	sl = append(sl, fmt.Sprintf(`username="%s"`, c.username))
 243  	sl = append(sl, fmt.Sprintf(`realm="%s"`, c.realm))
 244  	sl = append(sl, fmt.Sprintf(`nonce="%s"`, c.nonce))
 245  	sl = append(sl, fmt.Sprintf(`uri="%s"`, c.digestURI))
 246  	sl = append(sl, fmt.Sprintf(`response="%s"`, resp))
 247  	sl = append(sl, fmt.Sprintf(`algorithm=%s`, c.algorithm))
 248  	if c.opaque != "" {
 249  		sl = append(sl, fmt.Sprintf(`opaque="%s"`, c.opaque))
 250  	}
 251  	if c.messageQop != "" {
 252  		sl = append(sl, fmt.Sprintf("qop=%s", c.messageQop))
 253  		sl = append(sl, fmt.Sprintf("nc=%08x", c.nc))
 254  		sl = append(sl, fmt.Sprintf(`cnonce="%s"`, c.cNonce))
 255  	}
 256  
 257  	return fmt.Sprintf("Digest %s", strings.Join(sl, ", ")), nil
 258  }
 259  
 260  func (c *credentials) validateQop() error {
 261  	// Currently only supporting auth quality of protection. TODO: add auth-int support
 262  	// NOTE: cURL support auth-int qop for requests other than POST and PUT (i.e. w/o body) by hashing an empty string
 263  	// is this applicable for resty? see: https://github.com/curl/curl/blob/307b7543ea1e73ab04e062bdbe4b5bb409eaba3a/lib/vauth/digest.c#L774
 264  	if c.messageQop == "" {
 265  		return ErrDigestNoQop
 266  	}
 267  	possibleQops := strings.Split(c.messageQop, ",")
 268  	var authSupport bool
 269  	for _, qop := range possibleQops {
 270  		qop = strings.TrimSpace(qop)
 271  		if qop == "auth" {
 272  			authSupport = true
 273  			break
 274  		}
 275  	}
 276  	if !authSupport {
 277  		return ErrDigestQopNotSupported
 278  	}
 279  
 280  	c.messageQop = "auth"
 281  
 282  	return nil
 283  }
 284  
 285  func (c *credentials) h(data string) string {
 286  	hfCtor := hashFuncs[c.algorithm]
 287  	hf := hfCtor()
 288  	_, _ = hf.Write([]byte(data)) // Hash.Write never returns an error
 289  	return fmt.Sprintf("%x", hf.Sum(nil))
 290  }
 291  
 292  func (c *credentials) resp() (string, error) {
 293  	c.nc++
 294  
 295  	b := make([]byte, 16)
 296  	_, err := io.ReadFull(rand.Reader, b)
 297  	if err != nil {
 298  		return "", err
 299  	}
 300  	c.cNonce = fmt.Sprintf("%x", b)[:32]
 301  
 302  	ha1 := c.ha1()
 303  	ha2 := c.ha2()
 304  
 305  	return c.kd(ha1, fmt.Sprintf("%s:%08x:%s:%s:%s",
 306  		c.nonce, c.nc, c.cNonce, c.messageQop, ha2)), nil
 307  }
 308  
 309  func (c *credentials) kd(secret, data string) string {
 310  	return c.h(fmt.Sprintf("%s:%s", secret, data))
 311  }
 312  
 313  // RFC 7616 3.4.2
 314  func (c *credentials) ha1() string {
 315  	ret := c.h(fmt.Sprintf("%s:%s:%s", c.username, c.realm, c.password))
 316  	if c.sessionAlg {
 317  		return c.h(fmt.Sprintf("%s:%s:%s", ret, c.nonce, c.cNonce))
 318  	}
 319  
 320  	return ret
 321  }
 322  
 323  // RFC 7616 3.4.3
 324  func (c *credentials) ha2() string {
 325  	// currently no auth-int support
 326  	return c.h(fmt.Sprintf("%s:%s", c.method, c.digestURI))
 327  }
 328