server_subs.mx raw
1 package server
2
3 import (
4 "bytes"
5 "time"
6
7 "smesh.lol/pkg/access"
8 "smesh.lol/pkg/nostr/envelope"
9 "smesh.lol/pkg/nostr/event"
10 "smesh.lol/pkg/nostr/filter"
11 )
12
13 // --- Subscription and Nostr message handling ---
14
15 func (s *Server) handleReq(fd int, msg []byte) {
16 c := s.conns[fd]
17 if c == nil {
18 return
19 }
20 _, rem, _ := envelope.Identify(msg)
21 filter.Tainted = false
22 var req envelope.Req
23 if _, err := req.Unmarshal(rem); err != nil {
24 return
25 }
26 if filter.Tainted {
27 c.jailed = true
28 return
29 }
30 id := string(req.Subscription)
31 cfg := s.cfg
32 authed := len(c.authedPubkey) > 0
33 wl := s.t.ConnIsWhitelisted(fd)
34
35 if cfg.AuthRequired && !authed && !wl {
36 cl := &envelope.Closed{
37 Subscription: []byte(id),
38 Reason: []byte("auth-required: authentication required"),
39 }
40 s.t.SendWS(fd, cl.Marshal(nil))
41 return
42 }
43
44 maxSubs := cfg.MaxSubscriptions
45 if maxSubs > 0 {
46 if _, exists := c.subs[id]; !exists && len(c.subs) >= maxSubs {
47 cl := &envelope.Closed{
48 Subscription: []byte(id),
49 Reason: []byte("error: too many subscriptions"),
50 }
51 s.t.SendWS(fd, cl.Marshal(nil))
52 return
53 }
54 }
55 filters := filter.S(*req.Filters)
56 reqCopy := []byte{:len(msg)}
57 copy(reqCopy, msg)
58 c.subs[id] = &sub{id: id, filters: filters, rawReq: reqCopy}
59 s.bcast.OnSubscribe(int32(fd), []byte(id), msg)
60
61 needsFilter := cfg.RelayURL != "" && !cfg.PrivilegedOpen && !wl
62 limit := cfg.QueryResultLimit
63 if limit <= 0 {
64 limit = 256
65 }
66
67 var events []*event.E
68 subID := req.Subscription
69 seen := map[string]bool{}
70 hasSearch := false
71 for _, f := range filters {
72 if len(f.Search) > 0 {
73 hasSearch = true
74 break
75 }
76 }
77 if hasSearch {
78 for _, f := range filters {
79 if len(f.Search) == 0 {
80 continue
81 }
82 results := s.eng.Search(f.Search, limit)
83 for _, ev := range results {
84 if seen[string(ev.ID)] {
85 continue
86 }
87 seen[string(ev.ID)] = true
88 if f.Matches(ev) {
89 events = append(events, ev)
90 }
91 }
92 }
93 } else {
94 for _, f := range filters {
95 evs, err := s.eng.QueryEvents(f)
96 if err != nil {
97 continue
98 }
99 for _, ev := range evs {
100 if seen[string(ev.ID)] {
101 continue
102 }
103 seen[string(ev.ID)] = true
104 events = append(events, ev)
105 }
106 }
107 }
108
109 count := 0
110 for _, ev := range events {
111 if count >= limit {
112 break
113 }
114 if needsFilter && !s.canSeeEventFD(fd, ev) {
115 continue
116 }
117 er := &envelope.EventResult{Subscription: subID, Event: ev}
118 if err := s.t.SendWSErr(fd, er.Marshal(nil)); err != nil {
119 s.t.CloseConn(fd)
120 return
121 }
122 count++
123 }
124 s.t.SendWS(fd, (&envelope.EOSE{Subscription: subID}).Marshal(nil))
125 }
126
127 func (s *Server) handleClose(fd int, msg []byte) {
128 c := s.conns[fd]
129 if c == nil {
130 return
131 }
132 _, rem, _ := envelope.Identify(msg)
133 var cl envelope.Close
134 if _, err := cl.Unmarshal(rem); err != nil {
135 return
136 }
137 delete(c.subs, string(cl.ID))
138 s.bcast.OnUnsubscribe(int32(fd), cl.ID)
139 }
140
141 func (s *Server) handleCount(fd int, msg []byte) {
142 c := s.conns[fd]
143 if c == nil {
144 return
145 }
146 if s.cfg.AuthRequired && len(c.authedPubkey) == 0 && !s.t.ConnIsWhitelisted(fd) {
147 return
148 }
149 _, rem, err := envelope.Identify(msg)
150 if err != nil {
151 return
152 }
153 filter.Tainted = false
154 var cr envelope.CountRequest
155 if _, err := cr.Unmarshal(rem); err != nil {
156 return
157 }
158 if filter.Tainted {
159 c.jailed = true
160 return
161 }
162 var total int
163 for _, f := range cr.Filters {
164 evs, err := s.eng.QueryEvents(f)
165 if err != nil {
166 continue
167 }
168 total += len(evs)
169 }
170 resp := (&envelope.CountResponse{
171 Subscription: cr.Subscription,
172 Count: total,
173 }).Marshal(nil)
174 s.t.SendWS(fd, resp)
175 }
176
177 func (s *Server) handleAuth(fd int, msg []byte) {
178 c := s.conns[fd]
179 if c == nil {
180 return
181 }
182 _, rem, _ := envelope.Identify(msg)
183 var auth envelope.AuthResponse
184 if _, err := auth.Unmarshal(rem); err != nil || auth.Event == nil {
185 ok := &envelope.OK{
186 EventID: []byte{:32},
187 OK: false,
188 Reason: []byte("error: failed to parse auth event"),
189 }
190 s.t.SendWS(fd, ok.Marshal(nil))
191 return
192 }
193 if auth.Event.Kind != 22242 {
194 s.authFail(fd, auth.Event.ID, "error: wrong event kind for auth")
195 return
196 }
197 ct := auth.Event.Tags.GetFirst([]byte("challenge"))
198 if ct == nil || !bytes.Equal(ct.Value(), c.challenge) {
199 s.authFail(fd, auth.Event.ID, "error: wrong challenge")
200 return
201 }
202 rt := auth.Event.Tags.GetFirst([]byte("relay"))
203 if rt == nil || len(rt.Value()) == 0 {
204 s.authFail(fd, auth.Event.ID, "error: missing relay tag")
205 return
206 }
207 if !relayURLMatch([]byte(s.cfg.RelayURL), rt.Value()) {
208 s.authFail(fd, auth.Event.ID, "error: relay URL mismatch")
209 return
210 }
211 now := time.Now().Unix()
212 if auth.Event.CreatedAt > now+600 || auth.Event.CreatedAt < now-600 {
213 s.authFail(fd, auth.Event.ID, "error: timestamp out of range")
214 return
215 }
216 valid, err := auth.Event.Verify()
217 if err != nil || !valid {
218 s.authFail(fd, auth.Event.ID, "error: invalid signature")
219 return
220 }
221 c.authedPubkey = []byte{:len(auth.Event.Pubkey)}
222 copy(c.authedPubkey, auth.Event.Pubkey)
223 s.bcast.OnAuth(int32(fd), c.authedPubkey)
224 ok := &envelope.OK{EventID: auth.Event.ID, OK: true}
225 s.t.SendWS(fd, ok.Marshal(nil))
226 }
227
228 func (s *Server) authFail(fd int, id []byte, reason string) {
229 ok := &envelope.OK{EventID: id, OK: false, Reason: []byte(reason)}
230 s.t.SendWS(fd, ok.Marshal(nil))
231 }
232
233 func (s *Server) canSeeEventFD(fd int, ev *event.E) bool {
234 c := s.conns[fd]
235 var authedPubkey []byte
236 if c != nil {
237 authedPubkey = c.authedPubkey
238 }
239 return access.CanSee(len(authedPubkey) > 0, authedPubkey, ev, s.cfg.NIP70Enforce, s.cfg.MarmotOpen)
240 }
241
242 func relayURLMatch(expected, found []byte) bool {
243 if bytes.Equal(expected, found) {
244 return true
245 }
246 e := makeCopy(expected)
247 f := makeCopy(found)
248 toLower(e)
249 toLower(f)
250 if len(e) > 0 && e[len(e)-1] == '/' {
251 e = e[:len(e)-1]
252 }
253 if len(f) > 0 && f[len(f)-1] == '/' {
254 f = f[:len(f)-1]
255 }
256 return bytes.Equal(stripScheme(e), stripScheme(f))
257 }
258
// stripScheme returns u without a leading "wss://" or "ws://" prefix. The
// match is case-sensitive. If neither prefix is present, u is returned
// unchanged. The result is a subslice of u, not a copy.
func stripScheme(u []byte) []byte {
	for _, scheme := range [...]string{"wss://", "ws://"} {
		if bytes.HasPrefix(u, []byte(scheme)) {
			return u[len(scheme):]
		}
	}
	return u
}
268