tls.go raw

   1  // Package tls provides a TLS/ACME transport for the relay.
   2  package tls
   3  
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/acme/autocert"
	"next.orly.dev/pkg/lol/chk"
	"next.orly.dev/pkg/lol/log"
)
  19  
  20  // Config holds TLS transport configuration.
  21  type Config struct {
  22  	// Domains is the list of domains for ACME auto-cert.
  23  	Domains []string
  24  	// Certs is a list of manual certificate paths (without extension).
  25  	// For each path, .pem and .key files are loaded.
  26  	Certs []string
  27  	// DataDir is the base data directory for the autocert cache.
  28  	DataDir string
  29  	// Handler is the HTTP handler to serve.
  30  	Handler http.Handler
  31  }
  32  
  33  // Transport serves HTTPS with automatic or manual TLS certificates.
  34  // It runs two servers: HTTPS on :443 and HTTP on :80 for ACME challenges.
  35  type Transport struct {
  36  	cfg        *Config
  37  	tlsServer  *http.Server
  38  	httpServer *http.Server
  39  	mu         sync.Mutex
  40  }
  41  
  42  // New creates a new TLS transport.
  43  func New(cfg *Config) *Transport {
  44  	return &Transport{cfg: cfg}
  45  }
  46  
  47  func (t *Transport) Name() string { return "tls" }
  48  
  49  func (t *Transport) Start(ctx context.Context) error {
  50  	t.mu.Lock()
  51  	defer t.mu.Unlock()
  52  
  53  	if err := ValidateConfig(t.cfg.Domains, t.cfg.Certs); err != nil {
  54  		return fmt.Errorf("invalid TLS configuration: %w", err)
  55  	}
  56  
  57  	// Create cache directory for autocert
  58  	cacheDir := filepath.Join(t.cfg.DataDir, "autocert")
  59  	if err := os.MkdirAll(cacheDir, 0700); err != nil {
  60  		return fmt.Errorf("failed to create autocert cache directory: %w", err)
  61  	}
  62  
  63  	// Set up autocert manager
  64  	m := &autocert.Manager{
  65  		Prompt:     autocert.AcceptTOS,
  66  		Cache:      autocert.DirCache(cacheDir),
  67  		HostPolicy: autocert.HostWhitelist(t.cfg.Domains...),
  68  	}
  69  
  70  	// Create TLS server on port 443
  71  	t.tlsServer = &http.Server{
  72  		Addr:      ":443",
  73  		Handler:   t.cfg.Handler,
  74  		TLSConfig: tlsConfig(m, t.cfg.Certs...),
  75  	}
  76  
  77  	// Create HTTP server for ACME challenges and redirects on port 80
  78  	t.httpServer = &http.Server{
  79  		Addr:    ":80",
  80  		Handler: m.HTTPHandler(nil),
  81  	}
  82  
  83  	log.I.F("TLS enabled for domains: %v", t.cfg.Domains)
  84  
  85  	// Start TLS server
  86  	go func() {
  87  		log.I.F("starting TLS listener on https://:443")
  88  		if err := t.tlsServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
  89  			log.E.F("TLS server error: %v", err)
  90  		}
  91  	}()
  92  
  93  	// Start HTTP server for ACME challenges
  94  	go func() {
  95  		log.I.F("starting HTTP listener on http://:80 for ACME challenges")
  96  		if err := t.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
  97  			log.E.F("HTTP server error: %v", err)
  98  		}
  99  	}()
 100  
 101  	return nil
 102  }
 103  
 104  func (t *Transport) Stop(ctx context.Context) error {
 105  	t.mu.Lock()
 106  	defer t.mu.Unlock()
 107  
 108  	var firstErr error
 109  
 110  	if t.tlsServer != nil {
 111  		if err := t.tlsServer.Shutdown(ctx); err != nil {
 112  			log.E.F("TLS server shutdown error: %v", err)
 113  			firstErr = err
 114  		} else {
 115  			log.I.F("TLS server shutdown completed")
 116  		}
 117  	}
 118  
 119  	if t.httpServer != nil {
 120  		if err := t.httpServer.Shutdown(ctx); err != nil {
 121  			log.E.F("HTTP server shutdown error: %v", err)
 122  			if firstErr == nil {
 123  				firstErr = err
 124  			}
 125  		} else {
 126  			log.I.F("HTTP server shutdown completed")
 127  		}
 128  	}
 129  
 130  	return firstErr
 131  }
 132  
 133  func (t *Transport) Addresses() []string {
 134  	var addrs []string
 135  	for _, domain := range t.cfg.Domains {
 136  		addrs = append(addrs, "wss://"+domain+"/")
 137  	}
 138  	return addrs
 139  }
 140  
 141  // ValidateConfig checks if the TLS configuration is valid.
 142  func ValidateConfig(domains []string, certs []string) error {
 143  	if len(domains) == 0 {
 144  		return fmt.Errorf("no TLS domains specified")
 145  	}
 146  
 147  	for _, domain := range domains {
 148  		if domain == "" {
 149  			continue
 150  		}
 151  		if strings.Contains(domain, " ") || strings.Contains(domain, "\t") {
 152  			return fmt.Errorf("invalid domain name: %s", domain)
 153  		}
 154  	}
 155  
 156  	return nil
 157  }
 158  
 159  // tlsConfig returns a TLS configuration that works with LetsEncrypt automatic
 160  // SSL cert issuer as well as any provided certificate files.
 161  //
 162  // Certs are provided as paths where .pem and .key files exist.
 163  func tlsConfig(m *autocert.Manager, certs ...string) *tls.Config {
 164  	certMap := make(map[string]*tls.Certificate)
 165  	var mx sync.Mutex
 166  
 167  	for _, certPath := range certs {
 168  		if certPath == "" {
 169  			continue
 170  		}
 171  
 172  		var err error
 173  		var c tls.Certificate
 174  
 175  		if c, err = tls.LoadX509KeyPair(
 176  			certPath+".pem", certPath+".key",
 177  		); chk.E(err) {
 178  			log.E.F("failed to load certificate from %s: %v", certPath, err)
 179  			continue
 180  		}
 181  
 182  		if len(c.Certificate) > 0 {
 183  			if x509Cert, err := x509.ParseCertificate(c.Certificate[0]); err == nil {
 184  				if x509Cert.Subject.CommonName != "" {
 185  					certMap[x509Cert.Subject.CommonName] = &c
 186  					log.I.F("loaded certificate for domain: %s", x509Cert.Subject.CommonName)
 187  				}
 188  				for _, san := range x509Cert.DNSNames {
 189  					if san != "" {
 190  						certMap[san] = &c
 191  						log.I.F("loaded certificate for SAN domain: %s", san)
 192  					}
 193  				}
 194  			}
 195  		}
 196  	}
 197  
 198  	if m == nil {
 199  		return &tls.Config{
 200  			GetCertificate: func(helo *tls.ClientHelloInfo) (*tls.Certificate, error) {
 201  				mx.Lock()
 202  				defer mx.Unlock()
 203  
 204  				if cert, exists := certMap[helo.ServerName]; exists {
 205  					return cert, nil
 206  				}
 207  
 208  				for domain, cert := range certMap {
 209  					if strings.HasPrefix(domain, "*.") {
 210  						baseDomain := domain[2:]
 211  						if strings.HasSuffix(helo.ServerName, baseDomain) {
 212  							return cert, nil
 213  						}
 214  					}
 215  				}
 216  
 217  				return nil, fmt.Errorf("no certificate found for %s", helo.ServerName)
 218  			},
 219  		}
 220  	}
 221  
 222  	tc := m.TLSConfig()
 223  	tc.GetCertificate = func(helo *tls.ClientHelloInfo) (*tls.Certificate, error) {
 224  		mx.Lock()
 225  
 226  		if cert, exists := certMap[helo.ServerName]; exists {
 227  			mx.Unlock()
 228  			return cert, nil
 229  		}
 230  
 231  		for domain, cert := range certMap {
 232  			if strings.HasPrefix(domain, "*.") {
 233  				baseDomain := domain[2:]
 234  				if strings.HasSuffix(helo.ServerName, baseDomain) {
 235  					mx.Unlock()
 236  					return cert, nil
 237  				}
 238  			}
 239  		}
 240  
 241  		mx.Unlock()
 242  
 243  		return m.GetCertificate(helo)
 244  	}
 245  
 246  	return tc
 247  }
 248