sample.py raw

   1  #!/usr/bin/env python3
   2  """
   3  Quality sampling for round-trip test failures.
   4  Samples min(2%, 1000) pairs, pipes batches through the local `claude` CLI
   5  (no API credits — uses the Claude Code session).
   6  """
   7  import sys, os, random, re, subprocess, time
   8  
   9  FAIL_TSV   = os.environ.get("FAIL_TSV",   f"{os.environ['HOME']}/tmp/roundtrip_fails.tsv")
  10  OUT_REPORT = os.environ.get("OUT_REPORT", f"{os.environ['HOME']}/tmp/quality_report.txt")
  11  BATCH_SIZE = 40
  12  MAX_SAMPLE = 1000
  13  SAMPLE_PCT = 0.02
  14  BATCH_DELAY = 3  # seconds between batches to avoid rate limiting
  15  
  16  def load_fails(path):
  17      rows = []
  18      with open(path) as f:
  19          for i, line in enumerate(f):
  20              if i == 0:
  21                  continue
  22              parts = line.rstrip('\n').split('\t')
  23              if len(parts) < 4:
  24                  continue
  25              rows.append(tuple(parts[:4]))
  26      return rows
  27  
  28  def sample(rows):
  29      n = min(MAX_SAMPLE, int(len(rows) * SAMPLE_PCT))
  30      print(f"Total fails: {len(rows)}, sampling {n} ({SAMPLE_PCT*100:.0f}% or {MAX_SAMPLE} max)",
  31            file=sys.stderr)
  32      return random.sample(rows, n)
  33  
  34  def make_prompt(batch):
  35      lines = []
  36      for i, (direction, source, primary, backs) in enumerate(batch, 1):
  37          if direction.startswith("JA"):
  38              q = f"Is '{primary}' a valid English translation of Japanese '{source}'?"
  39          else:
  40              q = f"Is '{primary}' a valid Japanese translation of English '{source}'?"
  41          lines.append(f"{i}. {q}")
  42      return (
  43          "Rate each translation GOOD or BAD. GOOD = semantically valid even if not the only "
  44          "possible translation. BAD = wrong or misleading.\n"
  45          "One line per item: number, GOOD or BAD, brief reason.\n\n"
  46          + "\n".join(lines)
  47      )
  48  
  49  def parse_response(text, n):
  50      results = []
  51      for line in text.strip().split('\n'):
  52          m = re.match(r'(\d+)[.)]\s+(GOOD|BAD)', line.strip(), re.IGNORECASE)
  53          if m:
  54              results.append(m.group(2).upper())
  55      while len(results) < n:
  56          results.append("UNKNOWN")
  57      return results[:n]
  58  
  59  def call_claude(prompt):
  60      result = subprocess.run(
  61          ["claude", "--print", "--model", "claude-haiku-4-5-20251001"],
  62          input=prompt, capture_output=True, text=True, timeout=120
  63      )
  64      if result.returncode != 0:
  65          print(f"claude error: {result.stderr[:200]}", file=sys.stderr)
  66          return ""
  67      return result.stdout
  68  
  69  def main():
  70      random.seed(int(time.time()))
  71      rows = load_fails(FAIL_TSV)
  72      sampled = sample(rows)
  73  
  74      all_results = []
  75      batches = (len(sampled) + BATCH_SIZE - 1) // BATCH_SIZE
  76  
  77      for b in range(batches):
  78          batch = sampled[b*BATCH_SIZE:(b+1)*BATCH_SIZE]
  79          prompt = make_prompt(batch)
  80          response = call_claude(prompt)
  81          ratings = parse_response(response, len(batch))
  82          for (direction, source, primary, backs), rating in zip(batch, ratings):
  83              all_results.append((direction, source, primary, backs, rating))
  84  
  85          good  = sum(1 for r in all_results if r[4] == "GOOD")
  86          bad   = sum(1 for r in all_results if r[4] == "BAD")
  87          total = len(all_results)
  88          print(f"Batch {b+1}/{batches}: {good}/{total} GOOD ({good*100//max(total,1)}%)", file=sys.stderr)
  89  
  90          if b < batches - 1:
  91              time.sleep(BATCH_DELAY)
  92  
  93      good    = sum(1 for r in all_results if r[4] == "GOOD")
  94      bad     = sum(1 for r in all_results if r[4] == "BAD")
  95      unknown = sum(1 for r in all_results if r[4] == "UNKNOWN")
  96      total   = len(all_results)
  97  
  98      with open(OUT_REPORT, 'w') as f:
  99          f.write(f"Quality sample: {total} pairs assessed\n")
 100          f.write(f"  GOOD:    {good:5d} ({good*100//max(total,1)}%)\n")
 101          f.write(f"  BAD:     {bad:5d} ({bad*100//max(total,1)}%)\n")
 102          if unknown:
 103              f.write(f"  UNKNOWN: {unknown:5d}\n")
 104          f.write("\n--- BAD translations (sample) ---\n")
 105          count = 0
 106          for direction, source, primary, backs, rating in all_results:
 107              if rating == "BAD":
 108                  f.write(f"  [{direction}] {source} → {primary}  (back: {backs})\n")
 109                  count += 1
 110                  if count >= 50:
 111                      f.write("  ... (truncated)\n")
 112                      break
 113  
 114      print(f"\nQuality sample: {total} pairs assessed")
 115      print(f"  GOOD:      {good} ({good*100//max(total,1)}%)")
 116      print(f"  BAD:       {bad} ({bad*100//max(total,1)}%)")
 117      print(f"Full report: {OUT_REPORT}")
 118  
 119  if __name__ == "__main__":
 120      main()
 121