// neo4j_monitor.go
1 package ratelimit
2
3 import (
4 "context"
5 "sync"
6 "sync/atomic"
7 "time"
8
9 "github.com/neo4j/neo4j-go-driver/v5/neo4j"
10 "next.orly.dev/pkg/lol/log"
11 "next.orly.dev/pkg/interfaces/loadmonitor"
12 )
13
// Neo4jMonitor implements loadmonitor.Monitor for Neo4j database.
// Since Neo4j driver doesn't expose detailed metrics, we track:
// - Memory pressure via actual RSS (not Go runtime)
// - Query concurrency via the semaphore
// - Latency via recording
//
// This monitor implements aggressive memory-based limiting:
// When memory exceeds the target, it applies 50% more aggressive throttling.
// It rechecks every 10 seconds and doubles the throttling multiplier until
// memory returns under target.
type Neo4jMonitor struct {
	driver   neo4j.DriverWithContext
	querySem chan struct{} // Reference to the query semaphore shared with the query path

	// Target memory for pressure calculation, in bytes.
	targetMemoryBytes atomic.Uint64

	// Emergency mode configuration.
	emergencyThreshold atomic.Uint64 // stored as threshold * 1000 (e.g., 1500 = 1.5)
	emergencyModeUntil atomic.Int64  // Unix nano when forced emergency mode ends
	inEmergencyMode    atomic.Bool   // true while emergency throttling is active

	// Aggressive throttling multiplier for Neo4j.
	// Starts at 1.5 (50% more aggressive), doubles every 10 seconds while over limit.
	throttleMultiplier atomic.Uint64 // stored as multiplier * 100 (e.g., 150 = 1.5x)
	lastThrottleCheck  atomic.Int64  // Unix nano timestamp of the last doubling decision

	// Latency tracking with exponential moving average, in nanoseconds.
	queryLatencyNs atomic.Int64
	writeLatencyNs atomic.Int64
	latencyAlpha   float64 // EMA coefficient (default 0.1); set once in NewNeo4jMonitor, read-only after

	// Concurrency tracking (informational; exposed via GetConcurrencyStats).
	activeReads    atomic.Int32
	activeWrites   atomic.Int32
	maxConcurrency int

	// Cached metrics (updated by background goroutine in collectLoop,
	// read by GetMetrics under metricsLock).
	metricsLock   sync.RWMutex
	cachedMetrics loadmonitor.Metrics

	// Background collection lifecycle.
	stopChan chan struct{} // closed by Stop to request shutdown
	stopped  chan struct{} // closed by collectLoop when it exits
	interval time.Duration // metrics refresh period
}
60
// Compile-time checks for interface implementation.
var _ loadmonitor.Monitor = (*Neo4jMonitor)(nil)
var _ loadmonitor.EmergencyModeMonitor = (*Neo4jMonitor)(nil)

// ThrottleCheckInterval is how often to recheck memory and adjust throttling
// while in emergency mode; the multiplier doubles at most once per interval.
const ThrottleCheckInterval = 10 * time.Second
67
68 // NewNeo4jMonitor creates a new Neo4j load monitor.
69 // The querySem should be the same semaphore used for limiting concurrent queries.
70 // maxConcurrency is the maximum concurrent query limit (typically 10).
71 func NewNeo4jMonitor(
72 driver neo4j.DriverWithContext,
73 querySem chan struct{},
74 maxConcurrency int,
75 updateInterval time.Duration,
76 ) *Neo4jMonitor {
77 if updateInterval <= 0 {
78 updateInterval = 100 * time.Millisecond
79 }
80 if maxConcurrency <= 0 {
81 maxConcurrency = 10
82 }
83
84 m := &Neo4jMonitor{
85 driver: driver,
86 querySem: querySem,
87 maxConcurrency: maxConcurrency,
88 latencyAlpha: 0.1, // 10% new, 90% old for smooth EMA
89 stopChan: make(chan struct{}),
90 stopped: make(chan struct{}),
91 interval: updateInterval,
92 }
93
94 // Set a default target (1.5GB)
95 m.targetMemoryBytes.Store(1500 * 1024 * 1024)
96
97 // Default emergency threshold: 100% of target (same as target for Neo4j)
98 m.emergencyThreshold.Store(1000)
99
100 // Start with 1.0x multiplier (no throttling)
101 m.throttleMultiplier.Store(100)
102
103 return m
104 }
105
106 // SetEmergencyThreshold sets the memory threshold above which emergency mode is triggered.
107 // threshold is a fraction, e.g., 1.0 = 100% of target memory.
108 func (m *Neo4jMonitor) SetEmergencyThreshold(threshold float64) {
109 m.emergencyThreshold.Store(uint64(threshold * 1000))
110 }
111
112 // GetEmergencyThreshold returns the current emergency threshold as a fraction.
113 func (m *Neo4jMonitor) GetEmergencyThreshold() float64 {
114 return float64(m.emergencyThreshold.Load()) / 1000.0
115 }
116
117 // ForceEmergencyMode manually triggers emergency mode for a duration.
118 func (m *Neo4jMonitor) ForceEmergencyMode(duration time.Duration) {
119 m.emergencyModeUntil.Store(time.Now().Add(duration).UnixNano())
120 m.inEmergencyMode.Store(true)
121 m.throttleMultiplier.Store(150) // Start at 1.5x
122 log.W.F("⚠️ Neo4j emergency mode forced for %v", duration)
123 }
124
125 // GetThrottleMultiplier returns the current throttle multiplier.
126 // Returns a value >= 1.0, where 1.0 = no extra throttling, 1.5 = 50% more aggressive, etc.
127 func (m *Neo4jMonitor) GetThrottleMultiplier() float64 {
128 return float64(m.throttleMultiplier.Load()) / 100.0
129 }
130
131 // GetMetrics returns the current load metrics.
132 func (m *Neo4jMonitor) GetMetrics() loadmonitor.Metrics {
133 m.metricsLock.RLock()
134 defer m.metricsLock.RUnlock()
135 return m.cachedMetrics
136 }
137
138 // RecordQueryLatency records a query latency sample using exponential moving average.
139 func (m *Neo4jMonitor) RecordQueryLatency(latency time.Duration) {
140 ns := latency.Nanoseconds()
141 for {
142 old := m.queryLatencyNs.Load()
143 if old == 0 {
144 if m.queryLatencyNs.CompareAndSwap(0, ns) {
145 return
146 }
147 continue
148 }
149 // EMA: new = alpha * sample + (1-alpha) * old
150 newVal := int64(m.latencyAlpha*float64(ns) + (1-m.latencyAlpha)*float64(old))
151 if m.queryLatencyNs.CompareAndSwap(old, newVal) {
152 return
153 }
154 }
155 }
156
157 // RecordWriteLatency records a write latency sample using exponential moving average.
158 func (m *Neo4jMonitor) RecordWriteLatency(latency time.Duration) {
159 ns := latency.Nanoseconds()
160 for {
161 old := m.writeLatencyNs.Load()
162 if old == 0 {
163 if m.writeLatencyNs.CompareAndSwap(0, ns) {
164 return
165 }
166 continue
167 }
168 // EMA: new = alpha * sample + (1-alpha) * old
169 newVal := int64(m.latencyAlpha*float64(ns) + (1-m.latencyAlpha)*float64(old))
170 if m.writeLatencyNs.CompareAndSwap(old, newVal) {
171 return
172 }
173 }
174 }
175
176 // SetMemoryTarget sets the target memory limit in bytes.
177 func (m *Neo4jMonitor) SetMemoryTarget(bytes uint64) {
178 m.targetMemoryBytes.Store(bytes)
179 }
180
181 // Start begins background metric collection.
182 func (m *Neo4jMonitor) Start() <-chan struct{} {
183 go m.collectLoop()
184 return m.stopped
185 }
186
187 // Stop halts background metric collection.
188 func (m *Neo4jMonitor) Stop() {
189 close(m.stopChan)
190 <-m.stopped
191 }
192
193 // collectLoop periodically collects metrics.
194 func (m *Neo4jMonitor) collectLoop() {
195 defer close(m.stopped)
196
197 ticker := time.NewTicker(m.interval)
198 defer ticker.Stop()
199
200 for {
201 select {
202 case <-m.stopChan:
203 return
204 case <-ticker.C:
205 m.updateMetrics()
206 }
207 }
208 }
209
210 // updateMetrics collects current metrics and manages aggressive throttling.
211 func (m *Neo4jMonitor) updateMetrics() {
212 metrics := loadmonitor.Metrics{
213 Timestamp: time.Now(),
214 }
215
216 // Use RSS-based memory pressure (actual physical memory, not Go runtime)
217 procMem := ReadProcessMemoryStats()
218 physicalMemBytes := procMem.PhysicalMemoryBytes()
219 metrics.PhysicalMemoryMB = physicalMemBytes / (1024 * 1024)
220
221 targetBytes := m.targetMemoryBytes.Load()
222 if targetBytes > 0 {
223 // Use actual physical memory (RSS - shared) for pressure calculation
224 metrics.MemoryPressure = float64(physicalMemBytes) / float64(targetBytes)
225 }
226
227 // Check and update emergency mode with aggressive throttling
228 m.updateEmergencyMode(metrics.MemoryPressure)
229 metrics.InEmergencyMode = m.inEmergencyMode.Load()
230
231 // Calculate load from semaphore usage
232 // querySem is a buffered channel - count how many slots are taken
233 if m.querySem != nil {
234 usedSlots := len(m.querySem)
235 concurrencyLoad := float64(usedSlots) / float64(m.maxConcurrency)
236 if concurrencyLoad > 1.0 {
237 concurrencyLoad = 1.0
238 }
239 // Both read and write use the same semaphore
240 metrics.WriteLoad = concurrencyLoad
241 metrics.ReadLoad = concurrencyLoad
242 }
243
244 // Apply throttle multiplier to loads when in emergency mode
245 // This makes the PID controller think load is higher, causing more throttling
246 if metrics.InEmergencyMode {
247 multiplier := m.GetThrottleMultiplier()
248 metrics.WriteLoad = metrics.WriteLoad * multiplier
249 if metrics.WriteLoad > 1.0 {
250 metrics.WriteLoad = 1.0
251 }
252 metrics.ReadLoad = metrics.ReadLoad * multiplier
253 if metrics.ReadLoad > 1.0 {
254 metrics.ReadLoad = 1.0
255 }
256 }
257
258 // Add latency-based load adjustment
259 // High latency indicates the database is struggling
260 queryLatencyNs := m.queryLatencyNs.Load()
261 writeLatencyNs := m.writeLatencyNs.Load()
262
263 // Consider > 500ms query latency as concerning
264 const latencyThresholdNs = 500 * 1e6 // 500ms
265 if queryLatencyNs > 0 {
266 latencyLoad := float64(queryLatencyNs) / float64(latencyThresholdNs)
267 if latencyLoad > 1.0 {
268 latencyLoad = 1.0
269 }
270 // Blend concurrency and latency for read load
271 metrics.ReadLoad = 0.5*metrics.ReadLoad + 0.5*latencyLoad
272 }
273
274 if writeLatencyNs > 0 {
275 latencyLoad := float64(writeLatencyNs) / float64(latencyThresholdNs)
276 if latencyLoad > 1.0 {
277 latencyLoad = 1.0
278 }
279 // Blend concurrency and latency for write load
280 metrics.WriteLoad = 0.5*metrics.WriteLoad + 0.5*latencyLoad
281 }
282
283 // Store latencies
284 metrics.QueryLatency = time.Duration(queryLatencyNs)
285 metrics.WriteLatency = time.Duration(writeLatencyNs)
286
287 // Update cached metrics
288 m.metricsLock.Lock()
289 m.cachedMetrics = metrics
290 m.metricsLock.Unlock()
291 }
292
293 // updateEmergencyMode manages the emergency mode state and throttle multiplier.
294 // When memory exceeds the target:
295 // - Enters emergency mode with 1.5x throttle multiplier (50% more aggressive)
296 // - Every 10 seconds while still over limit, doubles the multiplier
297 // - When memory returns under target, resets to normal
298 func (m *Neo4jMonitor) updateEmergencyMode(memoryPressure float64) {
299 threshold := float64(m.emergencyThreshold.Load()) / 1000.0
300 forcedUntil := m.emergencyModeUntil.Load()
301 now := time.Now().UnixNano()
302
303 // Check if in forced emergency mode
304 if forcedUntil > now {
305 return // Stay in forced mode
306 }
307
308 // Check if memory exceeds threshold
309 if memoryPressure >= threshold {
310 if !m.inEmergencyMode.Load() {
311 // Entering emergency mode - start at 1.5x (50% more aggressive)
312 m.inEmergencyMode.Store(true)
313 m.throttleMultiplier.Store(150)
314 m.lastThrottleCheck.Store(now)
315 log.W.F("⚠️ Neo4j entering emergency mode: memory %.1f%% >= threshold %.1f%%, throttle 1.5x",
316 memoryPressure*100, threshold*100)
317 return
318 }
319
320 // Already in emergency mode - check if it's time to double throttling
321 lastCheck := m.lastThrottleCheck.Load()
322 elapsed := time.Duration(now - lastCheck)
323
324 if elapsed >= ThrottleCheckInterval {
325 // Double the throttle multiplier
326 currentMult := m.throttleMultiplier.Load()
327 newMult := currentMult * 2
328 if newMult > 1600 { // Cap at 16x to prevent overflow
329 newMult = 1600
330 }
331 m.throttleMultiplier.Store(newMult)
332 m.lastThrottleCheck.Store(now)
333 log.W.F("⚠️ Neo4j still over memory limit: %.1f%%, doubling throttle to %.1fx",
334 memoryPressure*100, float64(newMult)/100.0)
335 }
336 } else {
337 // Memory is under threshold
338 if m.inEmergencyMode.Load() {
339 m.inEmergencyMode.Store(false)
340 m.throttleMultiplier.Store(100) // Reset to 1.0x
341 log.I.F("✅ Neo4j exiting emergency mode: memory %.1f%% < threshold %.1f%%",
342 memoryPressure*100, threshold*100)
343 }
344 }
345 }
346
347 // IncrementActiveReads tracks an active read operation.
348 // Call this when starting a read, and call the returned function when done.
349 func (m *Neo4jMonitor) IncrementActiveReads() func() {
350 m.activeReads.Add(1)
351 return func() {
352 m.activeReads.Add(-1)
353 }
354 }
355
356 // IncrementActiveWrites tracks an active write operation.
357 // Call this when starting a write, and call the returned function when done.
358 func (m *Neo4jMonitor) IncrementActiveWrites() func() {
359 m.activeWrites.Add(1)
360 return func() {
361 m.activeWrites.Add(-1)
362 }
363 }
364
365 // GetConcurrencyStats returns current concurrency statistics for debugging.
366 func (m *Neo4jMonitor) GetConcurrencyStats() (reads, writes int32, semUsed int) {
367 reads = m.activeReads.Load()
368 writes = m.activeWrites.Load()
369 if m.querySem != nil {
370 semUsed = len(m.querySem)
371 }
372 return
373 }
374
375 // CheckConnectivity performs a connectivity check to Neo4j.
376 // This can be used to verify the database is responsive.
377 func (m *Neo4jMonitor) CheckConnectivity(ctx context.Context) error {
378 if m.driver == nil {
379 return nil
380 }
381 return m.driver.VerifyConnectivity(ctx)
382 }
383