tls.go raw
1 // Package tls provides a TLS/ACME transport for the relay.
2 package tls
3
4 import (
5 "context"
6 "crypto/tls"
7 "crypto/x509"
8 "fmt"
9 "net/http"
10 "os"
11 "path/filepath"
12 "strings"
13 "sync"
14
15 "golang.org/x/crypto/acme/autocert"
16 "next.orly.dev/pkg/lol/chk"
17 "next.orly.dev/pkg/lol/log"
18 )
19
// Config holds TLS transport configuration.
type Config struct {
	// Domains is the list of domains for ACME auto-cert. The same list is
	// used as the autocert host whitelist and to build the wss:// URLs
	// reported by Addresses.
	Domains []string
	// Certs is a list of manual certificate paths (without extension).
	// For each path, the certificate chain is read from <path>.pem and the
	// private key from <path>.key; entries that fail to load are logged and
	// skipped.
	Certs []string
	// DataDir is the base data directory for the autocert cache. The cache
	// lives in the "autocert" subdirectory, which Start creates with mode 0700.
	DataDir string
	// Handler is the HTTP handler served over HTTPS on :443. If nil,
	// net/http falls back to http.DefaultServeMux.
	Handler http.Handler
}
32
// Transport serves HTTPS with automatic or manual TLS certificates.
// It runs two servers: HTTPS on :443 and HTTP on :80, the latter handling
// ACME challenges and redirecting other requests to HTTPS.
type Transport struct {
	cfg        *Config
	tlsServer  *http.Server // HTTPS server on :443; nil until Start creates it
	httpServer *http.Server // ACME/redirect server on :80; nil until Start creates it
	mu         sync.Mutex   // serializes Start and Stop
}
41
42 // New creates a new TLS transport.
43 func New(cfg *Config) *Transport {
44 return &Transport{cfg: cfg}
45 }
46
47 func (t *Transport) Name() string { return "tls" }
48
49 func (t *Transport) Start(ctx context.Context) error {
50 t.mu.Lock()
51 defer t.mu.Unlock()
52
53 if err := ValidateConfig(t.cfg.Domains, t.cfg.Certs); err != nil {
54 return fmt.Errorf("invalid TLS configuration: %w", err)
55 }
56
57 // Create cache directory for autocert
58 cacheDir := filepath.Join(t.cfg.DataDir, "autocert")
59 if err := os.MkdirAll(cacheDir, 0700); err != nil {
60 return fmt.Errorf("failed to create autocert cache directory: %w", err)
61 }
62
63 // Set up autocert manager
64 m := &autocert.Manager{
65 Prompt: autocert.AcceptTOS,
66 Cache: autocert.DirCache(cacheDir),
67 HostPolicy: autocert.HostWhitelist(t.cfg.Domains...),
68 }
69
70 // Create TLS server on port 443
71 t.tlsServer = &http.Server{
72 Addr: ":443",
73 Handler: t.cfg.Handler,
74 TLSConfig: tlsConfig(m, t.cfg.Certs...),
75 }
76
77 // Create HTTP server for ACME challenges and redirects on port 80
78 t.httpServer = &http.Server{
79 Addr: ":80",
80 Handler: m.HTTPHandler(nil),
81 }
82
83 log.I.F("TLS enabled for domains: %v", t.cfg.Domains)
84
85 // Start TLS server
86 go func() {
87 log.I.F("starting TLS listener on https://:443")
88 if err := t.tlsServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
89 log.E.F("TLS server error: %v", err)
90 }
91 }()
92
93 // Start HTTP server for ACME challenges
94 go func() {
95 log.I.F("starting HTTP listener on http://:80 for ACME challenges")
96 if err := t.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
97 log.E.F("HTTP server error: %v", err)
98 }
99 }()
100
101 return nil
102 }
103
104 func (t *Transport) Stop(ctx context.Context) error {
105 t.mu.Lock()
106 defer t.mu.Unlock()
107
108 var firstErr error
109
110 if t.tlsServer != nil {
111 if err := t.tlsServer.Shutdown(ctx); err != nil {
112 log.E.F("TLS server shutdown error: %v", err)
113 firstErr = err
114 } else {
115 log.I.F("TLS server shutdown completed")
116 }
117 }
118
119 if t.httpServer != nil {
120 if err := t.httpServer.Shutdown(ctx); err != nil {
121 log.E.F("HTTP server shutdown error: %v", err)
122 if firstErr == nil {
123 firstErr = err
124 }
125 } else {
126 log.I.F("HTTP server shutdown completed")
127 }
128 }
129
130 return firstErr
131 }
132
133 func (t *Transport) Addresses() []string {
134 var addrs []string
135 for _, domain := range t.cfg.Domains {
136 addrs = append(addrs, "wss://"+domain+"/")
137 }
138 return addrs
139 }
140
// ValidateConfig checks that the TLS configuration names at least one
// non-empty domain and that no domain contains ASCII whitespace. Empty entries
// in domains are ignored rather than rejected, but a list made only of empty
// entries counts as "no domains".
//
// certs is accepted for API compatibility and is not inspected here.
func ValidateConfig(domains []string, certs []string) error {
	usable := 0
	for _, domain := range domains {
		if domain == "" {
			continue
		}
		// Any whitespace (not just space/tab) makes a host name invalid;
		// %q makes the offending characters visible in the error.
		if strings.ContainsAny(domain, " \t\r\n\v\f") {
			return fmt.Errorf("invalid domain name: %q", domain)
		}
		usable++
	}

	if usable == 0 {
		return fmt.Errorf("no TLS domains specified")
	}
	return nil
}
158
159 // tlsConfig returns a TLS configuration that works with LetsEncrypt automatic
160 // SSL cert issuer as well as any provided certificate files.
161 //
162 // Certs are provided as paths where .pem and .key files exist.
163 func tlsConfig(m *autocert.Manager, certs ...string) *tls.Config {
164 certMap := make(map[string]*tls.Certificate)
165 var mx sync.Mutex
166
167 for _, certPath := range certs {
168 if certPath == "" {
169 continue
170 }
171
172 var err error
173 var c tls.Certificate
174
175 if c, err = tls.LoadX509KeyPair(
176 certPath+".pem", certPath+".key",
177 ); chk.E(err) {
178 log.E.F("failed to load certificate from %s: %v", certPath, err)
179 continue
180 }
181
182 if len(c.Certificate) > 0 {
183 if x509Cert, err := x509.ParseCertificate(c.Certificate[0]); err == nil {
184 if x509Cert.Subject.CommonName != "" {
185 certMap[x509Cert.Subject.CommonName] = &c
186 log.I.F("loaded certificate for domain: %s", x509Cert.Subject.CommonName)
187 }
188 for _, san := range x509Cert.DNSNames {
189 if san != "" {
190 certMap[san] = &c
191 log.I.F("loaded certificate for SAN domain: %s", san)
192 }
193 }
194 }
195 }
196 }
197
198 if m == nil {
199 return &tls.Config{
200 GetCertificate: func(helo *tls.ClientHelloInfo) (*tls.Certificate, error) {
201 mx.Lock()
202 defer mx.Unlock()
203
204 if cert, exists := certMap[helo.ServerName]; exists {
205 return cert, nil
206 }
207
208 for domain, cert := range certMap {
209 if strings.HasPrefix(domain, "*.") {
210 baseDomain := domain[2:]
211 if strings.HasSuffix(helo.ServerName, baseDomain) {
212 return cert, nil
213 }
214 }
215 }
216
217 return nil, fmt.Errorf("no certificate found for %s", helo.ServerName)
218 },
219 }
220 }
221
222 tc := m.TLSConfig()
223 tc.GetCertificate = func(helo *tls.ClientHelloInfo) (*tls.Certificate, error) {
224 mx.Lock()
225
226 if cert, exists := certMap[helo.ServerName]; exists {
227 mx.Unlock()
228 return cert, nil
229 }
230
231 for domain, cert := range certMap {
232 if strings.HasPrefix(domain, "*.") {
233 baseDomain := domain[2:]
234 if strings.HasSuffix(helo.ServerName, baseDomain) {
235 mx.Unlock()
236 return cert, nil
237 }
238 }
239 }
240
241 mx.Unlock()
242
243 return m.GetCertificate(helo)
244 }
245
246 return tc
247 }
248