model.mx raw

// Package langdetect implements corpus-trained character trigram language
// identification. Models are built from raw text corpora and stored as
// compact TSV files. Detection scores the input's trigram profile against
// each stored model and picks the highest-scoring language; it is reported
// only if its share of the combined score of all models clears the
// confidence threshold.
   6  //
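// Trigram choice: character-level (not byte-level) trigrams capture
// orthographic patterns that are language-specific regardless of script.
// Each model keeps only its 300 most frequent trigrams (maxTrigrams). How much
// of a language's trigram frequency mass that covers varies by language, and
// is likely lowest for scripts with large character inventories (e.g. CJK).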
  10  package langdetect
  11  
import (
	"bufio"
	"fmt"
	"io"
	"math"
	"os"
	"strconv"
)
  19  
  20  const (
  21  	modelVersion   = "v1"
  22  	maxTrigrams    = 300
  23  	minInputChars  = 5   // minimum codepoints to attempt detection (3 gives one trigram)
  24  	DefaultThresh  = 0.90
  25  )
  26  
  27  // Model holds a trigram frequency profile for one language.
  28  type Model struct {
  29  	Lang      string            // ISO 639-1 code
  30  	Trigrams  map[string]float64 // trigram → normalized frequency
  31  }
  32  
  33  // TrainFromReader builds a Model by reading text from r.
  34  // Counts all 3-codepoint trigrams, keeps the top maxTrigrams by frequency,
  35  // normalizes so values sum to 1.0.
  36  func TrainFromReader(lang string, r io.Reader) (*Model, error) {
  37  	counts := map[string]uint64{}
  38  	total := uint64(0)
  39  
  40  	sc := bufio.NewScanner(r)
  41  	sc.Buffer([]byte{:1<<20}, 1<<20)
  42  	for sc.Scan() {
  43  		line := sc.Text()
  44  		extractTrigrams(line, func(t string) {
  45  			counts[t]++
  46  			total++
  47  		})
  48  	}
  49  	if err := sc.Err(); err != nil {
  50  		return nil, err
  51  	}
  52  	if total == 0 {
  53  		return nil, fmt.Errorf("langdetect: no trigrams found in corpus for %s", lang)
  54  	}
  55  
  56  	// Sort by frequency descending, keep top maxTrigrams.
  57  	type kv struct {
  58  		k string
  59  		v uint64
  60  	}
  61  	var pairs []kv
  62  	for k, v := range counts {
  63  		pairs = append(pairs, kv{k, v})
  64  	}
  65  	// Sort descending by count (insertion sort — stable, works for any size).
  66  	for i := 1; i < len(pairs); i++ {
  67  		key := pairs[i]
  68  		j := i - 1
  69  		for j >= 0 && pairs[j].v < key.v {
  70  			pairs[j+1] = pairs[j]
  71  			j--
  72  		}
  73  		pairs[j+1] = key
  74  	}
  75  	if len(pairs) > maxTrigrams {
  76  		pairs = pairs[:maxTrigrams]
  77  	}
  78  
  79  	// Compute sum of kept frequencies for normalization.
  80  	sum := uint64(0)
  81  	for _, p := range pairs {
  82  		sum += p.v
  83  	}
  84  
  85  	m := &Model{
  86  		Lang:     lang,
  87  		Trigrams: map[string]float64{},
  88  	}
  89  	for _, p := range pairs {
  90  		m.Trigrams[p.k] = float64(p.v) / float64(sum)
  91  	}
  92  	return m, nil
  93  }
  94  
  95  // Save writes the model to path in TSV format:
  96  //   # langmodel v1
  97  //   # lang: <iso>
  98  //   <trigram>\t<frequency>
  99  //   ...
 100  func (m *Model) Save(path string) error {
 101  	f, err := os.Create(path)
 102  	if err != nil {
 103  		return err
 104  	}
 105  	defer f.Close()
 106  	bw := bufio.NewWriter(f)
 107  	fmt.Fprintf(bw, "# langmodel %s\n# lang: %s\n", modelVersion, m.Lang)
 108  	// Sort for deterministic output.
 109  	type kv struct{ k string; v float64 }
 110  	var pairs []kv
 111  	for k, v := range m.Trigrams {
 112  		pairs = append(pairs, kv{k, v})
 113  	}
 114  	for i := 1; i < len(pairs); i++ {
 115  		key := pairs[i]
 116  		j := i - 1
 117  		for j >= 0 && pairs[j].v < key.v {
 118  			pairs[j+1] = pairs[j]
 119  			j--
 120  		}
 121  		pairs[j+1] = key
 122  	}
 123  	for _, p := range pairs {
 124  		fmt.Fprintf(bw, "%s\t%.8f\n", p.k, p.v)
 125  	}
 126  	return bw.Flush()
 127  }
 128  
 129  // Load reads a model from a TSV file produced by Save.
 130  func Load(path string) (*Model, error) {
 131  	f, err := os.Open(path)
 132  	if err != nil {
 133  		return nil, err
 134  	}
 135  	defer f.Close()
 136  
 137  	m := &Model{Trigrams: map[string]float64{}}
 138  	sc := bufio.NewScanner(f)
 139  	for sc.Scan() {
 140  		line := sc.Text()
 141  		const langPrefix = "# lang: "
 142  		if len(line) >= len(langPrefix) && line[:len(langPrefix)] == langPrefix {
 143  			m.Lang = line[len(langPrefix):]
 144  			continue
 145  		}
 146  		if len(line) > 0 && line[0] == '#' {
 147  			continue
 148  		}
 149  		tab := -1
 150  		for i := 0; i < len(line); i++ {
 151  			if line[i] == '\t' {
 152  				tab = i
 153  				break
 154  			}
 155  		}
 156  		if tab < 0 {
 157  			continue
 158  		}
 159  		freq, err := strconv.ParseFloat(line[tab+1:], 64)
 160  		if err != nil {
 161  			continue
 162  		}
 163  		m.Trigrams[line[:tab]] = freq
 164  	}
 165  	if m.Lang == "" {
 166  		return nil, fmt.Errorf("langdetect: missing lang header in %s", path)
 167  	}
 168  	return m, sc.Err()
 169  }
 170  
 171  // Detector holds a set of language models and performs identification.
 172  type Detector struct {
 173  	Models    []*Model
 174  	Threshold float64 // minimum cosine similarity to report a match (default 0.90)
 175  }
 176  
 177  // NewDetector creates a Detector with the given models and threshold.
 178  func NewDetector(models []*Model, threshold float64) *Detector {
 179  	if threshold <= 0 {
 180  		threshold = DefaultThresh
 181  	}
 182  	return &Detector{Models: models, Threshold: threshold}
 183  }
 184  
 185  // Detect returns the ISO language code and confidence [0,1] for text.
 186  // Returns ("", 0) if the text is too short or no model clears the threshold.
 187  func (d *Detector) Detect(text string) (lang string, confidence float64) {
 188  	// Fast path: count codepoints.
 189  	ncp := 0
 190  	for i := 0; i < len(text); {
 191  		i += utf8CharLen(text[i])
 192  		ncp++
 193  	}
 194  	if ncp < minInputChars {
 195  		return "", 0
 196  	}
 197  
 198  	// Build input trigram profile.
 199  	inputCounts := map[string]uint64{}
 200  	inputTotal := uint64(0)
 201  	extractTrigrams(text, func(t string) {
 202  		inputCounts[t]++
 203  		inputTotal++
 204  	})
 205  	if inputTotal == 0 {
 206  		return "", 0
 207  	}
 208  
 209  	// Compute cosine similarity against each model.
 210  	// cosine(A, B) = dot(A, B) / (|A| * |B|)
 211  	// Since model vectors are already normalized (sum=1 approximates L1 norm),
 212  	// we use dot product directly as a proxy — good enough for top-K trigrams.
 213  	// Score each model by dot product of input trigram profile vs model.
 214  	// For disjoint scripts (EN/JA) the wrong-language score is near zero.
 215  	type scored struct {
 216  		lang  string
 217  		score float64
 218  	}
 219  	var scores []scored
 220  	scoreSum := 0.0
 221  	for _, m := range d.Models {
 222  		dot := 0.0
 223  		for t, modelFreq := range m.Trigrams {
 224  			if cnt, ok := inputCounts[t]; ok {
 225  				dot += (float64(cnt) / float64(inputTotal)) * modelFreq
 226  			}
 227  		}
 228  		scores = append(scores, scored{m.Lang, dot})
 229  		scoreSum += dot
 230  	}
 231  
 232  	// Find best.
 233  	bestIdx := 0
 234  	for i, s := range scores {
 235  		if s.score > scores[bestIdx].score {
 236  			bestIdx = i
 237  		}
 238  	}
 239  
 240  	// Relative confidence: best / total. Near 1.0 when one language
 241  	// dominates (disjoint scripts); near 1/n when ambiguous.
 242  	if scoreSum == 0 {
 243  		return "", 0
 244  	}
 245  	confidence = scores[bestIdx].score / scoreSum
 246  	if confidence < d.Threshold {
 247  		return "", confidence
 248  	}
 249  	return scores[bestIdx].lang, confidence
 250  }
 251  
 252  // extractTrigrams calls fn for every 3-codepoint sequence in text.
 253  // Uses byte-offset iteration (not []rune) for Moxie compatibility.
 254  func extractTrigrams(text string, fn func(string)) {
 255  	// Build codepoint byte offsets.
 256  	offsets := []int{:0:len(text)/3+1}
 257  	i := 0
 258  	for i < len(text) {
 259  		offsets = append(offsets, i)
 260  		i += utf8CharLen(text[i])
 261  	}
 262  	offsets = append(offsets, len(text))
 263  
 264  	n := len(offsets) - 1
 265  	for start := 0; start+3 <= n; start++ {
 266  		// Copy bytes: trigram may be a slice of a scanner buffer that
 267  		// gets reused on the next Scan() call. Map keys must own their bytes.
 268  		raw := text[offsets[start]:offsets[start+3]]
 269  		t := string(append([]byte(nil), raw...))
 270  		fn(t)
 271  	}
 272  }
 273  
 274  // utf8CharLen returns byte length of the UTF-8 codepoint starting at b.
 275  func utf8CharLen(b byte) int {
 276  	switch {
 277  	case b < 0x80:
 278  		return 1
 279  	case b < 0xE0:
 280  		return 2
 281  	case b < 0xF0:
 282  		return 3
 283  	default:
 284  		return 4
 285  	}
 286  }
 287