# benchmark_translation_local_models_focus.py
#!/usr/bin/env python3
"""Focused translation benchmark for two stress scenarios on local CT2 models."""

from __future__ import annotations

import argparse
import copy
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

# Make the repository root importable when this script is run directly
# (it lives in scripts/, so the root is two levels up from this file).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from config.services_config import get_translation_config
from scripts.benchmark_translation_local_models import (
    SCENARIOS,
    benchmark_concurrency_case,
    benchmark_serial_case,
    build_environment_info,
    ensure_cuda_stats_reset,
    load_texts,
)
from translation.service import TranslationService

# Default sweep values for the two focus scenarios; the comma-separated CLI
# defaults in parse_args() mirror these lists.
DEFAULT_HIGH_BATCH_SIZES = [32, 64, 128]
DEFAULT_HIGH_CONCURRENCIES = [8, 16, 32, 64]


def parse_args() -> argparse.Namespace:
    """Parse CLI options for the focused benchmark.

    The sweep defaults are derived from the module-level DEFAULT_* constants
    (previously the same values were duplicated as string literals, which
    could silently drift out of sync).

    Returns:
        argparse.Namespace with csv_path, output_dir, high_batch_sizes,
        high_concurrencies, high_batch_rows, high_concurrency_requests and
        warmup_batches attributes.
    """
    parser = argparse.ArgumentParser(description="Focused benchmark for local CT2 translation models")
    parser.add_argument("--csv-path", default="products_analyzed.csv", help="Benchmark dataset CSV path")
    parser.add_argument(
        "--output-dir",
        default="perf_reports/20260318/translation_local_models_ct2_focus",
        help="Directory for JSON/Markdown focused reports",
    )
    parser.add_argument(
        "--high-batch-sizes",
        # Keep the CLI default in lockstep with the module constant.
        default=",".join(str(v) for v in DEFAULT_HIGH_BATCH_SIZES),
        help="Comma-separated batch sizes for the high-batch/low-concurrency scenario",
    )
    parser.add_argument(
        "--high-concurrencies",
        default=",".join(str(v) for v in DEFAULT_HIGH_CONCURRENCIES),
        help="Comma-separated concurrency levels for the high-concurrency/low-batch scenario",
    )
    parser.add_argument(
        "--high-batch-rows",
        type=int,
        default=512,
        help="Rows used for the high-batch/low-concurrency scenario",
    )
    parser.add_argument(
        "--high-concurrency-requests",
        type=int,
        default=32,
        help="Requests per high-concurrency/low-batch case",
    )
    parser.add_argument("--warmup-batches", type=int, default=1, help="Warmup batches before measuring")
    return parser.parse_args()


def parse_csv_ints(raw: str) -> List[int]:
    """Parse a comma-separated string into a non-empty list of positive ints.

    Blank segments (including those produced by trailing commas) are skipped.
    Raises ValueError for non-positive values or an effectively empty input.
    """
    parsed = [int(piece) for piece in (part.strip() for part in raw.split(",")) if piece]
    for number in parsed:
        if number <= 0:
            raise ValueError(f"Expected positive integer, got {number}")
    if not parsed:
        raise ValueError("Parsed empty integer list")
    return parsed


def build_variant_config(model: str, overrides: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str, Any]]:
    """Build a translation-service config with only *model* enabled.

    Starts from a deep copy of the project translation config, enables just
    the requested model, disables response caching everywhere (so benchmarks
    measure raw inference), and applies *overrides* to the model's capability.

    Returns:
        (config, capability) — capability is the capability dict for *model*
        inside config (mutated in place; config already references it, so the
        original redundant re-assignment back into config was removed).
    """
    config = copy.deepcopy(get_translation_config())
    for name, cfg in config["capabilities"].items():
        cfg["enabled"] = name == model
        cfg["use_cache"] = False
    config["default_model"] = model
    capability = config["capabilities"][model]
    capability.update(overrides)  # in-place: visible through config as well
    return config, capability


def render_markdown(report: Dict[str, Any]) -> str:
    """Render the focused benchmark *report* dict as a Markdown document."""
    env = report["environment"]
    out: List[str] = [
        "# Local Translation Model Focused Benchmark",
        "",
        f"- Generated at: `{report['generated_at']}`",
        f"- Python: `{env['python']}`",
        f"- Torch: `{env['torch']}`",
        f"- Transformers: `{env['transformers']}`",
        f"- CUDA: `{env['cuda_available']}`",
    ]
    # GPU line only when a GPU was detected in the environment snapshot.
    if env["gpu_name"]:
        out.append(f"- GPU: `{env['gpu_name']}` ({env['gpu_total_mem_gb']} GiB)")
    out += [
        "",
        "## Scope",
        "",
        "- Scenario 1: high batch size + low concurrency",
        "- Scenario 2: high concurrency + low batch size",
        "- Variants in this report:",
    ]
    out += [f"  - `{v['name']}`: `{v['overrides']}`" for v in report["variants"]]

    # One section per scenario, one table per variant.
    for scenario in report["scenarios"]:
        out += [
            "",
            f"## {scenario['name']}",
            "",
            f"- Direction: `{scenario['source_lang']} -> {scenario['target_lang']}`",
            f"- Column: `{scenario['column']}`",
        ]
        for variant in scenario["variants"]:
            out += [
                "",
                f"### Variant `{variant['name']}`",
                "",
                "| Scenario | Setting | Items/s | Req p95 ms | Avg req ms |",
                "|---|---|---:|---:|---:|",
            ]
            for row in variant["high_batch_low_concurrency"]:
                out.append(
                    f"| high-batch/low-concurrency | batch={row['batch_size']}, concurrency=1 | "
                    f"{row['items_per_second']} | {row['request_latency_p95_ms']} | {row['avg_request_latency_ms']} |"
                )
            for row in variant["high_concurrency_low_batch"]:
                out.append(
                    f"| high-concurrency/low-batch | batch=1, concurrency={row['concurrency']} | "
                    f"{row['items_per_second']} | {row['request_latency_p95_ms']} | {row['avg_request_latency_ms']} |"
                )
    return "\n".join(out) + "\n"


def _resolve_path(raw: str) -> Path:
    """Resolve *raw* against PROJECT_ROOT; already-absolute paths pass through."""
    path = Path(raw)
    return path if path.is_absolute() else (PROJECT_ROOT / path).resolve()


def main() -> None:
    """Run both focus scenarios for every model variant and write the reports.

    For each scenario in SCENARIOS and each capability variant, benchmarks
    (1) high batch size at concurrency 1 and (2) batch size 1 at high
    concurrency, then writes one JSON and one Markdown report to the output
    directory. NOTE: the filename stamp is %H%M%S only, so two runs in the
    same second (or on different days) target the same filenames.
    """
    args = parse_args()
    csv_path = _resolve_path(args.csv_path)
    output_dir = _resolve_path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    high_batch_sizes = parse_csv_ints(args.high_batch_sizes)
    high_concurrencies = parse_csv_ints(args.high_concurrencies)

    # Capability overrides layered on top of the project translation config.
    variants = [
        {"name": "ct2_default", "overrides": {}},
        {
            "name": "ct2_tuned_t4",
            "overrides": {
                "ct2_inter_threads": 2,
                "ct2_max_queued_batches": 16,
                "ct2_batch_type": "examples",
            },
        },
    ]

    report: Dict[str, Any] = {
        "generated_at": datetime.now().isoformat(timespec="seconds"),
        "environment": build_environment_info(),
        "csv_path": str(csv_path),
        "variants": variants,
        "scenarios": [],
    }

    # Guarantee at least one full batch of the largest requested batch size.
    high_batch_rows = max(args.high_batch_rows, max(high_batch_sizes))

    for scenario in SCENARIOS:
        scenario_entry = dict(scenario)
        scenario_entry["variants"] = []
        batch_texts = load_texts(csv_path, scenario["column"], high_batch_rows)
        # Enough rows for every request at the highest concurrency level.
        conc_needed = max(high_concurrencies) * args.high_concurrency_requests
        conc_texts = load_texts(csv_path, scenario["column"], conc_needed)

        for variant in variants:
            print(f"[start] {scenario['name']} | {variant['name']}", flush=True)
            config, capability = build_variant_config(scenario["model"], variant["overrides"])
            ensure_cuda_stats_reset()
            service = TranslationService(config)
            backend = service.get_backend(scenario["model"])

            high_batch_results = []
            for batch_size in high_batch_sizes:
                # high_batch_rows >= every batch_size by construction, so the
                # original `[: max(batch_size, high_batch_rows)]` slice always
                # reduced to `[:high_batch_rows]`.
                high_batch_results.append(
                    benchmark_serial_case(
                        service=service,
                        backend=backend,
                        scenario=scenario,
                        capability=capability,
                        texts=batch_texts[:high_batch_rows],
                        batch_size=batch_size,
                        warmup_batches=args.warmup_batches,
                    )
                )

            high_concurrency_results = []
            for concurrency in high_concurrencies:
                high_concurrency_results.append(
                    benchmark_concurrency_case(
                        service=service,
                        backend=backend,
                        scenario=scenario,
                        capability=capability,
                        texts=conc_texts,
                        batch_size=1,
                        concurrency=concurrency,
                        requests_per_case=args.high_concurrency_requests,
                        warmup_batches=args.warmup_batches,
                    )
                )

            scenario_entry["variants"].append(
                {
                    "name": variant["name"],
                    "overrides": variant["overrides"],
                    "high_batch_low_concurrency": high_batch_results,
                    "high_concurrency_low_batch": high_concurrency_results,
                }
            )
            print(f"[done] {scenario['name']} | {variant['name']}", flush=True)

        report["scenarios"].append(scenario_entry)

    stamp = datetime.now().strftime("%H%M%S")
    json_path = output_dir / f"translation_local_models_focus_{stamp}.json"
    md_path = output_dir / f"translation_local_models_focus_{stamp}.md"
    json_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    md_path.write_text(render_markdown(report), encoding="utf-8")
    print(f"JSON report: {json_path}")
    print(f"Markdown report: {md_path}")


if __name__ == "__main__":  # script entry point; keeps import side-effect free
    main()