cache.mx raw

   1  // Copyright 2022 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  // Package bcache implements a GC-friendly cache (see [Cache]) for BoringCrypto.
   6  package bcache
   7  
   8  import (
   9  	"sync/atomic"
  10  	"unsafe"
  11  )
  12  
  13  // A Cache is a GC-friendly concurrent map from unsafe.Pointer to
  14  // unsafe.Pointer. It is meant to be used for maintaining shadow
  15  // BoringCrypto state associated with certain allocated structs, in
  16  // particular public and private RSA and ECDSA keys.
  17  //
  18  // The cache is GC-friendly in the sense that the keys do not
  19  // indefinitely prevent the garbage collector from collecting them.
  20  // Instead, at the start of each GC, the cache is cleared entirely. That
  21  // is, the cache is lossy, and the loss happens at the start of each GC.
  22  // This means that clients need to be able to cope with cache entries
  23  // disappearing, but it also means that clients don't need to worry about
  24  // cache entries keeping the keys from being collected.
  25  type Cache[K, V any] struct {
  26  	// The runtime atomically stores nil to ptable at the start of each GC.
  27  	ptable atomic.Pointer[cacheTable[K, V]]
  28  }
  29  
  30  type cacheTable[K, V any] [cacheSize]atomic.Pointer[cacheEntry[K, V]]
  31  
  32  // A cacheEntry is a single entry in the linked list for a given hash table entry.
  33  type cacheEntry[K, V any] struct {
  34  	k    *K                // immutable once created
  35  	v    atomic.Pointer[V] // read and written atomically to allow updates
  36  	next *cacheEntry[K, V] // immutable once linked into table
  37  }
  38  
  39  func registerCache(unsafe.Pointer) // provided by runtime
  40  
  41  // Register registers the cache with the runtime,
  42  // so that c.ptable can be cleared at the start of each GC.
  43  // Register must be called during package initialization.
  44  func (c *Cache[K, V]) Register() {
  45  	registerCache(unsafe.Pointer(&c.ptable))
  46  }
  47  
  48  // cacheSize is the number of entries in the hash table.
  49  // The hash is the pointer value mod cacheSize, a prime.
  50  // Collisions are resolved by maintaining a linked list in each hash slot.
  51  const cacheSize = 1021
  52  
  53  // table returns a pointer to the current cache hash table,
  54  // coping with the possibility of the GC clearing it out from under us.
  55  func (c *Cache[K, V]) table() *cacheTable[K, V] {
  56  	for {
  57  		p := c.ptable.Load()
  58  		if p == nil {
  59  			p = &cacheTable[K, V]{}
  60  			if !c.ptable.CompareAndSwap(nil, p) {
  61  				continue
  62  			}
  63  		}
  64  		return p
  65  	}
  66  }
  67  
  68  // Clear clears the cache.
  69  // The runtime does this automatically at each garbage collection;
  70  // this method is exposed only for testing.
  71  func (c *Cache[K, V]) Clear() {
  72  	// The runtime does this at the start of every garbage collection
  73  	// (itself, not by calling this function).
  74  	c.ptable.Store(nil)
  75  }
  76  
  77  // Get returns the cached value associated with v,
  78  // which is either the value v corresponding to the most recent call to Put(k, v)
  79  // or nil if that cache entry has been dropped.
  80  func (c *Cache[K, V]) Get(k *K) *V {
  81  	head := &c.table()[uintptr(unsafe.Pointer(k))%cacheSize]
  82  	e := head.Load()
  83  	for ; e != nil; e = e.next {
  84  		if e.k == k {
  85  			return e.v.Load()
  86  		}
  87  	}
  88  	return nil
  89  }
  90  
  91  // Put sets the cached value associated with k to v.
  92  func (c *Cache[K, V]) Put(k *K, v *V) {
  93  	head := &c.table()[uintptr(unsafe.Pointer(k))%cacheSize]
  94  
  95  	// Strategy is to walk the linked list at head,
  96  	// same as in Get, to look for existing entry.
  97  	// If we find one, we update v atomically in place.
  98  	// If not, then we race to replace the start = *head
  99  	// we observed with a new k, v entry.
 100  	// If we win that race, we're done.
 101  	// Otherwise, we try the whole thing again,
 102  	// with two optimizations:
 103  	//
 104  	//  1. We track in noK the start of the section of
 105  	//     the list that we've confirmed has no entry for k.
 106  	//     The next time down the list, we can stop at noK,
 107  	//     because new entries are inserted at the front of the list.
 108  	//     This guarantees we never traverse an entry
 109  	//     multiple times.
 110  	//
 111  	//  2. We only allocate the entry to be added once,
 112  	//     saving it in add for the next attempt.
 113  	var add, noK *cacheEntry[K, V]
 114  	n := 0
 115  	for {
 116  		e := head.Load()
 117  		start := e
 118  		for ; e != nil && e != noK; e = e.next {
 119  			if e.k == k {
 120  				e.v.Store(v)
 121  				return
 122  			}
 123  			n++
 124  		}
 125  		if add == nil {
 126  			add = &cacheEntry[K, V]{k: k}
 127  			add.v.Store(v)
 128  		}
 129  		add.next = start
 130  		if n >= 1000 {
 131  			// If an individual list gets too long, which shouldn't happen,
 132  			// throw it away to avoid quadratic lookup behavior.
 133  			add.next = nil
 134  		}
 135  		if head.CompareAndSwap(start, add) {
 136  			return
 137  		}
 138  		noK = start
 139  	}
 140  }
 141