graph-thread.go raw

   1  package neo4j
   2  
   3  import (
   4  	"context"
   5  	"fmt"
   6  	"strings"
   7  
   8  	"next.orly.dev/pkg/nostr/encoders/hex"
   9  	"next.orly.dev/pkg/protocol/graph"
  10  )
  11  
  12  // TraverseThread performs BFS traversal of thread structure via e-tags.
  13  // Starting from a seed event, it finds all replies/references at each depth.
  14  //
  15  // The traversal works bidirectionally using REFERENCES relationships:
  16  //   - Inbound: Events that reference the seed (replies, reactions, reposts)
  17  //   - Outbound: Events that the seed references (parents, quoted posts)
  18  //
  19  // Note: REFERENCES relationships are only created if the referenced event exists
  20  // in the database at the time of saving. This means some references may be missing
  21  // if events were stored out of order.
  22  //
  23  // Parameters:
  24  //   - seedEventID: The event ID to start traversal from
  25  //   - maxDepth: Maximum depth to traverse
  26  //   - direction: "both" (default), "inbound" (replies to seed), "outbound" (seed's references)
  27  func (n *N) TraverseThread(seedEventID []byte, maxDepth int, direction string) (graph.GraphResultI, error) {
  28  	result := NewGraphResult()
  29  
  30  	if len(seedEventID) != 32 {
  31  		return result, fmt.Errorf("invalid event ID length: expected 32, got %d", len(seedEventID))
  32  	}
  33  
  34  	seedHex := strings.ToLower(hex.Enc(seedEventID))
  35  	ctx := context.Background()
  36  
  37  	// Normalize direction
  38  	if direction == "" {
  39  		direction = "both"
  40  	}
  41  
  42  	// Track visited events
  43  	visited := make(map[string]bool)
  44  	visited[seedHex] = true
  45  
  46  	// Process each depth level separately for BFS semantics
  47  	for depth := 1; depth <= maxDepth; depth++ {
  48  		newEventsAtDepth := 0
  49  
  50  		// Get events at current depth
  51  		visitedList := make([]string, 0, len(visited))
  52  		for id := range visited {
  53  			visitedList = append(visitedList, id)
  54  		}
  55  
  56  		// Process inbound references (events that reference the seed or its children)
  57  		if direction == "both" || direction == "inbound" {
  58  			inboundEvents, err := n.getInboundReferencesAtDepth(ctx, seedHex, depth, visitedList)
  59  			if err != nil {
  60  				n.Logger.Warningf("TraverseThread: error getting inbound refs at depth %d: %v", depth, err)
  61  			} else {
  62  				for _, eventID := range inboundEvents {
  63  					if !visited[eventID] {
  64  						visited[eventID] = true
  65  						result.AddEventAtDepth(eventID, depth)
  66  						newEventsAtDepth++
  67  					}
  68  				}
  69  			}
  70  		}
  71  
  72  		// Process outbound references (events that the seed or its children reference)
  73  		if direction == "both" || direction == "outbound" {
  74  			outboundEvents, err := n.getOutboundReferencesAtDepth(ctx, seedHex, depth, visitedList)
  75  			if err != nil {
  76  				n.Logger.Warningf("TraverseThread: error getting outbound refs at depth %d: %v", depth, err)
  77  			} else {
  78  				for _, eventID := range outboundEvents {
  79  					if !visited[eventID] {
  80  						visited[eventID] = true
  81  						result.AddEventAtDepth(eventID, depth)
  82  						newEventsAtDepth++
  83  					}
  84  				}
  85  			}
  86  		}
  87  
  88  		n.Logger.Debugf("TraverseThread: depth %d found %d new events", depth, newEventsAtDepth)
  89  
  90  		// Early termination if no new events found at this depth
  91  		if newEventsAtDepth == 0 {
  92  			break
  93  		}
  94  	}
  95  
  96  	n.Logger.Debugf("TraverseThread: completed with %d total events", result.TotalEvents)
  97  
  98  	return result, nil
  99  }
 100  
 101  // getInboundReferencesAtDepth finds events that reference the seed event at exactly the given depth.
 102  // Uses variable-length path patterns to find events N hops away.
 103  func (n *N) getInboundReferencesAtDepth(ctx context.Context, seedID string, depth int, visited []string) ([]string, error) {
 104  	// Query for events at exactly this depth that haven't been seen yet
 105  	// Direction: (referencing_event)-[:REFERENCES]->(seed)
 106  	// At depth 1: direct replies
 107  	// At depth 2: replies to replies, etc.
 108  	cypher := fmt.Sprintf(`
 109  		MATCH path = (ref:Event)-[:REFERENCES*%d]->(seed:Event {id: $seed})
 110  		WHERE ref.id <> $seed
 111  		  AND NOT ref.id IN $visited
 112  		RETURN DISTINCT ref.id AS event_id
 113  	`, depth)
 114  
 115  	params := map[string]any{
 116  		"seed":    seedID,
 117  		"visited": visited,
 118  	}
 119  
 120  	result, err := n.ExecuteRead(ctx, cypher, params)
 121  	if err != nil {
 122  		return nil, err
 123  	}
 124  
 125  	var events []string
 126  	for result.Next(ctx) {
 127  		record := result.Record()
 128  		eventID, ok := record.Values[0].(string)
 129  		if !ok || eventID == "" {
 130  			continue
 131  		}
 132  		events = append(events, strings.ToLower(eventID))
 133  	}
 134  
 135  	return events, nil
 136  }
 137  
 138  // getOutboundReferencesAtDepth finds events that the seed event references at exactly the given depth.
 139  // Uses variable-length path patterns to find events N hops away.
 140  func (n *N) getOutboundReferencesAtDepth(ctx context.Context, seedID string, depth int, visited []string) ([]string, error) {
 141  	// Query for events at exactly this depth that haven't been seen yet
 142  	// Direction: (seed)-[:REFERENCES]->(referenced_event)
 143  	// At depth 1: direct parents/quotes
 144  	// At depth 2: grandparents, etc.
 145  	cypher := fmt.Sprintf(`
 146  		MATCH path = (seed:Event {id: $seed})-[:REFERENCES*%d]->(ref:Event)
 147  		WHERE ref.id <> $seed
 148  		  AND NOT ref.id IN $visited
 149  		RETURN DISTINCT ref.id AS event_id
 150  	`, depth)
 151  
 152  	params := map[string]any{
 153  		"seed":    seedID,
 154  		"visited": visited,
 155  	}
 156  
 157  	result, err := n.ExecuteRead(ctx, cypher, params)
 158  	if err != nil {
 159  		return nil, err
 160  	}
 161  
 162  	var events []string
 163  	for result.Next(ctx) {
 164  		record := result.Record()
 165  		eventID, ok := record.Values[0].(string)
 166  		if !ok || eventID == "" {
 167  			continue
 168  		}
 169  		events = append(events, strings.ToLower(eventID))
 170  	}
 171  
 172  	return events, nil
 173  }
 174  
 175  // TraverseThreadFromHex is a convenience wrapper that accepts hex-encoded event ID.
 176  func (n *N) TraverseThreadFromHex(seedEventIDHex string, maxDepth int, direction string) (*GraphResult, error) {
 177  	seedEventID, err := hex.Dec(seedEventIDHex)
 178  	if err != nil {
 179  		return nil, err
 180  	}
 181  	result, err := n.TraverseThread(seedEventID, maxDepth, direction)
 182  	if err != nil {
 183  		return nil, err
 184  	}
 185  	return result.(*GraphResult), nil
 186  }
 187  
 188  // GetThreadReplies finds all direct replies to an event.
 189  // This is a convenience method that returns events at depth 1 with inbound direction.
 190  func (n *N) GetThreadReplies(eventID []byte, kinds []uint16) (*GraphResult, error) {
 191  	result := NewGraphResult()
 192  
 193  	if len(eventID) != 32 {
 194  		return result, fmt.Errorf("invalid event ID length: expected 32, got %d", len(eventID))
 195  	}
 196  
 197  	eventIDHex := strings.ToLower(hex.Enc(eventID))
 198  	ctx := context.Background()
 199  
 200  	// Build kinds filter if specified
 201  	var kindsFilter string
 202  	params := map[string]any{
 203  		"eventId": eventIDHex,
 204  	}
 205  
 206  	if len(kinds) > 0 {
 207  		kindsInt := make([]int64, len(kinds))
 208  		for i, k := range kinds {
 209  			kindsInt[i] = int64(k)
 210  		}
 211  		params["kinds"] = kindsInt
 212  		kindsFilter = "AND reply.kind IN $kinds"
 213  	}
 214  
 215  	// Query for direct replies
 216  	cypher := fmt.Sprintf(`
 217  		MATCH (reply:Event)-[:REFERENCES]->(e:Event {id: $eventId})
 218  		WHERE true %s
 219  		RETURN reply.id AS event_id
 220  		ORDER BY reply.created_at DESC
 221  	`, kindsFilter)
 222  
 223  	queryResult, err := n.ExecuteRead(ctx, cypher, params)
 224  	if err != nil {
 225  		return result, fmt.Errorf("failed to query replies: %w", err)
 226  	}
 227  
 228  	for queryResult.Next(ctx) {
 229  		record := queryResult.Record()
 230  		replyID, ok := record.Values[0].(string)
 231  		if !ok || replyID == "" {
 232  			continue
 233  		}
 234  		result.AddEventAtDepth(strings.ToLower(replyID), 1)
 235  	}
 236  
 237  	return result, nil
 238  }
 239  
 240  // GetThreadParents finds events that a given event references (its parents/quotes).
 241  func (n *N) GetThreadParents(eventID []byte) (*GraphResult, error) {
 242  	result := NewGraphResult()
 243  
 244  	if len(eventID) != 32 {
 245  		return result, fmt.Errorf("invalid event ID length: expected 32, got %d", len(eventID))
 246  	}
 247  
 248  	eventIDHex := strings.ToLower(hex.Enc(eventID))
 249  	ctx := context.Background()
 250  
 251  	params := map[string]any{
 252  		"eventId": eventIDHex,
 253  	}
 254  
 255  	// Query for events that this event references
 256  	cypher := `
 257  		MATCH (e:Event {id: $eventId})-[:REFERENCES]->(parent:Event)
 258  		RETURN parent.id AS event_id
 259  		ORDER BY parent.created_at ASC
 260  	`
 261  
 262  	queryResult, err := n.ExecuteRead(ctx, cypher, params)
 263  	if err != nil {
 264  		return result, fmt.Errorf("failed to query parents: %w", err)
 265  	}
 266  
 267  	for queryResult.Next(ctx) {
 268  		record := queryResult.Record()
 269  		parentID, ok := record.Values[0].(string)
 270  		if !ok || parentID == "" {
 271  			continue
 272  		}
 273  		result.AddEventAtDepth(strings.ToLower(parentID), 1)
 274  	}
 275  
 276  	return result, nil
 277  }
 278