ecdsa.go raw

   1  package jwt
   2  
   3  import (
   4  	"crypto"
   5  	"crypto/ecdsa"
   6  	"crypto/rand"
   7  	"errors"
   8  	"math/big"
   9  )
  10  
  11  var (
  12  	// Sadly this is missing from crypto/ecdsa compared to crypto/rsa
  13  	ErrECDSAVerification = errors.New("crypto/ecdsa: verification error")
  14  )
  15  
// SigningMethodECDSA implements the ECDSA family of signing methods.
// Expects *ecdsa.PrivateKey for signing and *ecdsa.PublicKey for verification.
//
// Fields:
//   - Name is the algorithm identifier returned by Alg.
//   - Hash is the hash function applied to the signing string before it is
//     signed or verified.
//   - KeySize is the byte length of each of the two signature halves (r and
//     s); Verify only accepts signatures that are exactly 2*KeySize bytes.
//   - CurveBits is the curve bit size a private key must have; Sign returns
//     ErrInvalidKey for keys on a curve of a different size.
//
// The field order is relied upon by the positional literals in init.
type SigningMethodECDSA struct {
	Name      string
	Hash      crypto.Hash
	KeySize   int
	CurveBits int
}
  24  
// Predefined ECDSA signing method instances for ES256, ES384 and ES512.
// They are nil until init assigns them and registers each one under the
// name returned by its Alg method.
var (
	SigningMethodES256 *SigningMethodECDSA
	SigningMethodES384 *SigningMethodECDSA
	SigningMethodES512 *SigningMethodECDSA
)
  31  
  32  func init() {
  33  	// ES256
  34  	SigningMethodES256 = &SigningMethodECDSA{"ES256", crypto.SHA256, 32, 256}
  35  	RegisterSigningMethod(SigningMethodES256.Alg(), func() SigningMethod {
  36  		return SigningMethodES256
  37  	})
  38  
  39  	// ES384
  40  	SigningMethodES384 = &SigningMethodECDSA{"ES384", crypto.SHA384, 48, 384}
  41  	RegisterSigningMethod(SigningMethodES384.Alg(), func() SigningMethod {
  42  		return SigningMethodES384
  43  	})
  44  
  45  	// ES512
  46  	SigningMethodES512 = &SigningMethodECDSA{"ES512", crypto.SHA512, 66, 521}
  47  	RegisterSigningMethod(SigningMethodES512.Alg(), func() SigningMethod {
  48  		return SigningMethodES512
  49  	})
  50  }
  51  
  52  func (m *SigningMethodECDSA) Alg() string {
  53  	return m.Name
  54  }
  55  
  56  // Verify implements token verification for the SigningMethod.
  57  // For this verify method, key must be an ecdsa.PublicKey struct
  58  func (m *SigningMethodECDSA) Verify(signingString string, sig []byte, key any) error {
  59  	// Get the key
  60  	var ecdsaKey *ecdsa.PublicKey
  61  	switch k := key.(type) {
  62  	case *ecdsa.PublicKey:
  63  		ecdsaKey = k
  64  	default:
  65  		return newError("ECDSA verify expects *ecdsa.PublicKey", ErrInvalidKeyType)
  66  	}
  67  
  68  	if len(sig) != 2*m.KeySize {
  69  		return ErrECDSAVerification
  70  	}
  71  
  72  	r := big.NewInt(0).SetBytes(sig[:m.KeySize])
  73  	s := big.NewInt(0).SetBytes(sig[m.KeySize:])
  74  
  75  	// Create hasher
  76  	if !m.Hash.Available() {
  77  		return ErrHashUnavailable
  78  	}
  79  	hasher := m.Hash.New()
  80  	hasher.Write([]byte(signingString))
  81  
  82  	// Verify the signature
  83  	if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus {
  84  		return nil
  85  	}
  86  
  87  	return ErrECDSAVerification
  88  }
  89  
  90  // Sign implements token signing for the SigningMethod.
  91  // For this signing method, key must be an ecdsa.PrivateKey struct
  92  func (m *SigningMethodECDSA) Sign(signingString string, key any) ([]byte, error) {
  93  	// Get the key
  94  	var ecdsaKey *ecdsa.PrivateKey
  95  	switch k := key.(type) {
  96  	case *ecdsa.PrivateKey:
  97  		ecdsaKey = k
  98  	default:
  99  		return nil, newError("ECDSA sign expects *ecdsa.PrivateKey", ErrInvalidKeyType)
 100  	}
 101  
 102  	// Create the hasher
 103  	if !m.Hash.Available() {
 104  		return nil, ErrHashUnavailable
 105  	}
 106  
 107  	hasher := m.Hash.New()
 108  	hasher.Write([]byte(signingString))
 109  
 110  	// Sign the string and return r, s
 111  	if r, s, err := ecdsa.Sign(rand.Reader, ecdsaKey, hasher.Sum(nil)); err == nil {
 112  		curveBits := ecdsaKey.Curve.Params().BitSize
 113  
 114  		if m.CurveBits != curveBits {
 115  			return nil, ErrInvalidKey
 116  		}
 117  
 118  		keyBytes := curveBits / 8
 119  		if curveBits%8 > 0 {
 120  			keyBytes += 1
 121  		}
 122  
 123  		// We serialize the outputs (r and s) into big-endian byte arrays
 124  		// padded with zeros on the left to make sure the sizes work out.
 125  		// Output must be 2*keyBytes long.
 126  		out := make([]byte, 2*keyBytes)
 127  		r.FillBytes(out[0:keyBytes]) // r is assigned to the first half of output.
 128  		s.FillBytes(out[keyBytes:])  // s is assigned to the second half of output.
 129  
 130  		return out, nil
 131  	} else {
 132  		return nil, err
 133  	}
 134  }
 135