follows.mx raw

   1  package acl
   2  
   3  import (
   4  	"bytes"
   5  	"time"
   6  
   7  	"smesh.lol/pkg/grapevine"
   8  	"smesh.lol/pkg/store"
   9  )
  10  
  11  // Follows is an ACL that allows writes only from pubkeys followed
  12  // by the configured admin set. Follow lists (kind 3) are fetched
  13  // from the store at a configurable frequency.
  14  type Follows struct {
  15  	store      *store.Engine
  16  	admins     [][]byte
  17  	followed   map[string]bool
  18  	lastRefresh int64
  19  	freqSec    int
  20  }
  21  
  22  func NewFollows(s *store.Engine, adminHexPubkeys []string, freqSec int) *Follows {
  23  	admins := [][]byte{:0:len(adminHexPubkeys)}
  24  	for _, h := range adminHexPubkeys {
  25  		if pk := hexDec(h); len(pk) == 32 {
  26  			admins = append(admins, pk)
  27  		}
  28  	}
  29  	f := &Follows{
  30  		store:    s,
  31  		admins:   admins,
  32  		followed: map[string]bool{},
  33  		freqSec:  freqSec,
  34  	}
  35  	f.refresh()
  36  	return f
  37  }
  38  
  39  func (f *Follows) AllowWrite(pubkey []byte, _ uint16) bool {
  40  	f.maybeRefresh()
  41  	for _, a := range f.admins {
  42  		if bytes.Equal(a, pubkey) {
  43  			return true
  44  		}
  45  	}
  46  	return f.followed[string(pubkey)]
  47  }
  48  
  49  func (f *Follows) AllowRead([]byte) bool { return true }
  50  
  51  func (f *Follows) IsFollowed(pubkey []byte) bool {
  52  	f.maybeRefresh()
  53  	return f.followed[string(pubkey)]
  54  }
  55  
  56  func (f *Follows) maybeRefresh() {
  57  	now := time.Now().Unix()
  58  	if now-f.lastRefresh < int64(f.freqSec) {
  59  		return
  60  	}
  61  	f.refresh()
  62  }
  63  
  64  func (f *Follows) refresh() {
  65  	f.lastRefresh = time.Now().Unix()
  66  	m := map[string]bool{}
  67  	for _, admin := range f.admins {
  68  		m[string(admin)] = true
  69  		for _, pk := range grapevine.GetFollows(f.store, admin) {
  70  			m[string(pk)] = true
  71  		}
  72  	}
  73  	f.followed = m
  74  }
  75  
  76  func hexDec(s string) []byte {
  77  	if len(s)%2 != 0 {
  78  		return nil
  79  	}
  80  	b := []byte{:len(s) / 2}
  81  	for i := 0; i < len(b); i++ {
  82  		hi := unhex(s[i*2])
  83  		lo := unhex(s[i*2+1])
  84  		if hi == 0xff || lo == 0xff {
  85  			return nil
  86  		}
  87  		b[i] = hi<<4 | lo
  88  	}
  89  	return b
  90  }
  91  
  92  func unhex(c byte) byte {
  93  	switch {
  94  	case c >= '0' && c <= '9':
  95  		return c - '0'
  96  	case c >= 'a' && c <= 'f':
  97  		return c - 'a' + 10
  98  	case c >= 'A' && c <= 'F':
  99  		return c - 'A' + 10
 100  	}
 101  	return 0xff
 102  }
 103