# NOTE(review): stray export/listing header "benchmark_reranker_random_titles.py 9.94 KB"
# commented out so the module parses; it also displaces the shebang below from
# line 1 — consider deleting this line entirely.
#!/usr/bin/env python3
"""
Single-request rerank latency probe using real title lines (e.g. 1.8w export).

Randomly samples N titles from a text file (one title per line), POSTs to the
rerank HTTP API, prints wall-clock latency.

Supports multiple N values (comma-separated) and multiple repeats per N.
Each invocation runs 3 warmup requests with n=400 first; those are not timed for summaries.

Example:
  source activate.sh
  python scripts/benchmark_reranker_random_titles.py 386
  python scripts/benchmark_reranker_random_titles.py 40,80,100
  python scripts/benchmark_reranker_random_titles.py 40,80,100 --repeat 3 --seed 42
  RERANK_BASE=http://127.0.0.1:6007 python scripts/benchmark_reranker_random_titles.py 200
"""

from __future__ import annotations

import argparse
import json
import os
import random
import statistics
import sys
import time
from pathlib import Path
from typing import List, Optional, Tuple

import httpx


def _load_titles(path: Path) -> List[str]:
    lines: List[str] = []
    with path.open(encoding="utf-8", errors="replace") as f:
        for line in f:
            s = line.strip()
            if s:
                lines.append(s)
    return lines


def _parse_doc_counts(s: str) -> List[int]:
    parts = [p.strip() for p in s.split(",") if p.strip()]
    if not parts:
        raise ValueError("empty doc-count list")
    out: List[int] = []
    for p in parts:
        v = int(p, 10)
        if v <= 0:
            raise ValueError(f"doc count must be positive, got {v}")
        out.append(v)
    return out


def _do_rerank(
    client: httpx.Client,
    url: str,
    query: str,
    docs: List[str],
    *,
    top_n: int,
    normalize: bool,
) -> Tuple[bool, int, float, Optional[int], str]:
    payload: dict = {"query": query, "docs": docs, "normalize": normalize}
    if top_n > 0:
        payload["top_n"] = top_n
    body = json.dumps(payload, ensure_ascii=False)
    headers = {"Content-Type": "application/json"}
    t0 = time.perf_counter()
    try:
        resp = client.post(url, content=body.encode("utf-8"), headers=headers)
    except httpx.HTTPError:
        raise
    elapsed_ms = (time.perf_counter() - t0) * 1000.0
    text = resp.text or ""
    ok = resp.status_code == 200
    scores_len: Optional[int] = None
    if ok:
        try:
            data = resp.json()
            sc = data.get("scores")
            if isinstance(sc, list):
                scores_len = len(sc)
        except json.JSONDecodeError:
            scores_len = None
    return ok, resp.status_code, elapsed_ms, scores_len, text


def main() -> int:
    """Parse CLI args, run warmups plus timed rerank requests, print summaries.

    Flow: validate inputs -> 3 untimed warmup requests with n=400 -> for each
    requested doc count, `--repeat` timed requests -> per-n text summaries ->
    optional JSON summary file.

    Returns:
        Process exit code: 0 on full success, 1 if any request failed or
        returned non-200, 2 on usage errors (bad N list, non-positive
        --repeat, missing/short titles file).
    """
    parser = argparse.ArgumentParser(
        description="POST /rerank with N random titles from a file and print latency."
    )
    parser.add_argument(
        "n",
        type=str,
        metavar="N[,N,...]",
        help="Doc counts: one integer or comma-separated list, e.g. 40,80,100.",
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=3,
        help="Number of runs per doc count (default: 3).",
    )
    parser.add_argument(
        "--titles-file",
        type=Path,
        default=Path(os.environ.get("RERANK_TITLE_FILE", "/home/ubuntu/rerank_test/titles.1.8w")),
        help="Path to newline-separated titles (default: %(default)s or env RERANK_TITLE_FILE).",
    )
    parser.add_argument(
        "--url",
        type=str,
        default=os.environ.get("RERANK_BASE", "http://127.0.0.1:6007").rstrip("/") + "/rerank",
        help="Full rerank URL (default: $RERANK_BASE/rerank or http://127.0.0.1:6007/rerank).",
    )
    parser.add_argument(
        "--query",
        type=str,
        default="健身女生T恤短袖",
        help="Rerank query string.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="RNG base seed; each (n, run) uses a derived seed when set (optional).",
    )
    parser.add_argument(
        "--top-n",
        type=int,
        default=0,
        help="If > 0, include top_n in JSON body (omit field when 0).",
    )
    parser.add_argument(
        "--no-normalize",
        action="store_true",
        help="Send normalize=false (default: normalize=true).",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=float(os.environ.get("RERANK_TIMEOUT_SEC", "240")),
        help="HTTP timeout seconds.",
    )
    parser.add_argument(
        "--print-body-preview",
        action="store_true",
        help="Print first ~500 chars of response body on success (last run only).",
    )
    parser.add_argument(
        "--tag",
        type=str,
        default=os.environ.get("BENCH_TAG", ""),
        help="Optional label stored in --json-summary-out (default: env BENCH_TAG or empty).",
    )
    parser.add_argument(
        "--json-summary-out",
        type=Path,
        default=None,
        help="Write one JSON object with per-n latencies and aggregates for downstream tables.",
    )
    parser.add_argument(
        "--quiet-runs",
        action="store_true",
        help="Suppress per-run lines; still prints warmup lines and text summaries.",
    )
    args = parser.parse_args()

    # --- input validation (exit code 2 for usage problems) -----------------
    try:
        doc_counts = _parse_doc_counts(args.n)
    except ValueError as exc:
        print(f"error: invalid N list {args.n!r}: {exc}", file=sys.stderr)
        return 2

    repeat = int(args.repeat)
    if repeat <= 0:
        print("error: --repeat must be positive", file=sys.stderr)
        return 2

    if not args.titles_file.is_file():
        print(f"error: titles file not found: {args.titles_file}", file=sys.stderr)
        return 2

    titles = _load_titles(args.titles_file)
    # Warmups use a fixed doc count so server-side caches/model state are warm
    # before any timed run; warmup latencies never enter the summaries.
    warmup_n = 400
    warmup_runs = 3
    max_n = max(max(doc_counts), warmup_n)
    if len(titles) < max_n:
        print(
            f"error: file has only {len(titles)} non-empty lines, need at least {max_n}",
            file=sys.stderr,
        )
        return 2

    top_n = int(args.top_n)
    normalize = not args.no_normalize
    any_fail = False
    # NOTE(review): duplicate values in doc_counts collapse into one bucket
    # here, so "40,40" records 2*repeat runs under n=40 — confirm intended.
    summary: dict[int, List[float]] = {n: [] for n in doc_counts}

    with httpx.Client(timeout=args.timeout) as client:
        # --- warmup phase (never timed) ------------------------------------
        for w in range(warmup_runs):
            if args.seed is not None:
                # Large offset keeps warmup seeds disjoint from timed-run seeds.
                random.seed(args.seed + 8_000_000 + w)
            docs_w = random.sample(titles, warmup_n)
            try:
                ok_w, status_w, _elapsed_w, scores_len_w, _text_w = _do_rerank(
                    client,
                    args.url,
                    args.query,
                    docs_w,
                    top_n=top_n,
                    normalize=normalize,
                )
            except httpx.HTTPError as exc:
                print(
                    f"warmup n={warmup_n} {w + 1}/{warmup_runs} error: request failed: {exc}",
                    file=sys.stderr,
                )
                any_fail = True
                continue
            if not ok_w:
                any_fail = True
            print(
                f"warmup n={warmup_n} {w + 1}/{warmup_runs} status={status_w} "
                f"scores={scores_len_w if scores_len_w is not None else 'n/a'} (not timed)"
            )

        # --- timed runs -----------------------------------------------------
        for n in doc_counts:
            for run_idx in range(repeat):
                if args.seed is not None:
                    # Derived per-(n, run) seed makes each sample reproducible
                    # independently of execution order.
                    random.seed(args.seed + n * 10_000 + run_idx)
                docs = random.sample(titles, n)
                try:
                    ok, status, elapsed_ms, scores_len, text = _do_rerank(
                        client,
                        args.url,
                        args.query,
                        docs,
                        top_n=top_n,
                        normalize=normalize,
                    )
                except httpx.HTTPError as exc:
                    print(
                        f"n={n} run={run_idx + 1}/{repeat} error: request failed: {exc}",
                        file=sys.stderr,
                    )
                    any_fail = True
                    continue

                # Only successful (HTTP 200) runs contribute latency samples.
                if ok:
                    summary[n].append(elapsed_ms)
                else:
                    any_fail = True

                if not args.quiet_runs:
                    print(
                        f"n={n} run={run_idx + 1}/{repeat} status={status} "
                        f"latency_ms={elapsed_ms:.2f} scores={scores_len if scores_len is not None else 'n/a'}"
                    )
                # BUGFIX: gate on `ok` so the preview is printed only for a
                # successful response, as the --print-body-preview help states
                # (previously error bodies were printed too).
                if (
                    args.print_body_preview
                    and ok
                    and text
                    and run_idx == repeat - 1
                    and n == doc_counts[-1]
                ):
                    preview = text[:500] + ("…" if len(text) > 500 else "")
                    print(preview)

    # --- per-n text summaries (only successful runs counted) ---------------
    for n in doc_counts:
        lat = summary[n]
        if not lat:
            print(f"summary n={n} runs=0 (all failed)")
            continue
        avg = statistics.mean(lat)
        lo, hi = min(lat), max(lat)
        extra = ""
        if len(lat) >= 2:
            # stdev needs at least two samples.
            extra = f" stdev_ms={statistics.stdev(lat):.2f}"
        print(
            f"summary n={n} runs={len(lat)} min_ms={lo:.2f} max_ms={hi:.2f} avg_ms={avg:.2f}{extra}"
        )

    # --- optional machine-readable summary ---------------------------------
    if args.json_summary_out is not None:
        per_n: dict = {}
        for n in doc_counts:
            lat = summary[n]
            row: dict = {"values_ms": lat, "runs": len(lat)}
            if lat:
                row["mean_ms"] = statistics.mean(lat)
                row["min_ms"] = min(lat)
                row["max_ms"] = max(lat)
                if len(lat) >= 2:
                    row["stdev_ms"] = statistics.stdev(lat)
            # Keys are stringified for JSON object compatibility.
            per_n[str(n)] = row
        out_obj = {
            "tag": args.tag or None,
            "doc_counts": doc_counts,
            "repeat": repeat,
            "url": args.url,
            "per_n": per_n,
            "failed": bool(any_fail),
        }
        args.json_summary_out.parent.mkdir(parents=True, exist_ok=True)
        args.json_summary_out.write_text(
            json.dumps(out_obj, ensure_ascii=False, indent=2) + "\n",
            encoding="utf-8",
        )
        print(f"wrote json summary -> {args.json_summary_out}")

    return 1 if any_fail else 0


if __name__ == "__main__":
    raise SystemExit(main())