#!/usr/bin/env python3 """ Single-request rerank latency probe using real title lines (e.g. 1.8w export). Randomly samples N titles from a text file (one title per line), POSTs to the rerank HTTP API, prints wall-clock latency. Supports multiple N values (comma-separated) and multiple repeats per N. Example: source activate.sh python scripts/benchmark_reranker_random_titles.py 386 python scripts/benchmark_reranker_random_titles.py 40,80,100 python scripts/benchmark_reranker_random_titles.py 40,80,100 --repeat 3 --seed 42 RERANK_BASE=http://127.0.0.1:6007 python scripts/benchmark_reranker_random_titles.py 200 """ from __future__ import annotations import argparse import json import os import random import statistics import sys import time from pathlib import Path from typing import List, Optional, Tuple import httpx def _load_titles(path: Path) -> List[str]: lines: List[str] = [] with path.open(encoding="utf-8", errors="replace") as f: for line in f: s = line.strip() if s: lines.append(s) return lines def _parse_doc_counts(s: str) -> List[int]: parts = [p.strip() for p in s.split(",") if p.strip()] if not parts: raise ValueError("empty doc-count list") out: List[int] = [] for p in parts: v = int(p, 10) if v <= 0: raise ValueError(f"doc count must be positive, got {v}") out.append(v) return out def _do_rerank( client: httpx.Client, url: str, query: str, docs: List[str], *, top_n: int, normalize: bool, ) -> Tuple[bool, int, float, Optional[int], str]: payload: dict = {"query": query, "docs": docs, "normalize": normalize} if top_n > 0: payload["top_n"] = top_n body = json.dumps(payload, ensure_ascii=False) headers = {"Content-Type": "application/json"} t0 = time.perf_counter() try: resp = client.post(url, content=body.encode("utf-8"), headers=headers) except httpx.HTTPError: raise elapsed_ms = (time.perf_counter() - t0) * 1000.0 text = resp.text or "" ok = resp.status_code == 200 scores_len: Optional[int] = None if ok: try: data = resp.json() sc = data.get("scores") if isinstance(sc, list): scores_len = len(sc) except json.JSONDecodeError: scores_len = None return ok, resp.status_code, elapsed_ms, scores_len, text def main() -> int: parser = argparse.ArgumentParser( description="POST /rerank with N random titles from a file and print latency." ) parser.add_argument( "n", type=str, metavar="N[,N,...]", help="Doc counts: one integer or comma-separated list, e.g. 40,80,100.", ) parser.add_argument( "--repeat", type=int, default=3, help="Number of runs per doc count (default: 3).", ) parser.add_argument( "--titles-file", type=Path, default=Path(os.environ.get("RERANK_TITLE_FILE", "/home/ubuntu/rerank_test/titles.1.8w")), help="Path to newline-separated titles (default: %(default)s or env RERANK_TITLE_FILE).", ) parser.add_argument( "--url", type=str, default=os.environ.get("RERANK_BASE", "http://127.0.0.1:6007").rstrip("/") + "/rerank", help="Full rerank URL (default: $RERANK_BASE/rerank or http://127.0.0.1:6007/rerank).", ) parser.add_argument( "--query", type=str, default="健身女生T恤短袖", help="Rerank query string.", ) parser.add_argument( "--seed", type=int, default=None, help="RNG base seed; each (n, run) uses a derived seed when set (optional).", ) parser.add_argument( "--top-n", type=int, default=0, help="If > 0, include top_n in JSON body (omit field when 0).", ) parser.add_argument( "--no-normalize", action="store_true", help="Send normalize=false (default: normalize=true).", ) parser.add_argument( "--timeout", type=float, default=float(os.environ.get("RERANK_TIMEOUT_SEC", "240")), help="HTTP timeout seconds.", ) parser.add_argument( "--print-body-preview", action="store_true", help="Print first ~500 chars of response body on success (last run only).", ) args = parser.parse_args() try: doc_counts = _parse_doc_counts(args.n) except ValueError as exc: print(f"error: invalid N list {args.n!r}: {exc}", file=sys.stderr) return 2 repeat = int(args.repeat) if repeat <= 0: print("error: --repeat must be positive", file=sys.stderr) return 2 if not args.titles_file.is_file(): print(f"error: titles file not found: {args.titles_file}", file=sys.stderr) return 2 titles = _load_titles(args.titles_file) max_n = max(doc_counts) if len(titles) < max_n: print( f"error: file has only {len(titles)} non-empty lines, need at least {max_n}", file=sys.stderr, ) return 2 top_n = int(args.top_n) normalize = not args.no_normalize any_fail = False summary: dict[int, List[float]] = {n: [] for n in doc_counts} with httpx.Client(timeout=args.timeout) as client: for n in doc_counts: for run_idx in range(repeat): if args.seed is not None: random.seed(args.seed + n * 10_000 + run_idx) docs = random.sample(titles, n) try: ok, status, elapsed_ms, scores_len, text = _do_rerank( client, args.url, args.query, docs, top_n=top_n, normalize=normalize, ) except httpx.HTTPError as exc: print( f"n={n} run={run_idx + 1}/{repeat} error: request failed: {exc}", file=sys.stderr, ) any_fail = True continue if ok: summary[n].append(elapsed_ms) else: any_fail = True print( f"n={n} run={run_idx + 1}/{repeat} status={status} " f"latency_ms={elapsed_ms:.2f} scores={scores_len if scores_len is not None else 'n/a'}" ) if args.print_body_preview and text and run_idx == repeat - 1 and n == doc_counts[-1]: preview = text[:500] + ("…" if len(text) > 500 else "") print(preview) for n in doc_counts: lat = summary[n] if not lat: print(f"summary n={n} runs=0 (all failed)") continue avg = statistics.mean(lat) lo, hi = min(lat), max(lat) extra = "" if len(lat) >= 2: extra = f" stdev_ms={statistics.stdev(lat):.2f}" print( f"summary n={n} runs={len(lat)} min_ms={lo:.2f} max_ms={hi:.2f} avg_ms={avg:.2f}{extra}" ) return 1 if any_fail else 0 if __name__ == "__main__": raise SystemExit(main())