jwt.go raw

   1  // Copyright 2023 Google LLC
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License");
   4  // you may not use this file except in compliance with the License.
   5  // You may obtain a copy of the License at
   6  //
   7  //      http://www.apache.org/licenses/LICENSE-2.0
   8  //
   9  // Unless required by applicable law or agreed to in writing, software
  10  // distributed under the License is distributed on an "AS IS" BASIS,
  11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12  // See the License for the specific language governing permissions and
  13  // limitations under the License.
  14  
  15  package jwt
  16  
  17  import (
  18  	"bytes"
  19  	"crypto"
  20  	"crypto/rand"
  21  	"crypto/rsa"
  22  	"crypto/sha256"
  23  	"encoding/base64"
  24  	"encoding/json"
  25  	"errors"
  26  	"fmt"
  27  	"strings"
  28  	"time"
  29  )
  30  
  31  const (
  32  	// HeaderAlgRSA256 is the RS256 [Header.Algorithm].
  33  	HeaderAlgRSA256 = "RS256"
  34  	// HeaderAlgES256 is the ES256 [Header.Algorithm].
  35  	HeaderAlgES256 = "ES256"
  36  	// HeaderType is the standard [Header.Type].
  37  	HeaderType = "JWT"
  38  )
  39  
  40  // Header represents a JWT header.
  41  type Header struct {
  42  	Algorithm string `json:"alg"`
  43  	Type      string `json:"typ"`
  44  	KeyID     string `json:"kid"`
  45  }
  46  
  47  func (h *Header) encode() (string, error) {
  48  	b, err := json.Marshal(h)
  49  	if err != nil {
  50  		return "", err
  51  	}
  52  	return base64.RawURLEncoding.EncodeToString(b), nil
  53  }
  54  
  55  // Claims represents the claims set of a JWT.
  56  type Claims struct {
  57  	// Iss is the issuer JWT claim.
  58  	Iss string `json:"iss"`
  59  	// Scope is the scope JWT claim.
  60  	Scope string `json:"scope,omitempty"`
  61  	// Exp is the expiry JWT claim. If unset, default is in one hour from now.
  62  	Exp int64 `json:"exp"`
  63  	// Iat is the subject issued at claim. If unset, default is now.
  64  	Iat int64 `json:"iat"`
  65  	// Aud is the audience JWT claim. Optional.
  66  	Aud string `json:"aud"`
  67  	// Sub is the subject JWT claim. Optional.
  68  	Sub string `json:"sub,omitempty"`
  69  	// AdditionalClaims contains any additional non-standard JWT claims. Optional.
  70  	AdditionalClaims map[string]interface{} `json:"-"`
  71  }
  72  
  73  func (c *Claims) encode() (string, error) {
  74  	// Compensate for skew
  75  	now := time.Now().Add(-10 * time.Second)
  76  	if c.Iat == 0 {
  77  		c.Iat = now.Unix()
  78  	}
  79  	if c.Exp == 0 {
  80  		c.Exp = now.Add(time.Hour).Unix()
  81  	}
  82  	if c.Exp < c.Iat {
  83  		return "", fmt.Errorf("jwt: invalid Exp = %d; must be later than Iat = %d", c.Exp, c.Iat)
  84  	}
  85  
  86  	b, err := json.Marshal(c)
  87  	if err != nil {
  88  		return "", err
  89  	}
  90  
  91  	if len(c.AdditionalClaims) == 0 {
  92  		return base64.RawURLEncoding.EncodeToString(b), nil
  93  	}
  94  
  95  	// Marshal private claim set and then append it to b.
  96  	prv, err := json.Marshal(c.AdditionalClaims)
  97  	if err != nil {
  98  		return "", fmt.Errorf("invalid map of additional claims %v: %w", c.AdditionalClaims, err)
  99  	}
 100  
 101  	// Concatenate public and private claim JSON objects.
 102  	if !bytes.HasSuffix(b, []byte{'}'}) {
 103  		return "", fmt.Errorf("invalid JSON %s", b)
 104  	}
 105  	if !bytes.HasPrefix(prv, []byte{'{'}) {
 106  		return "", fmt.Errorf("invalid JSON %s", prv)
 107  	}
 108  	b[len(b)-1] = ','         // Replace closing curly brace with a comma.
 109  	b = append(b, prv[1:]...) // Append private claims.
 110  	return base64.RawURLEncoding.EncodeToString(b), nil
 111  }
 112  
 113  // EncodeJWS encodes the data using the provided key as a JSON web signature.
 114  func EncodeJWS(header *Header, c *Claims, signer crypto.Signer) (string, error) {
 115  	head, err := header.encode()
 116  	if err != nil {
 117  		return "", err
 118  	}
 119  	claims, err := c.encode()
 120  	if err != nil {
 121  		return "", err
 122  	}
 123  	ss := fmt.Sprintf("%s.%s", head, claims)
 124  	h := sha256.New()
 125  	h.Write([]byte(ss))
 126  	sig, err := signer.Sign(rand.Reader, h.Sum(nil), crypto.SHA256)
 127  	if err != nil {
 128  		return "", err
 129  	}
 130  	return fmt.Sprintf("%s.%s", ss, base64.RawURLEncoding.EncodeToString(sig)), nil
 131  }
 132  
 133  // DecodeJWS decodes a claim set from a JWS payload.
 134  func DecodeJWS(payload string) (*Claims, error) {
 135  	// decode returned id token to get expiry
 136  	s := strings.Split(payload, ".")
 137  	if len(s) < 2 {
 138  		return nil, errors.New("invalid token received")
 139  	}
 140  	decoded, err := base64.RawURLEncoding.DecodeString(s[1])
 141  	if err != nil {
 142  		return nil, err
 143  	}
 144  	c := &Claims{}
 145  	if err := json.NewDecoder(bytes.NewBuffer(decoded)).Decode(c); err != nil {
 146  		return nil, err
 147  	}
 148  	if err := json.NewDecoder(bytes.NewBuffer(decoded)).Decode(&c.AdditionalClaims); err != nil {
 149  		return nil, err
 150  	}
 151  	return c, err
 152  }
 153  
 154  // VerifyJWS tests whether the provided JWT token's signature was produced by
 155  // the private key associated with the provided public key.
 156  func VerifyJWS(token string, key *rsa.PublicKey) error {
 157  	parts := strings.Split(token, ".")
 158  	if len(parts) != 3 {
 159  		return errors.New("jwt: invalid token received, token must have 3 parts")
 160  	}
 161  
 162  	signedContent := parts[0] + "." + parts[1]
 163  	signatureString, err := base64.RawURLEncoding.DecodeString(parts[2])
 164  	if err != nil {
 165  		return err
 166  	}
 167  
 168  	h := sha256.New()
 169  	h.Write([]byte(signedContent))
 170  	return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString)
 171  }
 172