verify.mx raw

   1  package iskra
   2  
   3  import "bytes"
   4  
   5  type IRSignature struct {
   6  	ReturnType string
   7  	Params     []IRParam
   8  	FuncName   string
   9  }
  10  
  11  type IRParam struct {
  12  	Type string
  13  	Name string
  14  }
  15  
  16  func ParseIRSignature(defLine []byte) IRSignature {
  17  	sig := IRSignature{}
  18  	// "define hidden i1 @"pkg.Func"(i8 %c, ptr %context) ..."
  19  	rest := defLine
  20  
  21  	// Skip "define" and attributes
  22  	for len(rest) > 0 && rest[0] != '@' {
  23  		// find return type: the token before @
  24  		idx := bytes.IndexByte(rest, '@')
  25  		if idx < 0 {
  26  			return sig
  27  		}
  28  		before := bytes.TrimSpace(rest[:idx])
  29  		// return type is last space-separated token before @
  30  		sp := bytes.LastIndexByte(before, ' ')
  31  		if sp >= 0 {
  32  			sig.ReturnType = string(before[sp+1:])
  33  		} else {
  34  			sig.ReturnType = string(before)
  35  		}
  36  		rest = rest[idx:]
  37  		break
  38  	}
  39  
  40  	// Extract function name
  41  	if len(rest) > 0 && rest[0] == '@' {
  42  		rest = rest[1:]
  43  		if len(rest) > 0 && rest[0] == '"' {
  44  			rest = rest[1:]
  45  			end := bytes.IndexByte(rest, '"')
  46  			if end >= 0 {
  47  				sig.FuncName = string(rest[:end])
  48  				rest = rest[end+1:]
  49  			}
  50  		} else {
  51  			end := bytes.IndexByte(rest, '(')
  52  			if end >= 0 {
  53  				sig.FuncName = string(rest[:end])
  54  				rest = rest[end:]
  55  			}
  56  		}
  57  	}
  58  
  59  	// Parse parameters
  60  	pOpen := bytes.IndexByte(rest, '(')
  61  	if pOpen < 0 {
  62  		return sig
  63  	}
  64  	rest = rest[pOpen+1:]
  65  	pClose := bytes.IndexByte(rest, ')')
  66  	if pClose < 0 {
  67  		return sig
  68  	}
  69  	paramStr := rest[:pClose]
  70  
  71  	parts := bytes.Split(paramStr, []byte(","))
  72  	for _, p := range parts {
  73  		p = bytes.TrimSpace(p)
  74  		if len(p) == 0 {
  75  			continue
  76  		}
  77  		param := IRParam{}
  78  		// Strip attributes like "dereferenceable_or_null(N)"
  79  		p = stripParamAttrs(p)
  80  		sp := bytes.LastIndexByte(p, ' ')
  81  		if sp >= 0 {
  82  			param.Type = string(bytes.TrimSpace(p[:sp]))
  83  			param.Name = string(bytes.TrimSpace(p[sp+1:]))
  84  		} else {
  85  			param.Type = string(p)
  86  		}
  87  		sig.Params = append(sig.Params, param)
  88  	}
  89  	return sig
  90  }
  91  
  92  func stripParamAttrs(p []byte) []byte {
  93  	// Remove "dereferenceable_or_null(N)" and "dereferenceable(N)"
  94  	for {
  95  		idx := bytes.Index(p, []byte("dereferenceable"))
  96  		if idx < 0 {
  97  			break
  98  		}
  99  		end := idx + 15
 100  		if end < len(p) && p[end] == '_' {
 101  			// dereferenceable_or_null
 102  			nEnd := bytes.Index(p[idx:], []byte(")"))
 103  			if nEnd >= 0 {
 104  				end = idx + nEnd + 1
 105  			}
 106  		} else if end < len(p) && p[end] == '(' {
 107  			nEnd := bytes.IndexByte(p[end:], ')')
 108  			if nEnd >= 0 {
 109  				end = end + nEnd + 1
 110  			}
 111  		}
 112  		var rebuilt []byte
 113  		rebuilt = append(rebuilt, p[:idx]...)
 114  		rebuilt = append(rebuilt, p[end:]...)
 115  		p = rebuilt
 116  	}
 117  	// Collapse double spaces
 118  	for bytes.Contains(p, []byte("  ")) {
 119  		p = bytes.Replace(p, []byte("  "), []byte(" "), -1)
 120  	}
 121  	return bytes.TrimSpace(p)
 122  }
 123  
 124  type VerifyCase struct {
 125  	FuncName  string
 126  	Sig       IRSignature
 127  	Testable  bool
 128  	TestKind  string // "byte-exhaustive", "int-boundary", "void-null", "untestable"
 129  	LatticeIR []byte
 130  	RefIR     []byte
 131  }
 132  
 133  func ClassifyTestability(sig IRSignature) (string, bool) {
 134  	if len(sig.Params) == 0 {
 135  		return "untestable", false
 136  	}
 137  	// Last param is always ptr %context - ignore it
 138  	realParams := sig.Params
 139  	if len(realParams) > 0 && realParams[len(realParams)-1].Name == "%context" {
 140  		realParams = realParams[:len(realParams)-1]
 141  	}
 142  
 143  	if len(realParams) == 0 {
 144  		if sig.ReturnType == "i1" || sig.ReturnType == "i32" {
 145  			return "const-null", true
 146  		}
 147  		if sig.ReturnType == "void" {
 148  			return "void-null", true
 149  		}
 150  		return "untestable", false
 151  	}
 152  
 153  	if len(realParams) == 1 {
 154  		t := realParams[0].Type
 155  		if t == "i8" && (sig.ReturnType == "i1" || sig.ReturnType == "i32") {
 156  			return "byte-exhaustive", true
 157  		}
 158  		if t == "i32" && (sig.ReturnType == "i1" || sig.ReturnType == "i32") {
 159  			return "int-boundary", true
 160  		}
 161  	}
 162  
 163  	// Multi-scalar params
 164  	allScalar := true
 165  	for _, p := range realParams {
 166  		if p.Type != "i8" && p.Type != "i16" && p.Type != "i32" && p.Type != "i64" && p.Type != "i1" {
 167  			allScalar = false
 168  			break
 169  		}
 170  	}
 171  	if allScalar && (sig.ReturnType == "i1" || sig.ReturnType == "i32" || sig.ReturnType == "i64" || sig.ReturnType == "void") {
 172  		return "scalar-boundary", true
 173  	}
 174  
 175  	return "untestable", false
 176  }
 177  
 178  func GenerateTestMain(sig IRSignature, testKind string, needDecl bool) []byte {
 179  	var b []byte
 180  	b = append(b, "declare i32 @printf(ptr, ...)\n"...)
 181  	b = append(b, "@fmt_d = private unnamed_addr constant [4 x i8] c\"%d\\0A\\00\"\n"...)
 182  	b = append(b, "@fmt_ld = private unnamed_addr constant [5 x i8] c\"%ld\\0A\\00\"\n"...)
 183  
 184  	if needDecl {
 185  		b = append(b, "declare " | sig.ReturnType | " @\"" | sig.FuncName | "\"("...)
 186  		for i, p := range sig.Params {
 187  			if i > 0 {
 188  				b = append(b, ", "...)
 189  			}
 190  			b = append(b, p.Type...)
 191  		}
 192  		b = append(b, ")\n\n"...)
 193  	}
 194  
 195  	b = append(b, "define i32 @main() {\n"...)
 196  	b = append(b, "entry:\n"...)
 197  
 198  	switch testKind {
 199  	case "byte-exhaustive":
 200  		b = append(b, "  br label %loop\n"...)
 201  		b = append(b, "loop:\n"...)
 202  		b = append(b, "  %i = phi i32 [ 0, %entry ], [ %next, %loop ]\n"...)
 203  		b = append(b, "  %c = trunc i32 %i to i8\n"...)
 204  		b = append(b, "  %r = call " | sig.ReturnType | " @\"" | sig.FuncName | "\"(i8 %c, ptr null)\n"...)
 205  		if sig.ReturnType == "i1" {
 206  			b = append(b, "  %r32 = zext i1 %r to i32\n"...)
 207  			b = append(b, "  call i32 (ptr, ...) @printf(ptr @fmt_d, i32 %r32)\n"...)
 208  		} else {
 209  			b = append(b, "  call i32 (ptr, ...) @printf(ptr @fmt_d, i32 %r)\n"...)
 210  		}
 211  		b = append(b, "  %next = add i32 %i, 1\n"...)
 212  		b = append(b, "  %done = icmp eq i32 %next, 256\n"...)
 213  		b = append(b, "  br i1 %done, label %exit, label %loop\n"...)
 214  		b = append(b, "exit:\n"...)
 215  		b = append(b, "  ret i32 0\n"...)
 216  
 217  	case "int-boundary":
 218  		vals := []string{"0", "1", "-1", "127", "-128", "255", "2147483647", "-2147483648", "42", "100", "1000"}
 219  		for vi, v := range vals {
 220  			vn := "v" | intToStr(vi)
 221  			b = append(b, "  %" | vn | " = call " | sig.ReturnType | " @\"" | sig.FuncName | "\"(i32 " | v | ", ptr null)\n"...)
 222  			if sig.ReturnType == "i1" {
 223  				b = append(b, "  %" | vn | "e = zext i1 %" | vn | " to i32\n"...)
 224  				b = append(b, "  call i32 (ptr, ...) @printf(ptr @fmt_d, i32 %" | vn | "e)\n"...)
 225  			} else {
 226  				b = append(b, "  call i32 (ptr, ...) @printf(ptr @fmt_d, i32 %" | vn | ")\n"...)
 227  			}
 228  		}
 229  		b = append(b, "  ret i32 0\n"...)
 230  
 231  	case "const-null":
 232  		b = append(b, "  %r = call " | sig.ReturnType | " @\"" | sig.FuncName | "\"(ptr null)\n"...)
 233  		if sig.ReturnType == "i1" {
 234  			b = append(b, "  %r32 = zext i1 %r to i32\n"...)
 235  			b = append(b, "  call i32 (ptr, ...) @printf(ptr @fmt_d, i32 %r32)\n"...)
 236  		} else {
 237  			b = append(b, "  call i32 (ptr, ...) @printf(ptr @fmt_d, i32 %r)\n"...)
 238  		}
 239  		b = append(b, "  ret i32 0\n"...)
 240  
 241  	case "void-null":
 242  		b = append(b, "  call void @\"" | sig.FuncName | "\"(ptr null)\n"...)
 243  		b = append(b, "  ret i32 0\n"...)
 244  
 245  	default:
 246  		b = append(b, "  ret i32 0\n"...)
 247  	}
 248  
 249  	b = append(b, "}\n"...)
 250  	return b
 251  }
 252  
 253  func ReplaceFunctionInModule(module []byte, funcName string, newBody []byte) []byte {
 254  	// Find "define ... @"funcName"(...) ... {"
 255  	searchStr := "@\"" | funcName | "\""
 256  	defIdx := bytes.Index(module, []byte("define"))
 257  	for defIdx >= 0 {
 258  		nameIdx := bytes.Index(module[defIdx:], []byte(searchStr))
 259  		if nameIdx < 0 {
 260  			break
 261  		}
 262  		nameIdx += defIdx
 263  		// Verify this is within a define line (find the newline before)
 264  		lineStart := bytes.LastIndexByte(module[:nameIdx], '\n')
 265  		if lineStart < 0 {
 266  			lineStart = 0
 267  		} else {
 268  			lineStart++
 269  		}
 270  		line := module[lineStart:]
 271  		if !bytes.HasPrefix(bytes.TrimSpace(line), []byte("define ")) {
 272  			defIdx = nameIdx + len(searchStr)
 273  			continue
 274  		}
 275  
 276  		// Find the opening {
 277  		braceIdx := bytes.IndexByte(module[nameIdx:], '{')
 278  		if braceIdx < 0 {
 279  			break
 280  		}
 281  		funcStart := lineStart
 282  		bodyStart := nameIdx + braceIdx
 283  
 284  		// Find the closing } - match braces
 285  		depth := 1
 286  		pos := bodyStart + 1
 287  		for pos < len(module) && depth > 0 {
 288  			if module[pos] == '{' {
 289  				depth++
 290  			} else if module[pos] == '}' {
 291  				depth--
 292  			}
 293  			pos++
 294  		}
 295  		funcEnd := pos
 296  
 297  		// Build replacement
 298  		var result []byte
 299  		result = append(result, module[:funcStart]...)
 300  		result = append(result, newBody...)
 301  		result = append(result, '\n')
 302  		result = append(result, module[funcEnd:]...)
 303  		return result
 304  	}
 305  	return module
 306  }
 307  
 308  func ExtractFunctionFromModule(module []byte, funcName string) []byte {
 309  	searchStr := "@\"" | funcName | "\""
 310  	defIdx := 0
 311  	for defIdx < len(module) {
 312  		nextDef := bytes.Index(module[defIdx:], []byte("define"))
 313  		if nextDef < 0 {
 314  			break
 315  		}
 316  		defIdx += nextDef
 317  		nameIdx := bytes.Index(module[defIdx:], []byte(searchStr))
 318  		if nameIdx < 0 || nameIdx > 200 {
 319  			defIdx += 7
 320  			continue
 321  		}
 322  		nameIdx += defIdx
 323  
 324  		lineStart := bytes.LastIndexByte(module[:nameIdx], '\n')
 325  		if lineStart < 0 {
 326  			lineStart = 0
 327  		} else {
 328  			lineStart++
 329  		}
 330  
 331  		braceIdx := bytes.IndexByte(module[nameIdx:], '{')
 332  		if braceIdx < 0 {
 333  			break
 334  		}
 335  		bodyStart := nameIdx + braceIdx
 336  		depth := 1
 337  		pos := bodyStart + 1
 338  		for pos < len(module) && depth > 0 {
 339  			if module[pos] == '{' {
 340  				depth++
 341  			} else if module[pos] == '}' {
 342  				depth--
 343  			}
 344  			pos++
 345  		}
 346  		return module[lineStart:pos]
 347  	}
 348  	return nil
 349  }
 350