classify.mx raw
1 package iskra
2
3 import "bytes"
4
// IRClass is the verdict assigned to a pair of compared IR fragments.
type IRClass uint8

// IRClass values. Their numeric values are fixed at 0 through 4.
const (
	ClassUnknown     IRClass = iota // 0
	ClassByteEqual                  // 1
	ClassBoundaryEq                 // 2
	ClassBoundaryDiv                // 3
	ClassPerfDiv                    // 4
)

// ClassifyResult is the outcome of comparing two IR fragments, called A and B.
type ClassifyResult struct {
	Class     IRClass // verdict for the pair
	NormMatch bool    // set by ClassifyIRPair when the fragments matched after normalization

	// Structural counts of fragment A and fragment B.
	InstrA, InstrB   int32
	BlocksA, BlocksB int32
	CallsA, CallsB   int32

	Detail string // human-readable explanation of the verdict
}

// ClassName returns the short lowercase name of c.Class. It returns "unknown"
// for ClassUnknown and for any value outside the defined constants.
func (c ClassifyResult) ClassName() string {
	names := [...]string{
		ClassUnknown:     "unknown",
		ClassByteEqual:   "byte-equal",
		ClassBoundaryEq:  "boundary-eq",
		ClassBoundaryDiv: "boundary-div",
		ClassPerfDiv:     "perf-div",
	}
	if int(c.Class) < len(names) {
		return names[c.Class]
	}
	return "unknown"
}
41
42 // NormalizeIR strips debug metadata, renumbers SSA registers canonically,
43 // and removes optimization hint flags that don't change semantics.
44 func NormalizeIR(ir []byte) []byte {
45 lines := bytes.Split(ir, []byte("\n"))
46 var out []byte
47 ssaMap := map[string]string{}
48 ssaCounter := 0
49
50 for _, line := range lines {
51 trimmed := bytes.TrimSpace(line)
52
53 // Strip debug metadata lines (! = ..., #dbg_...)
54 if isDebugLine(trimmed) {
55 continue
56 }
57
58 // Strip debug refs inline
59 line = stripDebugRefs(line)
60
61 // Strip tbaa metadata refs
62 line = stripMetaRef(line, []byte("!tbaa !"))
63 line = stripMetaRef(line, []byte("!range !"))
64 line = stripMetaRef(line, []byte("!noalias !"))
65 line = stripMetaRef(line, []byte("!alias.scope !"))
66 line = stripMetaRef(line, []byte("!nonnull !"))
67 line = stripMetaRef(line, []byte("!dereferenceable !"))
68
69 // Strip optimization hint flags
70 line = stripFlag(line, []byte(" nsw"))
71 line = stripFlag(line, []byte(" nuw"))
72 line = stripFlag(line, []byte(" exact"))
73 line = stripFlag(line, []byte(" nnan"))
74 line = stripFlag(line, []byte(" ninf"))
75 line = stripFlag(line, []byte(" nsz"))
76 line = stripFlag(line, []byte(" arcp"))
77 line = stripFlag(line, []byte(" contract"))
78 line = stripFlag(line, []byte(" reassoc"))
79 line = stripFlag(line, []byte(" afn"))
80
81 // Renumber SSA registers: %N -> %_N (canonical)
82 line = renumberSSA(line, ssaMap, &ssaCounter)
83
84 // Normalize alignment: strip "align N"
85 line = stripAlignAnnotation(line)
86
87 out = append(out, bytes.TrimRight(line, " \t")...)
88 out = append(out, '\n')
89 }
90 return out
91 }
92
// stripMetaRef removes every occurrence of prefix together with the run of
// decimal digits that follows it. It also removes at most one blank directly
// in front of the occurrence and, in front of that, at most one comma. For
// example, ", !tbaa !5" disappears entirely. A new slice is built for each
// removal, so the caller's line is never modified.
func stripMetaRef(line []byte, prefix []byte) []byte {
	at := bytes.Index(line, prefix)
	for at >= 0 {
		stop := at + len(prefix)
		for stop < len(line) && '0' <= line[stop] && line[stop] <= '9' {
			stop++
		}
		begin := at
		if begin > 0 && line[begin-1] == ' ' {
			begin--
		}
		if begin > 0 && line[begin-1] == ',' {
			begin--
		}
		var joined []byte
		joined = append(joined, line[:begin]...)
		line = append(joined, line[stop:]...)
		at = bytes.Index(line, prefix)
	}
	return line
}
116
117 func stripFlag(line []byte, flag []byte) []byte {
118 for {
119 idx := bytes.Index(line, flag)
120 if idx < 0 {
121 return line
122 }
123 after := idx + len(flag)
124 if after < len(line) && isWordByte(line[after]) {
125 idx++
126 continue
127 }
128 if idx > 0 && isWordByte(line[idx-1]) {
129 idx++
130 line = line // can't skip, find next occurrence
131 next := bytes.Index(line[idx:], flag)
132 if next < 0 {
133 return line
134 }
135 idx = idx + next
136 after = idx + len(flag)
137 if after < len(line) && isWordByte(line[after]) {
138 continue
139 }
140 }
141 var rebuilt []byte
142 rebuilt = append(rebuilt, line[:idx]...)
143 rebuilt = append(rebuilt, line[after:]...)
144 line = rebuilt
145 }
146 }
147
// stripAlignAnnotation deletes every "align N" annotation from line, where N
// is a possibly empty run of decimal digits. A ", align N" annotation loses
// its leading comma and blank; a " align N" annotation loses its leading
// blank. Extra blanks between "align" and N are consumed too. All ", align "
// occurrences are handled before the " align " form is considered. A new
// slice is built for each removal, so the caller's line is never modified.
func stripAlignAnnotation(line []byte) []byte {
	for {
		pattern := []byte(", align ")
		begin := bytes.Index(line, pattern)
		if begin < 0 {
			pattern = []byte(" align ")
			begin = bytes.Index(line, pattern)
		}
		if begin < 0 {
			return line
		}
		cursor := begin + len(pattern)
		for cursor < len(line) && line[cursor] == ' ' {
			cursor++
		}
		for cursor < len(line) && '0' <= line[cursor] && line[cursor] <= '9' {
			cursor++
		}
		var joined []byte
		joined = append(joined, line[:begin]...)
		line = append(joined, line[cursor:]...)
	}
}
182
183 func renumberSSA(line []byte, ssaMap map[string]string, counter *int) []byte {
184 var out []byte
185 i := 0
186 for i < len(line) {
187 if line[i] == '%' && i+1 < len(line) && isDigitByte(line[i+1]) {
188 start := i
189 i++
190 for i < len(line) && isDigitByte(line[i]) {
191 i++
192 }
193 orig := string(line[start:i])
194 mapped, ok := ssaMap[orig]
195 if !ok {
196 mapped = "%" | intToStr(*counter)
197 ssaMap[orig] = mapped
198 *counter++
199 }
200 out = append(out, mapped...)
201 } else {
202 out = append(out, line[i])
203 i++
204 }
205 }
206 return out
207 }
208
209 func intToStr(n int) string {
210 if n == 0 {
211 return "0"
212 }
213 var buf [10]byte
214 i := 9
215 for n > 0 {
216 buf[i] = byte('0' + n%10)
217 i--
218 n /= 10
219 }
220 return string(buf[i+1:])
221 }
222
223 // IRProfile extracts structural features from LLVM IR for cost comparison.
224 type IRProfile struct {
225 Instructions int32
226 Blocks int32
227 Calls int32
228 Phis int32
229 Loads int32
230 Stores int32
231 Branches int32
232 }
233
234 func ProfileIR(ir []byte) IRProfile {
235 p := IRProfile{}
236 lines := bytes.Split(ir, []byte("\n"))
237 inFunc := false
238 for _, line := range lines {
239 trimmed := bytes.TrimSpace(line)
240 if len(trimmed) == 0 {
241 continue
242 }
243 if bytes.HasPrefix(trimmed, []byte("define ")) {
244 inFunc = true
245 continue
246 }
247 if len(trimmed) == 1 && trimmed[0] == '}' {
248 inFunc = false
249 continue
250 }
251 if !inFunc {
252 continue
253 }
254 // Basic block label
255 if len(trimmed) > 0 && trimmed[len(trimmed)-1] == ':' && !bytes.HasPrefix(trimmed, []byte(" ")) {
256 p.Blocks++
257 continue
258 }
259 if isDebugLine(trimmed) {
260 continue
261 }
262 p.Instructions++
263 if bytes.Contains(trimmed, []byte(" call ")) || bytes.Contains(trimmed, []byte(" invoke ")) {
264 p.Calls++
265 }
266 if bytes.HasPrefix(trimmed, []byte("call ")) || bytes.HasPrefix(trimmed, []byte("invoke ")) {
267 p.Calls++
268 }
269 if bytes.Contains(trimmed, []byte(" = phi ")) {
270 p.Phis++
271 }
272 if bytes.Contains(trimmed, []byte(" = load ")) {
273 p.Loads++
274 }
275 if bytes.HasPrefix(trimmed, []byte("store ")) {
276 p.Stores++
277 }
278 if bytes.HasPrefix(trimmed, []byte("br ")) {
279 p.Branches++
280 }
281 }
282 return p
283 }
284
285 // ClassifyIRPair performs Phase A (normalize) and Phase B (structural diff)
286 // classification of two IR fragments.
287 func ClassifyIRPair(resultIR, actualIR []byte) ClassifyResult {
288 cr := ClassifyResult{}
289
290 // Phase A: normalize and compare
291 normResult := NormalizeIR(resultIR)
292 normActual := NormalizeIR(actualIR)
293
294 if bytes.Equal(normResult, normActual) {
295 cr.Class = ClassBoundaryEq
296 cr.NormMatch = true
297 cr.Detail = "matched after normalization"
298 return cr
299 }
300
301 // Phase A+: strip nil-check/safety blocks and compare
302 strippedResult := StripSafetyBlocks(normResult)
303 strippedActual := StripSafetyBlocks(normActual)
304 if bytes.Equal(strippedResult, strippedActual) {
305 cr.Class = ClassBoundaryEq
306 cr.NormMatch = true
307 cr.Detail = "matched after nil-check block stripping"
308 return cr
309 }
310
311 // Phase B: structural comparison
312 profA := ProfileIR(resultIR)
313 profB := ProfileIR(actualIR)
314 cr.InstrA = profA.Instructions
315 cr.InstrB = profB.Instructions
316 cr.BlocksA = profA.Blocks
317 cr.BlocksB = profB.Blocks
318 cr.CallsA = profA.Calls
319 cr.CallsB = profB.Calls
320
321 // Phase B determines the nature of the divergence for diagnostics
322 if profA.Calls != profB.Calls {
323 cr.Class = ClassBoundaryDiv
324 cr.Detail = "call count differs (type/template mismatch)"
325 return cr
326 }
327 if profA.Blocks != profB.Blocks {
328 cr.Class = ClassBoundaryDiv
329 cr.Detail = "block count differs (structural mismatch)"
330 return cr
331 }
332 instrDelta := profA.Instructions - profB.Instructions
333 if instrDelta < 0 {
334 instrDelta = -instrDelta
335 }
336 minInstr := profA.Instructions
337 if profB.Instructions < minInstr {
338 minInstr = profB.Instructions
339 }
340 if minInstr > 0 && instrDelta*100/minInstr > 10 {
341 cr.Class = ClassBoundaryDiv
342 cr.Detail = "instruction count diverges >10% (wrong template)"
343 return cr
344 }
345
346 cr.Class = ClassBoundaryDiv
347 cr.Detail = "structural divergence after normalization"
348 return cr
349 }
350
351 func StripSafetyBlocks(ir []byte) []byte {
352 lines := bytes.Split(ir, []byte("\n"))
353 var out []byte
354 skip := false
355 for _, line := range lines {
356 trimmed := bytes.TrimSpace(line)
357 // Check if this is a safety-check block label
358 // Labels look like "deref.next: ; preds = %entry" or just "deref.next:"
359 if isBlockLabel(trimmed) && isSafetyBlockLabel(trimmed) {
360 skip = true
361 continue
362 }
363 // A new non-safety block label ends the skip
364 if skip && isBlockLabel(trimmed) {
365 skip = false
366 }
367 // Opening brace of a function also ends skip
368 if skip && len(trimmed) > 0 && trimmed[0] == '}' {
369 skip = false
370 }
371 if skip {
372 continue
373 }
374 // Strip safety block references from branch targets and phi nodes
375 line = stripSafetyRefs(line)
376 out = append(out, line...)
377 out = append(out, '\n')
378 }
379 return out
380 }
381
// isBlockLabel reports whether line (expected to be whitespace-trimmed) is an
// LLVM basic-block label. A label is either an unquoted name made of
// [-a-zA-Z$._0-9] characters, or a non-empty double-quoted name, immediately
// followed by ':'. Anything after the colon, typically a "; preds = ..."
// comment, is ignored.
//
// Because the colon must directly follow a valid name, the following are
// rejected:
//   - lines beginning with '%' or whitespace;
//   - instructions with a colon inside a comment or string;
//   - metadata lines such as "!0 = !DILocation(line: 3)";
//   - header lines such as `target datalayout = "e-m:e"`.
func isBlockLabel(line []byte) bool {
	if len(line) == 0 {
		return false
	}
	if line[0] == '"' {
		// Quoted form: "any name":
		closing := bytes.IndexByte(line[1:], '"') + 1 // index of closing quote in line
		return closing > 1 && closing+1 < len(line) && line[closing+1] == ':'
	}
	for i := 0; i < len(line); i++ {
		c := line[i]
		switch {
		case c == ':':
			return i > 0
		case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c >= '0' && c <= '9',
			c == '-', c == '$', c == '.', c == '_':
			// Still inside the label name.
		default:
			// Any other byte (including '%', blanks, '=', '!', '"') cannot
			// appear in an unquoted label name.
			return false
		}
	}
	return false
}
395
// isSafetyBlockLabel reports whether label names one of the compiler's
// safety-check blocks. The name is the text before the first ':'. It must
// start with deref, gep, store, lookup or slice, followed by ".next" or
// ".throw"; any suffix such as a number is allowed. A label without a ':'
// is never a safety block.
func isSafetyBlockLabel(label []byte) bool {
	colon := bytes.IndexByte(label, ':')
	if colon < 0 {
		return false
	}
	name := label[:colon]
	for _, known := range []string{
		"deref.next", "deref.throw",
		"gep.next", "gep.throw",
		"store.next", "store.throw",
		"lookup.next", "lookup.throw",
		"slice.next", "slice.throw",
	} {
		if bytes.HasPrefix(name, []byte(known)) {
			return true
		}
	}
	return false
}
423
// stripSafetyRefs removes references to compiler-generated safety-check
// blocks from one IR line. A safety block is named
// %{deref,gep,store,lookup,slice}.{next,throw}, optionally followed by
// digits. After stripping, the line should line up with the same line taken
// from IR that never had those blocks. A new slice is built for each
// removal, so the caller's line is never modified.
//
// On a phi line the whole incoming entry "[ val, %block ]" is removed
// together with exactly one ", " separator, so the remaining entries stay
// well formed:
//
//	%x = phi i1 [ true, %deref.next ], [ false, %entry ]  ->  %x = phi i1 [ false, %entry ]
//
// On any other line a reference is removed only when it is a "label %block"
// operand; the operand goes together with its leading ", ":
//
//	br i1 %c, label %ok, label %gep.throw  ->  br i1 %c, label %ok
//
// References in any other position are left in place.
func stripSafetyRefs(line []byte) []byte {
	safetyPrefixes := [][]byte{
		[]byte("%deref.next"),
		[]byte("%deref.throw"),
		[]byte("%gep.next"),
		[]byte("%gep.throw"),
		[]byte("%store.next"),
		[]byte("%store.throw"),
		[]byte("%lookup.next"),
		[]byte("%lookup.throw"),
		[]byte("%slice.next"),
		[]byte("%slice.throw"),
	}
	isPhi := bytes.Contains(line, []byte(" = phi "))
	for _, sp := range safetyPrefixes {
		from := 0
		for {
			rel := bytes.Index(line[from:], sp)
			if rel < 0 {
				break
			}
			idx := from + rel
			end := idx + len(sp)
			// Numbered variants such as %deref.next2.
			for end < len(line) && line[end] >= '0' && line[end] <= '9' {
				end++
			}
			var start, stop int
			var ok bool
			if isPhi {
				start, stop, ok = safetyPhiEntrySpan(line, idx, end)
			} else {
				start, stop, ok = safetyLabelOperandSpan(line, idx, end)
			}
			if !ok {
				// Unrecognized context: keep it, but keep looking further on.
				from = end
				continue
			}
			rebuilt := make([]byte, 0, len(line)-(stop-start))
			rebuilt = append(rebuilt, line[:start]...)
			rebuilt = append(rebuilt, line[stop:]...)
			line = rebuilt
			from = start
		}
	}
	return line
}

// safetyPhiEntrySpan finds the phi entry "[ val, %block ]" whose block
// reference occupies line[ref:refEnd] and returns the span [start, stop) to
// delete. The span covers the entry plus its trailing ", " separator, or,
// for the last entry, the entry plus the separator in front of it. ok is
// false when the reference is not enclosed in brackets.
func safetyPhiEntrySpan(line []byte, ref, refEnd int) (start, stop int, ok bool) {
	open := -1
	for j := ref - 1; j >= 0; j-- {
		if line[j] == '[' {
			open = j
			break
		}
		// ']' closes the previous entry and ';' starts a comment; the comma
		// inside "[ val, %block ]" must NOT stop the scan.
		if line[j] == ']' || line[j] == ';' {
			break
		}
	}
	if open < 0 {
		return 0, 0, false
	}
	closeRel := bytes.IndexByte(line[refEnd:], ']')
	if closeRel < 0 {
		return 0, 0, false
	}
	stop = refEnd + closeRel + 1

	// Prefer eating the separator after the entry ...
	next := stop
	for next < len(line) && line[next] == ' ' {
		next++
	}
	if next < len(line) && line[next] == ',' {
		next++
		for next < len(line) && line[next] == ' ' {
			next++
		}
		return open, next, true
	}
	// ... otherwise this is the last entry: eat the separator before it.
	start = open
	for start > 0 && line[start-1] == ' ' {
		start--
	}
	if start > 0 && line[start-1] == ',' {
		start--
	}
	return start, stop, true
}

// safetyLabelOperandSpan returns the span [start, stop) covering a
// "label %block" operand whose reference occupies line[ref:refEnd], together
// with a single leading blank and comma. ok is false unless "label " directly
// precedes the reference.
func safetyLabelOperandSpan(line []byte, ref, refEnd int) (start, stop int, ok bool) {
	kw := []byte("label ")
	start = ref - len(kw)
	if start < 0 || !bytes.Equal(line[start:ref], kw) {
		return 0, 0, false
	}
	if start > 0 && line[start-1] == ' ' {
		start--
	}
	if start > 0 && line[start-1] == ',' {
		start--
	}
	return start, refEnd, true
}
509