store.go raw

   1  package marmot
   2  
import (
	"encoding/hex"
	"encoding/json"
	"errors"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"sync"
)
  10  
  11  // GroupStore persists MLS group state so conversations survive restarts.
  12  type GroupStore interface {
  13  	SaveGroup(groupID, state []byte) error
  14  	LoadGroup(groupID []byte) ([]byte, error)
  15  	ListGroups() ([][]byte, error)
  16  	DeleteGroup(groupID []byte) error
  17  }
  18  
  19  // FileGroupStore implements GroupStore using one file per group in a directory.
  20  type FileGroupStore struct {
  21  	dir string
  22  	mu  sync.Mutex
  23  }
  24  
  25  // NewFileGroupStore creates a file-backed group store. The directory is created
  26  // if it does not exist.
  27  func NewFileGroupStore(dir string) (*FileGroupStore, error) {
  28  	if err := os.MkdirAll(dir, 0700); err != nil {
  29  		return nil, err
  30  	}
  31  	return &FileGroupStore{dir: dir}, nil
  32  }
  33  
  34  func (s *FileGroupStore) path(groupID []byte) string {
  35  	return filepath.Join(s.dir, hex.EncodeToString(groupID)+".json")
  36  }
  37  
  38  func (s *FileGroupStore) SaveGroup(groupID, state []byte) error {
  39  	s.mu.Lock()
  40  	defer s.mu.Unlock()
  41  	return os.WriteFile(s.path(groupID), state, 0600)
  42  }
  43  
  44  func (s *FileGroupStore) LoadGroup(groupID []byte) ([]byte, error) {
  45  	s.mu.Lock()
  46  	defer s.mu.Unlock()
  47  	return os.ReadFile(s.path(groupID))
  48  }
  49  
  50  func (s *FileGroupStore) ListGroups() ([][]byte, error) {
  51  	s.mu.Lock()
  52  	defer s.mu.Unlock()
  53  
  54  	entries, err := os.ReadDir(s.dir)
  55  	if err != nil {
  56  		return nil, err
  57  	}
  58  	var ids [][]byte
  59  	for _, e := range entries {
  60  		if e.IsDir() {
  61  			continue
  62  		}
  63  		name := e.Name()
  64  		if filepath.Ext(name) != ".json" {
  65  			continue
  66  		}
  67  		base := name[:len(name)-5] // strip .json
  68  		id, err := hex.DecodeString(base)
  69  		if err != nil {
  70  			continue
  71  		}
  72  		ids = append(ids, id)
  73  	}
  74  	return ids, nil
  75  }
  76  
  77  func (s *FileGroupStore) DeleteGroup(groupID []byte) error {
  78  	s.mu.Lock()
  79  	defer s.mu.Unlock()
  80  	return os.Remove(s.path(groupID))
  81  }
  82  
  83  // MemoryGroupStore implements GroupStore in memory. Useful for testing.
  84  type MemoryGroupStore struct {
  85  	mu     sync.Mutex
  86  	groups map[string][]byte
  87  }
  88  
  89  func NewMemoryGroupStore() *MemoryGroupStore {
  90  	return &MemoryGroupStore{groups: make(map[string][]byte)}
  91  }
  92  
  93  func (s *MemoryGroupStore) SaveGroup(groupID, state []byte) error {
  94  	s.mu.Lock()
  95  	defer s.mu.Unlock()
  96  	cp := make([]byte, len(state))
  97  	copy(cp, state)
  98  	s.groups[string(groupID)] = cp
  99  	return nil
 100  }
 101  
 102  func (s *MemoryGroupStore) LoadGroup(groupID []byte) ([]byte, error) {
 103  	s.mu.Lock()
 104  	defer s.mu.Unlock()
 105  	data, ok := s.groups[string(groupID)]
 106  	if !ok {
 107  		return nil, os.ErrNotExist
 108  	}
 109  	cp := make([]byte, len(data))
 110  	copy(cp, data)
 111  	return cp, nil
 112  }
 113  
 114  func (s *MemoryGroupStore) ListGroups() ([][]byte, error) {
 115  	s.mu.Lock()
 116  	defer s.mu.Unlock()
 117  	ids := make([][]byte, 0, len(s.groups))
 118  	for k := range s.groups {
 119  		ids = append(ids, []byte(k))
 120  	}
 121  	return ids, nil
 122  }
 123  
 124  func (s *MemoryGroupStore) DeleteGroup(groupID []byte) error {
 125  	s.mu.Lock()
 126  	defer s.mu.Unlock()
 127  	delete(s.groups, string(groupID))
 128  	return nil
 129  }
 130  
 131  // groupStateSerialized is the JSON structure for persisting group state.
 132  type groupStateSerialized struct {
 133  	GroupID  []byte `json:"group_id"`
 134  	PeerPub  []byte `json:"peer_pub"`
 135  	MLSState []byte `json:"mls_state"`
 136  }
 137  
 138  func marshalGroupState(gs *GroupState) ([]byte, error) {
 139  	return json.Marshal(&groupStateSerialized{
 140  		GroupID:  gs.GroupID,
 141  		PeerPub:  gs.PeerPub,
 142  		MLSState: gs.mlsBytes,
 143  	})
 144  }
 145  
 146  func unmarshalGroupState(data []byte) (*groupStateSerialized, error) {
 147  	var s groupStateSerialized
 148  	if err := json.Unmarshal(data, &s); err != nil {
 149  		return nil, err
 150  	}
 151  	return &s, nil
 152  }
 153