server.go raw

   1  package wireguard
   2  
   3  import (
   4  	"context"
   5  	"encoding/base64"
   6  	"encoding/hex"
   7  	"fmt"
   8  	"net"
   9  	"net/netip"
  10  	"sync"
  11  
  12  	"golang.zx2c4.com/wireguard/conn"
  13  	"golang.zx2c4.com/wireguard/device"
  14  	"golang.zx2c4.com/wireguard/tun"
  15  	"golang.zx2c4.com/wireguard/tun/netstack"
  16  	"next.orly.dev/pkg/lol/log"
  17  )
  18  
  19  // Config holds the WireGuard server configuration.
  20  type Config struct {
  21  	Port       int    // UDP port for WireGuard (default 51820)
  22  	Endpoint   string // Public IP/domain for clients to connect to
  23  	PrivateKey []byte // Server's 32-byte Curve25519 private key
  24  	Network    string // CIDR for internal network (e.g., "10.73.0.0/16")
  25  	ServerIP   string // Server's internal IP (e.g., "10.73.0.1")
  26  }
  27  
  28  // Peer represents a WireGuard peer (client).
  29  type Peer struct {
  30  	NostrPubkey []byte // User's Nostr pubkey (32 bytes)
  31  	WGPublicKey []byte // WireGuard public key (32 bytes)
  32  	AssignedIP  string // Assigned internal IP
  33  }
  34  
  35  // Server manages the embedded WireGuard VPN server.
  36  type Server struct {
  37  	cfg       *Config
  38  	device    *device.Device
  39  	tun       *netstack.Net
  40  	tunDev    tun.Device
  41  	publicKey []byte
  42  
  43  	peers   map[string]*Peer // WG pubkey (base64) -> Peer
  44  	peersMu sync.RWMutex
  45  
  46  	ctx     context.Context
  47  	cancel  context.CancelFunc
  48  	running bool
  49  	mu      sync.RWMutex
  50  }
  51  
  52  // New creates a new WireGuard server with the given configuration.
  53  func New(cfg *Config) (*Server, error) {
  54  	if cfg.Endpoint == "" {
  55  		return nil, ErrEndpointRequired
  56  	}
  57  
  58  	// Parse network CIDR to validate it
  59  	_, _, err := net.ParseCIDR(cfg.Network)
  60  	if err != nil {
  61  		return nil, fmt.Errorf("%w: %v", ErrInvalidNetwork, err)
  62  	}
  63  
  64  	// Default server IP if not set
  65  	if cfg.ServerIP == "" {
  66  		cfg.ServerIP = "10.73.0.1"
  67  	}
  68  
  69  	// Derive public key from private key
  70  	publicKey, err := DerivePublicKey(cfg.PrivateKey)
  71  	if err != nil {
  72  		return nil, fmt.Errorf("failed to derive public key: %w", err)
  73  	}
  74  
  75  	return &Server{
  76  		cfg:       cfg,
  77  		publicKey: publicKey,
  78  		peers:     make(map[string]*Peer),
  79  	}, nil
  80  }
  81  
  82  // Start initializes and starts the WireGuard server.
  83  func (s *Server) Start() error {
  84  	s.mu.Lock()
  85  	defer s.mu.Unlock()
  86  
  87  	if s.running {
  88  		return nil
  89  	}
  90  
  91  	s.ctx, s.cancel = context.WithCancel(context.Background())
  92  
  93  	// Parse server IP
  94  	serverAddr, err := netip.ParseAddr(s.cfg.ServerIP)
  95  	if err != nil {
  96  		return fmt.Errorf("invalid server IP: %w", err)
  97  	}
  98  
  99  	// Create netstack TUN device (userspace, no root required)
 100  	s.tunDev, s.tun, err = netstack.CreateNetTUN(
 101  		[]netip.Addr{serverAddr},
 102  		[]netip.Addr{}, // No DNS servers
 103  		1420,           // MTU
 104  	)
 105  	if err != nil {
 106  		return fmt.Errorf("failed to create netstack TUN: %w", err)
 107  	}
 108  
 109  	// Create WireGuard device
 110  	s.device = device.NewDevice(
 111  		s.tunDev,
 112  		conn.NewDefaultBind(),
 113  		device.NewLogger(device.LogLevelSilent, "wg"),
 114  	)
 115  
 116  	// Configure device with server private key and listen port
 117  	privateKeyHex := hex.EncodeToString(s.cfg.PrivateKey)
 118  	ipcConfig := fmt.Sprintf("private_key=%s\nlisten_port=%d\n",
 119  		privateKeyHex,
 120  		s.cfg.Port,
 121  	)
 122  
 123  	if err = s.device.IpcSet(ipcConfig); err != nil {
 124  		s.device.Close()
 125  		return fmt.Errorf("failed to configure WireGuard device: %w", err)
 126  	}
 127  
 128  	// Bring up the device
 129  	if err = s.device.Up(); err != nil {
 130  		s.device.Close()
 131  		return fmt.Errorf("failed to bring up WireGuard device: %w", err)
 132  	}
 133  
 134  	s.running = true
 135  	log.I.F("WireGuard server started on UDP port %d", s.cfg.Port)
 136  	log.I.F("WireGuard server public key: %s", base64.StdEncoding.EncodeToString(s.publicKey))
 137  	log.I.F("WireGuard internal network: %s (server: %s)", s.cfg.Network, s.cfg.ServerIP)
 138  
 139  	return nil
 140  }
 141  
 142  // Stop shuts down the WireGuard server.
 143  func (s *Server) Stop() error {
 144  	s.mu.Lock()
 145  	defer s.mu.Unlock()
 146  
 147  	if !s.running {
 148  		return nil
 149  	}
 150  
 151  	if s.cancel != nil {
 152  		s.cancel()
 153  	}
 154  
 155  	if s.device != nil {
 156  		s.device.Close()
 157  	}
 158  
 159  	s.running = false
 160  	log.I.F("WireGuard server stopped")
 161  
 162  	return nil
 163  }
 164  
 165  // IsRunning returns whether the server is currently running.
 166  func (s *Server) IsRunning() bool {
 167  	s.mu.RLock()
 168  	defer s.mu.RUnlock()
 169  	return s.running
 170  }
 171  
 172  // ServerPublicKey returns the server's WireGuard public key.
 173  func (s *Server) ServerPublicKey() []byte {
 174  	return s.publicKey
 175  }
 176  
 177  // Endpoint returns the configured endpoint address.
 178  func (s *Server) Endpoint() string {
 179  	return fmt.Sprintf("%s:%d", s.cfg.Endpoint, s.cfg.Port)
 180  }
 181  
 182  // GetNetstack returns the netstack networking interface.
 183  // This is used by the bunker to listen on the WireGuard network.
 184  func (s *Server) GetNetstack() *netstack.Net {
 185  	s.mu.RLock()
 186  	defer s.mu.RUnlock()
 187  	return s.tun
 188  }
 189  
 190  // ServerIP returns the server's internal IP address.
 191  func (s *Server) ServerIP() string {
 192  	return s.cfg.ServerIP
 193  }
 194  
 195  // AddPeer adds a new peer to the WireGuard server.
 196  func (s *Server) AddPeer(nostrPubkey, wgPublicKey []byte, assignedIP string) error {
 197  	s.mu.RLock()
 198  	if !s.running {
 199  		s.mu.RUnlock()
 200  		return ErrServerNotRunning
 201  	}
 202  	s.mu.RUnlock()
 203  
 204  	// Encode WG public key as hex for IPC
 205  	wgPubkeyHex := hex.EncodeToString(wgPublicKey)
 206  	wgPubkeyBase64 := base64.StdEncoding.EncodeToString(wgPublicKey)
 207  
 208  	// Configure peer in WireGuard device
 209  	ipcConfig := fmt.Sprintf(
 210  		"public_key=%s\nallowed_ip=%s/32\n",
 211  		wgPubkeyHex,
 212  		assignedIP,
 213  	)
 214  
 215  	if err := s.device.IpcSet(ipcConfig); err != nil {
 216  		return fmt.Errorf("failed to add peer: %w", err)
 217  	}
 218  
 219  	// Track peer
 220  	s.peersMu.Lock()
 221  	s.peers[wgPubkeyBase64] = &Peer{
 222  		NostrPubkey: nostrPubkey,
 223  		WGPublicKey: wgPublicKey,
 224  		AssignedIP:  assignedIP,
 225  	}
 226  	s.peersMu.Unlock()
 227  
 228  	log.D.F("WireGuard peer added: %s -> %s", wgPubkeyBase64[:16]+"...", assignedIP)
 229  
 230  	return nil
 231  }
 232  
 233  // RemovePeer removes a peer from the WireGuard server.
 234  func (s *Server) RemovePeer(wgPublicKey []byte) error {
 235  	s.mu.RLock()
 236  	if !s.running {
 237  		s.mu.RUnlock()
 238  		return ErrServerNotRunning
 239  	}
 240  	s.mu.RUnlock()
 241  
 242  	wgPubkeyHex := hex.EncodeToString(wgPublicKey)
 243  	wgPubkeyBase64 := base64.StdEncoding.EncodeToString(wgPublicKey)
 244  
 245  	// Remove peer from WireGuard device
 246  	ipcConfig := fmt.Sprintf(
 247  		"public_key=%s\nremove=true\n",
 248  		wgPubkeyHex,
 249  	)
 250  
 251  	if err := s.device.IpcSet(ipcConfig); err != nil {
 252  		return fmt.Errorf("failed to remove peer: %w", err)
 253  	}
 254  
 255  	// Remove from tracking
 256  	s.peersMu.Lock()
 257  	delete(s.peers, wgPubkeyBase64)
 258  	s.peersMu.Unlock()
 259  
 260  	log.D.F("WireGuard peer removed: %s", wgPubkeyBase64[:16]+"...")
 261  
 262  	return nil
 263  }
 264  
 265  // GetPeer returns a peer by their WireGuard public key.
 266  func (s *Server) GetPeer(wgPublicKey []byte) (*Peer, bool) {
 267  	wgPubkeyBase64 := base64.StdEncoding.EncodeToString(wgPublicKey)
 268  
 269  	s.peersMu.RLock()
 270  	defer s.peersMu.RUnlock()
 271  
 272  	peer, ok := s.peers[wgPubkeyBase64]
 273  	return peer, ok
 274  }
 275  
 276  // PeerCount returns the number of active peers.
 277  func (s *Server) PeerCount() int {
 278  	s.peersMu.RLock()
 279  	defer s.peersMu.RUnlock()
 280  	return len(s.peers)
 281  }
 282