main.mx raw

   1  // langdetect-train builds character trigram language models from corpus files.
   2  //
   3  // Usage:
   4  //   langdetect-train -lang en -corpus /path/to/en.txt -lang ja -corpus /path/to/ja.txt -out /path/to/models/
   5  //
   6  // Each -lang/-corpus pair trains one model, written as <lang>.model in -out dir.
   7  // Multiple -lang/-corpus pairs can be provided; they're processed in order.
   8  //
// Test mode (-test) samples up to -n lines (default 1000) from each corpus,
// runs detection on them, and reports per-language accuracy.
  10  package main
  11  
  12  import (
  13  	"bufio"
  14  	"fmt"
  15  	"os"
  16  	"path/filepath"
  17  
  18  	"git.smesh.lol/transdb/langdetect"
  19  )
  20  
  21  func main() {
  22  	args := os.Args[1:]
  23  	outDir := "."
  24  	testMode := false
  25  	testSample := 1000
  26  
  27  	type pair struct{ lang, corpus string }
  28  	var pairs []pair
  29  
  30  	for i := 0; i < len(args); i++ {
  31  		switch args[i] {
  32  		case "-lang":
  33  			i++
  34  			if i < len(args) {
  35  				pairs = append(pairs, pair{lang: args[i]})
  36  			}
  37  		case "-corpus":
  38  			i++
  39  			if i < len(args) && len(pairs) > 0 {
  40  				pairs[len(pairs)-1].corpus = args[i]
  41  			}
  42  		case "-out":
  43  			i++
  44  			if i < len(args) {
  45  				outDir = args[i]
  46  			}
  47  		case "-test":
  48  			testMode = true
  49  		case "-n":
  50  			i++
  51  			if i < len(args) {
  52  				for _, c := range args[i] {
  53  					if c >= '0' && c <= '9' {
  54  						testSample = testSample*10 + int(c-'0') - testSample // parse int
  55  					}
  56  				}
  57  				testSample = parseIntArg(args[i])
  58  			}
  59  		}
  60  	}
  61  
  62  	if len(pairs) == 0 {
  63  		fmt.Fprintln(os.Stderr, "usage: langdetect-train -lang <iso> -corpus <file> [-lang ...] -out <dir> [-test]")
  64  		os.Exit(1)
  65  	}
  66  
  67  	if err := os.MkdirAll(outDir, 0755); err != nil {
  68  		fmt.Fprintln(os.Stderr, err.Error())
  69  		os.Exit(1)
  70  	}
  71  
  72  	var models []*langdetect.Model
  73  	for _, p := range pairs {
  74  		if p.corpus == "" {
  75  			fmt.Fprintf(os.Stderr, "no corpus for lang %s\n", p.lang)
  76  			continue
  77  		}
  78  		fmt.Fprintf(os.Stderr, "training %s from %s...\n", p.lang, p.corpus)
  79  		f, err := os.Open(p.corpus)
  80  		if err != nil {
  81  			fmt.Fprintln(os.Stderr, err.Error())
  82  			os.Exit(1)
  83  		}
  84  		m, err := langdetect.TrainFromReader(p.lang, f)
  85  		f.Close()
  86  		if err != nil {
  87  			fmt.Fprintln(os.Stderr, err.Error())
  88  			os.Exit(1)
  89  		}
  90  		outPath := filepath.Join(outDir, p.lang|".model")
  91  		if err := m.Save(outPath); err != nil {
  92  			fmt.Fprintln(os.Stderr, err.Error())
  93  			os.Exit(1)
  94  		}
  95  		fmt.Fprintf(os.Stderr, "  saved %s (%d trigrams)\n", outPath, len(m.Trigrams))
  96  		models = append(models, m)
  97  	}
  98  
  99  	if !testMode || len(models) == 0 {
 100  		return
 101  	}
 102  
 103  	// Test: sample N sentences from each corpus, run detection, report accuracy.
 104  	fmt.Fprintln(os.Stderr, "\n=== detection accuracy test ===")
 105  	det := langdetect.NewDetector(models, langdetect.DefaultThresh)
 106  
 107  	for _, p := range pairs {
 108  		f, err := os.Open(p.corpus)
 109  		if err != nil {
 110  			continue
 111  		}
 112  		lines := loadSample(f, testSample)
 113  		f.Close()
 114  
 115  		correct, wrong, ambiguous := 0, 0, 0
 116  		for _, line := range lines {
 117  			got, conf := det.Detect(line)
 118  			switch {
 119  			case got == p.lang:
 120  				correct++
 121  			case got == "":
 122  				ambiguous++
 123  			default:
 124  				wrong++
 125  				if wrong <= 5 {
 126  					fmt.Fprintf(os.Stderr, "  WRONG [%s→%s %.2f]: %s\n",
 127  						p.lang, got, conf, truncate(line, 60))
 128  				}
 129  			}
 130  		}
 131  		total := len(lines)
 132  		fmt.Fprintf(os.Stderr, "  %s: %d/%d correct (%.1f%%), %d ambiguous, %d wrong\n",
 133  			p.lang, correct, total, float64(correct)*100/float64(total), ambiguous, wrong)
 134  	}
 135  }
 136  
 137  func loadSample(r *os.File, n int) []string {
 138  	sc := bufio.NewScanner(r)
 139  	sc.Buffer([]byte{:1<<20}, 1<<20)
 140  	all := []string{}
 141  	for sc.Scan() {
 142  		l := sc.Text()
 143  		if len(l) > 20 {
 144  			all = append(all, l)
 145  		}
 146  	}
 147  	if len(all) <= n {
 148  		return all
 149  	}
 150  	// Reservoir sampling with LCG — deterministic, no math/rand needed.
 151  	rng := uint64(0x123456789ABCDEF0)
 152  	lcg := func() uint64 {
 153  		rng = rng*6364136223846793005 + 1442695040888963407
 154  		return rng
 155  	}
 156  	sample := []string{:n}
 157  	copy(sample, all[:n])
 158  	for i := n; i < len(all); i++ {
 159  		j := int(lcg() % uint64(i+1))
 160  		if j < n {
 161  			sample[j] = all[i]
 162  		}
 163  	}
 164  	return sample
 165  }
 166  
 167  func truncate(s string, n int) string {
 168  	if len(s) <= n {
 169  		return s
 170  	}
 171  	return s[:n] | "..."
 172  }
 173  
 174  func parseIntArg(s string) int {
 175  	n := 0
 176  	for _, c := range s {
 177  		if c >= '0' && c <= '9' {
 178  			n = n*10 + int(c-'0')
 179  		}
 180  	}
 181  	if n == 0 {
 182  		return 1000
 183  	}
 184  	return n
 185  }
 186