benchmark_translation_longtext_single.py 7.01 KB
#!/usr/bin/env python3
"""Benchmark a single long-text translation request for local models."""

from __future__ import annotations

import argparse
import copy
import json
import logging
import statistics
import time
from pathlib import Path

import torch

import sys

# Repository root: this script's directory sits one level below it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Put the repository root at the front of the import path so the project's own
# packages (config, translation) resolve when the script is executed directly.
_PROJECT_ROOT_STR = str(PROJECT_ROOT)
if _PROJECT_ROOT_STR not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT_STR)

from config.services_config import get_translation_config  # noqa: E402
from translation.service import TranslationService  # noqa: E402
from translation.text_splitter import compute_safe_input_token_limit  # noqa: E402


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the long-text translation benchmark.

    Options named ``--ct2-*``, ``--batch-size`` and ``--max-new-tokens`` are
    written into the selected model's capability config before the service is
    built.

    Returns:
        The parsed namespace.

    Exits:
        Via ``parser.error`` (exit status 2) when ``--runs`` is below 1, because
        the summary statistics need at least one timed run.
    """
    parser = argparse.ArgumentParser(description="Benchmark a long-text translation request")
    parser.add_argument(
        "--model",
        default="nllb-200-distilled-600m",
        help="Translation capability to enable and benchmark (all others are disabled)",
    )
    parser.add_argument("--source-lang", default="zh", help="Source language code")
    parser.add_argument("--target-lang", default="en", help="Target language code")
    parser.add_argument(
        "--scene", default="sku_name", help="Scene value passed to TranslationService.translate"
    )
    parser.add_argument(
        "--source-md",
        default="docs/DEVELOPER_GUIDE.md",
        help="Markdown file, relative to the project root, used to build the benchmark text",
    )
    parser.add_argument(
        "--paragraph-min-chars",
        type=int,
        default=250,
        help="Skip Markdown paragraphs shorter than this many characters",
    )
    parser.add_argument(
        "--target-doc-chars",
        type=int,
        default=4500,
        help="Stop adding paragraphs once roughly this many characters are collected",
    )
    parser.add_argument(
        "--min-doc-chars",
        type=int,
        default=2400,
        help="Abort if the assembled document is shorter than this many characters",
    )
    parser.add_argument(
        "--runs", type=int, default=3, help="Number of timed runs (after one warm-up run)"
    )
    parser.add_argument(
        "--batch-size", type=int, default=64, help="Override the capability's batch_size"
    )
    parser.add_argument(
        "--ct2-inter-threads",
        type=int,
        default=4,
        help="Override the capability's ct2_inter_threads",
    )
    parser.add_argument(
        "--ct2-max-queued-batches",
        type=int,
        default=32,
        help="Override the capability's ct2_max_queued_batches",
    )
    parser.add_argument(
        "--ct2-batch-type", default="examples", help="Override the capability's ct2_batch_type"
    )
    parser.add_argument(
        "--max-new-tokens", type=int, default=64, help="Override the capability's max_new_tokens"
    )
    parser.add_argument(
        "--ct2-decoding-length-mode",
        default="source",
        help="Override the capability's ct2_decoding_length_mode",
    )
    parser.add_argument(
        "--ct2-decoding-length-extra",
        type=int,
        default=8,
        help="Override the capability's ct2_decoding_length_extra",
    )
    parser.add_argument(
        "--ct2-decoding-length-min",
        type=int,
        default=32,
        help="Override the capability's ct2_decoding_length_min",
    )
    args = parser.parse_args()
    # Zero runs would make the latency mean and throughput divisions fail, but
    # only after the expensive model load and warm-up; reject it up front.
    if args.runs < 1:
        parser.error(f"--runs must be >= 1 (got {args.runs})")
    return args


def build_long_document(args: argparse.Namespace) -> str:
    """Assemble a long plain-text document from the prose paragraphs of a Markdown file.

    The file ``args.source_md`` (resolved against the project root) is split on
    blank lines. Each block is collapsed onto a single line by stripping its lines
    and joining them with spaces. Blocks shorter than ``args.paragraph_min_chars``,
    or starting with a ```` ``` ```` fence, are dropped. Surviving paragraphs are
    taken in file order until the running size reaches ``args.target_doc_chars``.
    The result can therefore overshoot the target by up to one paragraph.

    Returns:
        The selected paragraphs joined with blank lines.

    Raises:
        ValueError: if the assembled document is shorter than ``args.min_doc_chars``.
        Any error from resolving or reading the Markdown file is propagated.
    """
    markdown = (PROJECT_ROOT / args.source_md).resolve().read_text(encoding="utf-8")

    # Note: the fence test only drops blocks that *begin* with ```; chunks from the
    # middle of a fenced code block that itself contains blank lines can still pass.
    flattened = (
        " ".join(piece for piece in map(str.strip, chunk.splitlines()) if piece)
        for chunk in markdown.split("\n\n")
    )
    paragraphs = [
        candidate
        for candidate in flattened
        if len(candidate) >= args.paragraph_min_chars and not candidate.startswith("```")
    ]

    selected: list[str] = []
    running_chars = 0
    for paragraph in paragraphs:
        selected.append(paragraph)
        running_chars += len(paragraph) + 2  # +2 for the "\n\n" separator used below
        if running_chars >= args.target_doc_chars:
            break

    document = "\n\n".join(selected)
    if len(document) < args.min_doc_chars:
        raise ValueError(
            f"Prepared long document is too short: {len(document)} chars < {args.min_doc_chars}"
        )
    return document


def build_service(args: argparse.Namespace) -> TranslationService:
    """Build a TranslationService with only ``args.model`` enabled and CLI overrides applied.

    The configuration from ``get_translation_config()`` is deep-copied, so the
    object it returns is never modified. Every capability except ``args.model``
    has ``enabled`` set to False. The selected capability then receives the batch,
    CTranslate2 and decoding settings from the CLI. ``use_cache`` is forced to
    False on the selected capability, presumably so timed runs are not served
    from a result cache. ``default_model`` is pointed at ``args.model``.

    Raises:
        KeyError: if ``args.model`` is not one of the configured capabilities.
    """
    config = copy.deepcopy(get_translation_config())
    capabilities = config["capabilities"]

    # Validate before touching any capability so a typo in --model yields an
    # actionable message instead of a bare KeyError('<name>').
    if args.model not in capabilities:
        available = ", ".join(sorted(map(str, capabilities))) or "<none>"
        raise KeyError(f"Unknown model {args.model!r}; configured capabilities: {available}")

    for name, capability in capabilities.items():
        capability["enabled"] = name == args.model

    overrides = {
        "use_cache": False,
        "batch_size": args.batch_size,
        "ct2_inter_threads": args.ct2_inter_threads,
        "ct2_max_queued_batches": args.ct2_max_queued_batches,
        "ct2_batch_type": args.ct2_batch_type,
        "max_new_tokens": args.max_new_tokens,
        "ct2_decoding_length_mode": args.ct2_decoding_length_mode,
        "ct2_decoding_length_extra": args.ct2_decoding_length_extra,
        "ct2_decoding_length_min": args.ct2_decoding_length_min,
    }
    capabilities[args.model].update(overrides)
    config["default_model"] = args.model
    return TranslationService(config)


def percentile(values: list[float], p: float) -> float:
    """Return the *p*-quantile of *values* using linear interpolation between order statistics.

    This is the same definition as numpy's default ``percentile`` method and
    ``statistics.quantiles(method="inclusive")``. ``percentile(xs, 0.5)`` equals
    the median, including for an even number of samples.

    Args:
        values: Samples; they need not be sorted.
        p: Quantile in [0, 1]. Values outside that range are clamped, so they
            yield the minimum or maximum sample.

    Returns:
        The interpolated quantile, or 0.0 when *values* is empty.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    if len(ordered) == 1:
        return float(ordered[0])
    rank = (len(ordered) - 1) * min(1.0, max(0.0, p))
    # rank >= 0, so int() truncation is floor(). Interpolating avoids the
    # parity-dependent tie-breaking that round() (banker's rounding) gives.
    lower = int(rank)
    upper = min(lower + 1, len(ordered) - 1)
    fraction = rank - lower
    return float(ordered[lower] + (ordered[upper] - ordered[lower]) * fraction)


def _sync_cuda() -> None:
    """Block until queued CUDA work finishes (if a GPU is available) so timings include it."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _translate_document(
    service: TranslationService, document: str, args: argparse.Namespace
) -> str | None:
    """Translate *document* once with the CLI-selected languages, model and scene, then sync CUDA."""
    output = service.translate(
        document,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        model=args.model,
        scene=args.scene,
    )
    _sync_cuda()
    return output


def _segment_char_stats(segments: list[str]) -> dict[str, float]:
    """Return min/max/avg character length of *segments* (all zero when there are none)."""
    lengths = [len(segment) for segment in segments]
    if not lengths:
        return {"min": 0, "max": 0, "avg": 0.0}
    return {
        "min": min(lengths),
        "max": max(lengths),
        "avg": round(statistics.fmean(lengths), 1),
    }


def main() -> None:
    """Run the benchmark and print a one-line JSON summary to stdout.

    Steps:
    1. Build the input document.
    2. Build the service and time how long building it and fetching the backend
       takes. This presumably includes model loading; confirm that
       ``get_backend`` is not lazy.
    3. Do one untimed warm-up translation.
    4. Time ``args.runs`` full translations of the same document.

    The JSON includes the effective backend settings, segmentation stats,
    latency statistics and character throughput.
    """
    args = parse_args()
    if args.runs < 1:
        # Latency mean and throughput need at least one timed run; fail before
        # spending time on model loading and warm-up.
        raise SystemExit(f"--runs must be >= 1, got {args.runs}")
    logging.getLogger().setLevel(logging.WARNING)

    document = build_long_document(args)
    load_started = time.perf_counter()
    service = build_service(args)
    backend = service.get_backend(args.model)
    load_seconds = time.perf_counter() - load_started

    # Reported for reference only; the benchmark does not use it to drive splitting.
    safe_input_limit = compute_safe_input_token_limit(
        max_input_length=backend.max_input_length,
        max_new_tokens=backend.max_new_tokens,
        decoding_length_mode=backend.ct2_decoding_length_mode,
        decoding_length_extra=backend.ct2_decoding_length_extra,
    )
    # NOTE(review): relies on a private backend method, presumably the same
    # splitting translate() applies internally; confirm if the backend changes.
    segments = backend._split_text_if_needed(
        document,
        target_lang=args.target_lang,
        source_lang=args.source_lang,
    )

    # One untimed warm-up run so one-off initialisation is excluded from latencies.
    _translate_document(service, document, args)

    latencies_ms: list[float] = []
    output_chars = 0
    for _ in range(args.runs):
        started = time.perf_counter()
        output = _translate_document(service, document, args)
        latencies_ms.append((time.perf_counter() - started) * 1000)
        output_chars += len(output or "")

    total_seconds = sum(latencies_ms) / 1000.0
    payload = {
        "model": args.model,
        "source_lang": args.source_lang,
        "target_lang": args.target_lang,
        "doc_chars": len(document),
        "runs": args.runs,
        "load_seconds": round(load_seconds, 3),
        "batch_size": backend.batch_size,
        "ct2_inter_threads": backend.ct2_inter_threads,
        "ct2_max_queued_batches": backend.ct2_max_queued_batches,
        "ct2_batch_type": backend.ct2_batch_type,
        "max_new_tokens": backend.max_new_tokens,
        "ct2_decoding_length_mode": backend.ct2_decoding_length_mode,
        "ct2_decoding_length_extra": backend.ct2_decoding_length_extra,
        "ct2_decoding_length_min": backend.ct2_decoding_length_min,
        "safe_input_limit": safe_input_limit,
        "segment_count": len(segments),
        "segment_char_lengths": _segment_char_stats(segments),
        "latency_avg_ms": round(statistics.fmean(latencies_ms), 2),
        "latency_p50_ms": round(percentile(latencies_ms, 0.50), 2),
        "latency_p95_ms": round(percentile(latencies_ms, 0.95), 2),
        "latency_max_ms": round(max(latencies_ms), 2),
        "input_chars_per_second": round((len(document) * args.runs) / total_seconds, 2),
        "output_chars_per_second": round(output_chars / total_seconds, 2),
    }
    print(json.dumps(payload, ensure_ascii=False))


if __name__ == "__main__":
    # main() returns None, so the process exits with status 0 unless main raises.
    raise SystemExit(main())