benchgen.mx raw

   1  package iskra
   2  
   3  import (
   4  	"bytes"
   5  	"os"
   6  )
   7  
   8  type BenchFunc struct {
   9  	Name       string
  10  	RecvType   string
  11  	ParamTypes []string
  12  	ParamNames []string
  13  	ResultTypes []string
  14  	IsMethod   bool
  15  	Class      BenchClass
  16  	SizeDriver int32
  17  }
  18  
  19  const benchItersSimple = "10000000"
  20  const benchItersComplex = "100000"
  21  
  22  func EmitBenchFile(pkgName string, importPath string, funcs []BenchFunc) []byte {
  23  	var out []byte
  24  	out = append(out, "// Generated by iskra bench-gen\n"...)
  25  	out = append(out, "package main\n\n"...)
  26  	out = append(out, "import (\n"...)
  27  	out = append(out, "\t\"fmt\"\n"...)
  28  	out = append(out, "\t\"time\"\n"...)
  29  	if importPath != "" {
  30  		out = append(out, "\t\"" | importPath | "\"\n"...)
  31  	}
  32  	out = append(out, ")\n\n"...)
  33  	out = append(out, "const Ns = " | benchItersSimple | "\n"...)
  34  	out = append(out, "const Nc = " | benchItersComplex | "\n\n"...)
  35  	out = append(out, "var sink int64\n\n"...)
  36  
  37  	out = append(out, "func main() {\n"...)
  38  	for _, f := range funcs {
  39  		if f.Class == BenchSkip {
  40  			continue
  41  		}
  42  		if f.SizeDriver >= 0 {
  43  			emitSizedBench(&out, f, pkgName)
  44  		} else {
  45  			emitSimpleBench(&out, f, pkgName)
  46  		}
  47  	}
  48  	out = append(out, "\t_ = sink\n"...)
  49  	out = append(out, "}\n"...)
  50  	return out
  51  }
  52  
  53  func iterVar(f BenchFunc) string {
  54  	if f.Class == BenchComplex {
  55  		return "Nc"
  56  	}
  57  	return "Ns"
  58  }
  59  
  60  func iterLabel(f BenchFunc) string {
  61  	if f.Class == BenchComplex {
  62  		return benchItersComplex
  63  	}
  64  	return benchItersSimple
  65  }
  66  
  67  func emitSimpleBench(out *[]byte, f BenchFunc, pkg string) {
  68  	label := benchLabel(f, "")
  69  	*out = append(*out, "\t{\n"...)
  70  	if f.IsMethod && f.RecvType != "" {
  71  		emitRecvSetup(out, f.RecvType, pkg)
  72  	}
  73  	emitParamSetup(out, f.ParamTypes, f.ParamNames, -1, 0, pkg)
  74  	emitTimedLoop(out, f, label, pkg)
  75  	*out = append(*out, "\t}\n"...)
  76  }
  77  
  78  func emitSizedBench(out *[]byte, f BenchFunc, pkg string) {
  79  	sizes := []int32{1, 10, 100, 1000}
  80  	for _, sz := range sizes {
  81  		label := benchLabel(f, intToStr(int(sz)))
  82  		*out = append(*out, "\t{\n"...)
  83  		if f.IsMethod && f.RecvType != "" {
  84  			emitRecvSetup(out, f.RecvType, pkg)
  85  		}
  86  		emitParamSetup(out, f.ParamTypes, f.ParamNames, f.SizeDriver, sz, pkg)
  87  		emitTimedLoop(out, f, label, pkg)
  88  		*out = append(*out, "\t}\n"...)
  89  	}
  90  }
  91  
  92  func benchLabel(f BenchFunc, sizeSuffix string) string {
  93  	name := f.Name
  94  	if f.IsMethod && f.RecvType != "" {
  95  		rt := f.RecvType
  96  		if len(rt) > 0 && rt[0] == '*' {
  97  			rt = rt[1:]
  98  		}
  99  		name = rt | "." | f.Name
 100  	}
 101  	if sizeSuffix != "" {
 102  		return name | "/N=" | sizeSuffix
 103  	}
 104  	return name
 105  }
 106  
 107  func emitRecvSetup(out *[]byte, recvType string, pkg string) {
 108  	clean := recvType
 109  	isPtr := false
 110  	if len(clean) > 0 && clean[0] == '*' {
 111  		isPtr = true
 112  		clean = clean[1:]
 113  	}
 114  	qualified := pkg | "." | clean
 115  	if isPtr {
 116  		*out = append(*out, "\t\trecv := &" | qualified | "{}\n"...)
 117  	} else {
 118  		*out = append(*out, "\t\tvar recv " | qualified | "\n"...)
 119  	}
 120  }
 121  
 122  func emitParamSetup(out *[]byte, types []string, names []string, sizeDriver int32, size int32, pkg string) {
 123  	for i, typ := range types {
 124  		varName := "p" | intToStr(i)
 125  		if i < len(names) && names[i] != "" {
 126  			varName = names[i]
 127  		}
 128  		qtyp := qualifyType(typ, pkg)
 129  		if int32(i) == sizeDriver {
 130  			emitSizedParam(out, varName, qtyp, size)
 131  		} else {
 132  			emitDefaultParam(out, varName, qtyp)
 133  		}
 134  	}
 135  }
 136  
 137  func emitDefaultParam(out *[]byte, name string, typ string) {
 138  	switch typ {
 139  	case "int", "int32":
 140  		*out = append(*out, "\t\t" | name | " := 42\n"...)
 141  	case "int64":
 142  		*out = append(*out, "\t\t" | name | " := int64(42)\n"...)
 143  	case "uint", "uint32":
 144  		*out = append(*out, "\t\t" | name | " := uint32(42)\n"...)
 145  	case "uint64":
 146  		*out = append(*out, "\t\t" | name | " := uint64(42)\n"...)
 147  	case "uint8", "byte":
 148  		*out = append(*out, "\t\t" | name | " := byte('a')\n"...)
 149  	case "rune":
 150  		*out = append(*out, "\t\t" | name | " := rune('a')\n"...)
 151  	case "float64":
 152  		*out = append(*out, "\t\t" | name | " := 3.14\n"...)
 153  	case "float32":
 154  		*out = append(*out, "\t\t" | name | " := float32(3.14)\n"...)
 155  	case "bool":
 156  		*out = append(*out, "\t\t" | name | " := true\n"...)
 157  	case "string", "[]byte":
 158  		*out = append(*out, "\t\t" | name | " := []byte(\"hello, world\")\n"...)
 159  	case "[][]byte":
 160  		*out = append(*out, "\t\t" | name | " := [][]byte{[]byte(\"hello\"), []byte(\"world\")}\n"...)
 161  	case "error":
 162  		*out = append(*out, "\t\tvar " | name | " error\n"...)
 163  	default:
 164  		if len(typ) > 2 && typ[:2] == "[]" {
 165  			*out = append(*out, "\t\t" | name | " := " | typ | "{}\n"...)
 166  		} else if len(typ) > 0 && typ[0] == '*' {
 167  			*out = append(*out, "\t\t" | name | " := &" | typ[1:] | "{}\n"...)
 168  		} else {
 169  			*out = append(*out, "\t\tvar " | name | " " | typ | "\n"...)
 170  		}
 171  	}
 172  }
 173  
 174  func emitSizedParam(out *[]byte, name string, typ string, size int32) {
 175  	szStr := intToStr(int(size))
 176  	switch typ {
 177  	case "int", "int32":
 178  		*out = append(*out, "\t\t" | name | " := " | szStr | "\n"...)
 179  	case "int64":
 180  		*out = append(*out, "\t\t" | name | " := int64(" | szStr | ")\n"...)
 181  	case "uint", "uint32":
 182  		*out = append(*out, "\t\t" | name | " := uint32(" | szStr | ")\n"...)
 183  	case "uint64":
 184  		*out = append(*out, "\t\t" | name | " := uint64(" | szStr | ")\n"...)
 185  	case "float64":
 186  		*out = append(*out, "\t\t" | name | " := float64(" | szStr | ")\n"...)
 187  	case "float32":
 188  		*out = append(*out, "\t\t" | name | " := float32(" | szStr | ")\n"...)
 189  	case "string", "[]byte":
 190  		*out = append(*out, "\t\t" | name | " := []byte{:" | szStr | "}\n"...)
 191  		*out = append(*out, "\t\tfor j := range " | name | " { " | name | "[j] = byte('a' + j%26) }\n"...)
 192  	default:
 193  		if len(typ) > 2 && typ[:2] == "[]" {
 194  			*out = append(*out, "\t\t" | name | " := " | typ | "{:" | szStr | "}\n"...)
 195  		} else {
 196  			*out = append(*out, "\t\t" | name | " := " | szStr | "\n"...)
 197  		}
 198  	}
 199  }
 200  
 201  func emitTimedLoop(out *[]byte, f BenchFunc, label string, pkg string) {
 202  	nVar := iterVar(f)
 203  	*out = append(*out, "\t\tt0 := time.Now()\n"...)
 204  	*out = append(*out, "\t\tfor j := 0; j < " | nVar | "; j++ {\n"...)
 205  
 206  	*out = append(*out, "\t\t\t"...)
 207  	nResults := len(f.ResultTypes)
 208  	if nResults == 1 {
 209  		*out = append(*out, "_ = "...)
 210  	} else if nResults > 1 {
 211  		for ri := 0; ri < nResults; ri++ {
 212  			if ri > 0 {
 213  				*out = append(*out, ", "...)
 214  			}
 215  			*out = append(*out, "_"...)
 216  		}
 217  		*out = append(*out, " = "...)
 218  	}
 219  
 220  	if f.IsMethod && f.RecvType != "" {
 221  		*out = append(*out, "recv." | f.Name | "("...)
 222  	} else {
 223  		*out = append(*out, pkg | "." | f.Name | "("...)
 224  	}
 225  
 226  	for i := range f.ParamTypes {
 227  		if i > 0 {
 228  			*out = append(*out, ", "...)
 229  		}
 230  		varName := "p" | intToStr(i)
 231  		if i < len(f.ParamNames) && f.ParamNames[i] != "" {
 232  			varName = f.ParamNames[i]
 233  		}
 234  		*out = append(*out, varName...)
 235  	}
 236  	*out = append(*out, ")\n"...)
 237  	*out = append(*out, "\t\t\tsink++\n"...)
 238  	*out = append(*out, "\t\t}\n"...)
 239  
 240  	nLabel := iterLabel(f)
 241  	*out = append(*out, "\t\telapsed := time.Since(t0).Nanoseconds()\n"...)
 242  	*out = append(*out, "\t\tnsop := elapsed / int64(" | nVar | ")\n"...)
 243  	*out = append(*out, "\t\tsubns := (elapsed * 100 / int64(" | nVar | ")) % 100\n"...)
 244  	*out = append(*out, "\t\tfmt.Println(\"" | label | "\\t" | nLabel | "\\t\" | fmt.Sprint(nsop) | \".\" | fmt.Sprint(subns) | \" ns/op\")\n"...)
 245  }
 246  
 247  func qualifyType(typ string, pkg string) string {
 248  	if len(typ) == 0 || pkg == "" {
 249  		return typ
 250  	}
 251  	// Slice of custom type: []Foo -> []pkg.Foo
 252  	if len(typ) > 2 && typ[:2] == "[]" {
 253  		inner := typ[2:]
 254  		if needsQualification(inner) {
 255  			return "[]" | pkg | "." | inner
 256  		}
 257  		return typ
 258  	}
 259  	// Pointer to custom type: *Foo -> *pkg.Foo
 260  	if typ[0] == '*' {
 261  		inner := typ[1:]
 262  		if needsQualification(inner) {
 263  			return "*" | pkg | "." | inner
 264  		}
 265  		return typ
 266  	}
 267  	if needsQualification(typ) {
 268  		return pkg | "." | typ
 269  	}
 270  	return typ
 271  }
 272  
 273  func needsQualification(name string) bool {
 274  	if len(name) == 0 {
 275  		return false
 276  	}
 277  	if name[0] < 'A' || name[0] > 'Z' {
 278  		return false
 279  	}
 280  	switch name {
 281  	case "Reader", "Writer", "ReadWriter":
 282  		return true
 283  	}
 284  	return true
 285  }
 286  
 287  func benchFuncName(f BenchFunc) string {
 288  	name := f.Name
 289  	if !f.IsMethod || f.RecvType == "" {
 290  		return name
 291  	}
 292  	rt := f.RecvType
 293  	if len(rt) > 0 && rt[0] == '*' {
 294  		rt = rt[1:]
 295  	}
 296  	return rt | "_" | name
 297  }
 298  
 299  func ExtractBenchFuncsFlat(t *Tree) []BenchFunc {
 300  	var funcs []BenchFunc
 301  	seen := map[string]bool{}
 302  
 303  	for i := range t.RecMeta {
 304  		meta := &t.RecMeta[i]
 305  		if meta.StageTag != StageAST {
 306  			continue
 307  		}
 308  		if meta.Kind != KindFunc && meta.Kind != KindMethod {
 309  			continue
 310  		}
 311  		astContent := t.GetContent(uint32(i))
 312  		if len(astContent) == 0 {
 313  			continue
 314  		}
 315  
 316  		rec := t.db.GetRecord(uint32(i))
 317  		if rec == nil {
 318  			continue
 319  		}
 320  		fullName := FormFromRecord(rec, t.StringPool)
 321  		funcName := unqualifiedName(fullName)
 322  
 323  		if len(funcName) == 0 || funcName[0] < 'A' || funcName[0] > 'Z' {
 324  			continue
 325  		}
 326  
 327  		st := ExtractSymbols(string(astContent))
 328  		cls := ClassifyBenchCost(string(astContent), funcName)
 329  
 330  		paramTypes := expandParamTypes(st.ParamTypes, extractParamNamesFromAST(astContent))
 331  		if hasUnsynthesizableParam(paramTypes) {
 332  			continue
 333  		}
 334  		paramNames := extractParamNamesFromAST(astContent)
 335  		sanitizeParamNames(paramNames)
 336  
 337  		if meta.Kind == KindMethod && st.RecvType != "" {
 338  			rt := st.RecvType
 339  			if len(rt) > 0 && rt[0] == '*' {
 340  				rt = rt[1:]
 341  			}
 342  			if len(rt) == 0 || rt[0] < 'A' || rt[0] > 'Z' {
 343  				continue
 344  			}
 345  		}
 346  
 347  		resultTypes := expandResultTypes(st.ResultTypes, astContent)
 348  
 349  		bf := BenchFunc{
 350  			Name:        funcName,
 351  			RecvType:    st.RecvType,
 352  			ParamTypes:  paramTypes,
 353  			ParamNames:  paramNames,
 354  			ResultTypes: resultTypes,
 355  			IsMethod:    meta.Kind == KindMethod,
 356  			Class:       cls,
 357  			SizeDriver:  detectSizeDriver(astContent, paramTypes),
 358  		}
 359  		key := benchFuncName(bf)
 360  		if seen[key] {
 361  			continue
 362  		}
 363  		seen[key] = true
 364  		funcs = append(funcs, bf)
 365  	}
 366  	return funcs
 367  }
 368  
 369  func hasUnsynthesizableParam(types []string) bool {
 370  	for _, t := range types {
 371  		if bytes.Contains([]byte(t), []byte("func(")) {
 372  			return true
 373  		}
 374  		if len(t) >= 3 && t[:3] == "..." {
 375  			return true
 376  		}
 377  		if t == "interface{}" || t == "any" {
 378  			return true
 379  		}
 380  		if bytes.Contains([]byte(t), []byte("interface{")) {
 381  			return true
 382  		}
 383  		// Skip any io.* types (interfaces that can't be defaulted)
 384  		if len(t) >= 3 && t[:3] == "io." {
 385  			return true
 386  		}
 387  			// Skip cross-package types with dots
 388  		if bytes.ContainsAny([]byte(t), ".") {
 389  			return true
 390  		}
 391  		// Skip well-known interface/abstract types
 392  		switch t {
 393  		case "Reader", "Writer", "ReadWriter", "ReadCloser",
 394  			"WriteCloser", "ReadWriteCloser", "Closer",
 395  			"ByteReader", "ByteWriter", "ByteScanner",
 396  			"RuneReader", "RuneScanner", "Seeker",
 397  			"ReadSeeker", "WriteSeeker", "ReadWriteSeeker",
 398  			"ReaderFrom", "WriterTo", "ReaderAt", "WriterAt",
 399  			"Block", "AEAD", "Hash", "Image", "Model":
 400  			return true
 401  		}
 402  	}
 403  	return false
 404  }
 405  
 406  func sanitizeParamNames(names []string) {
 407  	for i, n := range names {
 408  		if n == "b" || n == "i" || n == "j" {
 409  			names[i] = n | "0"
 410  		}
 411  	}
 412  }
 413  
 414  // expandParamTypes aligns types with names when the AST groups
 415  // multiple names under one type (e.g. "err,target error" -> 1 type, 2 names).
 416  func expandParamTypes(types []string, names []string) []string {
 417  	if len(types) == 0 || len(names) <= len(types) {
 418  		return types
 419  	}
 420  	expanded := []string{:0:len(names)}
 421  	ti := 0
 422  	ni := 0
 423  	for ti < len(types) && ni < len(names) {
 424  		remaining := len(names) - ni
 425  		typesLeft := len(types) - ti
 426  		count := remaining - typesLeft + 1
 427  		if count < 1 {
 428  			count = 1
 429  		}
 430  		for c := 0; c < count && ni < len(names); c++ {
 431  			expanded = append(expanded, types[ti])
 432  			ni++
 433  		}
 434  		ti++
 435  	}
 436  	return expanded
 437  }
 438  
 439  func expandResultTypes(types []string, astDump []byte) []string {
 440  	resultNames := extractResultNamesFromAST(astDump)
 441  	if len(resultNames) <= len(types) {
 442  		return types
 443  	}
 444  	return expandParamTypes(types, resultNames)
 445  }
 446  
 447  func extractResultNamesFromAST(astDump []byte) []string {
 448  	lines := bytes.Split(astDump, []byte("\n"))
 449  	inResults := false
 450  	var names []string
 451  	for _, line := range lines {
 452  		trimmed := bytes.TrimSpace(line)
 453  		if bytes.Equal(trimmed, []byte("Results")) {
 454  			inResults = true
 455  			continue
 456  		}
 457  		if inResults {
 458  			if len(trimmed) == 0 || (trimmed[0] != ' ' && !bytes.HasPrefix(line, []byte("    "))) {
 459  				if !bytes.HasPrefix(trimmed, []byte("  ")) {
 460  					break
 461  				}
 462  			}
 463  			parts := bytes.Fields(trimmed)
 464  			if len(parts) >= 2 {
 465  				namesPart := parts[0]
 466  				for _, n := range bytes.Split(namesPart, []byte(",")) {
 467  					names = append(names, string(n))
 468  				}
 469  			} else if len(parts) == 1 {
 470  				names = append(names, "_")
 471  			}
 472  		}
 473  	}
 474  	return names
 475  }
 476  
 477  func extractParamNamesFromAST(astDump []byte) []string {
 478  	lines := bytes.Split(astDump, []byte("\n"))
 479  	inParams := false
 480  	var names []string
 481  	for _, line := range lines {
 482  		trimmed := bytes.TrimSpace(line)
 483  		if bytes.Equal(trimmed, []byte("Params")) {
 484  			inParams = true
 485  			continue
 486  		}
 487  		if inParams {
 488  			if len(trimmed) == 0 || (trimmed[0] != ' ' && !bytes.HasPrefix(line, []byte("    "))) {
 489  				if !bytes.HasPrefix(trimmed, []byte("  ")) {
 490  					break
 491  				}
 492  			}
 493  			parts := bytes.Fields(trimmed)
 494  			if len(parts) >= 2 {
 495  				namesPart := parts[0]
 496  				for _, n := range bytes.Split(namesPart, []byte(",")) {
 497  					names = append(names, string(n))
 498  				}
 499  			}
 500  		}
 501  	}
 502  	return names
 503  }
 504  
 505  // ImportPathFromManifest extracts the Go import path from a corpus
 506  // manifest.csv by parsing IR filenames like "unicode_utf8_RuneBytes.O0.ll"
 507  // into "unicode/utf8". Returns pkgName, importPath.
 508  func ImportPathFromManifest(manifestPath string) (string, string) {
 509  	data, err := os.ReadFile(manifestPath)
 510  	if err != nil {
 511  		return "", ""
 512  	}
 513  	lines := bytes.Split(data, []byte("\n"))
 514  	for _, line := range lines {
 515  		fields := bytes.Split(line, []byte("\t"))
 516  		if len(fields) < 6 {
 517  			continue
 518  		}
 519  		kind := string(fields[1])
 520  		if kind != "func" {
 521  			continue
 522  		}
 523  		irFile := string(fields[5])
 524  		if len(irFile) == 0 {
 525  			continue
 526  		}
 527  		funcName := string(fields[2])
 528  		// IR filename: crypto_aes_NewCipher.O0.ll
 529  		// Strip ".O0.ll" suffix, then strip "_FuncName" suffix to get path prefix
 530  		base := irFile
 531  		if dotIdx := bytes.IndexByte([]byte(base), '.'); dotIdx >= 0 {
 532  			base = base[:dotIdx]
 533  		}
 534  		// Strip _FuncName suffix
 535  		if len(base) > len(funcName)+1 {
 536  			prefix := base[:len(base)-len(funcName)-1]
 537  			// Replace _ with / to reconstruct import path
 538  			importPath := bytes.Replace([]byte(prefix), []byte("_"), []byte("/"), -1)
 539  			// Package name is last segment
 540  			pkgName := prefix
 541  			if lastUnderscore := bytes.LastIndexByte([]byte(prefix), '_'); lastUnderscore >= 0 {
 542  				pkgName = prefix[lastUnderscore+1:]
 543  			}
 544  			return pkgName, string(importPath)
 545  		}
 546  	}
 547  	return "", ""
 548  }
 549  
 550  func detectSizeDriver(astDump []byte, paramTypes []string) int32 {
 551  	lines := bytes.Split(astDump, []byte("\n"))
 552  	paramNames := extractParamNamesFromAST(astDump)
 553  
 554  	for _, line := range lines {
 555  		trimmed := bytes.TrimSpace(line)
 556  		if !bytes.HasPrefix(trimmed, []byte("Range")) && !bytes.HasPrefix(trimmed, []byte("For")) {
 557  			continue
 558  		}
 559  		for i, pn := range paramNames {
 560  			if bytes.Contains(trimmed, []byte(pn)) {
 561  				return int32(i)
 562  			}
 563  		}
 564  	}
 565  	for i, pt := range paramTypes {
 566  		if pt == "[]byte" || pt == "string" {
 567  			return int32(i)
 568  		}
 569  	}
 570  	return -1
 571  }
 572