crypto.go raw

   1  package certcrypto
   2  
   3  import (
   4  	"crypto"
   5  	"crypto/ecdsa"
   6  	"crypto/ed25519"
   7  	"crypto/elliptic"
   8  	"crypto/rand"
   9  	"crypto/rsa"
  10  	"crypto/x509"
  11  	"crypto/x509/pkix"
  12  	"encoding/asn1"
  13  	"encoding/pem"
  14  	"errors"
  15  	"fmt"
  16  	"math/big"
  17  	"net"
  18  	"slices"
  19  	"strings"
  20  	"time"
  21  
  22  	"golang.org/x/crypto/ocsp"
  23  )
  24  
  25  // Constants for all key types we support.
  26  const (
  27  	EC256   = KeyType("P256")
  28  	EC384   = KeyType("P384")
  29  	RSA2048 = KeyType("2048")
  30  	RSA3072 = KeyType("3072")
  31  	RSA4096 = KeyType("4096")
  32  	RSA8192 = KeyType("8192")
  33  )
  34  
  35  const (
  36  	// OCSPGood means that the certificate is valid.
  37  	OCSPGood = ocsp.Good
  38  	// OCSPRevoked means that the certificate has been deliberately revoked.
  39  	OCSPRevoked = ocsp.Revoked
  40  	// OCSPUnknown means that the OCSP responder doesn't know about the certificate.
  41  	OCSPUnknown = ocsp.Unknown
  42  	// OCSPServerFailed means that the OCSP responder failed to process the request.
  43  	OCSPServerFailed = ocsp.ServerFailed
  44  )
  45  
  46  // Constants for OCSP must staple.
  47  var (
  48  	tlsFeatureExtensionOID = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 24}
  49  	ocspMustStapleFeature  = []byte{0x30, 0x03, 0x02, 0x01, 0x05}
  50  )
  51  
  type (
	// KeyType represents the key algo as well as the key size or curve to use.
	KeyType string

	// DERCertificateBytes holds raw DER certificate bytes; PEMBlock recognizes
	// this type and wraps it in a "CERTIFICATE" PEM block.
	DERCertificateBytes []byte
)
  56  
  // ParsePEMBundle walks every PEM block in bundle from top to bottom and parses
// each block of type "CERTIFICATE" as an x509 certificate, preserving order.
// Blocks of any other type are skipped. It returns the first certificate parse
// error it hits, and an error when the bundle holds no certificates at all.
func ParsePEMBundle(bundle []byte) ([]*x509.Certificate, error) {
	var certs []*x509.Certificate

	for block, rest := pem.Decode(bundle); block != nil; block, rest = pem.Decode(rest) {
		if block.Type != "CERTIFICATE" {
			continue
		}

		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			return nil, err
		}
		certs = append(certs, cert)
	}

	if len(certs) == 0 {
		return nil, errors.New("no certificates were found while parsing the bundle")
	}
	return certs, nil
}
  87  
  // ParsePEMPrivateKey decodes the first PEM block in key and parses it as a
// private key. The block type must be "PRIVATE KEY" or end in " PRIVATE KEY".
// Encodings are tried in order: PKCS #1 (RSA), PKCS #8 (accepting only RSA,
// ECDSA and Ed25519 keys), then SEC 1 EC.
// Borrowed from Go standard library, to handle various private key and PEM block types.
// https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L291-L308
// https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L238
func ParsePEMPrivateKey(key []byte) (crypto.PrivateKey, error) {
	block, _ := pem.Decode(key)
	if block == nil {
		return nil, errors.New("invalid PEM block")
	}

	isPrivateKeyBlock := block.Type == "PRIVATE KEY" || strings.HasSuffix(block.Type, " PRIVATE KEY")
	if !isPrivateKeyBlock {
		return nil, fmt.Errorf("unknown PEM header %q", block.Type)
	}

	der := block.Bytes

	if rsaKey, err := x509.ParsePKCS1PrivateKey(der); err == nil {
		return rsaKey, nil
	}

	// A successful PKCS #8 parse is final: unsupported key types are rejected
	// here rather than falling through to the EC attempt.
	if parsed, err := x509.ParsePKCS8PrivateKey(der); err == nil {
		switch parsed.(type) {
		case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
			return parsed, nil
		}
		return nil, fmt.Errorf("found unknown private key type in PKCS#8 wrapping: %T", parsed)
	}

	if ecKey, err := x509.ParseECPrivateKey(der); err == nil {
		return ecKey, nil
	}

	return nil, errors.New("failed to parse private key")
}
 121  
 122  func GeneratePrivateKey(keyType KeyType) (crypto.PrivateKey, error) {
 123  	switch keyType {
 124  	case EC256:
 125  		return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
 126  	case EC384:
 127  		return ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
 128  	case RSA2048:
 129  		return rsa.GenerateKey(rand.Reader, 2048)
 130  	case RSA3072:
 131  		return rsa.GenerateKey(rand.Reader, 3072)
 132  	case RSA4096:
 133  		return rsa.GenerateKey(rand.Reader, 4096)
 134  	case RSA8192:
 135  		return rsa.GenerateKey(rand.Reader, 8192)
 136  	}
 137  
 138  	return nil, fmt.Errorf("invalid KeyType: %s", keyType)
 139  }
 140  
 // GenerateCSR builds a DER-encoded certificate signing request for domain,
// with san as its subject alternative names and, when mustStaple is true,
// the OCSP must-staple extension. It is a thin wrapper around CreateCSR.
//
// Deprecated: use [CreateCSR] instead.
func GenerateCSR(privateKey crypto.PrivateKey, domain string, san []string, mustStaple bool) ([]byte, error) {
	opts := CSROptions{
		MustStaple: mustStaple,
		SAN:        san,
		Domain:     domain,
	}

	return CreateCSR(privateKey, opts)
}
 149  
 // CSROptions configures the certificate signing request built by CreateCSR.
//
// Field order is part of the API (unkeyed literals depend on it); append new
// fields at the end.
type CSROptions struct {
	// Domain is used as the subject common name of the request.
	Domain string
	// SAN lists subject alternative names. Entries that parse as IP addresses
	// become IP SANs; all other entries become DNS SANs.
	SAN []string
	// MustStaple, when true, adds the TLS Feature (OCSP must-staple) extension.
	MustStaple bool
	// EmailAddresses are copied into the request's email address SANs.
	EmailAddresses []string
}
 156  
 157  func CreateCSR(privateKey crypto.PrivateKey, opts CSROptions) ([]byte, error) {
 158  	var (
 159  		dnsNames    []string
 160  		ipAddresses []net.IP
 161  	)
 162  
 163  	for _, altname := range opts.SAN {
 164  		if ip := net.ParseIP(altname); ip != nil {
 165  			ipAddresses = append(ipAddresses, ip)
 166  		} else {
 167  			dnsNames = append(dnsNames, altname)
 168  		}
 169  	}
 170  
 171  	template := x509.CertificateRequest{
 172  		Subject:        pkix.Name{CommonName: opts.Domain},
 173  		DNSNames:       dnsNames,
 174  		EmailAddresses: opts.EmailAddresses,
 175  		IPAddresses:    ipAddresses,
 176  	}
 177  
 178  	if opts.MustStaple {
 179  		template.ExtraExtensions = append(template.ExtraExtensions, pkix.Extension{
 180  			Id:    tlsFeatureExtensionOID,
 181  			Value: ocspMustStapleFeature,
 182  		})
 183  	}
 184  
 185  	return x509.CreateCertificateRequest(rand.Reader, &template, privateKey)
 186  }
 187  
 // PEMEncode returns the PEM encoding of data, letting PEMBlock pick the block
// type and DER payload. It returns nil when PEMBlock yields no block for data
// (an unsupported type), instead of panicking.
func PEMEncode(data any) []byte {
	block := PEMBlock(data)
	if block == nil {
		// pem.EncodeToMemory dereferences its argument; a nil block would panic.
		return nil
	}

	return pem.EncodeToMemory(block)
}
 191  
 // PEMBlock wraps data in a PEM block whose type matches data's Go type:
//
//   - *ecdsa.PrivateKey: "EC PRIVATE KEY" (SEC 1 DER)
//   - *rsa.PrivateKey: "RSA PRIVATE KEY" (PKCS #1 DER)
//   - ed25519.PrivateKey: "PRIVATE KEY" (PKCS #8 DER)
//   - *x509.CertificateRequest: "CERTIFICATE REQUEST" (the request's Raw DER)
//   - DERCertificateBytes: "CERTIFICATE"
//
// It returns nil for any other type, or when the key cannot be marshaled.
func PEMBlock(data any) *pem.Block {
	switch v := data.(type) {
	case *ecdsa.PrivateKey:
		der, err := x509.MarshalECPrivateKey(v)
		if err != nil {
			// Do not emit a block with an empty payload.
			return nil
		}
		return &pem.Block{Type: "EC PRIVATE KEY", Bytes: der}
	case *rsa.PrivateKey:
		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(v)}
	case ed25519.PrivateKey:
		// PKCS #8 is the only standard container for Ed25519 keys; this is the
		// form ParsePEMPrivateKey accepts for them.
		der, err := x509.MarshalPKCS8PrivateKey(v)
		if err != nil {
			return nil
		}
		return &pem.Block{Type: "PRIVATE KEY", Bytes: der}
	case *x509.CertificateRequest:
		return &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: v.Raw}
	case DERCertificateBytes:
		return &pem.Block{Type: "CERTIFICATE", Bytes: []byte(v)}
	default:
		return nil
	}
}
 209  
 // pemDecode returns the first PEM block found in data; anything after it is
// ignored. It reports an error when data contains no decodable PEM block, so a
// nil block always comes with a non-nil error.
func pemDecode(data []byte) (*pem.Block, error) {
	if block, _ := pem.Decode(data); block != nil {
		return block, nil
	}

	return nil, errors.New("PEM decode did not yield a valid block. Is the certificate in the right format?")
}
 218  
 // PemDecodeTox509CSR decodes the first PEM block in data and parses it as an
// x509 certificate signing request. Both "CERTIFICATE REQUEST" and
// "NEW CERTIFICATE REQUEST" block types are accepted; any other type is
// rejected before parsing.
func PemDecodeTox509CSR(data []byte) (*x509.CertificateRequest, error) {
	pemBlock, err := pemDecode(data)
	if err != nil {
		return nil, err
	}

	switch pemBlock.Type {
	case "CERTIFICATE REQUEST", "NEW CERTIFICATE REQUEST":
		return x509.ParseCertificateRequest(pemBlock.Bytes)
	default:
		return nil, errors.New("PEM block is not a certificate request")
	}
}
 231  
 // ParsePEMCertificate returns Certificate from a PEM encoded certificate.
// The certificate has to be PEM encoded. Any other encodings like DER will fail.
// Only the first PEM block in cert is used, and its type header is not checked;
// the block payload is handed straight to x509.ParseCertificate.
func ParsePEMCertificate(cert []byte) (*x509.Certificate, error) {
	pemBlock, err := pemDecode(cert)
	if err != nil {
		return nil, err
	}

	// The PEM payload is the DER encoding of the certificate.
	return x509.ParseCertificate(pemBlock.Bytes)
}
 243  
 // GetCertificateMainDomain returns the primary domain of cert: its subject
// common name when set, otherwise its first DNS SAN. It returns an error when
// the certificate has neither.
func GetCertificateMainDomain(cert *x509.Certificate) (string, error) {
	subject, dnsNames := cert.Subject, cert.DNSNames
	return getMainDomain(subject, dnsNames)
}
 247  
 // GetCSRMainDomain returns the primary domain of a certificate signing request:
// its subject common name when set, otherwise its first DNS SAN. It returns an
// error when the request has neither.
func GetCSRMainDomain(cert *x509.CertificateRequest) (string, error) {
	subject, dnsNames := cert.Subject, cert.DNSNames
	return getMainDomain(subject, dnsNames)
}
 251  
 // getMainDomain chooses the primary domain from a subject and its DNS SANs.
// A non-empty common name wins; otherwise the first DNS name is used (even if
// it is empty). With no common name and no DNS names it returns an error.
func getMainDomain(subject pkix.Name, dnsNames []string) (string, error) {
	switch {
	case subject.CommonName != "":
		return subject.CommonName, nil
	case len(dnsNames) > 0:
		return dnsNames[0], nil
	default:
		return "", errors.New("missing domain")
	}
}
 263  
 // ExtractDomains lists the names cert was issued for. The order is: the
// subject common name (if non-empty), then each DNS SAN, then each IP SAN in
// string form.
//
// DNS SANs equal to the common name or to an already-listed DNS SAN are
// skipped, matching ExtractDomainsCSR. An IP SAN equal to the common name
// (when the common name parses as an IP) is skipped as well.
func ExtractDomains(cert *x509.Certificate) []string {
	commonName := cert.Subject.CommonName

	var domains []string
	if commonName != "" {
		domains = append(domains, commonName)
	}

	// Subject Alternative Name DNS entries, without duplicates.
	for _, sanDomain := range cert.DNSNames {
		// The explicit common-name comparison keeps the historical behavior of
		// also skipping empty names when the common name is empty.
		if sanDomain == commonName || slices.Contains(domains, sanDomain) {
			continue
		}

		domains = append(domains, sanDomain)
	}

	commonNameIP := net.ParseIP(commonName)
	for _, sanIP := range cert.IPAddresses {
		if !commonNameIP.Equal(sanIP) {
			domains = append(domains, sanIP.String())
		}
	}

	return domains
}
 288  
 // ExtractDomainsCSR lists the names requested by csr. The order is: the
// subject common name (if non-empty), then each distinct DNS SAN not already
// listed, then each IP SAN in string form. An IP SAN equal to the common name
// (when the common name parses as an IP) is skipped.
func ExtractDomainsCSR(csr *x509.CertificateRequest) []string {
	cn := csr.Subject.CommonName

	var names []string
	seen := make(map[string]struct{})
	if cn != "" {
		names = append(names, cn)
		seen[cn] = struct{}{}
	}

	for _, dnsName := range csr.DNSNames {
		if _, dup := seen[dnsName]; dup {
			continue
		}
		seen[dnsName] = struct{}{}
		names = append(names, dnsName)
	}

	cnAsIP := net.ParseIP(cn)
	for _, ip := range csr.IPAddresses {
		if cnAsIP.Equal(ip) {
			continue
		}
		names = append(names, ip.String())
	}

	return names
}
 315  
 316  func GeneratePemCert(privateKey *rsa.PrivateKey, domain string, extensions []pkix.Extension) ([]byte, error) {
 317  	derBytes, err := generateDerCert(privateKey, time.Time{}, domain, extensions)
 318  	if err != nil {
 319  		return nil, err
 320  	}
 321  
 322  	return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}), nil
 323  }
 324  
 // generateDerCert creates a self-signed, DER-encoded certificate for domain,
// signed with privateKey. The subject common name is "ACME Challenge TEMP" and
// the serial number is a random value below 2^128. The certificate is valid
// from now until expiration, or for one year when expiration is the zero time.
// domain goes into the IP SANs when it parses as an IP address, otherwise into
// the DNS SANs. extensions are attached verbatim.
func generateDerCert(privateKey *rsa.PrivateKey, expiration time.Time, domain string, extensions []pkix.Extension) ([]byte, error) {
	limit := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, limit)
	if err != nil {
		return nil, err
	}

	if expiration.IsZero() {
		expiration = time.Now().AddDate(1, 0, 0)
	}

	template := &x509.Certificate{
		SerialNumber:          serial,
		Subject:               pkix.Name{CommonName: "ACME Challenge TEMP"},
		NotBefore:             time.Now(),
		NotAfter:              expiration,
		KeyUsage:              x509.KeyUsageKeyEncipherment,
		BasicConstraintsValid: true,
		ExtraExtensions:       extensions,
	}

	// Put domain in the SAN slot that matches its form.
	switch ip := net.ParseIP(domain); {
	case ip != nil:
		template.IPAddresses = []net.IP{ip}
	default:
		template.DNSNames = []string{domain}
	}

	// Self-signed: the template is both the certificate and its issuer.
	return x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
}
 359