marmot.go raw
1 package marmot
2
3 import (
4 "context"
5 "fmt"
6 "sync"
7
8 "next.orly.dev/pkg/nostr/encoders/event"
9 "next.orly.dev/pkg/nostr/encoders/filter"
10 "next.orly.dev/pkg/nostr/encoders/hex"
11 "next.orly.dev/pkg/nostr/encoders/kind"
12 "next.orly.dev/pkg/nostr/encoders/tag"
13 "next.orly.dev/pkg/nostr/interfaces/signer"
14 "github.com/emersion/go-mls"
15 "next.orly.dev/pkg/lol/log"
16 )
17
18 // RelayConnection abstracts the relay interface so the Marmot client can be
19 // used with any relay transport (WebSocket, in-process channel, mock).
20 type RelayConnection interface {
21 Publish(ctx context.Context, ev *event.E) error
22 Subscribe(ctx context.Context, ff *filter.S) (EventStream, error)
23 }
24
25 // EventStream delivers events from a subscription. Close stops delivery.
26 type EventStream interface {
27 Events() <-chan *event.E
28 Close()
29 }
30
// DMHandler receives each incoming DM once it has been decrypted. senderPub
// is the public key of the event that carried the message and plaintext is
// the decrypted message body.
type DMHandler func(senderPub, plaintext []byte)
33
34 // Client manages Marmot DM conversations. It holds MLS group state for
35 // active 1:1 conversations and handles the lifecycle of key packages,
36 // welcomes, and encrypted messages.
37 type Client struct {
38 sign signer.I
39 store GroupStore
40 relay RelayConnection
41 onDM DMHandler
42 kpp *mls.KeyPairPackage // our current key pair package
43 groups map[string]*GroupState
44 mu sync.RWMutex
45 }
46
47 // NewClient creates a Marmot client. The signer provides identity and
48 // signing. The store persists group state. The relay handles event transport.
49 func NewClient(sign signer.I, store GroupStore, relay RelayConnection) (*Client, error) {
50 kpp, err := GenerateKeyPackage(sign)
51 if err != nil {
52 return nil, fmt.Errorf("generate key package: %w", err)
53 }
54
55 c := &Client{
56 sign: sign,
57 store: store,
58 relay: relay,
59 kpp: kpp,
60 groups: make(map[string]*GroupState),
61 }
62
63 // Load persisted groups
64 ids, err := store.ListGroups()
65 if err == nil {
66 for _, id := range ids {
67 data, err := store.LoadGroup(id)
68 if err != nil {
69 log.W.F("failed to load group %x: %v", id, err)
70 continue
71 }
72 gs, err := unmarshalGroupState(data)
73 if err != nil {
74 log.W.F("failed to unmarshal group %x: %v", id, err)
75 continue
76 }
77 // We store the serialized state but can't re-hydrate the
78 // mls.Group from bytes with the current go-mls API.
79 // For now, groups are re-established on restart via welcome
80 // re-exchange. Store the metadata so we know about them.
81 c.groups[string(gs.GroupID)] = &GroupState{
82 GroupID: gs.GroupID,
83 PeerPub: gs.PeerPub,
84 mlsBytes: gs.MLSState,
85 }
86 }
87 }
88
89 return c, nil
90 }
91
92 // OnDM registers a handler for incoming decrypted DMs.
93 func (c *Client) OnDM(handler DMHandler) {
94 c.onDM = handler
95 }
96
97 // PublishKeyPackage publishes our MLS key package as a kind 443 event so
98 // peers can create DM groups with us.
99 func (c *Client) PublishKeyPackage(ctx context.Context) error {
100 ev, err := KeyPackageToEvent(c.kpp, c.sign)
101 if err != nil {
102 return err
103 }
104 return c.relay.Publish(ctx, ev)
105 }
106
107 // SendDM sends an encrypted DM to the given recipient. If no group exists,
108 // it fetches the recipient's key package, creates a group, and sends a
109 // welcome. Then it encrypts and publishes the message.
110 func (c *Client) SendDM(ctx context.Context, recipientPub []byte, plaintext []byte) error {
111 groupID := DMGroupID(c.sign.Pub(), recipientPub)
112
113 c.mu.RLock()
114 gs, ok := c.groups[string(groupID)]
115 c.mu.RUnlock()
116
117 if !ok || gs.group == nil {
118 // Need to establish a new group
119 var err error
120 gs, err = c.establishGroup(ctx, recipientPub)
121 if err != nil {
122 return fmt.Errorf("establish group: %w", err)
123 }
124 }
125
126 ciphertext, err := gs.Encrypt(plaintext)
127 if err != nil {
128 return fmt.Errorf("encrypt: %w", err)
129 }
130
131 ev, err := MessageToEvent(groupID, ciphertext, c.sign)
132 if err != nil {
133 return err
134 }
135
136 return c.relay.Publish(ctx, ev)
137 }
138
139 // establishGroup fetches the peer's key package and creates a DM group.
140 func (c *Client) establishGroup(ctx context.Context, peerPub []byte) (*GroupState, error) {
141 // Fetch the peer's latest key package (kind 443)
142 f := filter.New()
143 f.Kinds = kind.NewS(kind.New(KindKeyPackage))
144 f.Authors = &tag.T{T: [][]byte{peerPub}}
145 limit := uint(1)
146 f.Limit = &limit
147
148 stream, err := c.relay.Subscribe(ctx, filter.NewS(f))
149 if err != nil {
150 return nil, fmt.Errorf("subscribe for key package: %w", err)
151 }
152 defer stream.Close()
153
154 // Wait for one event
155 var peerKPEvent *event.E
156 select {
157 case ev := <-stream.Events():
158 peerKPEvent = ev
159 case <-ctx.Done():
160 return nil, ctx.Err()
161 }
162
163 if peerKPEvent == nil {
164 return nil, fmt.Errorf("no key package found for %s", hex.Enc(peerPub))
165 }
166
167 peerKP, err := EventToKeyPackage(peerKPEvent)
168 if err != nil {
169 return nil, fmt.Errorf("parse peer key package: %w", err)
170 }
171
172 gs, welcome, _, err := CreateDMGroup(c.kpp, peerKP, c.sign.Pub(), peerPub)
173 if err != nil {
174 return nil, fmt.Errorf("create DM group: %w", err)
175 }
176
177 // Send the welcome as a gift-wrapped event
178 wrapEv, err := WelcomeToGiftWrap(welcome, peerPub, c.sign)
179 if err != nil {
180 return nil, fmt.Errorf("gift wrap welcome: %w", err)
181 }
182 if err := c.relay.Publish(ctx, wrapEv); err != nil {
183 return nil, fmt.Errorf("publish welcome: %w", err)
184 }
185
186 // Store the group
187 c.mu.Lock()
188 c.groups[string(gs.GroupID)] = gs
189 c.mu.Unlock()
190
191 c.persistGroup(gs)
192
193 return gs, nil
194 }
195
196 // HandleEvent processes an incoming event. Call this from the subscription loop.
197 func (c *Client) HandleEvent(ctx context.Context, ev *event.E) error {
198 switch ev.Kind {
199 case KindGiftWrap:
200 return c.handleWelcome(ctx, ev)
201 case KindGroupMessage:
202 return c.handleGroupMessage(ctx, ev)
203 default:
204 return nil
205 }
206 }
207
208 func (c *Client) handleWelcome(ctx context.Context, ev *event.E) error {
209 welcome, err := UnwrapWelcome(ev, c.sign)
210 if err != nil {
211 return fmt.Errorf("unwrap welcome: %w", err)
212 }
213
214 senderPub := ev.Pubkey
215 gs, err := JoinDMGroup(welcome, c.kpp, senderPub)
216 if err != nil {
217 return fmt.Errorf("join DM group: %w", err)
218 }
219
220 // Derive the group ID
221 gs.GroupID = DMGroupID(c.sign.Pub(), senderPub)
222
223 c.mu.Lock()
224 c.groups[string(gs.GroupID)] = gs
225 c.mu.Unlock()
226
227 c.persistGroup(gs)
228
229 log.I.F("joined DM group with %s", hex.Enc(senderPub))
230 return nil
231 }
232
233 func (c *Client) handleGroupMessage(ctx context.Context, ev *event.E) error {
234 groupID, ciphertext, err := EventToMessage(ev)
235 if err != nil {
236 return err
237 }
238
239 c.mu.RLock()
240 gs, ok := c.groups[string(groupID)]
241 c.mu.RUnlock()
242
243 if !ok || gs.group == nil {
244 return fmt.Errorf("unknown group %x", groupID)
245 }
246
247 plaintext, err := gs.Decrypt(ciphertext)
248 if err != nil {
249 return fmt.Errorf("decrypt: %w", err)
250 }
251
252 if c.onDM != nil {
253 c.onDM(ev.Pubkey, plaintext)
254 }
255
256 return nil
257 }
258
259 func (c *Client) persistGroup(gs *GroupState) {
260 data, err := marshalGroupState(gs)
261 if err != nil {
262 log.W.F("failed to marshal group state: %v", err)
263 return
264 }
265 if err := c.store.SaveGroup(gs.GroupID, data); err != nil {
266 log.W.F("failed to persist group state: %v", err)
267 }
268 }
269
270 // SubscriptionFilter returns the filter for subscribing to events relevant
271 // to this client (key packages, welcomes, and group messages addressed to us).
272 func (c *Client) SubscriptionFilter() *filter.S {
273 f := filter.New()
274 f.Kinds = kind.NewS(
275 kind.New(KindGiftWrap),
276 kind.New(KindGroupMessage),
277 )
278 f.Tags = tag.NewS(
279 tag.NewFromAny("#p", hex.Enc(c.sign.Pub())),
280 )
281 return filter.NewS(f)
282 }
283