graph-thread.go raw
1 package neo4j
2
3 import (
4 "context"
5 "fmt"
6 "strings"
7
8 "next.orly.dev/pkg/nostr/encoders/hex"
9 "next.orly.dev/pkg/protocol/graph"
10 )
11
12 // TraverseThread performs BFS traversal of thread structure via e-tags.
13 // Starting from a seed event, it finds all replies/references at each depth.
14 //
15 // The traversal works bidirectionally using REFERENCES relationships:
16 // - Inbound: Events that reference the seed (replies, reactions, reposts)
17 // - Outbound: Events that the seed references (parents, quoted posts)
18 //
19 // Note: REFERENCES relationships are only created if the referenced event exists
20 // in the database at the time of saving. This means some references may be missing
21 // if events were stored out of order.
22 //
23 // Parameters:
24 // - seedEventID: The event ID to start traversal from
25 // - maxDepth: Maximum depth to traverse
26 // - direction: "both" (default), "inbound" (replies to seed), "outbound" (seed's references)
27 func (n *N) TraverseThread(seedEventID []byte, maxDepth int, direction string) (graph.GraphResultI, error) {
28 result := NewGraphResult()
29
30 if len(seedEventID) != 32 {
31 return result, fmt.Errorf("invalid event ID length: expected 32, got %d", len(seedEventID))
32 }
33
34 seedHex := strings.ToLower(hex.Enc(seedEventID))
35 ctx := context.Background()
36
37 // Normalize direction
38 if direction == "" {
39 direction = "both"
40 }
41
42 // Track visited events
43 visited := make(map[string]bool)
44 visited[seedHex] = true
45
46 // Process each depth level separately for BFS semantics
47 for depth := 1; depth <= maxDepth; depth++ {
48 newEventsAtDepth := 0
49
50 // Get events at current depth
51 visitedList := make([]string, 0, len(visited))
52 for id := range visited {
53 visitedList = append(visitedList, id)
54 }
55
56 // Process inbound references (events that reference the seed or its children)
57 if direction == "both" || direction == "inbound" {
58 inboundEvents, err := n.getInboundReferencesAtDepth(ctx, seedHex, depth, visitedList)
59 if err != nil {
60 n.Logger.Warningf("TraverseThread: error getting inbound refs at depth %d: %v", depth, err)
61 } else {
62 for _, eventID := range inboundEvents {
63 if !visited[eventID] {
64 visited[eventID] = true
65 result.AddEventAtDepth(eventID, depth)
66 newEventsAtDepth++
67 }
68 }
69 }
70 }
71
72 // Process outbound references (events that the seed or its children reference)
73 if direction == "both" || direction == "outbound" {
74 outboundEvents, err := n.getOutboundReferencesAtDepth(ctx, seedHex, depth, visitedList)
75 if err != nil {
76 n.Logger.Warningf("TraverseThread: error getting outbound refs at depth %d: %v", depth, err)
77 } else {
78 for _, eventID := range outboundEvents {
79 if !visited[eventID] {
80 visited[eventID] = true
81 result.AddEventAtDepth(eventID, depth)
82 newEventsAtDepth++
83 }
84 }
85 }
86 }
87
88 n.Logger.Debugf("TraverseThread: depth %d found %d new events", depth, newEventsAtDepth)
89
90 // Early termination if no new events found at this depth
91 if newEventsAtDepth == 0 {
92 break
93 }
94 }
95
96 n.Logger.Debugf("TraverseThread: completed with %d total events", result.TotalEvents)
97
98 return result, nil
99 }
100
101 // getInboundReferencesAtDepth finds events that reference the seed event at exactly the given depth.
102 // Uses variable-length path patterns to find events N hops away.
103 func (n *N) getInboundReferencesAtDepth(ctx context.Context, seedID string, depth int, visited []string) ([]string, error) {
104 // Query for events at exactly this depth that haven't been seen yet
105 // Direction: (referencing_event)-[:REFERENCES]->(seed)
106 // At depth 1: direct replies
107 // At depth 2: replies to replies, etc.
108 cypher := fmt.Sprintf(`
109 MATCH path = (ref:Event)-[:REFERENCES*%d]->(seed:Event {id: $seed})
110 WHERE ref.id <> $seed
111 AND NOT ref.id IN $visited
112 RETURN DISTINCT ref.id AS event_id
113 `, depth)
114
115 params := map[string]any{
116 "seed": seedID,
117 "visited": visited,
118 }
119
120 result, err := n.ExecuteRead(ctx, cypher, params)
121 if err != nil {
122 return nil, err
123 }
124
125 var events []string
126 for result.Next(ctx) {
127 record := result.Record()
128 eventID, ok := record.Values[0].(string)
129 if !ok || eventID == "" {
130 continue
131 }
132 events = append(events, strings.ToLower(eventID))
133 }
134
135 return events, nil
136 }
137
138 // getOutboundReferencesAtDepth finds events that the seed event references at exactly the given depth.
139 // Uses variable-length path patterns to find events N hops away.
140 func (n *N) getOutboundReferencesAtDepth(ctx context.Context, seedID string, depth int, visited []string) ([]string, error) {
141 // Query for events at exactly this depth that haven't been seen yet
142 // Direction: (seed)-[:REFERENCES]->(referenced_event)
143 // At depth 1: direct parents/quotes
144 // At depth 2: grandparents, etc.
145 cypher := fmt.Sprintf(`
146 MATCH path = (seed:Event {id: $seed})-[:REFERENCES*%d]->(ref:Event)
147 WHERE ref.id <> $seed
148 AND NOT ref.id IN $visited
149 RETURN DISTINCT ref.id AS event_id
150 `, depth)
151
152 params := map[string]any{
153 "seed": seedID,
154 "visited": visited,
155 }
156
157 result, err := n.ExecuteRead(ctx, cypher, params)
158 if err != nil {
159 return nil, err
160 }
161
162 var events []string
163 for result.Next(ctx) {
164 record := result.Record()
165 eventID, ok := record.Values[0].(string)
166 if !ok || eventID == "" {
167 continue
168 }
169 events = append(events, strings.ToLower(eventID))
170 }
171
172 return events, nil
173 }
174
175 // TraverseThreadFromHex is a convenience wrapper that accepts hex-encoded event ID.
176 func (n *N) TraverseThreadFromHex(seedEventIDHex string, maxDepth int, direction string) (*GraphResult, error) {
177 seedEventID, err := hex.Dec(seedEventIDHex)
178 if err != nil {
179 return nil, err
180 }
181 result, err := n.TraverseThread(seedEventID, maxDepth, direction)
182 if err != nil {
183 return nil, err
184 }
185 return result.(*GraphResult), nil
186 }
187
188 // GetThreadReplies finds all direct replies to an event.
189 // This is a convenience method that returns events at depth 1 with inbound direction.
190 func (n *N) GetThreadReplies(eventID []byte, kinds []uint16) (*GraphResult, error) {
191 result := NewGraphResult()
192
193 if len(eventID) != 32 {
194 return result, fmt.Errorf("invalid event ID length: expected 32, got %d", len(eventID))
195 }
196
197 eventIDHex := strings.ToLower(hex.Enc(eventID))
198 ctx := context.Background()
199
200 // Build kinds filter if specified
201 var kindsFilter string
202 params := map[string]any{
203 "eventId": eventIDHex,
204 }
205
206 if len(kinds) > 0 {
207 kindsInt := make([]int64, len(kinds))
208 for i, k := range kinds {
209 kindsInt[i] = int64(k)
210 }
211 params["kinds"] = kindsInt
212 kindsFilter = "AND reply.kind IN $kinds"
213 }
214
215 // Query for direct replies
216 cypher := fmt.Sprintf(`
217 MATCH (reply:Event)-[:REFERENCES]->(e:Event {id: $eventId})
218 WHERE true %s
219 RETURN reply.id AS event_id
220 ORDER BY reply.created_at DESC
221 `, kindsFilter)
222
223 queryResult, err := n.ExecuteRead(ctx, cypher, params)
224 if err != nil {
225 return result, fmt.Errorf("failed to query replies: %w", err)
226 }
227
228 for queryResult.Next(ctx) {
229 record := queryResult.Record()
230 replyID, ok := record.Values[0].(string)
231 if !ok || replyID == "" {
232 continue
233 }
234 result.AddEventAtDepth(strings.ToLower(replyID), 1)
235 }
236
237 return result, nil
238 }
239
240 // GetThreadParents finds events that a given event references (its parents/quotes).
241 func (n *N) GetThreadParents(eventID []byte) (*GraphResult, error) {
242 result := NewGraphResult()
243
244 if len(eventID) != 32 {
245 return result, fmt.Errorf("invalid event ID length: expected 32, got %d", len(eventID))
246 }
247
248 eventIDHex := strings.ToLower(hex.Enc(eventID))
249 ctx := context.Background()
250
251 params := map[string]any{
252 "eventId": eventIDHex,
253 }
254
255 // Query for events that this event references
256 cypher := `
257 MATCH (e:Event {id: $eventId})-[:REFERENCES]->(parent:Event)
258 RETURN parent.id AS event_id
259 ORDER BY parent.created_at ASC
260 `
261
262 queryResult, err := n.ExecuteRead(ctx, cypher, params)
263 if err != nil {
264 return result, fmt.Errorf("failed to query parents: %w", err)
265 }
266
267 for queryResult.Next(ctx) {
268 record := queryResult.Record()
269 parentID, ok := record.Values[0].(string)
270 if !ok || parentID == "" {
271 continue
272 }
273 result.AddEventAtDepth(strings.ToLower(parentID), 1)
274 }
275
276 return result, nil
277 }
278