patch_rerank_vllm_benchmark_config.py 3.27 KB
#!/usr/bin/env python3
"""
Surgically patch config/config.yaml:
  services.rerank.backend
  services.rerank.backends.qwen3_vllm.instruction_format
  services.rerank.backends.qwen3_vllm_score.instruction_format

Preserves comments and unrelated lines. Used for benchmark matrix runs.
"""

from __future__ import annotations

import argparse
import re
import sys
from pathlib import Path


def _with_stripped_body(line: str) -> tuple[str, str]:
    """Return (body without end newline, newline suffix including '' if none)."""
    if line.endswith("\r\n"):
        return line[:-2], "\r\n"
    if line.endswith("\n"):
        return line[:-1], "\n"
    return line, ""


def _patch_backend_in_rerank_block(lines: list[str], backend: str) -> None:
    in_rerank = False
    for i, line in enumerate(lines):
        if line.startswith("  rerank:"):
            in_rerank = True
            continue
        if in_rerank:
            if line.startswith("  ") and not line.startswith("    ") and line.strip():
                in_rerank = False
                continue
            body, nl = _with_stripped_body(line)
            m = re.match(r'^(\s*backend:\s*")[^"]+(".*)$', body)
            if m:
                lines[i] = f'{m.group(1)}{backend}{m.group(2)}{nl}'
                return
    raise RuntimeError("services.rerank.backend line not found")


def _patch_instruction_format_under_backend(
    lines: list[str], section: str, fmt: str
) -> None:
    """section is 'qwen3_vllm' or 'qwen3_vllm_score' (first line is '      qwen3_vllm:')."""
    header = f"      {section}:"
    start = None
    for i, line in enumerate(lines):
        if line.rstrip() == header:
            start = i
            break
    if start is None:
        raise RuntimeError(f"section {section!r} not found")

    for j in range(start + 1, len(lines)):
        line = lines[j]
        body, nl = _with_stripped_body(line)
        if re.match(r"^      [a-zA-Z0-9_]+:\s*$", body):
            break
        m = re.match(r"^(\s*instruction_format:\s*)\S+", body)
        if m:
            lines[j] = f"{m.group(1)}{fmt}{nl}"
            return
    raise RuntimeError(f"instruction_format not found under {section!r}")


def main() -> int:
    """Parse CLI arguments and patch the rerank settings in the config file.

    Sets ``services.rerank.backend`` to ``--backend``. Sets the
    ``instruction_format`` of both the ``qwen3_vllm`` and ``qwen3_vllm_score``
    backend sections to ``--instruction-format``. The default ``--config`` is
    ``config/config.yaml`` in the parent of this script's directory.

    All patches are applied in memory first; the file is rewritten only if
    every one of them succeeds. The file is read and written as raw UTF-8 bytes
    so its original line terminators (e.g. CRLF) survive the round trip.
    ``Path.read_text``/``write_text`` would translate them.

    Returns:
        Process exit status: 0 on success, 1 if a patch target is missing,
        2 if the config file is empty.
    """
    p = argparse.ArgumentParser(
        description="Patch rerank backend / instruction_format in config.yaml for benchmark runs."
    )
    p.add_argument(
        "--config",
        type=Path,
        default=Path(__file__).resolve().parent.parent / "config" / "config.yaml",
    )
    p.add_argument("--backend", choices=("qwen3_vllm", "qwen3_vllm_score"), required=True)
    p.add_argument(
        "--instruction-format",
        dest="instruction_format",
        choices=("compact", "standard"),
        required=True,
    )
    args = p.parse_args()

    # Bytes + explicit decode: no universal-newline translation, so untouched
    # lines keep their exact terminators.
    text = args.config.read_bytes().decode("utf-8")
    lines = text.splitlines(keepends=True)
    if not lines:
        print("empty config", file=sys.stderr)
        return 2

    try:
        _patch_backend_in_rerank_block(lines, args.backend)
        for section in ("qwen3_vllm", "qwen3_vllm_score"):
            _patch_instruction_format_under_backend(lines, section, args.instruction_format)
    except RuntimeError as exc:
        # Same exit status as an uncaught exception, but without a traceback.
        print(f"{args.config}: {exc}", file=sys.stderr)
        return 1

    args.config.write_bytes("".join(lines).encode("utf-8"))
    print(
        f"patched {args.config}: backend={args.backend} "
        f"instruction_format={args.instruction_format} (both vLLM blocks)"
    )
    return 0


if __name__ == "__main__":
    # Propagate main()'s integer result as the process exit status.
    sys.exit(main())