#!/usr/bin/env python3
"""Benchmark a single long-text translation request for local models."""

from __future__ import annotations

import argparse
import copy
import json
import logging
import statistics
import sys
import time
from pathlib import Path

import torch

PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from config.services_config import get_translation_config  # noqa: E402
from translation.service import TranslationService  # noqa: E402
from translation.text_splitter import compute_safe_input_token_limit  # noqa: E402


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Benchmark a long-text translation request")
    parser.add_argument("--model", default="nllb-200-distilled-600m")
    parser.add_argument("--source-lang", default="zh")
    parser.add_argument("--target-lang", default="en")
    parser.add_argument("--scene", default="sku_name")
    parser.add_argument("--source-md", default="docs/DEVELOPER_GUIDE.md")
    parser.add_argument("--paragraph-min-chars", type=int, default=250)
    parser.add_argument("--target-doc-chars", type=int, default=4500)
    parser.add_argument("--min-doc-chars", type=int, default=2400)
    parser.add_argument("--runs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--ct2-inter-threads", type=int, default=4)
    parser.add_argument("--ct2-max-queued-batches", type=int, default=32)
    parser.add_argument("--ct2-batch-type", default="examples")
    parser.add_argument("--max-new-tokens", type=int, default=64)
    parser.add_argument("--ct2-decoding-length-mode", default="source")
    parser.add_argument("--ct2-decoding-length-extra", type=int, default=8)
    parser.add_argument("--ct2-decoding-length-min", type=int, default=32)
    return parser.parse_args()


def build_long_document(args: argparse.Namespace) -> str:
    """Assemble a long benchmark document from prose paragraphs of a Markdown file."""
    source_path = (PROJECT_ROOT / args.source_md).resolve()
    text = source_path.read_text(encoding="utf-8")

    # Keep only sufficiently long prose paragraphs; skip fenced code blocks.
    paragraphs = []
    for raw in text.split("\n\n"):
        normalized = " ".join(line.strip() for line in raw.splitlines() if line.strip())
        if len(normalized) >= args.paragraph_min_chars and not normalized.startswith("```"):
            paragraphs.append(normalized)

    # Accumulate paragraphs until the target document size is reached.
    parts = []
    total = 0
    for paragraph in paragraphs:
        parts.append(paragraph)
        total += len(paragraph) + 2
        if total >= args.target_doc_chars:
            break

    document = "\n\n".join(parts)
    if len(document) < args.min_doc_chars:
        raise ValueError(
            f"Prepared long document is too short: {len(document)} chars < {args.min_doc_chars}"
        )
    return document


def build_service(args: argparse.Namespace) -> TranslationService:
    """Build a TranslationService with only the benchmarked model enabled."""
    config = copy.deepcopy(get_translation_config())
    for name, capability in config["capabilities"].items():
        capability["enabled"] = name == args.model

    # Disable caching and apply the CLI overrides so runs are comparable.
    capability = config["capabilities"][args.model]
    capability["use_cache"] = False
    capability["batch_size"] = args.batch_size
    capability["ct2_inter_threads"] = args.ct2_inter_threads
    capability["ct2_max_queued_batches"] = args.ct2_max_queued_batches
    capability["ct2_batch_type"] = args.ct2_batch_type
    capability["max_new_tokens"] = args.max_new_tokens
    capability["ct2_decoding_length_mode"] = args.ct2_decoding_length_mode
    capability["ct2_decoding_length_extra"] = args.ct2_decoding_length_extra
    capability["ct2_decoding_length_min"] = args.ct2_decoding_length_min
    config["default_model"] = args.model
    return TranslationService(config)


def percentile(values: list[float], p: float) -> float:
    """Nearest-rank percentile over the sorted samples."""
    if not values:
        return 0.0
    ordered = sorted(values)
    if len(ordered) == 1:
        return float(ordered[0])
    index = min(len(ordered) - 1, max(0, round((len(ordered) - 1) * p)))
    return float(ordered[index])


def main() -> None:
    args = parse_args()
    logging.getLogger().setLevel(logging.WARNING)
    document = build_long_document(args)

    load_started = time.perf_counter()
    service = build_service(args)
    backend = service.get_backend(args.model)
    load_seconds = time.perf_counter() - load_started

    safe_input_limit = compute_safe_input_token_limit(
        max_input_length=backend.max_input_length,
        max_new_tokens=backend.max_new_tokens,
        decoding_length_mode=backend.ct2_decoding_length_mode,
        decoding_length_extra=backend.ct2_decoding_length_extra,
    )
    segments = backend._split_text_if_needed(
        document,
        target_lang=args.target_lang,
        source_lang=args.source_lang,
    )

    # Warm up once before measurements.
    _ = service.translate(
        document,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        model=args.model,
        scene=args.scene,
    )
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # Timed runs.
    latencies_ms: list[float] = []
    output_chars = 0
    for _ in range(args.runs):
        started = time.perf_counter()
        output = service.translate(
            document,
            source_lang=args.source_lang,
            target_lang=args.target_lang,
            model=args.model,
            scene=args.scene,
        )
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        latencies_ms.append((time.perf_counter() - started) * 1000)
        output_chars += len(output or "")

    total_seconds = sum(latencies_ms) / 1000.0
    payload = {
        "model": args.model,
        "source_lang": args.source_lang,
        "target_lang": args.target_lang,
        "doc_chars": len(document),
        "runs": args.runs,
        "load_seconds": round(load_seconds, 3),
        "batch_size": backend.batch_size,
        "ct2_inter_threads": backend.ct2_inter_threads,
        "ct2_max_queued_batches": backend.ct2_max_queued_batches,
        "ct2_batch_type": backend.ct2_batch_type,
        "max_new_tokens": backend.max_new_tokens,
        "ct2_decoding_length_mode": backend.ct2_decoding_length_mode,
        "ct2_decoding_length_extra": backend.ct2_decoding_length_extra,
        "ct2_decoding_length_min": backend.ct2_decoding_length_min,
        "safe_input_limit": safe_input_limit,
        "segment_count": len(segments),
        "segment_char_lengths": {
            "min": min(len(segment) for segment in segments),
            "max": max(len(segment) for segment in segments),
            "avg": round(statistics.fmean(len(segment) for segment in segments), 1),
        },
        "latency_avg_ms": round(statistics.fmean(latencies_ms), 2),
        "latency_p50_ms": round(percentile(latencies_ms, 0.50), 2),
        "latency_p95_ms": round(percentile(latencies_ms, 0.95), 2),
        "latency_max_ms": round(max(latencies_ms), 2),
        "input_chars_per_second": round((len(document) * args.runs) / total_seconds, 2),
        "output_chars_per_second": round(output_chars / total_seconds, 2),
    }
    print(json.dumps(payload, ensure_ascii=False))


if __name__ == "__main__":
    main()
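
# Example invocation (illustrative only; the path below assumes this script lives one
# directory below the project root, e.g. in a scripts/ folder -- adjust to the real name):
#
#   python scripts/benchmark_long_translation.py \
#       --model nllb-200-distilled-600m --source-lang zh --target-lang en \
#       --runs 5 --batch-size 32 --max-new-tokens 64
#
# The script emits a single JSON line containing the model load time, the effective CT2
# settings, segment statistics, latency avg/p50/p95/max in milliseconds, and input/output
# characters-per-second throughput.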