#!/usr/bin/env python3
"""Focused NLLB T4 tuning benchmark for product-name translation."""

from __future__ import annotations

import argparse
import copy
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple

PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from config.services_config import get_translation_config
from scripts.benchmark_translation_local_models import (
    benchmark_concurrency_case,
    benchmark_serial_case,
    build_environment_info,
    ensure_cuda_stats_reset,
    load_texts,
)
from translation.service import TranslationService


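# Translation directions under test; "column" names the CSV field read as input
# and "scene" tags each request as product-name (SKU) translation.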
SCENARIOS = [
    {
        "name": "nllb zh->en",
        "model": "nllb-200-distilled-600m",
        "source_lang": "zh",
        "target_lang": "en",
        "column": "title_cn",
        "scene": "sku_name",
    },
    {
        "name": "nllb en->zh",
        "model": "nllb-200-distilled-600m",
        "source_lang": "en",
        "target_lang": "zh",
        "column": "title",
        "scene": "sku_name",
    },
]

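# CTranslate2 tuning profiles to compare. Each entry's overrides are applied on
# top of the model's capability config; the ct2_decoding_length_* keys appear to
# cap decoding length relative to the source length (they are stripped again in
# build_quality_reference_overrides to form a fixed-length reference).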
VARIANTS = [
    {
        "name": "ct2_default_fixed64",
        "description": "Original CT2 default",
        "overrides": {
            "ct2_inter_threads": 1,
            "ct2_max_queued_batches": 0,
            "ct2_batch_type": "examples",
            "max_new_tokens": 64,
        },
    },
    {
        "name": "ct2_prev_t4_fixed64",
        "description": "Previous T4 tuning result",
        "overrides": {
            "ct2_inter_threads": 2,
            "ct2_max_queued_batches": 16,
            "ct2_batch_type": "examples",
            "max_new_tokens": 64,
        },
    },
    {
        "name": "ct2_best_t4_dynamic",
        "description": "Recommended T4 profile after this round",
        "overrides": {
            "ct2_inter_threads": 4,
            "ct2_max_queued_batches": 32,
            "ct2_batch_type": "examples",
            "max_new_tokens": 64,
            "ct2_decoding_length_mode": "source",
            "ct2_decoding_length_extra": 8,
            "ct2_decoding_length_min": 32,
        },
    },
    {
        "name": "ct2_fixed48_experiment",
        "description": "High-gain experiment with truncation risk",
        "overrides": {
            "ct2_inter_threads": 3,
            "ct2_max_queued_batches": 16,
            "ct2_batch_type": "examples",
            "max_new_tokens": 48,
        },
    },
]


def parse_args() -> argparse.Namespace:
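    """Parse CLI options controlling dataset, output location, load shape, and sampling."""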
    parser = argparse.ArgumentParser(description="Focused NLLB T4 tuning benchmark")
    parser.add_argument("--csv-path", default="products_analyzed.csv", help="Benchmark dataset CSV path")
    parser.add_argument(
        "--output-dir",
        default="perf_reports/20260318/nllb_t4_product_names_ct2",
        help="Directory for JSON/Markdown reports",
    )
    parser.add_argument("--batch-size", type=int, default=64, help="Batch size for the bulk scenario")
    parser.add_argument("--batch-items", type=int, default=256, help="Rows used for the bulk scenario")
    parser.add_argument("--concurrency", type=int, default=64, help="Concurrency for the online scenario")
    parser.add_argument(
        "--requests-per-case",
        type=int,
        default=24,
        help="Requests per worker in the online scenario",
    )
    parser.add_argument("--quality-samples", type=int, default=100, help="Rows used for quality spot-checks")
    parser.add_argument("--warmup-batches", type=int, default=1, help="Warmup batches before measuring")
    return parser.parse_args()


def build_service(model: str, overrides: Dict[str, Any]) -> Tuple[TranslationService, Dict[str, Any]]:
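    """Return a TranslationService with only `model` enabled and `overrides` applied.

    Caching is disabled on every capability so runs measure real inference rather
    than cache hits. The mutated capability dict is returned for the benchmark cases.
    """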
    config = copy.deepcopy(get_translation_config())
    for name, cfg in config["capabilities"].items():
        cfg["enabled"] = name == model
        cfg["use_cache"] = False
    config["default_model"] = model
    capability = config["capabilities"][model]
    capability.update(overrides)
    return TranslationService(config), capability


def build_quality_reference_overrides(overrides: Dict[str, Any]) -> Dict[str, Any]:
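    """Build reference overrides: strip dynamic-length keys and decode with at least 64 tokens."""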
    reference = dict(overrides)
    reference.pop("ct2_decoding_length_mode", None)
    reference.pop("ct2_decoding_length_extra", None)
    reference.pop("ct2_decoding_length_min", None)
    reference["max_new_tokens"] = max(64, int(reference.get("max_new_tokens", 64)))
    return reference


def summarize_quality(reference_outputs: List[Any], candidate_outputs: List[Any], texts: List[str]) -> Dict[str, Any]:
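    """Count candidate outputs identical to the reference, keeping up to three sample diffs."""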
    same = 0
    diffs: List[Dict[str, str]] = []
    for text, ref_output, candidate_output in zip(texts, reference_outputs, candidate_outputs):
        if ref_output == candidate_output:
            same += 1
            continue
        if len(diffs) < 3:
            diffs.append(
                {
                    "input": text,
                    "candidate": "" if candidate_output is None else str(candidate_output),
                    "reference": "" if ref_output is None else str(ref_output),
                }
            )
    return {
        "same": same,
        "total": len(texts),
        "changed": len(texts) - same,
        "sample_diffs": diffs,
    }


def render_markdown(report: Dict[str, Any]) -> str:
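    """Render the report as Markdown: environment, scope, variants, and per-scenario tables."""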
    lines = [
        "# NLLB T4 Product-Name Tuning",
        "",
        f"- Generated at: `{report['generated_at']}`",
        f"- Python: `{report['environment']['python']}`",
        f"- Torch: `{report['environment']['torch']}`",
        f"- Transformers: `{report['environment']['transformers']}`",
        f"- CUDA: `{report['environment']['cuda_available']}`",
    ]
    if report["environment"]["gpu_name"]:
        lines.append(f"- GPU: `{report['environment']['gpu_name']}` ({report['environment']['gpu_total_mem_gb']} GiB)")
    lines.extend(
        [
            "",
            "## Scope",
            "",
            f"- Bulk scenario: `batch={report['config']['batch_size']}, concurrency=1`",
            f"- Online scenario: `batch=1, concurrency={report['config']['concurrency']}`",
            f"- Online requests per worker: `{report['config']['requests_per_case']}`",
            f"- Quality spot-check samples: `{report['config']['quality_samples']}`",
            "",
            "## Variants",
            "",
        ]
    )
    for variant in report["variants"]:
        lines.append(f"- `{variant['name']}`: {variant['description']} -> `{variant['overrides']}`")

    for scenario in report["scenarios"]:
        lines.extend(
            [
                "",
                f"## {scenario['name']}",
                "",
                "| Variant | Bulk items/s | Bulk p95 ms | Online items/s | Online p95 ms | Quality same/total |",
                "|---|---:|---:|---:|---:|---:|",
            ]
        )
        for variant in scenario["variants"]:
            quality = variant["quality_vs_reference"]
            lines.append(
                f"| {variant['name']} | {variant['bulk']['items_per_second']} | {variant['bulk']['request_latency_p95_ms']} | "
                f"{variant['online']['items_per_second']} | {variant['online']['request_latency_p95_ms']} | "
                f"{quality['same']}/{quality['total']} |"
            )
        for variant in scenario["variants"]:
            quality = variant["quality_vs_reference"]
            if not quality["sample_diffs"]:
                continue
            lines.extend(
                [
                    "",
                    f"### Quality Notes: {variant['name']}",
                    "",
                ]
            )
            for diff in quality["sample_diffs"]:
                lines.append(f"- Input: `{diff['input']}`")
                lines.append(f"- Candidate: `{diff['candidate']}`")
                lines.append(f"- Reference: `{diff['reference']}`")
                lines.append("")

    return "\n".join(lines).rstrip() + "\n"


def main() -> None:
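    """Benchmark every variant across both scenarios and write JSON/Markdown reports."""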
    args = parse_args()
    csv_path = Path(args.csv_path)
    if not csv_path.is_absolute():
        csv_path = (PROJECT_ROOT / csv_path).resolve()
    output_dir = Path(args.output_dir)
    if not output_dir.is_absolute():
        output_dir = (PROJECT_ROOT / output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    report: Dict[str, Any] = {
        "generated_at": datetime.now().isoformat(timespec="seconds"),
        "environment": build_environment_info(),
        "config": {
            "csv_path": str(csv_path),
            "batch_size": args.batch_size,
            "batch_items": args.batch_items,
            "concurrency": args.concurrency,
            "requests_per_case": args.requests_per_case,
            "quality_samples": args.quality_samples,
        },
        "variants": VARIANTS,
        "scenarios": [],
    }

    for scenario in SCENARIOS:
        batch_texts = load_texts(csv_path, scenario["column"], args.batch_items)
        online_texts = load_texts(csv_path, scenario["column"], args.concurrency * args.requests_per_case)
        quality_texts = load_texts(csv_path, scenario["column"], args.quality_samples)

        scenario_report = dict(scenario)
        scenario_report["variants"] = []
        for variant in VARIANTS:
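            # Build a fresh service per variant so capability overrides take
            # effect and CUDA memory stats start from a clean slate.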
            print(f"[start] {scenario['name']} | {variant['name']}", flush=True)
            ensure_cuda_stats_reset()
            service, capability = build_service(scenario["model"], variant["overrides"])
            backend = service.get_backend(scenario["model"])
            bulk = benchmark_serial_case(
                service=service,
                backend=backend,
                scenario=scenario,
                capability=capability,
                texts=batch_texts,
                batch_size=args.batch_size,
                warmup_batches=args.warmup_batches,
            )
            online = benchmark_concurrency_case(
                service=service,
                backend=backend,
                scenario=scenario,
                capability=capability,
                texts=online_texts,
                batch_size=1,
                concurrency=args.concurrency,
                requests_per_case=args.requests_per_case,
                warmup_batches=args.warmup_batches,
            )
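            # Quality spot-check: decode the same texts with a fixed >=64-token
            # reference profile (dynamic-length keys stripped) to flag truncation.
            # Note that the variant and reference models are resident at once.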
            quality_reference_overrides = build_quality_reference_overrides(variant["overrides"])
            reference_service, _ = build_service(scenario["model"], quality_reference_overrides)
            reference_outputs = reference_service.translate(
                quality_texts,
                source_lang=scenario["source_lang"],
                target_lang=scenario["target_lang"],
                model=scenario["model"],
                scene=scenario["scene"],
            )
            candidate_outputs = service.translate(
                quality_texts,
                source_lang=scenario["source_lang"],
                target_lang=scenario["target_lang"],
                model=scenario["model"],
                scene=scenario["scene"],
            )
            scenario_report["variants"].append(
                {
                    "name": variant["name"],
                    "description": variant["description"],
                    "overrides": variant["overrides"],
                    "quality_reference_overrides": quality_reference_overrides,
                    "bulk": bulk,
                    "online": online,
                    "quality_vs_reference": summarize_quality(reference_outputs, candidate_outputs, quality_texts),
                }
            )
        report["scenarios"].append(scenario_report)

    timestamp = datetime.now().strftime("%H%M%S")
    json_path = output_dir / f"nllb_t4_tuning_{timestamp}.json"
    md_path = output_dir / f"nllb_t4_tuning_{timestamp}.md"
    json_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    md_path.write_text(render_markdown(report), encoding="utf-8")
    print(f"JSON_REPORT={json_path}")
    print(f"MARKDOWN_REPORT={md_path}")


if __name__ == "__main__":
    main()