pools.go raw
1 /* SPDX-License-Identifier: MIT
2 *
3 * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
4 */
5
6 package device
7
8 import (
9 "sync"
10 )
11
// WaitPool wraps a sync.Pool and, when max is non-zero, caps how many values
// may be checked out at the same time. Get blocks while max values are
// outstanding; each Put gives back one slot and wakes a single waiting Get.
// With max == 0 no limit is enforced and WaitPool acts like a plain sync.Pool.
type WaitPool struct {
	pool  sync.Pool
	cond  sync.Cond
	lock  sync.Mutex
	count uint32 // Get calls not yet Put back
	max   uint32
}

// NewWaitPool builds a WaitPool allowing at most max outstanding values
// (0 means unlimited). new allocates a value whenever the pool has none cached.
func NewWaitPool(max uint32, new func() any) *WaitPool {
	wp := &WaitPool{max: max}
	wp.pool.New = new
	// The condition variable must use the same mutex that guards count, so a
	// Wait in Get releases exactly the lock that Put acquires.
	wp.cond.L = &wp.lock
	return wp
}

// Get hands out a pooled value, allocating one through the pool's New function
// if nothing is cached. When a limit is configured, Get first waits until
// fewer than max values are outstanding and then reserves a slot.
func (wp *WaitPool) Get() any {
	if wp.max == 0 {
		return wp.pool.Get()
	}
	wp.lock.Lock()
	for wp.count >= wp.max {
		// Wait drops wp.lock while blocked and re-acquires it before returning,
		// so the limit is re-checked under the lock after every wake-up.
		wp.cond.Wait()
	}
	wp.count++
	wp.lock.Unlock()
	return wp.pool.Get()
}

// Put returns x to the pool. When a limit is configured, Put also releases the
// slot reserved by the matching Get and wakes one goroutine blocked in Get, if
// any. Each Put must follow a Get: count is unsigned, so an unmatched Put would
// wrap it past max and leave subsequent Get calls waiting.
func (wp *WaitPool) Put(x any) {
	wp.pool.Put(x)
	if wp.max != 0 {
		wp.lock.Lock()
		defer wp.lock.Unlock()
		wp.count--
		wp.cond.Signal()
	}
}
48
49 func (device *Device) PopulatePools() {
50 device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
51 s := make([]*QueueInboundElement, 0, device.BatchSize())
52 return &QueueInboundElementsContainer{elems: s}
53 })
54 device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
55 s := make([]*QueueOutboundElement, 0, device.BatchSize())
56 return &QueueOutboundElementsContainer{elems: s}
57 })
58 device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
59 return new([MaxMessageSize]byte)
60 })
61 device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
62 return new(QueueInboundElement)
63 })
64 device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
65 return new(QueueOutboundElement)
66 })
67 }
68
69 func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
70 c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
71 c.Mutex = sync.Mutex{}
72 return c
73 }
74
75 func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
76 for i := range c.elems {
77 c.elems[i] = nil
78 }
79 c.elems = c.elems[:0]
80 device.pool.inboundElementsContainer.Put(c)
81 }
82
83 func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
84 c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
85 c.Mutex = sync.Mutex{}
86 return c
87 }
88
89 func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
90 for i := range c.elems {
91 c.elems[i] = nil
92 }
93 c.elems = c.elems[:0]
94 device.pool.outboundElementsContainer.Put(c)
95 }
96
97 func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
98 return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
99 }
100
101 func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
102 device.pool.messageBuffers.Put(msg)
103 }
104
105 func (device *Device) GetInboundElement() *QueueInboundElement {
106 return device.pool.inboundElements.Get().(*QueueInboundElement)
107 }
108
109 func (device *Device) PutInboundElement(elem *QueueInboundElement) {
110 elem.clearPointers()
111 device.pool.inboundElements.Put(elem)
112 }
113
114 func (device *Device) GetOutboundElement() *QueueOutboundElement {
115 return device.pool.outboundElements.Get().(*QueueOutboundElement)
116 }
117
118 func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
119 elem.clearPointers()
120 device.pool.outboundElements.Put(elem)
121 }
122