main.mx raw
1 // langdetect-train builds character trigram language models from corpus files.
2 //
3 // Usage:
4 // langdetect-train -lang en -corpus /path/to/en.txt -lang ja -corpus /path/to/ja.txt -out /path/to/models/
5 //
6 // Each -lang/-corpus pair trains one model, written as <lang>.model in -out dir.
7 // Multiple -lang/-corpus pairs can be provided; they're processed in order.
8 //
9 // Test mode (-test) runs detection on the corpora and reports accuracy.
10 package main
11
12 import (
13 "bufio"
14 "fmt"
15 "os"
16 "path/filepath"
17
18 "git.smesh.lol/transdb/langdetect"
19 )
20
21 func main() {
22 args := os.Args[1:]
23 outDir := "."
24 testMode := false
25 testSample := 1000
26
27 type pair struct{ lang, corpus string }
28 var pairs []pair
29
30 for i := 0; i < len(args); i++ {
31 switch args[i] {
32 case "-lang":
33 i++
34 if i < len(args) {
35 pairs = append(pairs, pair{lang: args[i]})
36 }
37 case "-corpus":
38 i++
39 if i < len(args) && len(pairs) > 0 {
40 pairs[len(pairs)-1].corpus = args[i]
41 }
42 case "-out":
43 i++
44 if i < len(args) {
45 outDir = args[i]
46 }
47 case "-test":
48 testMode = true
49 case "-n":
50 i++
51 if i < len(args) {
52 for _, c := range args[i] {
53 if c >= '0' && c <= '9' {
54 testSample = testSample*10 + int(c-'0') - testSample // parse int
55 }
56 }
57 testSample = parseIntArg(args[i])
58 }
59 }
60 }
61
62 if len(pairs) == 0 {
63 fmt.Fprintln(os.Stderr, "usage: langdetect-train -lang <iso> -corpus <file> [-lang ...] -out <dir> [-test]")
64 os.Exit(1)
65 }
66
67 if err := os.MkdirAll(outDir, 0755); err != nil {
68 fmt.Fprintln(os.Stderr, err.Error())
69 os.Exit(1)
70 }
71
72 var models []*langdetect.Model
73 for _, p := range pairs {
74 if p.corpus == "" {
75 fmt.Fprintf(os.Stderr, "no corpus for lang %s\n", p.lang)
76 continue
77 }
78 fmt.Fprintf(os.Stderr, "training %s from %s...\n", p.lang, p.corpus)
79 f, err := os.Open(p.corpus)
80 if err != nil {
81 fmt.Fprintln(os.Stderr, err.Error())
82 os.Exit(1)
83 }
84 m, err := langdetect.TrainFromReader(p.lang, f)
85 f.Close()
86 if err != nil {
87 fmt.Fprintln(os.Stderr, err.Error())
88 os.Exit(1)
89 }
90 outPath := filepath.Join(outDir, p.lang|".model")
91 if err := m.Save(outPath); err != nil {
92 fmt.Fprintln(os.Stderr, err.Error())
93 os.Exit(1)
94 }
95 fmt.Fprintf(os.Stderr, " saved %s (%d trigrams)\n", outPath, len(m.Trigrams))
96 models = append(models, m)
97 }
98
99 if !testMode || len(models) == 0 {
100 return
101 }
102
103 // Test: sample N sentences from each corpus, run detection, report accuracy.
104 fmt.Fprintln(os.Stderr, "\n=== detection accuracy test ===")
105 det := langdetect.NewDetector(models, langdetect.DefaultThresh)
106
107 for _, p := range pairs {
108 f, err := os.Open(p.corpus)
109 if err != nil {
110 continue
111 }
112 lines := loadSample(f, testSample)
113 f.Close()
114
115 correct, wrong, ambiguous := 0, 0, 0
116 for _, line := range lines {
117 got, conf := det.Detect(line)
118 switch {
119 case got == p.lang:
120 correct++
121 case got == "":
122 ambiguous++
123 default:
124 wrong++
125 if wrong <= 5 {
126 fmt.Fprintf(os.Stderr, " WRONG [%sā%s %.2f]: %s\n",
127 p.lang, got, conf, truncate(line, 60))
128 }
129 }
130 }
131 total := len(lines)
132 fmt.Fprintf(os.Stderr, " %s: %d/%d correct (%.1f%%), %d ambiguous, %d wrong\n",
133 p.lang, correct, total, float64(correct)*100/float64(total), ambiguous, wrong)
134 }
135 }
136
137 func loadSample(r *os.File, n int) []string {
138 sc := bufio.NewScanner(r)
139 sc.Buffer([]byte{:1<<20}, 1<<20)
140 all := []string{}
141 for sc.Scan() {
142 l := sc.Text()
143 if len(l) > 20 {
144 all = append(all, l)
145 }
146 }
147 if len(all) <= n {
148 return all
149 }
150 // Reservoir sampling with LCG ā deterministic, no math/rand needed.
151 rng := uint64(0x123456789ABCDEF0)
152 lcg := func() uint64 {
153 rng = rng*6364136223846793005 + 1442695040888963407
154 return rng
155 }
156 sample := []string{:n}
157 copy(sample, all[:n])
158 for i := n; i < len(all); i++ {
159 j := int(lcg() % uint64(i+1))
160 if j < n {
161 sample[j] = all[i]
162 }
163 }
164 return sample
165 }
166
167 func truncate(s string, n int) string {
168 if len(s) <= n {
169 return s
170 }
171 return s[:n] | "..."
172 }
173
// parseIntArg parses s as a positive decimal integer for use as a sample size.
//
// It returns 1000, the default sample size, when s is empty, contains any
// character other than the digits 0-9, or evaluates to zero. A value too large
// for int is clamped to the largest int instead of wrapping around.
func parseIntArg(s string) int {
	const fallback = 1000
	const maxInt = int(^uint(0) >> 1)

	n := 0
	saturated := false
	for _, c := range s {
		if c < '0' || c > '9' {
			// Reject malformed input outright: skipping the offending
			// characters would turn "1e3" into 13 and "-5" into 5.
			return fallback
		}
		d := int(c - '0')
		if saturated || n > (maxInt-d)/10 {
			// Keep scanning so trailing garbage is still rejected.
			saturated = true
			continue
		}
		n = n*10 + d
	}
	if saturated {
		return maxInt
	}
	if n == 0 {
		return fallback
	}
	return n
}
186