manager.go raw

   1  package main
   2  
   3  import (
   4  	"crypto"
   5  	"crypto/ecdsa"
   6  	"crypto/elliptic"
   7  	"crypto/rand"
   8  	"crypto/x509"
   9  	"encoding/json"
  10  	"encoding/pem"
  11  	"fmt"
  12  	"os"
  13  	"path/filepath"
  14  	"time"
  15  
  16  	"github.com/go-acme/lego/v4/certcrypto"
  17  	"github.com/go-acme/lego/v4/certificate"
  18  	"github.com/go-acme/lego/v4/lego"
  19  	"github.com/go-acme/lego/v4/registration"
  20  	"next.orly.dev/pkg/lol/chk"
  21  	"next.orly.dev/pkg/lol/log"
  22  )
  23  
  24  // CertManager handles certificate acquisition and renewal.
  25  type CertManager struct {
  26  	cfg      *Config
  27  	client   *lego.Client
  28  	user     *User
  29  	certPath string
  30  	keyPath  string
  31  	metaPath string
  32  }
  33  
  34  // User implements the lego registration.User interface.
  35  type User struct {
  36  	Email        string
  37  	Registration *registration.Resource
  38  	key          crypto.PrivateKey
  39  }
  40  
  41  func (u *User) GetEmail() string {
  42  	return u.Email
  43  }
  44  
  45  func (u *User) GetRegistration() *registration.Resource {
  46  	return u.Registration
  47  }
  48  
  49  func (u *User) GetPrivateKey() crypto.PrivateKey {
  50  	return u.key
  51  }
  52  
  53  // CertMetadata stores certificate metadata.
  54  type CertMetadata struct {
  55  	Domain    string    `json:"domain"`
  56  	Domains   []string  `json:"domains"`
  57  	NotBefore time.Time `json:"not_before"`
  58  	NotAfter  time.Time `json:"not_after"`
  59  	Issuer    string    `json:"issuer"`
  60  	RenewedAt time.Time `json:"renewed_at"`
  61  }
  62  
  63  // NewCertManager creates a new certificate manager.
  64  func NewCertManager(cfg *Config) (*CertManager, error) {
  65  	// Create output directory
  66  	domainDir := filepath.Join(cfg.OutputDir, cfg.BaseDomain())
  67  	if err := os.MkdirAll(domainDir, 0755); chk.E(err) {
  68  		return nil, fmt.Errorf("failed to create output directory: %w", err)
  69  	}
  70  
  71  	// Generate or load account private key
  72  	privateKey, err := loadOrCreateAccountKey(cfg)
  73  	if chk.E(err) {
  74  		return nil, fmt.Errorf("failed to load/create account key: %w", err)
  75  	}
  76  
  77  	user := &User{
  78  		Email: cfg.Email,
  79  		key:   privateKey,
  80  	}
  81  
  82  	// Create lego config
  83  	legoCfg := lego.NewConfig(user)
  84  	legoCfg.CADirURL = cfg.ACMEServerURL()
  85  	legoCfg.Certificate.KeyType = certcrypto.EC256
  86  
  87  	// Create lego client
  88  	client, err := lego.NewClient(legoCfg)
  89  	if chk.E(err) {
  90  		return nil, fmt.Errorf("failed to create ACME client: %w", err)
  91  	}
  92  
  93  	// Set up DNS provider
  94  	dnsProvider, err := NewDNSProvider(cfg.DNSProvider)
  95  	if chk.E(err) {
  96  		return nil, err
  97  	}
  98  
  99  	if err := client.Challenge.SetDNS01Provider(dnsProvider); chk.E(err) {
 100  		return nil, fmt.Errorf("failed to set DNS provider: %w", err)
 101  	}
 102  
 103  	// Register account if needed
 104  	reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
 105  	if err != nil {
 106  		// Try to recover existing registration
 107  		reg, err = client.Registration.ResolveAccountByKey()
 108  		if chk.E(err) {
 109  			return nil, fmt.Errorf("failed to register account: %w", err)
 110  		}
 111  	}
 112  	user.Registration = reg
 113  
 114  	return &CertManager{
 115  		cfg:      cfg,
 116  		client:   client,
 117  		user:     user,
 118  		certPath: filepath.Join(domainDir, "cert.pem"),
 119  		keyPath:  filepath.Join(domainDir, "key.pem"),
 120  		metaPath: filepath.Join(domainDir, "metadata.json"),
 121  	}, nil
 122  }
 123  
 124  // EnsureCertificate obtains a certificate if none exists or if it needs renewal.
 125  func (m *CertManager) EnsureCertificate() error {
 126  	// Check if certificate exists and is valid
 127  	if m.certificateExists() {
 128  		needsRenewal, err := m.needsRenewal()
 129  		if chk.E(err) {
 130  			log.W.F("failed to check renewal status, will obtain new cert: %v", err)
 131  		} else if !needsRenewal {
 132  			log.I.F("certificate is valid, no renewal needed")
 133  			return nil
 134  		}
 135  		log.I.F("certificate needs renewal")
 136  	}
 137  
 138  	return m.obtainCertificate()
 139  }
 140  
 141  // CheckRenewal checks if the certificate needs renewal and renews if needed.
 142  func (m *CertManager) CheckRenewal() error {
 143  	if !m.certificateExists() {
 144  		return m.obtainCertificate()
 145  	}
 146  
 147  	needsRenewal, err := m.needsRenewal()
 148  	if chk.E(err) {
 149  		return err
 150  	}
 151  
 152  	if needsRenewal {
 153  		log.I.F("certificate expiring soon, renewing...")
 154  		return m.obtainCertificate()
 155  	}
 156  
 157  	log.D.F("certificate still valid, no renewal needed")
 158  	return nil
 159  }
 160  
 161  func (m *CertManager) certificateExists() bool {
 162  	_, err := os.Stat(m.certPath)
 163  	return err == nil
 164  }
 165  
 166  func (m *CertManager) needsRenewal() (bool, error) {
 167  	certPEM, err := os.ReadFile(m.certPath)
 168  	if chk.E(err) {
 169  		return true, err
 170  	}
 171  
 172  	block, _ := pem.Decode(certPEM)
 173  	if block == nil {
 174  		return true, fmt.Errorf("failed to decode certificate PEM")
 175  	}
 176  
 177  	cert, err := x509.ParseCertificate(block.Bytes)
 178  	if chk.E(err) {
 179  		return true, err
 180  	}
 181  
 182  	// Check if certificate expires within RenewDays
 183  	renewTime := time.Now().Add(time.Duration(m.cfg.RenewDays) * 24 * time.Hour)
 184  	return cert.NotAfter.Before(renewTime), nil
 185  }
 186  
 187  func (m *CertManager) obtainCertificate() error {
 188  	log.I.F("obtaining certificate for %s", m.cfg.Domain)
 189  
 190  	request := certificate.ObtainRequest{
 191  		Domains: []string{m.cfg.Domain, m.cfg.BaseDomain()},
 192  		Bundle:  true,
 193  	}
 194  
 195  	certificates, err := m.client.Certificate.Obtain(request)
 196  	if chk.E(err) {
 197  		return fmt.Errorf("failed to obtain certificate: %w", err)
 198  	}
 199  
 200  	// Write certificate chain
 201  	if err := os.WriteFile(m.certPath, certificates.Certificate, 0644); chk.E(err) {
 202  		return fmt.Errorf("failed to write certificate: %w", err)
 203  	}
 204  
 205  	// Write private key with restricted permissions
 206  	if err := os.WriteFile(m.keyPath, certificates.PrivateKey, 0600); chk.E(err) {
 207  		return fmt.Errorf("failed to write private key: %w", err)
 208  	}
 209  
 210  	// Write issuer certificate if available
 211  	if len(certificates.IssuerCertificate) > 0 {
 212  		issuerPath := filepath.Join(filepath.Dir(m.certPath), "issuer.pem")
 213  		if err := os.WriteFile(issuerPath, certificates.IssuerCertificate, 0644); chk.E(err) {
 214  			log.W.F("failed to write issuer certificate: %v", err)
 215  		}
 216  	}
 217  
 218  	// Write metadata
 219  	if err := m.writeMetadata(certificates.Certificate); chk.E(err) {
 220  		log.W.F("failed to write metadata: %v", err)
 221  	}
 222  
 223  	log.I.F("certificate obtained successfully for %s", m.cfg.Domain)
 224  	log.I.F("  cert: %s", m.certPath)
 225  	log.I.F("  key:  %s", m.keyPath)
 226  
 227  	return nil
 228  }
 229  
 230  func (m *CertManager) writeMetadata(certPEM []byte) error {
 231  	block, _ := pem.Decode(certPEM)
 232  	if block == nil {
 233  		return fmt.Errorf("failed to decode certificate for metadata")
 234  	}
 235  
 236  	cert, err := x509.ParseCertificate(block.Bytes)
 237  	if chk.E(err) {
 238  		return err
 239  	}
 240  
 241  	meta := CertMetadata{
 242  		Domain:    m.cfg.Domain,
 243  		Domains:   cert.DNSNames,
 244  		NotBefore: cert.NotBefore,
 245  		NotAfter:  cert.NotAfter,
 246  		Issuer:    cert.Issuer.CommonName,
 247  		RenewedAt: time.Now(),
 248  	}
 249  
 250  	data, err := json.MarshalIndent(meta, "", "  ")
 251  	if chk.E(err) {
 252  		return err
 253  	}
 254  
 255  	return os.WriteFile(m.metaPath, data, 0644)
 256  }
 257  
 258  func loadOrCreateAccountKey(cfg *Config) (crypto.PrivateKey, error) {
 259  	keyPath := cfg.AccountKeyPath
 260  	if keyPath == "" {
 261  		keyPath = filepath.Join(cfg.OutputDir, "account.key")
 262  	}
 263  
 264  	// Try to load existing key
 265  	if data, err := os.ReadFile(keyPath); err == nil {
 266  		block, _ := pem.Decode(data)
 267  		if block != nil {
 268  			key, err := x509.ParseECPrivateKey(block.Bytes)
 269  			if err == nil {
 270  				log.D.F("loaded existing account key from %s", keyPath)
 271  				return key, nil
 272  			}
 273  		}
 274  	}
 275  
 276  	// Generate new key
 277  	log.I.F("generating new account key")
 278  	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
 279  	if chk.E(err) {
 280  		return nil, err
 281  	}
 282  
 283  	// Save key
 284  	keyBytes, err := x509.MarshalECPrivateKey(key)
 285  	if chk.E(err) {
 286  		return nil, err
 287  	}
 288  
 289  	keyPEM := pem.EncodeToMemory(&pem.Block{
 290  		Type:  "EC PRIVATE KEY",
 291  		Bytes: keyBytes,
 292  	})
 293  
 294  	if err := os.MkdirAll(filepath.Dir(keyPath), 0755); chk.E(err) {
 295  		return nil, err
 296  	}
 297  
 298  	if err := os.WriteFile(keyPath, keyPEM, 0600); chk.E(err) {
 299  		return nil, err
 300  	}
 301  
 302  	log.I.F("saved new account key to %s", keyPath)
 303  	return key, nil
 304  }
 305