rsa_pss.go raw

   1  package jwt
   2  
   3  import (
   4  	"crypto"
   5  	"crypto/rand"
   6  	"crypto/rsa"
   7  )
   8  
// SigningMethodRSAPSS implements the RSAPSS family of signing methods.
type SigningMethodRSAPSS struct {
	// SigningMethodRSA is embedded; Sign and Verify read the hash function
	// through it (m.Hash).
	*SigningMethodRSA
	// Options is passed to rsa.SignPSS, and to rsa.VerifyPSS when
	// VerifyOptions is nil.
	Options *rsa.PSSOptions
	// VerifyOptions is optional. If set, it overrides Options for rsa.VerifyPSS.
	// Used to accept tokens signed with rsa.PSSSaltLengthAuto, which doesn't follow
	// https://tools.ietf.org/html/rfc7518#section-3.5 but was used previously.
	// See https://github.com/dgrijalva/jwt-go/issues/285#issuecomment-437451244 for details.
	VerifyOptions *rsa.PSSOptions
}
  19  
// Shared instances of the RSA-PSS signing methods, one per supported hash
// size. They are nil until this package's init function assigns them.
var (
	// SigningMethodPS256 is the "PS256" method (SHA-256).
	SigningMethodPS256 *SigningMethodRSAPSS
	// SigningMethodPS384 is the "PS384" method (SHA-384).
	SigningMethodPS384 *SigningMethodRSAPSS
	// SigningMethodPS512 is the "PS512" method (SHA-512).
	SigningMethodPS512 *SigningMethodRSAPSS
)
  26  
  27  func init() {
  28  	// PS256
  29  	SigningMethodPS256 = &SigningMethodRSAPSS{
  30  		SigningMethodRSA: &SigningMethodRSA{
  31  			Name: "PS256",
  32  			Hash: crypto.SHA256,
  33  		},
  34  		Options: &rsa.PSSOptions{
  35  			SaltLength: rsa.PSSSaltLengthEqualsHash,
  36  		},
  37  		VerifyOptions: &rsa.PSSOptions{
  38  			SaltLength: rsa.PSSSaltLengthAuto,
  39  		},
  40  	}
  41  	RegisterSigningMethod(SigningMethodPS256.Alg(), func() SigningMethod {
  42  		return SigningMethodPS256
  43  	})
  44  
  45  	// PS384
  46  	SigningMethodPS384 = &SigningMethodRSAPSS{
  47  		SigningMethodRSA: &SigningMethodRSA{
  48  			Name: "PS384",
  49  			Hash: crypto.SHA384,
  50  		},
  51  		Options: &rsa.PSSOptions{
  52  			SaltLength: rsa.PSSSaltLengthEqualsHash,
  53  		},
  54  		VerifyOptions: &rsa.PSSOptions{
  55  			SaltLength: rsa.PSSSaltLengthAuto,
  56  		},
  57  	}
  58  	RegisterSigningMethod(SigningMethodPS384.Alg(), func() SigningMethod {
  59  		return SigningMethodPS384
  60  	})
  61  
  62  	// PS512
  63  	SigningMethodPS512 = &SigningMethodRSAPSS{
  64  		SigningMethodRSA: &SigningMethodRSA{
  65  			Name: "PS512",
  66  			Hash: crypto.SHA512,
  67  		},
  68  		Options: &rsa.PSSOptions{
  69  			SaltLength: rsa.PSSSaltLengthEqualsHash,
  70  		},
  71  		VerifyOptions: &rsa.PSSOptions{
  72  			SaltLength: rsa.PSSSaltLengthAuto,
  73  		},
  74  	}
  75  	RegisterSigningMethod(SigningMethodPS512.Alg(), func() SigningMethod {
  76  		return SigningMethodPS512
  77  	})
  78  }
  79  
  80  // Verify implements token verification for the SigningMethod.
  81  // For this verify method, key must be an rsa.PublicKey struct
  82  func (m *SigningMethodRSAPSS) Verify(signingString string, sig []byte, key any) error {
  83  	var rsaKey *rsa.PublicKey
  84  	switch k := key.(type) {
  85  	case *rsa.PublicKey:
  86  		rsaKey = k
  87  	default:
  88  		return newError("RSA-PSS verify expects *rsa.PublicKey", ErrInvalidKeyType)
  89  	}
  90  
  91  	// Create hasher
  92  	if !m.Hash.Available() {
  93  		return ErrHashUnavailable
  94  	}
  95  	hasher := m.Hash.New()
  96  	hasher.Write([]byte(signingString))
  97  
  98  	opts := m.Options
  99  	if m.VerifyOptions != nil {
 100  		opts = m.VerifyOptions
 101  	}
 102  
 103  	return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts)
 104  }
 105  
 106  // Sign implements token signing for the SigningMethod.
 107  // For this signing method, key must be an rsa.PrivateKey struct
 108  func (m *SigningMethodRSAPSS) Sign(signingString string, key any) ([]byte, error) {
 109  	var rsaKey *rsa.PrivateKey
 110  
 111  	switch k := key.(type) {
 112  	case *rsa.PrivateKey:
 113  		rsaKey = k
 114  	default:
 115  		return nil, newError("RSA-PSS sign expects *rsa.PrivateKey", ErrInvalidKeyType)
 116  	}
 117  
 118  	// Create the hasher
 119  	if !m.Hash.Available() {
 120  		return nil, ErrHashUnavailable
 121  	}
 122  
 123  	hasher := m.Hash.New()
 124  	hasher.Write([]byte(signingString))
 125  
 126  	// Sign the string and return the encoded bytes
 127  	if sigBytes, err := rsa.SignPSS(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil), m.Options); err == nil {
 128  		return sigBytes, nil
 129  	} else {
 130  		return nil, err
 131  	}
 132  }
 133