cooccur.mx raw

   1  package ingest
   2  
   3  import (
   4  	"bufio"
   5  	"fmt"
   6  	"io"
   7  	"os"
   8  	"slices"
   9  
  10  	"git.smesh.lol/iskradb/lattice"
  11  	"git.smesh.lol/transdb"
  12  )
  13  
  14  // CooccurConfig controls co-occurrence counting and PMI filtering.
  15  type CooccurConfig struct {
  16  	MinCooc             uint32  // minimum times a pair must co-occur (default 3)
  17  	PMIMin              float64 // minimum PMI score (default 2.0 ≈ 4× expected)
  18  	MaxPairs            int     // maximum new pairs to insert (0 = no limit)
  19  	MaxPairsPerSentence int     // skip sentences whose EN×JA product exceeds this (default 40)
  20  	JAWordlist          string  // path to pre-built JA wordlist file (empty = build from lattice)
  21  }
  22  
  23  func DefaultCooccurConfig() CooccurConfig {
  24  	return CooccurConfig{MinCooc: 3, PMIMin: 2.0, MaxPairs: 0, MaxPairsPerSentence: 40}
  25  }
  26  
  27  // ExtendFromSentences reads parallel EN/JA sentence files, computes PMI
  28  // co-occurrence scores, and inserts high-confidence pairs into db that
  29  // are not already covered by existing links.
  30  func ExtendFromSentences(db *DB, enPath, jaPath string, cfg CooccurConfig, verbose bool) (int, error) {
  31  	// Phase 1: build set of valid JA forms.
  32  	// Use pre-built wordlist file if provided (faster); otherwise scan lattice.
  33  	var validJA map[string]uint32
  34  	if cfg.JAWordlist != "" {
  35  		wl, err := transdb.LoadWordlist(cfg.JAWordlist)
  36  		if err != nil {
  37  			return 0, fmt.Errorf("load wordlist: %w", err)
  38  		}
  39  		validJA = wl
  40  		if verbose {
  41  			fmt.Fprintf(os.Stderr, "extend: %d valid JA forms from wordlist\n", len(validJA))
  42  		}
  43  	} else {
  44  		validJA = buildValidJASet(db)
  45  		if verbose {
  46  			fmt.Fprintf(os.Stderr, "extend: %d valid JA forms from lattice\n", len(validJA))
  47  		}
  48  	}
  49  
  50  	// Phase 2: stream sentence pairs, count co-occurrences.
  51  	cooc, enFreq, jaFreq, total, err := countCooc(db.Tree, enPath, jaPath, validJA, cfg.MaxPairsPerSentence, verbose)
  52  	if err != nil {
  53  		return 0, err
  54  	}
  55  	if verbose {
  56  		fmt.Fprintf(os.Stderr, "extend: %d sentence pairs, %d unique co-occurrence pairs\n",
  57  			total, len(cooc))
  58  	}
  59  
  60  	// Phase 3: score with PMI, collect high-scoring candidates.
  61  	type Candidate struct {
  62  		EN    string
  63  		ENCtx uint64 // 22-bit coord (cooccurrence axis only for now)
  64  		JA    string
  65  		PMI   float64
  66  	}
  67  	var candidates []Candidate
  68  	for pair, cnt := range cooc {
  69  		if cnt < cfg.MinCooc {
  70  			continue
  71  		}
  72  		en, enCtx, ja := splitPairCtx(pair)
  73  		pmi := pmiScore(cnt, enFreq[en], jaFreq[ja], uint32(total))
  74  		if pmi >= cfg.PMIMin {
  75  			candidates = append(candidates, Candidate{en, enCtx, ja, pmi})
  76  		}
  77  	}
  78  	// Sort descending by PMI.
  79  	slices.SortFunc(candidates, func(a, b Candidate) int {
  80  		if a.PMI > b.PMI {
  81  			return -1
  82  		}
  83  		if a.PMI < b.PMI {
  84  			return 1
  85  		}
  86  		return 0
  87  	})
  88  	if verbose {
  89  		fmt.Fprintf(os.Stderr, "extend: %d candidates above PMI %.1f\n", len(candidates), cfg.PMIMin)
  90  	}
  91  
  92  	// Phase 4: insert new pairs into lattice.
  93  	inserted := 0
  94  	for _, c := range candidates {
  95  		if cfg.MaxPairs > 0 && inserted >= cfg.MaxPairs {
  96  			break
  97  		}
  98  		if insertCooccurPair(db, c.EN, c.ENCtx, c.JA) {
  99  			inserted++
 100  		}
 101  	}
 102  
 103  	// Phase 5: accumulate corpus evidence counts in JA.DataLen.
 104  	// For each high-PMI (EN, JA) pair, find the JA record and increment its
 105  	// DataLen. JA.DataLen = total co-occurrence evidence across all EN partners.
 106  	// Stored on the JA record so every candidate in a rerank comparison carries
 107  	// its own evidence — challengers are not disadvantaged vs the current Link[0].
 108  	// Only inline JA records (DataFile==0, form ≤23 bytes) are counted.
 109  	// Accumulates across corpus re-runs.
 110  	confirmed := 0
 111  	for pair, cnt := range cooc {
 112  		if cnt < cfg.MinCooc {
 113  			continue
 114  		}
 115  		_, _, ja := splitPairCtx(pair)
 116  		jaKey := transdb.MakeKey(transdb.LangJA, 0, ja)
 117  		for _, b := range transdb.ActiveBranches {
 118  			jaRI := db.Tree.LookupRecIdx(lattice.Branch(b), jaKey)
 119  			if jaRI == lattice.NullRec {
 120  				continue
 121  			}
 122  			jaRec := db.Tree.GetRecord(jaRI)
 123  			if jaRec == nil || jaRec.DataFile != 0 {
 124  				break // overflow — DataLen is byte length, don't touch
 125  			}
 126  			if jaRec.DataLen < 0xFFFFFFFF {
 127  				jaRec.DataLen += cnt
 128  			}
 129  			confirmed++
 130  			break
 131  		}
 132  	}
 133  	if verbose && confirmed > 0 {
 134  		fmt.Fprintf(os.Stderr, "extend: %d JA records gained corpus evidence counts\n", confirmed)
 135  	}
 136  
 137  	return inserted, nil
 138  }
 139  
 140  // buildValidJASet collects all JA surface forms from the existing lattice
 141  // into a map[form]recIdx for fast substring matching.
 142  // Language is detected from the form content (JA = hiragana/katakana/CJK).
 143  func buildValidJASet(db *DB) map[string]uint32 {
 144  	valid := map[string]uint32{}
 145  	for recIdx := range db.Tree.RecKey {
 146  		rec := db.Tree.GetRecord(recIdx)
 147  		if rec == nil {
 148  			continue
 149  		}
 150  		form := transdb.FormFromInline(rec, db.StringPool)
 151  		if form != "" && transdb.Detect(form) == transdb.LangJA {
 152  			valid[form] = recIdx
 153  		}
 154  	}
 155  	return valid
 156  }
 157  
 158  // countCooc streams two parallel files line-by-line and counts
 159  // co-occurrences between EN tokens (with POS trigram context) and JA substrings.
 160  func countCooc(tree *lattice.Tree, enPath, jaPath string, validJA map[string]uint32, maxPairsPerSentence int, verbose bool) (
 161  	cooc map[string]uint32, enFreq map[string]uint32, jaFreq map[string]uint32, total int, err error) {
 162  
 163  	enF, err := os.Open(enPath)
 164  	if err != nil {
 165  		return nil, nil, nil, 0, fmt.Errorf("open %s: %w", enPath, err)
 166  	}
 167  	defer enF.Close()
 168  
 169  	jaF, err := os.Open(jaPath)
 170  	if err != nil {
 171  		return nil, nil, nil, 0, fmt.Errorf("open %s: %w", jaPath, err)
 172  	}
 173  	defer jaF.Close()
 174  
 175  	cooc = map[string]uint32{}
 176  	enFreq = map[string]uint32{}
 177  	jaFreq = map[string]uint32{}
 178  
 179  	enSc := bufio.NewScanner(enF)
 180  	jaSc := bufio.NewScanner(jaF)
 181  
 182  	// Prune every pruneInterval sentences: evict pairs seen < 2 times.
 183  	// More frequent = lower memory, slightly less recall on rare-but-valid pairs.
 184  	const pruneInterval = 10000
 185  
 186  	logInterval := 100000
 187  	for enSc.Scan() && jaSc.Scan() {
 188  		enLine := enSc.Text()
 189  		jaLine := jaSc.Text()
 190  		total++
 191  
 192  		if verbose && total%logInterval == 0 {
 193  			fmt.Fprintf(os.Stderr, "extend: processed %d sentence pairs... (cooc map: %d entries)\n",
 194  				total, len(cooc))
 195  		}
 196  
 197  		if total%pruneInterval == 0 {
 198  			before := len(cooc)
 199  			for k, v := range cooc {
 200  				if v < 2 {
 201  					delete(cooc, k)
 202  				}
 203  			}
 204  			if verbose {
 205  				fmt.Fprintf(os.Stderr, "extend: pruned cooc map %d→%d entries\n", before, len(cooc))
 206  			}
 207  		}
 208  
 209  		enToks := tokenizeENSentence(enLine)
 210  		jaToks := extractJATokens(jaLine, validJA)
 211  
 212  		// Skip degenerate pairs.
 213  		if len(enToks) == 0 || len(jaToks) == 0 {
 214  			continue
 215  		}
 216  		// Skip sentences whose cartesian product is too large.
 217  		if len(enToks)*len(jaToks) > maxPairsPerSentence {
 218  			continue
 219  		}
 220  
 221  		// Compute POS context for each EN token position (overlapping trigram window).
 222  		enPOS := []uint8{:len(enToks):len(enToks)}
 223  		for i, t := range enToks {
 224  			enPOS[i] = transdb.POSForWord(tree, transdb.LangEN, t)
 225  		}
 226  
 227  		// Dedup JA tokens within sentence.
 228  		jaSeen := map[string]bool{}
 229  		for _, t := range jaToks {
 230  			jaSeen[t] = true
 231  		}
 232  
 233  		// EN frequency: count unique words (not per-position to avoid inflation).
 234  		enCounted := map[string]bool{}
 235  		for _, t := range enToks {
 236  			if !enCounted[t] {
 237  				enFreq[t]++
 238  				enCounted[t] = true
 239  			}
 240  		}
 241  		for j := range jaSeen {
 242  			jaFreq[j]++
 243  		}
 244  
 245  		// Count co-occurrences per EN position (with overlapping trigram ctx).
 246  		enPosSeen := map[string]bool{} // dedup (enWord+ctx, jaWord) within sentence
 247  		for i, e := range enToks {
 248  			var prev, next uint8
 249  			if i > 0 {
 250  				prev = enPOS[i-1]
 251  			}
 252  			if i+1 < len(enPOS) {
 253  				next = enPOS[i+1]
 254  			}
 255  			// Cooccurrence axis: (prev_type, next_type) packed into coord.
 256  			// cur position (enPOS[i]) is implicit in the word's grammatical axis.
 257  			cooccur := transdb.CoordCooccur(prev, next)
 258  			ctx := transdb.PackCoord(0, 0, cooccur, 0, 0, 0, 0)
 259  			for j := range jaSeen {
 260  				k := joinPairCtx(ctx, e, j)
 261  				if !enPosSeen[k] {
 262  					enPosSeen[k] = true
 263  					cooc[k]++
 264  				}
 265  			}
 266  		}
 267  	}
 268  	if err = enSc.Err(); err != nil {
 269  		return nil, nil, nil, total, fmt.Errorf("scan %s: %w", enPath, err)
 270  	}
 271  	_ = io.EOF // suppress unused import warning
 272  	return cooc, enFreq, jaFreq, total, nil
 273  }
 274  
 275  // enStopWords are high-frequency English function words that co-occur with
 276  // everything and produce no translation signal — only noise and memory pressure.
 277  var enStopWords = map[string]bool{
 278  	"the": true, "a": true, "an": true, "is": true, "are": true,
 279  	"was": true, "were": true, "be": true, "been": true, "being": true,
 280  	"have": true, "has": true, "had": true, "do": true, "does": true,
 281  	"did": true, "will": true, "would": true, "could": true, "should": true,
 282  	"may": true, "might": true, "shall": true, "can": true,
 283  	"of": true, "in": true, "to": true, "for": true, "on": true,
 284  	"at": true, "by": true, "with": true, "from": true, "as": true,
 285  	"and": true, "or": true, "but": true, "not": true, "no": true,
 286  	"it": true, "its": true, "this": true, "that": true, "these": true,
 287  	"those": true, "he": true, "she": true, "we": true, "they": true,
 288  	"you": true, "me": true, "him": true, "her": true, "us": true,
 289  	"them": true, "my": true, "your": true, "his": true, "our": true,
 290  	"their": true, "what": true, "which": true, "who": true, "all": true,
 291  	"if": true, "so": true, "up": true, "out": true, "just": true,
 292  	"also": true, "than": true, "when": true, "where": true, "how": true,
 293  	"why": true, "about": true, "into": true, "then": true, "now": true,
 294  	"here": true, "there": true, "some": true, "any": true, "more": true,
 295  }
 296  
 297  // tokenizeENSentence splits an English subtitle line into lowercase tokens,
 298  // skipping stop words and tokens shorter than 3 chars.
 299  func tokenizeENSentence(line string) []string {
 300  	var tokens []string
 301  	var cur []byte
 302  	for i := 0; i < len(line); i++ {
 303  		c := line[i]
 304  		if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') {
 305  			if c >= 'A' && c <= 'Z' {
 306  				c += 32
 307  			}
 308  			cur = append(cur, c)
 309  		} else {
 310  			if len(cur) >= 3 {
 311  				tok := string(append([]byte(nil), cur...))
 312  				if !enStopWords[tok] {
 313  					tokens = append(tokens, tok)
 314  				}
 315  			}
 316  			cur = cur[:0]
 317  		}
 318  	}
 319  	if len(cur) >= 3 {
 320  		tok := string(append([]byte(nil), cur...))
 321  		if !enStopWords[tok] {
 322  			tokens = append(tokens, tok)
 323  		}
 324  	}
 325  	return tokens
 326  }
 327  
 328  // extractJATokens finds all substrings of line (up to 8 codepoints)
 329  // that exist as valid JA forms in the lattice.
 330  //
 331  // Uses byte-offset iteration because Moxie's range yields bytes not runes
 332  // and []rune(string) does not decode UTF-8.  CJK/kana are 3 bytes each;
 333  // utf8Start computes the correct byte-length per codepoint from the first byte.
 334  func extractJATokens(line string, validJA map[string]uint32) []string {
 335  	// Build byte offsets for each codepoint boundary.
 336  	offsets := []int{:0:len(line)/3 + 1}
 337  	i := 0
 338  	for i < len(line) {
 339  		offsets = append(offsets, i)
 340  		i += utf8CharLen(line[i])
 341  	}
 342  	offsets = append(offsets, len(line))
 343  
 344  	seen := map[string]bool{}
 345  	var tokens []string
 346  	maxCodepoints := 20
 347  	minCodepoints := 2 // skip single-char JA (particles: は, が, を, に, の…)
 348  	n := len(offsets) - 1 // number of codepoints
 349  	for start := 0; start < n; start++ {
 350  		for l := minCodepoints; l <= maxCodepoints && start+l <= n; l++ {
 351  			sub := line[offsets[start]:offsets[start+l]]
 352  			if _, ok := validJA[sub]; ok && !seen[sub] {
 353  				// Copy: sub is a slice of line which aliases the scanner buffer.
 354  				tokens = append(tokens, string(append([]byte(nil), []byte(sub)...)))
 355  				seen[sub] = true
 356  			}
 357  		}
 358  	}
 359  	return tokens
 360  }
 361  
 362  // utf8CharLen returns the byte length of the UTF-8 codepoint starting at b.
 363  func utf8CharLen(b byte) int {
 364  	switch {
 365  	case b < 0x80:
 366  		return 1
 367  	case b < 0xE0:
 368  		return 2
 369  	case b < 0xF0:
 370  		return 3
 371  	default:
 372  		return 4
 373  	}
 374  }
 375  
 376  // pmiScore computes pointwise mutual information (in bits / log2).
 377  // pmi = log2(P(x,y) / (P(x)*P(y))) = log2(cnt*N / freqX / freqY)
 378  func pmiScore(cnt, freqX, freqY, N uint32) float64 {
 379  	if freqX == 0 || freqY == 0 || N == 0 {
 380  		return 0
 381  	}
 382  	// Use log2 approximation via integer arithmetic converted to float.
 383  	num := float64(cnt) * float64(N)
 384  	den := float64(freqX) * float64(freqY)
 385  	if den == 0 {
 386  		return 0
 387  	}
 388  	return log2(num / den)
 389  }
 390  
 391  // log2 computes natural-log-based log2 using ln(x)/ln(2).
 392  func log2(x float64) float64 {
 393  	if x <= 0 {
 394  		return -999
 395  	}
 396  	// Integer-based approximation: count leading bits.
 397  	// For the PMI use case (x often in 1-1000 range), this is accurate enough.
 398  	// Use the series ln(x) ≈ 2*arctanh((x-1)/(x+1)) for x near 1.
 399  	// Better: implement as bit manipulation + correction.
 400  	// For simplicity, compute using a precomputed table of powers of 2.
 401  	result := 0.0
 402  	for x >= 2.0 {
 403  		x /= 2.0
 404  		result += 1.0
 405  	}
 406  	for x < 1.0 {
 407  		x *= 2.0
 408  		result -= 1.0
 409  	}
 410  	// x is now in [1, 2). Use linear approximation: log2(x) ≈ x - 1.
 411  	result += x - 1.0
 412  	return result
 413  }
 414  
 415  // joinPairCtx encodes (coord uint64, enWord, jaWord) as a cooc map key.
 416  // coord stored as 8 LE bytes. EN tokens are ASCII (≥0x61) so no ambiguity.
 417  func joinPairCtx(ctx uint64, en, ja string) string {
 418  	return string([]byte{
 419  		byte(ctx), byte(ctx >> 8), byte(ctx >> 16), byte(ctx >> 24),
 420  		byte(ctx >> 32), byte(ctx >> 40), byte(ctx >> 48), byte(ctx >> 56),
 421  	}) | en | "\x00" | ja
 422  }
 423  
 424  // splitPairCtx decodes a key produced by joinPairCtx.
 425  func splitPairCtx(pair string) (en string, ctx uint64, ja string) {
 426  	if len(pair) < 8 {
 427  		return "", 0, ""
 428  	}
 429  	ctx = uint64(pair[0]) | uint64(pair[1])<<8 | uint64(pair[2])<<16 | uint64(pair[3])<<24 |
 430  		uint64(pair[4])<<32 | uint64(pair[5])<<40 | uint64(pair[6])<<48 | uint64(pair[7])<<56
 431  	rest := pair[8:]
 432  	for i := 0; i < len(rest); i++ {
 433  		if rest[i] == 0 {
 434  			return rest[:i], ctx, rest[i+1:]
 435  		}
 436  	}
 437  	return rest, ctx, ""
 438  }
 439  
 440  // insertCooccurPair inserts a corpus-derived EN-JA translation link.
 441  // enCtx is the packed 3-position POS window. For ctx=0 (baseline) the
 442  // logic is symmetric: new EN record points to JA and vice versa.
 443  // For ctx≠0 (context entries), only the EN record is created pointing
 444  // to the existing JA record — JA links are not modified.
 445  // Returns true if something was inserted.
 446  func insertCooccurPair(db *DB, enWord string, enCtx uint64, jaWord string) bool {
 447  	enKey := transdb.MakeKey(transdb.LangEN, enCtx, enWord)
 448  	jaKey := transdb.MakeKey(transdb.LangJA, 0, jaWord)
 449  
 450  	// JA must exist in lattice.
 451  	jaRI := lattice.NullRec
 452  	for _, b := range transdb.ActiveBranches {
 453  		if ri := db.Tree.LookupRecIdx(lattice.Branch(b), jaKey); ri != lattice.NullRec {
 454  			jaRI = ri
 455  			break
 456  		}
 457  	}
 458  	if jaRI == lattice.NullRec {
 459  		return false
 460  	}
 461  
 462  	// EN record at this context key must not already exist.
 463  	for _, b := range transdb.ActiveBranches {
 464  		if db.Tree.LookupRecIdx(lattice.Branch(b), enKey) != lattice.NullRec {
 465  			return false
 466  		}
 467  	}
 468  
 469  	// Create EN record pointing to JA.
 470  	jaRec := db.Tree.GetRecord(jaRI)
 471  	if jaRec == nil {
 472  		return false
 473  	}
 474  	branch := lattice.Branch(transdb.POSFromBranch(jaRec.Branch))
 475  	var enRec lattice.Record
 476  	transdb.SetFormOnRecord(&enRec, enWord, &db.StringPool)
 477  	enRec.Branch = uint8(branch)
 478  	enRec.Link[0] = jaRI
 479  	db.Tree.InsertRec(branch, enKey, enRec)
 480  
 481  	// For ctx=0 new words: also wire JA→EN if JA has no primary EN link yet.
 482  	if enCtx == 0 {
 483  		jaRec = db.Tree.GetRecord(jaRI) // re-fetch after potential realloc
 484  		if jaRec != nil && jaRec.Link[0] == lattice.NullRec {
 485  			newEnRI := lattice.NullRec
 486  			for _, b := range transdb.ActiveBranches {
 487  				if ri := db.Tree.LookupRecIdx(lattice.Branch(b), enKey); ri != lattice.NullRec {
 488  					newEnRI = ri
 489  					break
 490  				}
 491  			}
 492  			if newEnRI != lattice.NullRec {
 493  				jaRec.Link[0] = newEnRI
 494  			}
 495  		}
 496  	}
 497  	return true
 498  }
 499