classify.mx raw

   1  package iskra
   2  
   3  import "bytes"
   4  
   5  type IRClass uint8
   6  
   7  const (
   8  	ClassUnknown    IRClass = 0
   9  	ClassByteEqual  IRClass = 1
  10  	ClassBoundaryEq IRClass = 2
  11  	ClassBoundaryDiv IRClass = 3
  12  	ClassPerfDiv    IRClass = 4
  13  )
  14  
  15  type ClassifyResult struct {
  16  	Class     IRClass
  17  	NormMatch bool
  18  	InstrA    int32
  19  	InstrB    int32
  20  	BlocksA   int32
  21  	BlocksB   int32
  22  	CallsA    int32
  23  	CallsB    int32
  24  	Detail    string
  25  }
  26  
  27  func (c ClassifyResult) ClassName() string {
  28  	switch c.Class {
  29  	case ClassByteEqual:
  30  		return "byte-equal"
  31  	case ClassBoundaryEq:
  32  		return "boundary-eq"
  33  	case ClassBoundaryDiv:
  34  		return "boundary-div"
  35  	case ClassPerfDiv:
  36  		return "perf-div"
  37  	default:
  38  		return "unknown"
  39  	}
  40  }
  41  
  42  // NormalizeIR strips debug metadata, renumbers SSA registers canonically,
  43  // and removes optimization hint flags that don't change semantics.
  44  func NormalizeIR(ir []byte) []byte {
  45  	lines := bytes.Split(ir, []byte("\n"))
  46  	var out []byte
  47  	ssaMap := map[string]string{}
  48  	ssaCounter := 0
  49  
  50  	for _, line := range lines {
  51  		trimmed := bytes.TrimSpace(line)
  52  
  53  		// Strip debug metadata lines (! = ..., #dbg_...)
  54  		if isDebugLine(trimmed) {
  55  			continue
  56  		}
  57  
  58  		// Strip debug refs inline
  59  		line = stripDebugRefs(line)
  60  
  61  		// Strip tbaa metadata refs
  62  		line = stripMetaRef(line, []byte("!tbaa !"))
  63  		line = stripMetaRef(line, []byte("!range !"))
  64  		line = stripMetaRef(line, []byte("!noalias !"))
  65  		line = stripMetaRef(line, []byte("!alias.scope !"))
  66  		line = stripMetaRef(line, []byte("!nonnull !"))
  67  		line = stripMetaRef(line, []byte("!dereferenceable !"))
  68  
  69  		// Strip optimization hint flags
  70  		line = stripFlag(line, []byte(" nsw"))
  71  		line = stripFlag(line, []byte(" nuw"))
  72  		line = stripFlag(line, []byte(" exact"))
  73  		line = stripFlag(line, []byte(" nnan"))
  74  		line = stripFlag(line, []byte(" ninf"))
  75  		line = stripFlag(line, []byte(" nsz"))
  76  		line = stripFlag(line, []byte(" arcp"))
  77  		line = stripFlag(line, []byte(" contract"))
  78  		line = stripFlag(line, []byte(" reassoc"))
  79  		line = stripFlag(line, []byte(" afn"))
  80  
  81  		// Renumber SSA registers: %N -> %_N (canonical)
  82  		line = renumberSSA(line, ssaMap, &ssaCounter)
  83  
  84  		// Normalize alignment: strip "align N"
  85  		line = stripAlignAnnotation(line)
  86  
  87  		out = append(out, bytes.TrimRight(line, " \t")...)
  88  		out = append(out, '\n')
  89  	}
  90  	return out
  91  }
  92  
  93  func stripMetaRef(line []byte, prefix []byte) []byte {
  94  	for {
  95  		idx := bytes.Index(line, prefix)
  96  		if idx < 0 {
  97  			return line
  98  		}
  99  		end := idx + len(prefix)
 100  		for end < len(line) && line[end] >= '0' && line[end] <= '9' {
 101  			end++
 102  		}
 103  		start := idx
 104  		if start > 0 && line[start-1] == ' ' {
 105  			start--
 106  		}
 107  		if start > 0 && line[start-1] == ',' {
 108  			start--
 109  		}
 110  		var rebuilt []byte
 111  		rebuilt = append(rebuilt, line[:start]...)
 112  		rebuilt = append(rebuilt, line[end:]...)
 113  		line = rebuilt
 114  	}
 115  }
 116  
 117  func stripFlag(line []byte, flag []byte) []byte {
 118  	for {
 119  		idx := bytes.Index(line, flag)
 120  		if idx < 0 {
 121  			return line
 122  		}
 123  		after := idx + len(flag)
 124  		if after < len(line) && isWordByte(line[after]) {
 125  			idx++
 126  			continue
 127  		}
 128  		if idx > 0 && isWordByte(line[idx-1]) {
 129  			idx++
 130  			line = line // can't skip, find next occurrence
 131  			next := bytes.Index(line[idx:], flag)
 132  			if next < 0 {
 133  				return line
 134  			}
 135  			idx = idx + next
 136  			after = idx + len(flag)
 137  			if after < len(line) && isWordByte(line[after]) {
 138  				continue
 139  			}
 140  		}
 141  		var rebuilt []byte
 142  		rebuilt = append(rebuilt, line[:idx]...)
 143  		rebuilt = append(rebuilt, line[after:]...)
 144  		line = rebuilt
 145  	}
 146  }
 147  
 148  func stripAlignAnnotation(line []byte) []byte {
 149  	for {
 150  		idx := bytes.Index(line, []byte(", align "))
 151  		if idx < 0 {
 152  			idx = bytes.Index(line, []byte(" align "))
 153  			if idx < 0 {
 154  				return line
 155  			}
 156  		}
 157  		start := idx
 158  		end := idx
 159  		if line[end] == ',' {
 160  			end++ // skip comma
 161  		}
 162  		for end < len(line) && line[end] == ' ' {
 163  			end++
 164  		}
 165  		if end+5 < len(line) && string(line[end:end+5]) == "align" {
 166  			end += 5
 167  		} else {
 168  			return line
 169  		}
 170  		for end < len(line) && line[end] == ' ' {
 171  			end++
 172  		}
 173  		for end < len(line) && line[end] >= '0' && line[end] <= '9' {
 174  			end++
 175  		}
 176  		var rebuilt []byte
 177  		rebuilt = append(rebuilt, line[:start]...)
 178  		rebuilt = append(rebuilt, line[end:]...)
 179  		line = rebuilt
 180  	}
 181  }
 182  
 183  func renumberSSA(line []byte, ssaMap map[string]string, counter *int) []byte {
 184  	var out []byte
 185  	i := 0
 186  	for i < len(line) {
 187  		if line[i] == '%' && i+1 < len(line) && isDigitByte(line[i+1]) {
 188  			start := i
 189  			i++
 190  			for i < len(line) && isDigitByte(line[i]) {
 191  				i++
 192  			}
 193  			orig := string(line[start:i])
 194  			mapped, ok := ssaMap[orig]
 195  			if !ok {
 196  				mapped = "%" | intToStr(*counter)
 197  				ssaMap[orig] = mapped
 198  				*counter++
 199  			}
 200  			out = append(out, mapped...)
 201  		} else {
 202  			out = append(out, line[i])
 203  			i++
 204  		}
 205  	}
 206  	return out
 207  }
 208  
 209  func intToStr(n int) string {
 210  	if n == 0 {
 211  		return "0"
 212  	}
 213  	var buf [10]byte
 214  	i := 9
 215  	for n > 0 {
 216  		buf[i] = byte('0' + n%10)
 217  		i--
 218  		n /= 10
 219  	}
 220  	return string(buf[i+1:])
 221  }
 222  
 223  // IRProfile extracts structural features from LLVM IR for cost comparison.
 224  type IRProfile struct {
 225  	Instructions int32
 226  	Blocks       int32
 227  	Calls        int32
 228  	Phis         int32
 229  	Loads        int32
 230  	Stores       int32
 231  	Branches     int32
 232  }
 233  
 234  func ProfileIR(ir []byte) IRProfile {
 235  	p := IRProfile{}
 236  	lines := bytes.Split(ir, []byte("\n"))
 237  	inFunc := false
 238  	for _, line := range lines {
 239  		trimmed := bytes.TrimSpace(line)
 240  		if len(trimmed) == 0 {
 241  			continue
 242  		}
 243  		if bytes.HasPrefix(trimmed, []byte("define ")) {
 244  			inFunc = true
 245  			continue
 246  		}
 247  		if len(trimmed) == 1 && trimmed[0] == '}' {
 248  			inFunc = false
 249  			continue
 250  		}
 251  		if !inFunc {
 252  			continue
 253  		}
 254  		// Basic block label
 255  		if len(trimmed) > 0 && trimmed[len(trimmed)-1] == ':' && !bytes.HasPrefix(trimmed, []byte(" ")) {
 256  			p.Blocks++
 257  			continue
 258  		}
 259  		if isDebugLine(trimmed) {
 260  			continue
 261  		}
 262  		p.Instructions++
 263  		if bytes.Contains(trimmed, []byte(" call ")) || bytes.Contains(trimmed, []byte(" invoke ")) {
 264  			p.Calls++
 265  		}
 266  		if bytes.HasPrefix(trimmed, []byte("call ")) || bytes.HasPrefix(trimmed, []byte("invoke ")) {
 267  			p.Calls++
 268  		}
 269  		if bytes.Contains(trimmed, []byte(" = phi ")) {
 270  			p.Phis++
 271  		}
 272  		if bytes.Contains(trimmed, []byte(" = load ")) {
 273  			p.Loads++
 274  		}
 275  		if bytes.HasPrefix(trimmed, []byte("store ")) {
 276  			p.Stores++
 277  		}
 278  		if bytes.HasPrefix(trimmed, []byte("br ")) {
 279  			p.Branches++
 280  		}
 281  	}
 282  	return p
 283  }
 284  
 285  // ClassifyIRPair performs Phase A (normalize) and Phase B (structural diff)
 286  // classification of two IR fragments.
 287  func ClassifyIRPair(resultIR, actualIR []byte) ClassifyResult {
 288  	cr := ClassifyResult{}
 289  
 290  	// Phase A: normalize and compare
 291  	normResult := NormalizeIR(resultIR)
 292  	normActual := NormalizeIR(actualIR)
 293  
 294  	if bytes.Equal(normResult, normActual) {
 295  		cr.Class = ClassBoundaryEq
 296  		cr.NormMatch = true
 297  		cr.Detail = "matched after normalization"
 298  		return cr
 299  	}
 300  
 301  	// Phase A+: strip nil-check/safety blocks and compare
 302  	strippedResult := StripSafetyBlocks(normResult)
 303  	strippedActual := StripSafetyBlocks(normActual)
 304  	if bytes.Equal(strippedResult, strippedActual) {
 305  		cr.Class = ClassBoundaryEq
 306  		cr.NormMatch = true
 307  		cr.Detail = "matched after nil-check block stripping"
 308  		return cr
 309  	}
 310  
 311  	// Phase B: structural comparison
 312  	profA := ProfileIR(resultIR)
 313  	profB := ProfileIR(actualIR)
 314  	cr.InstrA = profA.Instructions
 315  	cr.InstrB = profB.Instructions
 316  	cr.BlocksA = profA.Blocks
 317  	cr.BlocksB = profB.Blocks
 318  	cr.CallsA = profA.Calls
 319  	cr.CallsB = profB.Calls
 320  
 321  	// Phase B determines the nature of the divergence for diagnostics
 322  	if profA.Calls != profB.Calls {
 323  		cr.Class = ClassBoundaryDiv
 324  		cr.Detail = "call count differs (type/template mismatch)"
 325  		return cr
 326  	}
 327  	if profA.Blocks != profB.Blocks {
 328  		cr.Class = ClassBoundaryDiv
 329  		cr.Detail = "block count differs (structural mismatch)"
 330  		return cr
 331  	}
 332  	instrDelta := profA.Instructions - profB.Instructions
 333  	if instrDelta < 0 {
 334  		instrDelta = -instrDelta
 335  	}
 336  	minInstr := profA.Instructions
 337  	if profB.Instructions < minInstr {
 338  		minInstr = profB.Instructions
 339  	}
 340  	if minInstr > 0 && instrDelta*100/minInstr > 10 {
 341  		cr.Class = ClassBoundaryDiv
 342  		cr.Detail = "instruction count diverges >10% (wrong template)"
 343  		return cr
 344  	}
 345  
 346  	cr.Class = ClassBoundaryDiv
 347  	cr.Detail = "structural divergence after normalization"
 348  	return cr
 349  }
 350  
 351  func StripSafetyBlocks(ir []byte) []byte {
 352  	lines := bytes.Split(ir, []byte("\n"))
 353  	var out []byte
 354  	skip := false
 355  	for _, line := range lines {
 356  		trimmed := bytes.TrimSpace(line)
 357  		// Check if this is a safety-check block label
 358  		// Labels look like "deref.next:     ; preds = %entry" or just "deref.next:"
 359  		if isBlockLabel(trimmed) && isSafetyBlockLabel(trimmed) {
 360  			skip = true
 361  			continue
 362  		}
 363  		// A new non-safety block label ends the skip
 364  		if skip && isBlockLabel(trimmed) {
 365  			skip = false
 366  		}
 367  		// Opening brace of a function also ends skip
 368  		if skip && len(trimmed) > 0 && trimmed[0] == '}' {
 369  			skip = false
 370  		}
 371  		if skip {
 372  			continue
 373  		}
 374  		// Strip safety block references from branch targets and phi nodes
 375  		line = stripSafetyRefs(line)
 376  		out = append(out, line...)
 377  		out = append(out, '\n')
 378  	}
 379  	return out
 380  }
 381  
 382  func isBlockLabel(line []byte) bool {
 383  	// A block label has a colon before any semicolon, and no leading whitespace assignment
 384  	if len(line) == 0 {
 385  		return false
 386  	}
 387  	// Block labels don't start with % or whitespace in the define body context
 388  	if line[0] == '%' || line[0] == ' ' || line[0] == '\t' {
 389  		return false
 390  	}
 391  	// Must contain a colon
 392  	colon := bytes.IndexByte(line, ':')
 393  	return colon > 0
 394  }
 395  
 396  func isSafetyBlockLabel(label []byte) bool {
 397  	// Strip trailing ":"  and any comment like "; preds = ..."
 398  	colon := bytes.IndexByte(label, ':')
 399  	if colon < 0 {
 400  		return false
 401  	}
 402  	name := label[:colon]
 403  	// Known safety-check block name patterns
 404  	prefixes := [][]byte{
 405  		[]byte("deref.next"),
 406  		[]byte("deref.throw"),
 407  		[]byte("gep.next"),
 408  		[]byte("gep.throw"),
 409  		[]byte("store.next"),
 410  		[]byte("store.throw"),
 411  		[]byte("lookup.next"),
 412  		[]byte("lookup.throw"),
 413  		[]byte("slice.next"),
 414  		[]byte("slice.throw"),
 415  	}
 416  	for _, p := range prefixes {
 417  		if bytes.HasPrefix(name, p) {
 418  			return true
 419  		}
 420  	}
 421  	return false
 422  }
 423  
 424  func stripSafetyRefs(line []byte) []byte {
 425  	// Remove references to safety blocks in phi nodes and branch instructions
 426  	// E.g. "[ true, %deref.next ]," or "label %gep.throw"
 427  	safetyPrefixes := [][]byte{
 428  		[]byte("%deref.next"),
 429  		[]byte("%deref.throw"),
 430  		[]byte("%gep.next"),
 431  		[]byte("%gep.throw"),
 432  		[]byte("%store.next"),
 433  		[]byte("%store.throw"),
 434  		[]byte("%lookup.next"),
 435  		[]byte("%lookup.throw"),
 436  		[]byte("%slice.next"),
 437  		[]byte("%slice.throw"),
 438  	}
 439  	for _, sp := range safetyPrefixes {
 440  		for bytes.Contains(line, sp) {
 441  			idx := bytes.Index(line, sp)
 442  			if idx < 0 {
 443  				break
 444  			}
 445  			// Find the enclosing context:
 446  			// In phi: "[ val, %block ]," - remove the whole bracket pair
 447  			// In br: "label %block" - remove "label %block"
 448  			bracketStart := -1
 449  			for j := idx - 1; j >= 0; j-- {
 450  				if line[j] == '[' {
 451  					bracketStart = j
 452  					break
 453  				}
 454  				if line[j] == ',' || line[j] == ';' {
 455  					break
 456  				}
 457  			}
 458  			if bracketStart >= 0 {
 459  				// Phi node entry: remove "[ val, %block ]" including trailing comma
 460  				bracketEnd := bytes.IndexByte(line[idx:], ']')
 461  				if bracketEnd >= 0 {
 462  					end := idx + bracketEnd + 1
 463  					// Skip trailing comma and space
 464  					for end < len(line) && (line[end] == ',' || line[end] == ' ') {
 465  						end++
 466  					}
 467  					// Also strip leading comma and space before bracket
 468  					start := bracketStart
 469  					if start > 0 && line[start-1] == ' ' {
 470  						start--
 471  					}
 472  					if start > 0 && line[start-1] == ',' {
 473  						start--
 474  					}
 475  					var rebuilt []byte
 476  					rebuilt = append(rebuilt, line[:start]...)
 477  					rebuilt = append(rebuilt, line[end:]...)
 478  					line = rebuilt
 479  					continue
 480  				}
 481  			}
 482  			// Branch target: "label %block" - remove
 483  			labelIdx := bytes.LastIndex(line[:idx], []byte("label "))
 484  			if labelIdx >= 0 {
 485  				end := idx + len(sp)
 486  				// Skip digits after the prefix (e.g., %deref.next2)
 487  				for end < len(line) && line[end] >= '0' && line[end] <= '9' {
 488  					end++
 489  				}
 490  				start := labelIdx
 491  				if start > 0 && line[start-1] == ' ' {
 492  					start--
 493  				}
 494  				if start > 0 && line[start-1] == ',' {
 495  					start--
 496  				}
 497  				var rebuilt []byte
 498  				rebuilt = append(rebuilt, line[:start]...)
 499  				rebuilt = append(rebuilt, line[end:]...)
 500  				line = rebuilt
 501  				continue
 502  			}
 503  			// Can't find context, just skip past this occurrence
 504  			break
 505  		}
 506  	}
 507  	return line
 508  }
 509