neo4j_monitor.go raw

   1  package ratelimit
   2  
   3  import (
   4  	"context"
   5  	"sync"
   6  	"sync/atomic"
   7  	"time"
   8  
   9  	"github.com/neo4j/neo4j-go-driver/v5/neo4j"
  10  	"next.orly.dev/pkg/lol/log"
  11  	"next.orly.dev/pkg/interfaces/loadmonitor"
  12  )
  13  
// Neo4jMonitor implements loadmonitor.Monitor for Neo4j database.
// Since Neo4j driver doesn't expose detailed metrics, we track:
// - Memory pressure via actual RSS (not Go runtime)
// - Query concurrency via the semaphore
// - Latency via recording
//
// This monitor implements aggressive memory-based limiting:
// When memory exceeds the target, it applies 50% more aggressive throttling.
// It rechecks every 10 seconds and doubles the throttling multiplier until
// memory returns under target.
//
// All hot-path state is held in atomics so recording and reading metrics
// never contend on a lock; only the cached Metrics snapshot is mutex-guarded.
type Neo4jMonitor struct {
	driver   neo4j.DriverWithContext
	querySem chan struct{} // Reference to the query semaphore

	// Target memory for pressure calculation (bytes).
	targetMemoryBytes atomic.Uint64

	// Emergency mode configuration
	emergencyThreshold atomic.Uint64 // stored as threshold * 1000 (e.g., 1500 = 1.5)
	emergencyModeUntil atomic.Int64  // Unix nano when forced emergency mode ends
	inEmergencyMode    atomic.Bool

	// Aggressive throttling multiplier for Neo4j
	// Starts at 1.5 (50% more aggressive), doubles every 10 seconds while over limit
	throttleMultiplier atomic.Uint64 // stored as multiplier * 100 (e.g., 150 = 1.5x)
	lastThrottleCheck  atomic.Int64  // Unix nano timestamp of the last doubling decision

	// Latency tracking with exponential moving average
	queryLatencyNs atomic.Int64
	writeLatencyNs atomic.Int64
	latencyAlpha   float64 // EMA coefficient (default 0.1)

	// Concurrency tracking (informational counters; the semaphore enforces limits)
	activeReads    atomic.Int32
	activeWrites   atomic.Int32
	maxConcurrency int

	// Cached metrics (updated by background goroutine)
	metricsLock   sync.RWMutex
	cachedMetrics loadmonitor.Metrics

	// Background collection lifecycle: stopChan signals shutdown,
	// stopped is closed when the collector goroutine has exited.
	stopChan chan struct{}
	stopped  chan struct{}
	interval time.Duration
}
  60  
  61  // Compile-time checks for interface implementation
  62  var _ loadmonitor.Monitor = (*Neo4jMonitor)(nil)
  63  var _ loadmonitor.EmergencyModeMonitor = (*Neo4jMonitor)(nil)
  64  
  65  // ThrottleCheckInterval is how often to recheck memory and adjust throttling
  66  const ThrottleCheckInterval = 10 * time.Second
  67  
  68  // NewNeo4jMonitor creates a new Neo4j load monitor.
  69  // The querySem should be the same semaphore used for limiting concurrent queries.
  70  // maxConcurrency is the maximum concurrent query limit (typically 10).
  71  func NewNeo4jMonitor(
  72  	driver neo4j.DriverWithContext,
  73  	querySem chan struct{},
  74  	maxConcurrency int,
  75  	updateInterval time.Duration,
  76  ) *Neo4jMonitor {
  77  	if updateInterval <= 0 {
  78  		updateInterval = 100 * time.Millisecond
  79  	}
  80  	if maxConcurrency <= 0 {
  81  		maxConcurrency = 10
  82  	}
  83  
  84  	m := &Neo4jMonitor{
  85  		driver:         driver,
  86  		querySem:       querySem,
  87  		maxConcurrency: maxConcurrency,
  88  		latencyAlpha:   0.1, // 10% new, 90% old for smooth EMA
  89  		stopChan:       make(chan struct{}),
  90  		stopped:        make(chan struct{}),
  91  		interval:       updateInterval,
  92  	}
  93  
  94  	// Set a default target (1.5GB)
  95  	m.targetMemoryBytes.Store(1500 * 1024 * 1024)
  96  
  97  	// Default emergency threshold: 100% of target (same as target for Neo4j)
  98  	m.emergencyThreshold.Store(1000)
  99  
 100  	// Start with 1.0x multiplier (no throttling)
 101  	m.throttleMultiplier.Store(100)
 102  
 103  	return m
 104  }
 105  
 106  // SetEmergencyThreshold sets the memory threshold above which emergency mode is triggered.
 107  // threshold is a fraction, e.g., 1.0 = 100% of target memory.
 108  func (m *Neo4jMonitor) SetEmergencyThreshold(threshold float64) {
 109  	m.emergencyThreshold.Store(uint64(threshold * 1000))
 110  }
 111  
 112  // GetEmergencyThreshold returns the current emergency threshold as a fraction.
 113  func (m *Neo4jMonitor) GetEmergencyThreshold() float64 {
 114  	return float64(m.emergencyThreshold.Load()) / 1000.0
 115  }
 116  
 117  // ForceEmergencyMode manually triggers emergency mode for a duration.
 118  func (m *Neo4jMonitor) ForceEmergencyMode(duration time.Duration) {
 119  	m.emergencyModeUntil.Store(time.Now().Add(duration).UnixNano())
 120  	m.inEmergencyMode.Store(true)
 121  	m.throttleMultiplier.Store(150) // Start at 1.5x
 122  	log.W.F("⚠️  Neo4j emergency mode forced for %v", duration)
 123  }
 124  
 125  // GetThrottleMultiplier returns the current throttle multiplier.
 126  // Returns a value >= 1.0, where 1.0 = no extra throttling, 1.5 = 50% more aggressive, etc.
 127  func (m *Neo4jMonitor) GetThrottleMultiplier() float64 {
 128  	return float64(m.throttleMultiplier.Load()) / 100.0
 129  }
 130  
 131  // GetMetrics returns the current load metrics.
 132  func (m *Neo4jMonitor) GetMetrics() loadmonitor.Metrics {
 133  	m.metricsLock.RLock()
 134  	defer m.metricsLock.RUnlock()
 135  	return m.cachedMetrics
 136  }
 137  
 138  // RecordQueryLatency records a query latency sample using exponential moving average.
 139  func (m *Neo4jMonitor) RecordQueryLatency(latency time.Duration) {
 140  	ns := latency.Nanoseconds()
 141  	for {
 142  		old := m.queryLatencyNs.Load()
 143  		if old == 0 {
 144  			if m.queryLatencyNs.CompareAndSwap(0, ns) {
 145  				return
 146  			}
 147  			continue
 148  		}
 149  		// EMA: new = alpha * sample + (1-alpha) * old
 150  		newVal := int64(m.latencyAlpha*float64(ns) + (1-m.latencyAlpha)*float64(old))
 151  		if m.queryLatencyNs.CompareAndSwap(old, newVal) {
 152  			return
 153  		}
 154  	}
 155  }
 156  
 157  // RecordWriteLatency records a write latency sample using exponential moving average.
 158  func (m *Neo4jMonitor) RecordWriteLatency(latency time.Duration) {
 159  	ns := latency.Nanoseconds()
 160  	for {
 161  		old := m.writeLatencyNs.Load()
 162  		if old == 0 {
 163  			if m.writeLatencyNs.CompareAndSwap(0, ns) {
 164  				return
 165  			}
 166  			continue
 167  		}
 168  		// EMA: new = alpha * sample + (1-alpha) * old
 169  		newVal := int64(m.latencyAlpha*float64(ns) + (1-m.latencyAlpha)*float64(old))
 170  		if m.writeLatencyNs.CompareAndSwap(old, newVal) {
 171  			return
 172  		}
 173  	}
 174  }
 175  
 176  // SetMemoryTarget sets the target memory limit in bytes.
 177  func (m *Neo4jMonitor) SetMemoryTarget(bytes uint64) {
 178  	m.targetMemoryBytes.Store(bytes)
 179  }
 180  
// Start begins background metric collection.
// It launches collectLoop in a goroutine and returns the channel that is
// closed once the loop has exited (i.e. after Stop completes).
// NOTE(review): Start appears intended to be called at most once — a second
// call would launch a duplicate collector sharing the same stop channels;
// confirm against callers.
func (m *Neo4jMonitor) Start() <-chan struct{} {
	go m.collectLoop()
	return m.stopped
}
 186  
// Stop halts background metric collection and blocks until the collector
// goroutine has exited (collectLoop closes m.stopped on return).
// NOTE(review): Stop must be called exactly once, and only after Start —
// a second call panics closing the already-closed stopChan, and calling it
// without a prior Start blocks forever waiting on stopped.
func (m *Neo4jMonitor) Stop() {
	close(m.stopChan)
	<-m.stopped
}
 192  
 193  // collectLoop periodically collects metrics.
 194  func (m *Neo4jMonitor) collectLoop() {
 195  	defer close(m.stopped)
 196  
 197  	ticker := time.NewTicker(m.interval)
 198  	defer ticker.Stop()
 199  
 200  	for {
 201  		select {
 202  		case <-m.stopChan:
 203  			return
 204  		case <-ticker.C:
 205  			m.updateMetrics()
 206  		}
 207  	}
 208  }
 209  
 210  // updateMetrics collects current metrics and manages aggressive throttling.
 211  func (m *Neo4jMonitor) updateMetrics() {
 212  	metrics := loadmonitor.Metrics{
 213  		Timestamp: time.Now(),
 214  	}
 215  
 216  	// Use RSS-based memory pressure (actual physical memory, not Go runtime)
 217  	procMem := ReadProcessMemoryStats()
 218  	physicalMemBytes := procMem.PhysicalMemoryBytes()
 219  	metrics.PhysicalMemoryMB = physicalMemBytes / (1024 * 1024)
 220  
 221  	targetBytes := m.targetMemoryBytes.Load()
 222  	if targetBytes > 0 {
 223  		// Use actual physical memory (RSS - shared) for pressure calculation
 224  		metrics.MemoryPressure = float64(physicalMemBytes) / float64(targetBytes)
 225  	}
 226  
 227  	// Check and update emergency mode with aggressive throttling
 228  	m.updateEmergencyMode(metrics.MemoryPressure)
 229  	metrics.InEmergencyMode = m.inEmergencyMode.Load()
 230  
 231  	// Calculate load from semaphore usage
 232  	// querySem is a buffered channel - count how many slots are taken
 233  	if m.querySem != nil {
 234  		usedSlots := len(m.querySem)
 235  		concurrencyLoad := float64(usedSlots) / float64(m.maxConcurrency)
 236  		if concurrencyLoad > 1.0 {
 237  			concurrencyLoad = 1.0
 238  		}
 239  		// Both read and write use the same semaphore
 240  		metrics.WriteLoad = concurrencyLoad
 241  		metrics.ReadLoad = concurrencyLoad
 242  	}
 243  
 244  	// Apply throttle multiplier to loads when in emergency mode
 245  	// This makes the PID controller think load is higher, causing more throttling
 246  	if metrics.InEmergencyMode {
 247  		multiplier := m.GetThrottleMultiplier()
 248  		metrics.WriteLoad = metrics.WriteLoad * multiplier
 249  		if metrics.WriteLoad > 1.0 {
 250  			metrics.WriteLoad = 1.0
 251  		}
 252  		metrics.ReadLoad = metrics.ReadLoad * multiplier
 253  		if metrics.ReadLoad > 1.0 {
 254  			metrics.ReadLoad = 1.0
 255  		}
 256  	}
 257  
 258  	// Add latency-based load adjustment
 259  	// High latency indicates the database is struggling
 260  	queryLatencyNs := m.queryLatencyNs.Load()
 261  	writeLatencyNs := m.writeLatencyNs.Load()
 262  
 263  	// Consider > 500ms query latency as concerning
 264  	const latencyThresholdNs = 500 * 1e6 // 500ms
 265  	if queryLatencyNs > 0 {
 266  		latencyLoad := float64(queryLatencyNs) / float64(latencyThresholdNs)
 267  		if latencyLoad > 1.0 {
 268  			latencyLoad = 1.0
 269  		}
 270  		// Blend concurrency and latency for read load
 271  		metrics.ReadLoad = 0.5*metrics.ReadLoad + 0.5*latencyLoad
 272  	}
 273  
 274  	if writeLatencyNs > 0 {
 275  		latencyLoad := float64(writeLatencyNs) / float64(latencyThresholdNs)
 276  		if latencyLoad > 1.0 {
 277  			latencyLoad = 1.0
 278  		}
 279  		// Blend concurrency and latency for write load
 280  		metrics.WriteLoad = 0.5*metrics.WriteLoad + 0.5*latencyLoad
 281  	}
 282  
 283  	// Store latencies
 284  	metrics.QueryLatency = time.Duration(queryLatencyNs)
 285  	metrics.WriteLatency = time.Duration(writeLatencyNs)
 286  
 287  	// Update cached metrics
 288  	m.metricsLock.Lock()
 289  	m.cachedMetrics = metrics
 290  	m.metricsLock.Unlock()
 291  }
 292  
 293  // updateEmergencyMode manages the emergency mode state and throttle multiplier.
 294  // When memory exceeds the target:
 295  // - Enters emergency mode with 1.5x throttle multiplier (50% more aggressive)
 296  // - Every 10 seconds while still over limit, doubles the multiplier
 297  // - When memory returns under target, resets to normal
 298  func (m *Neo4jMonitor) updateEmergencyMode(memoryPressure float64) {
 299  	threshold := float64(m.emergencyThreshold.Load()) / 1000.0
 300  	forcedUntil := m.emergencyModeUntil.Load()
 301  	now := time.Now().UnixNano()
 302  
 303  	// Check if in forced emergency mode
 304  	if forcedUntil > now {
 305  		return // Stay in forced mode
 306  	}
 307  
 308  	// Check if memory exceeds threshold
 309  	if memoryPressure >= threshold {
 310  		if !m.inEmergencyMode.Load() {
 311  			// Entering emergency mode - start at 1.5x (50% more aggressive)
 312  			m.inEmergencyMode.Store(true)
 313  			m.throttleMultiplier.Store(150)
 314  			m.lastThrottleCheck.Store(now)
 315  			log.W.F("⚠️  Neo4j entering emergency mode: memory %.1f%% >= threshold %.1f%%, throttle 1.5x",
 316  				memoryPressure*100, threshold*100)
 317  			return
 318  		}
 319  
 320  		// Already in emergency mode - check if it's time to double throttling
 321  		lastCheck := m.lastThrottleCheck.Load()
 322  		elapsed := time.Duration(now - lastCheck)
 323  
 324  		if elapsed >= ThrottleCheckInterval {
 325  			// Double the throttle multiplier
 326  			currentMult := m.throttleMultiplier.Load()
 327  			newMult := currentMult * 2
 328  			if newMult > 1600 { // Cap at 16x to prevent overflow
 329  				newMult = 1600
 330  			}
 331  			m.throttleMultiplier.Store(newMult)
 332  			m.lastThrottleCheck.Store(now)
 333  			log.W.F("⚠️  Neo4j still over memory limit: %.1f%%, doubling throttle to %.1fx",
 334  				memoryPressure*100, float64(newMult)/100.0)
 335  		}
 336  	} else {
 337  		// Memory is under threshold
 338  		if m.inEmergencyMode.Load() {
 339  			m.inEmergencyMode.Store(false)
 340  			m.throttleMultiplier.Store(100) // Reset to 1.0x
 341  			log.I.F("✅ Neo4j exiting emergency mode: memory %.1f%% < threshold %.1f%%",
 342  				memoryPressure*100, threshold*100)
 343  		}
 344  	}
 345  }
 346  
 347  // IncrementActiveReads tracks an active read operation.
 348  // Call this when starting a read, and call the returned function when done.
 349  func (m *Neo4jMonitor) IncrementActiveReads() func() {
 350  	m.activeReads.Add(1)
 351  	return func() {
 352  		m.activeReads.Add(-1)
 353  	}
 354  }
 355  
 356  // IncrementActiveWrites tracks an active write operation.
 357  // Call this when starting a write, and call the returned function when done.
 358  func (m *Neo4jMonitor) IncrementActiveWrites() func() {
 359  	m.activeWrites.Add(1)
 360  	return func() {
 361  		m.activeWrites.Add(-1)
 362  	}
 363  }
 364  
 365  // GetConcurrencyStats returns current concurrency statistics for debugging.
 366  func (m *Neo4jMonitor) GetConcurrencyStats() (reads, writes int32, semUsed int) {
 367  	reads = m.activeReads.Load()
 368  	writes = m.activeWrites.Load()
 369  	if m.querySem != nil {
 370  		semUsed = len(m.querySem)
 371  	}
 372  	return
 373  }
 374  
 375  // CheckConnectivity performs a connectivity check to Neo4j.
 376  // This can be used to verify the database is responsive.
 377  func (m *Neo4jMonitor) CheckConnectivity(ctx context.Context) error {
 378  	if m.driver == nil {
 379  		return nil
 380  	}
 381  	return m.driver.VerifyConnectivity(ctx)
 382  }
 383