graph-thread.go raw

   1  //go:build !(js && wasm)
   2  
   3  package database
   4  
   5  import (
   6  	"next.orly.dev/pkg/lol/log"
   7  	"next.orly.dev/pkg/database/indexes/types"
   8  	"next.orly.dev/pkg/nostr/encoders/hex"
   9  )
  10  
  11  // TraverseThread performs BFS traversal of thread structure via e-tags.
  12  // Starting from a seed event, it finds all replies/references at each depth.
  13  //
  14  // The traversal works bidirectionally:
  15  // - Forward: Events that the seed references (parents, quoted posts)
  16  // - Backward: Events that reference the seed (replies, reactions, reposts)
  17  //
  18  // Parameters:
  19  // - seedEventID: The event ID to start traversal from
  20  // - maxDepth: Maximum depth to traverse
  21  // - direction: "both" (default), "inbound" (replies to seed), "outbound" (seed's references)
  22  func (d *D) TraverseThread(seedEventID []byte, maxDepth int, direction string) (*GraphResult, error) {
  23  	result := NewGraphResult()
  24  
  25  	if len(seedEventID) != 32 {
  26  		return result, ErrEventNotFound
  27  	}
  28  
  29  	// Get seed event serial
  30  	seedSerial, err := d.GetSerialById(seedEventID)
  31  	if err != nil {
  32  		log.D.F("TraverseThread: seed event not in database: %s", hex.Enc(seedEventID))
  33  		return result, nil
  34  	}
  35  
  36  	// Normalize direction
  37  	if direction == "" {
  38  		direction = "both"
  39  	}
  40  
  41  	// Track visited events
  42  	visited := make(map[uint64]bool)
  43  	visited[seedSerial.Get()] = true
  44  
  45  	// Current frontier
  46  	currentFrontier := []*types.Uint40{seedSerial}
  47  
  48  	consecutiveEmptyDepths := 0
  49  
  50  	for currentDepth := 1; currentDepth <= maxDepth; currentDepth++ {
  51  		var nextFrontier []*types.Uint40
  52  		newEventsAtDepth := 0
  53  
  54  		for _, eventSerial := range currentFrontier {
  55  			// Get inbound references (events that reference this event)
  56  			if direction == "both" || direction == "inbound" {
  57  				inboundSerials, err := d.GetReferencingEvents(eventSerial, nil)
  58  				if err != nil {
  59  					log.D.F("TraverseThread: error getting inbound refs for serial %d: %v", eventSerial.Get(), err)
  60  				} else {
  61  					for _, refSerial := range inboundSerials {
  62  						if visited[refSerial.Get()] {
  63  							continue
  64  						}
  65  						visited[refSerial.Get()] = true
  66  
  67  						eventIDHex, err := d.GetEventIDFromSerial(refSerial)
  68  						if err != nil {
  69  							continue
  70  						}
  71  
  72  						result.AddEventAtDepth(eventIDHex, currentDepth)
  73  						newEventsAtDepth++
  74  						nextFrontier = append(nextFrontier, refSerial)
  75  					}
  76  				}
  77  			}
  78  
  79  			// Get outbound references (events this event references)
  80  			if direction == "both" || direction == "outbound" {
  81  				outboundSerials, err := d.GetETagsFromEventSerial(eventSerial)
  82  				if err != nil {
  83  					log.D.F("TraverseThread: error getting outbound refs for serial %d: %v", eventSerial.Get(), err)
  84  				} else {
  85  					for _, refSerial := range outboundSerials {
  86  						if visited[refSerial.Get()] {
  87  							continue
  88  						}
  89  						visited[refSerial.Get()] = true
  90  
  91  						eventIDHex, err := d.GetEventIDFromSerial(refSerial)
  92  						if err != nil {
  93  							continue
  94  						}
  95  
  96  						result.AddEventAtDepth(eventIDHex, currentDepth)
  97  						newEventsAtDepth++
  98  						nextFrontier = append(nextFrontier, refSerial)
  99  					}
 100  				}
 101  			}
 102  		}
 103  
 104  		log.T.F("TraverseThread: depth %d found %d new events", currentDepth, newEventsAtDepth)
 105  
 106  		if newEventsAtDepth == 0 {
 107  			consecutiveEmptyDepths++
 108  			if consecutiveEmptyDepths >= 2 {
 109  				break
 110  			}
 111  		} else {
 112  			consecutiveEmptyDepths = 0
 113  		}
 114  
 115  		currentFrontier = nextFrontier
 116  	}
 117  
 118  	log.D.F("TraverseThread: completed with %d total events", result.TotalEvents)
 119  
 120  	return result, nil
 121  }
 122  
 123  // TraverseThreadFromHex is a convenience wrapper that accepts hex-encoded event ID.
 124  func (d *D) TraverseThreadFromHex(seedEventIDHex string, maxDepth int, direction string) (*GraphResult, error) {
 125  	seedEventID, err := hex.Dec(seedEventIDHex)
 126  	if err != nil {
 127  		return nil, err
 128  	}
 129  	return d.TraverseThread(seedEventID, maxDepth, direction)
 130  }
 131  
 132  // GetThreadReplies finds all direct replies to an event.
 133  // This is a convenience method that returns events at depth 1 with inbound direction.
 134  func (d *D) GetThreadReplies(eventID []byte, kinds []uint16) (*GraphResult, error) {
 135  	result := NewGraphResult()
 136  
 137  	if len(eventID) != 32 {
 138  		return result, ErrEventNotFound
 139  	}
 140  
 141  	eventSerial, err := d.GetSerialById(eventID)
 142  	if err != nil {
 143  		return result, nil
 144  	}
 145  
 146  	// Get events that reference this event
 147  	replySerials, err := d.GetReferencingEvents(eventSerial, kinds)
 148  	if err != nil {
 149  		return nil, err
 150  	}
 151  
 152  	for _, replySerial := range replySerials {
 153  		eventIDHex, err := d.GetEventIDFromSerial(replySerial)
 154  		if err != nil {
 155  			continue
 156  		}
 157  		result.AddEventAtDepth(eventIDHex, 1)
 158  	}
 159  
 160  	return result, nil
 161  }
 162  
 163  // GetThreadParents finds events that a given event references (its parents/quotes).
 164  func (d *D) GetThreadParents(eventID []byte) (*GraphResult, error) {
 165  	result := NewGraphResult()
 166  
 167  	if len(eventID) != 32 {
 168  		return result, ErrEventNotFound
 169  	}
 170  
 171  	eventSerial, err := d.GetSerialById(eventID)
 172  	if err != nil {
 173  		return result, nil
 174  	}
 175  
 176  	// Get events that this event references
 177  	parentSerials, err := d.GetETagsFromEventSerial(eventSerial)
 178  	if err != nil {
 179  		return nil, err
 180  	}
 181  
 182  	for _, parentSerial := range parentSerials {
 183  		eventIDHex, err := d.GetEventIDFromSerial(parentSerial)
 184  		if err != nil {
 185  			continue
 186  		}
 187  		result.AddEventAtDepth(eventIDHex, 1)
 188  	}
 189  
 190  	return result, nil
 191  }
 192