pools.go raw

   1  /* SPDX-License-Identifier: MIT
   2   *
   3   * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
   4   */
   5  
   6  package device
   7  
   8  import (
   9  	"sync"
  10  )
  11  
  12  type WaitPool struct {
  13  	pool  sync.Pool
  14  	cond  sync.Cond
  15  	lock  sync.Mutex
  16  	count uint32 // Get calls not yet Put back
  17  	max   uint32
  18  }
  19  
  20  func NewWaitPool(max uint32, new func() any) *WaitPool {
  21  	p := &WaitPool{pool: sync.Pool{New: new}, max: max}
  22  	p.cond = sync.Cond{L: &p.lock}
  23  	return p
  24  }
  25  
  26  func (p *WaitPool) Get() any {
  27  	if p.max != 0 {
  28  		p.lock.Lock()
  29  		for p.count >= p.max {
  30  			p.cond.Wait()
  31  		}
  32  		p.count++
  33  		p.lock.Unlock()
  34  	}
  35  	return p.pool.Get()
  36  }
  37  
  38  func (p *WaitPool) Put(x any) {
  39  	p.pool.Put(x)
  40  	if p.max == 0 {
  41  		return
  42  	}
  43  	p.lock.Lock()
  44  	defer p.lock.Unlock()
  45  	p.count--
  46  	p.cond.Signal()
  47  }
  48  
  49  func (device *Device) PopulatePools() {
  50  	device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
  51  		s := make([]*QueueInboundElement, 0, device.BatchSize())
  52  		return &QueueInboundElementsContainer{elems: s}
  53  	})
  54  	device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
  55  		s := make([]*QueueOutboundElement, 0, device.BatchSize())
  56  		return &QueueOutboundElementsContainer{elems: s}
  57  	})
  58  	device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
  59  		return new([MaxMessageSize]byte)
  60  	})
  61  	device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
  62  		return new(QueueInboundElement)
  63  	})
  64  	device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
  65  		return new(QueueOutboundElement)
  66  	})
  67  }
  68  
  69  func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
  70  	c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
  71  	c.Mutex = sync.Mutex{}
  72  	return c
  73  }
  74  
  75  func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
  76  	for i := range c.elems {
  77  		c.elems[i] = nil
  78  	}
  79  	c.elems = c.elems[:0]
  80  	device.pool.inboundElementsContainer.Put(c)
  81  }
  82  
  83  func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
  84  	c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
  85  	c.Mutex = sync.Mutex{}
  86  	return c
  87  }
  88  
  89  func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
  90  	for i := range c.elems {
  91  		c.elems[i] = nil
  92  	}
  93  	c.elems = c.elems[:0]
  94  	device.pool.outboundElementsContainer.Put(c)
  95  }
  96  
  97  func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
  98  	return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
  99  }
 100  
 101  func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
 102  	device.pool.messageBuffers.Put(msg)
 103  }
 104  
 105  func (device *Device) GetInboundElement() *QueueInboundElement {
 106  	return device.pool.inboundElements.Get().(*QueueInboundElement)
 107  }
 108  
 109  func (device *Device) PutInboundElement(elem *QueueInboundElement) {
 110  	elem.clearPointers()
 111  	device.pool.inboundElements.Put(elem)
 112  }
 113  
 114  func (device *Device) GetOutboundElement() *QueueOutboundElement {
 115  	return device.pool.outboundElements.Get().(*QueueOutboundElement)
 116  }
 117  
 118  func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
 119  	elem.clearPointers()
 120  	device.pool.outboundElements.Put(elem)
 121  }
 122