walk.mx raw

   1  package iskra
   2  
   3  import "git.smesh.lol/iskradb/lattice"
   4  
// WalkKind classifies the intent of a walk step.
// Beam.Step switches on it to decide how the beam width changes after a step.
type WalkKind uint8

// Walk kinds understood by Beam.Step.
const (
	WalkTranslate WalkKind = 0 // width stays constant
	WalkElaborate WalkKind = 1 // width grows toward MaxWidth
	WalkNarrow    WalkKind = 2 // width shrinks toward MinWidth
)
  13  
// WalkState is a single position in an in-progress lattice walk.
// WalkStep and WalkStepCrossDomain derive a new WalkState from an existing
// one; both rebuild Key with MakeKey from the new domain, coord and word.
type WalkState struct {
	Key    lattice.Key // current lattice position
	Domain uint8       // current domain
	Coord  uint64      // accumulated coord context
	Word   string      // surface form at this position
	Score  uint32      // accumulated confidence (higher = better); WalkStep adds each step's weight
	Depth  uint8       // steps taken; incremented by one on every step
}
  23  
// WalkCandidate is a scored next-step from WalkStep.
// State is the position the walk would move to; Weight is the score of that
// single step and is the key sortCandidates orders by.
type WalkCandidate struct {
	State  WalkState
	Weight uint32 // bigram weight that produced this candidate (WalkStep substitutes the record count when that weight is zero)
}
  29  
  30  // WalkStep computes the set of candidate next positions from a current state.
  31  // Uses BigramIdx for O(1) lookup of continuations from state.Word, then scores
  32  // each by bigram weight with coord relaxation. Returns candidates ordered by
  33  // descending weight, capped at maxCandidates.
  34  func WalkStep(t *Tree, state WalkState, maxCandidates int32) []WalkCandidate {
  35  	idxEntries := t.BigramIdx[state.Word]
  36  	if len(idxEntries) == 0 {
  37  		return nil
  38  	}
  39  
  40  	var candidates []WalkCandidate
  41  	prefix := state.Word | "|"
  42  
  43  	for _, ri := range idxEntries {
  44  		if int32(ri) >= len(t.RecMeta) {
  45  			continue
  46  		}
  47  		meta := &t.RecMeta[ri]
  48  		if meta.Count == 0 {
  49  			continue
  50  		}
  51  		rec := t.db.GetRecord(ri)
  52  		if rec == nil {
  53  			continue
  54  		}
  55  		form := FormFromRecord(rec, t.StringPool)
  56  		if len(form) <= len(prefix) {
  57  			continue
  58  		}
  59  		nextWord := form[len(prefix):]
  60  
  61  		w, matchedCoord := BigramWeightRelaxed(t, state.Domain, state.Coord, state.Word, nextWord)
  62  		if w == 0 {
  63  			w = meta.Count
  64  			matchedCoord = 0
  65  		}
  66  
  67  		nextKey := MakeKey(state.Domain, matchedCoord, nextWord)
  68  		candidates = append(candidates, WalkCandidate{
  69  			State: WalkState{
  70  				Key:    nextKey,
  71  				Domain: state.Domain,
  72  				Coord:  matchedCoord,
  73  				Word:   nextWord,
  74  				Score:  state.Score + w,
  75  				Depth:  state.Depth + 1,
  76  			},
  77  			Weight: w,
  78  		})
  79  	}
  80  
  81  	sortCandidates(candidates)
  82  	if maxCandidates > 0 && len(candidates) > maxCandidates {
  83  		candidates = candidates[:maxCandidates]
  84  	}
  85  	return candidates
  86  }
  87  
  88  // WalkStepCrossDomain performs a walk step that crosses from srcDomain to dstDomain.
  89  // This is the "translate" step: find the best match for state.Word in dstDomain
  90  // using coord-relaxation and cross-domain links.
  91  func WalkStepCrossDomain(t *Tree, state WalkState, dstDomain uint8) WalkState {
  92  	src := NewSubLattice(t.db, t.StringPool, state.Domain, t.Reg)
  93  	dst := NewSubLattice(t.db, t.StringPool, dstDomain, t.Reg)
  94  	translated := Translate(src, dst, state.Word, state.Coord)
  95  	if translated == "" {
  96  		translated = state.Word
  97  	}
  98  	newKey := MakeKey(dstDomain, state.Coord, translated)
  99  	return WalkState{
 100  		Key:    newKey,
 101  		Domain: dstDomain,
 102  		Coord:  state.Coord,
 103  		Word:   translated,
 104  		Score:  state.Score,
 105  		Depth:  state.Depth + 1,
 106  	}
 107  }
 108  
// Beam holds the top-K active walks in progress.
type Beam struct {
	Walks    []WalkState // active walks; replaced by Step whenever any candidate is found
	Width    int32       // current K: per-walk candidate cap and number of walks Step keeps
	MaxWidth int32       // ceiling for Width when Kind is WalkElaborate
	MinWidth int32       // floor for Width when Kind is WalkNarrow
	Kind     WalkKind    // selects how Step adjusts Width
}
 117  
 118  // NewBeam creates a beam with the initial state.
 119  func NewBeam(initial WalkState, width int32, kind WalkKind) *Beam {
 120  	return &Beam{
 121  		Walks:    []WalkState{initial},
 122  		Width:    width,
 123  		MaxWidth: width * 4,
 124  		MinWidth: 1,
 125  		Kind:     kind,
 126  	}
 127  }
 128  
 129  // Step advances all walks in the beam by one position.
 130  // For WalkElaborate: width grows by 1 each step (up to MaxWidth).
 131  // For WalkNarrow: width shrinks by 1 each step (down to MinWidth).
 132  // For WalkTranslate: width stays constant.
 133  func (b *Beam) Step(t *Tree) {
 134  	var allCandidates []WalkCandidate
 135  
 136  	for _, walk := range b.Walks {
 137  		candidates := WalkStep(t, walk, b.Width)
 138  		allCandidates = append(allCandidates, candidates...)
 139  	}
 140  
 141  	if len(allCandidates) == 0 {
 142  		return
 143  	}
 144  
 145  	sortCandidates(allCandidates)
 146  
 147  	switch b.Kind {
 148  	case WalkElaborate:
 149  		if b.Width < b.MaxWidth {
 150  			b.Width++
 151  		}
 152  	case WalkNarrow:
 153  		if b.Width > b.MinWidth {
 154  			b.Width--
 155  		}
 156  	}
 157  
 158  	limit := b.Width
 159  	if len(allCandidates) < limit {
 160  		limit = len(allCandidates)
 161  	}
 162  	b.Walks = []WalkState{:0:limit}
 163  	for i := 0; i < limit; i++ {
 164  		b.Walks = append(b.Walks, allCandidates[i].State)
 165  	}
 166  }
 167  
 168  // Best returns the highest-scoring walk in the beam.
 169  func (b *Beam) Best() WalkState {
 170  	if len(b.Walks) == 0 {
 171  		return WalkState{}
 172  	}
 173  	best := b.Walks[0]
 174  	for i := 1; i < len(b.Walks); i++ {
 175  		if b.Walks[i].Score > best.Score {
 176  			best = b.Walks[i]
 177  		}
 178  	}
 179  	return best
 180  }
 181  
 182  // Run executes the beam for maxSteps iterations, returning the best final state.
 183  func (b *Beam) Run(t *Tree, maxSteps int32) WalkState {
 184  	for step := 0; step < maxSteps; step++ {
 185  		b.Step(t)
 186  		if len(b.Walks) == 0 {
 187  			break
 188  		}
 189  	}
 190  	return b.Best()
 191  }
 192  
 193  // sortCandidates orders by descending weight (highest first).
 194  func sortCandidates(cs []WalkCandidate) {
 195  	for i := 1; i < len(cs); i++ {
 196  		j := i
 197  		for j > 0 && cs[j].Weight > cs[j-1].Weight {
 198  			cs[j], cs[j-1] = cs[j-1], cs[j]
 199  			j--
 200  		}
 201  	}
 202  }
 203  
 204