Blame view

scripts/patch_rerank_vllm_benchmark_config.py 3.27 KB
52ea6529   tangwang   性能测试:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
  #!/usr/bin/env python3
  """
  Surgically patch config/config.yaml:
    services.rerank.backend
    services.rerank.backends.qwen3_vllm.instruction_format
    services.rerank.backends.qwen3_vllm_score.instruction_format
  
  Preserves comments and unrelated lines. Used for benchmark matrix runs.
  """
  
  from __future__ import annotations
  
  import argparse
  import re
  import sys
  from pathlib import Path
  
  
  def _with_stripped_body(line: str) -> tuple[str, str]:
      """Return (body without end newline, newline suffix including '' if none)."""
      if line.endswith("\r\n"):
          return line[:-2], "\r\n"
      if line.endswith("\n"):
          return line[:-1], "\n"
      return line, ""
  
  
  def _patch_backend_in_rerank_block(lines: list[str], backend: str) -> None:
      """Replace the value of ``services.rerank.backend`` in *lines*, in place.

      The rerank block starts at a line beginning with ``"  rerank:"``
      (two-space indent) and ends at the next non-blank, non-comment line
      indented two spaces or less (including top-level keys at column 0).
      Only keys that are direct children of the block (the indent of its
      first key) are considered, so a ``backend:`` nested deeper, e.g. under
      ``backends.<name>``, is never touched. The value may be double-quoted,
      single-quoted or bare; its quoting and anything after it on the line
      (such as a trailing comment) are preserved.

      Args:
          lines: File lines as produced by ``str.splitlines(keepends=True)``;
              modified in place.
          backend: New backend name to write.

      Raises:
          RuntimeError: if no direct ``backend:`` child of a rerank block with
              a scalar value is found.
      """
      value_re = re.compile(
          r"""^(?P<head>\s*backend:[ \t]*)(?:(?P<q>["'])[^"']+(?P=q)|[^\s"'#]+)(?P<tail>.*)$"""
      )
      in_rerank = False
      child_indent: int | None = None
      for i, line in enumerate(lines):
          if line.startswith("  rerank:"):
              in_rerank = True
              child_indent = None
              continue
          if not in_rerank:
              continue
          body, nl = _with_stripped_body(line)
          content = body.lstrip(" ")
          # Blank lines and comments carry no YAML nesting information.
          if not content or content.startswith("#"):
              continue
          indent = len(body) - len(content)
          # A key at the rerank level or shallower (incl. column 0) closes the block.
          if indent <= 2:
              in_rerank = False
              continue
          if child_indent is None:
              child_indent = indent  # the first key fixes the direct-child indent
          if indent != child_indent:
              continue  # nested deeper than a direct child of rerank
          m = value_re.match(body)
          if m:
              quote = m["q"] or ""
              lines[i] = f"{m['head']}{quote}{backend}{quote}{m['tail']}{nl}"
              return
      raise RuntimeError("services.rerank.backend line not found")
  
  
  def _patch_instruction_format_under_backend(
      lines: list[str], section: str, fmt: str
  ) -> None:
      """Replace ``instruction_format`` inside one backend section, in place.

      *section* is a backend key such as ``'qwen3_vllm'`` or
      ``'qwen3_vllm_score'``. Its header must sit at a six-space indent
      (``'      qwen3_vllm:'``), optionally followed by a comment; the first
      such header in the file is used. The section ends at the next
      non-blank, non-comment line indented six spaces or less. Within it, the
      first ``instruction_format:`` key has its value replaced; the value's
      quoting and any trailing text (e.g. a comment) are preserved.

      NOTE(review): the header is searched file-wide, not only under
      ``services.rerank.backends``; a same-named section at the same indent
      earlier in the file would be matched first.

      Args:
          lines: File lines as produced by ``str.splitlines(keepends=True)``;
              modified in place.
          section: Backend section name.
          fmt: New instruction format value.

      Raises:
          RuntimeError: if the section header or its ``instruction_format``
              key cannot be found.
      """
      header_re = re.compile(r"^      " + re.escape(section) + r":\s*(?:#.*)?$")
      start = next(
          (i for i, line in enumerate(lines) if header_re.match(line.rstrip())),
          None,
      )
      if start is None:
          raise RuntimeError(f"section {section!r} not found")

      value_re = re.compile(
          r"""^(?P<head>\s*instruction_format:[ \t]*)(?:(?P<q>["'])[^"']+(?P=q)|[^\s"'#]+)(?P<tail>.*)$"""
      )
      for j in range(start + 1, len(lines)):
          body, nl = _with_stripped_body(lines[j])
          content = body.lstrip(" ")
          # Blank lines and comments carry no YAML nesting information.
          if not content or content.startswith("#"):
              continue
          # A sibling or ancestor key (indent <= header indent) ends the section.
          if len(body) - len(content) <= 6:
              break
          m = value_re.match(body)
          if m:
              quote = m["q"] or ""
              lines[j] = f"{m['head']}{quote}{fmt}{quote}{m['tail']}{nl}"
              return
      raise RuntimeError(f"instruction_format not found under {section!r}")
  
  
  def main() -> int:
      """Command-line entry point: patch the rerank settings of a config file.

      Reads the config as UTF-8 text and rewrites ``services.rerank.backend``.
      It also rewrites the ``instruction_format`` of both the ``qwen3_vllm``
      and ``qwen3_vllm_score`` backend sections, then writes the file back
      with its original line endings.

      Returns:
          Process exit status: 0 on success, 1 if a key to patch could not be
          found (the file is left untouched), 2 if the file is empty.
      """
      parser = argparse.ArgumentParser(
          description="Patch rerank backend settings in config.yaml for benchmark runs."
      )
      parser.add_argument(
          "--config",
          type=Path,
          default=Path(__file__).resolve().parent.parent / "config" / "config.yaml",
          help="YAML file to patch in place (default: %(default)s)",
      )
      parser.add_argument(
          "--backend",
          choices=("qwen3_vllm", "qwen3_vllm_score"),
          required=True,
          help="value written to services.rerank.backend",
      )
      parser.add_argument(
          "--instruction-format",
          dest="instruction_format",
          choices=("compact", "standard"),
          required=True,
          help="value written to instruction_format in both vLLM backend sections",
      )
      args = parser.parse_args()

      # newline="" disables universal-newline translation on read and write,
      # so CRLF files stay CRLF. Path.read_text/write_text would normalise
      # them to the platform's line ending.
      with args.config.open(encoding="utf-8", newline="") as fh:
          text = fh.read()
      lines = text.splitlines(keepends=True)
      if not lines:
          print("empty config", file=sys.stderr)
          return 2

      try:
          _patch_backend_in_rerank_block(lines, args.backend)
          for section in ("qwen3_vllm", "qwen3_vllm_score"):
              _patch_instruction_format_under_backend(
                  lines, section, args.instruction_format
              )
      except RuntimeError as exc:
          # All patching happens in memory, so nothing has been written yet.
          print(f"{args.config}: {exc}", file=sys.stderr)
          return 1

      with args.config.open("w", encoding="utf-8", newline="") as fh:
          fh.write("".join(lines))
      print(
          f"patched {args.config}: backend={args.backend} "
          f"instruction_format={args.instruction_format} (both vLLM blocks)"
      )
      return 0
  
  
  if __name__ == "__main__":
      # main() returns the process exit status; sys.exit raises SystemExit with it.
      sys.exit(main())