#!/usr/bin/env python3
"""
Surgically patch config/config.yaml, rewriting only these keys:

  - services.rerank.backend
  - services.rerank.backends.qwen3_vllm.instruction_format
  - services.rerank.backends.qwen3_vllm_score.instruction_format

Preserves comments and unrelated lines. Used for benchmark matrix runs.
"""
from __future__ import annotations
import argparse
import re
import sys
from pathlib import Path
def _with_stripped_body(line: str) -> tuple[str, str]:
"""Return (body without end newline, newline suffix including '' if none)."""
if line.endswith("\r\n"):
return line[:-2], "\r\n"
if line.endswith("\n"):
return line[:-1], "\n"
return line, ""
def _patch_backend_in_rerank_block(lines: list[str], backend: str) -> None:
in_rerank = False
for i, line in enumerate(lines):
if line.startswith(" rerank:"):
in_rerank = True
continue
if in_rerank:
if line.startswith(" ") and not line.startswith(" ") and line.strip():
in_rerank = False
continue
body, nl = _with_stripped_body(line)
m = re.match(r'^(\s*backend:\s*")[^"]+(".*)$', body)
if m:
lines[i] = f'{m.group(1)}{backend}{m.group(2)}{nl}'
return
raise RuntimeError("services.rerank.backend line not found")
def _patch_instruction_format_under_backend(
lines: list[str], section: str, fmt: str
) -> None:
"""section is 'qwen3_vllm' or 'qwen3_vllm_score' (first line is ' qwen3_vllm:')."""
header = f" {section}:"
start = None
for i, line in enumerate(lines):
if line.rstrip() == header:
start = i
break
if start is None:
raise RuntimeError(f"section {section!r} not found")
for j in range(start + 1, len(lines)):
line = lines[j]
body, nl = _with_stripped_body(line)
if re.match(r"^ [a-zA-Z0-9_]+:\s*$", body):
break
m = re.match(r"^(\s*instruction_format:\s*)\S+", body)
if m:
lines[j] = f"{m.group(1)}{fmt}{nl}"
return
raise RuntimeError(f"instruction_format not found under {section!r}")
def main() -> int:
    """Parse CLI arguments and patch the rerank settings in the config file.

    ``--backend`` selects the value written to ``services.rerank.backend``;
    ``--instruction-format`` is written to both the ``qwen3_vllm`` and the
    ``qwen3_vllm_score`` sections.  The file is rewritten only after all three
    patches succeeded.

    Returns:
        0 on success, 2 if the config file is empty.

    Raises:
        RuntimeError: propagated from the patch helpers when an expected key
            or section is missing; the file is left untouched in that case.
    """
    p = argparse.ArgumentParser()
    p.add_argument(
        "--config",
        type=Path,
        default=Path(__file__).resolve().parent.parent / "config" / "config.yaml",
    )
    p.add_argument("--backend", choices=("qwen3_vllm", "qwen3_vllm_score"), required=True)
    p.add_argument(
        "--instruction-format",
        dest="instruction_format",
        choices=("compact", "standard"),
        required=True,
    )
    args = p.parse_args()

    # newline="" turns off universal-newline translation; with the default
    # text mode "\r\n" would be read as "\n" and every line of a CRLF file
    # would be rewritten, defeating the "unrelated lines untouched" goal.
    with args.config.open("r", encoding="utf-8", newline="") as fh:
        text = fh.read()
    lines = text.splitlines(keepends=True)
    if not lines:
        print("empty config", file=sys.stderr)
        return 2

    _patch_backend_in_rerank_block(lines, args.backend)
    _patch_instruction_format_under_backend(lines, "qwen3_vllm", args.instruction_format)
    _patch_instruction_format_under_backend(lines, "qwen3_vllm_score", args.instruction_format)

    # Same reason on the way out: write the kept line endings verbatim.
    with args.config.open("w", encoding="utf-8", newline="") as fh:
        fh.write("".join(lines))
    print(
        f"patched {args.config}: backend={args.backend} "
        f"instruction_format={args.instruction_format} (both vLLM blocks)"
    )
    return 0
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())