httpguard.go raw

   1  // Package httpguard provides application-level HTTP protection: bot User-Agent
   2  // blocking and per-IP rate limiting. Designed for environments where the reverse
   3  // proxy (e.g., Cloudron's nginx) cannot be customized.
   4  package httpguard
   5  
   6  import (
   7  	"net"
   8  	"net/http"
   9  	"strings"
  10  	"sync"
  11  	"sync/atomic"
  12  	"time"
  13  
  14  	"next.orly.dev/pkg/lol/log"
  15  )
  16  
  17  // blockedBots is the list of User-Agent substrings to block. Matched
  18  // case-insensitively. Sourced from relay.orly.dev Caddy config.
  19  var blockedBots = []string{
  20  	"semrushbot",
  21  	"ahrefsbot",
  22  	"mj12bot",
  23  	"dotbot",
  24  	"petalbot",
  25  	"blexbot",
  26  	"dataforseobot",
  27  	"amazonbot",
  28  	"meta-externalagent",
  29  	"bytespider",
  30  	"gptbot",
  31  	"claudebot",
  32  	"ccbot",
  33  	"facebookbot",
  34  }
  35  
  36  // Config holds Guard configuration.
  37  type Config struct {
  38  	Enabled      bool
  39  	BotBlock     bool
  40  	RPM          int // HTTP requests per minute per IP
  41  	WSPerMin     int // WebSocket upgrades per minute per IP
  42  	IPBlacklist  []string
  43  }
  44  
  45  // Guard is the HTTP guard middleware.
  46  type Guard struct {
  47  	cfg     Config
  48  	clients sync.Map // map[string]*clientState
  49  
  50  	// cleanup tracking
  51  	stopCleanup chan struct{}
  52  }
  53  
  54  type clientState struct {
  55  	httpTokens  atomic.Int64
  56  	wsTokens    atomic.Int64
  57  	lastSeen    atomic.Int64 // unix seconds
  58  }
  59  
  60  const (
  61  	cleanupInterval = 5 * time.Minute
  62  	idleEvictTime   = 10 * time.Minute
  63  )
  64  
  65  // New creates a new Guard. Starts a background cleanup goroutine.
  66  func New(cfg Config) *Guard {
  67  	if cfg.RPM <= 0 {
  68  		cfg.RPM = 120
  69  	}
  70  	if cfg.WSPerMin <= 0 {
  71  		cfg.WSPerMin = 10
  72  	}
  73  
  74  	g := &Guard{
  75  		cfg:         cfg,
  76  		stopCleanup: make(chan struct{}),
  77  	}
  78  
  79  	go g.cleanupLoop()
  80  	go g.refillLoop()
  81  
  82  	return g
  83  }
  84  
  85  // Stop shuts down the cleanup goroutine.
  86  func (g *Guard) Stop() {
  87  	close(g.stopCleanup)
  88  }
  89  
  90  // Allow checks whether the request should be allowed. If blocked, it writes
  91  // the appropriate HTTP response (403 or 429) and returns false. If allowed,
  92  // returns true without touching the ResponseWriter.
  93  func (g *Guard) Allow(w http.ResponseWriter, r *http.Request) bool {
  94  	if !g.cfg.Enabled {
  95  		return true
  96  	}
  97  
  98  	ip := extractIP(r)
  99  
 100  	// IP blacklist
 101  	for _, blocked := range g.cfg.IPBlacklist {
 102  		if strings.HasPrefix(ip, blocked) {
 103  			http.Error(w, "Forbidden", http.StatusForbidden)
 104  			return false
 105  		}
 106  	}
 107  
 108  	// Bot User-Agent blocking
 109  	if g.cfg.BotBlock {
 110  		ua := strings.ToLower(r.Header.Get("User-Agent"))
 111  		for _, bot := range blockedBots {
 112  			if strings.Contains(ua, bot) {
 113  				http.Error(w, "Forbidden", http.StatusForbidden)
 114  				return false
 115  			}
 116  		}
 117  	}
 118  
 119  	// Rate limiting
 120  	cs := g.getOrCreate(ip)
 121  	now := time.Now().Unix()
 122  	cs.lastSeen.Store(now)
 123  
 124  	isWS := isWebSocketUpgrade(r)
 125  	if isWS {
 126  		if cs.wsTokens.Add(-1) < 0 {
 127  			cs.wsTokens.Add(1) // restore
 128  			w.Header().Set("Retry-After", "60")
 129  			http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
 130  			return false
 131  		}
 132  	}
 133  
 134  	if cs.httpTokens.Add(-1) < 0 {
 135  		cs.httpTokens.Add(1) // restore
 136  		w.Header().Set("Retry-After", "60")
 137  		http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
 138  		return false
 139  	}
 140  
 141  	return true
 142  }
 143  
 144  func (g *Guard) getOrCreate(ip string) *clientState {
 145  	if val, ok := g.clients.Load(ip); ok {
 146  		return val.(*clientState)
 147  	}
 148  	cs := &clientState{}
 149  	cs.httpTokens.Store(int64(g.cfg.RPM))
 150  	cs.wsTokens.Store(int64(g.cfg.WSPerMin))
 151  	actual, _ := g.clients.LoadOrStore(ip, cs)
 152  	return actual.(*clientState)
 153  }
 154  
 155  // refillLoop refills token buckets every minute.
 156  func (g *Guard) refillLoop() {
 157  	ticker := time.NewTicker(1 * time.Minute)
 158  	defer ticker.Stop()
 159  
 160  	for {
 161  		select {
 162  		case <-g.stopCleanup:
 163  			return
 164  		case <-ticker.C:
 165  			g.clients.Range(func(key, value any) bool {
 166  				cs := value.(*clientState)
 167  				// Refill to max, don't exceed
 168  				httpMax := int64(g.cfg.RPM)
 169  				if cs.httpTokens.Load() < httpMax {
 170  					cs.httpTokens.Store(httpMax)
 171  				}
 172  				wsMax := int64(g.cfg.WSPerMin)
 173  				if cs.wsTokens.Load() < wsMax {
 174  					cs.wsTokens.Store(wsMax)
 175  				}
 176  				return true
 177  			})
 178  		}
 179  	}
 180  }
 181  
 182  // cleanupLoop evicts idle clients.
 183  func (g *Guard) cleanupLoop() {
 184  	ticker := time.NewTicker(cleanupInterval)
 185  	defer ticker.Stop()
 186  
 187  	for {
 188  		select {
 189  		case <-g.stopCleanup:
 190  			return
 191  		case <-ticker.C:
 192  			cutoff := time.Now().Add(-idleEvictTime).Unix()
 193  			evicted := 0
 194  			g.clients.Range(func(key, value any) bool {
 195  				cs := value.(*clientState)
 196  				if cs.lastSeen.Load() < cutoff {
 197  					g.clients.Delete(key)
 198  					evicted++
 199  				}
 200  				return true
 201  			})
 202  			if evicted > 0 {
 203  				log.D.F("httpguard: evicted %d idle client entries", evicted)
 204  			}
 205  		}
 206  	}
 207  }
 208  
 209  func extractIP(r *http.Request) string {
 210  	// Check X-Forwarded-For first (reverse proxy)
 211  	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
 212  		// First IP in the chain is the client
 213  		if idx := strings.IndexByte(xff, ','); idx > 0 {
 214  			return strings.TrimSpace(xff[:idx])
 215  		}
 216  		return strings.TrimSpace(xff)
 217  	}
 218  	// Check X-Real-IP
 219  	if xri := r.Header.Get("X-Real-Ip"); xri != "" {
 220  		return strings.TrimSpace(xri)
 221  	}
 222  	// Fall back to remote address
 223  	host, _, err := net.SplitHostPort(r.RemoteAddr)
 224  	if err != nil {
 225  		return r.RemoteAddr
 226  	}
 227  	return host
 228  }
 229  
 230  func isWebSocketUpgrade(r *http.Request) bool {
 231  	return strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
 232  }
 233