# baidu_translate_benchmark.py
import argparse
import json
import os
import random
import statistics
import time
from dataclasses import dataclass

import requests


def _now() -> float:
    return time.perf_counter()


def _mask(s: str, keep: int = 3) -> str:
    if not s:
        return ""
    if len(s) <= keep * 2:
        return "*" * len(s)
    return f"{s[:keep]}***{s[-keep:]}"


def get_access_token(api_key: str, secret_key: str, timeout_s: float) -> str:
    """Exchange a Baidu API key / secret key pair for an OAuth access token.

    Sends a ``client_credentials`` grant to Baidu's token endpoint, with the
    credentials passed as query parameters.

    Raises:
        requests.HTTPError: if the endpoint answers with a non-2xx status.
        RuntimeError: if the decoded body has no (or an empty) ``access_token``.
    """
    query = dict(
        grant_type="client_credentials",
        client_id=api_key,
        client_secret=secret_key,
    )
    resp = requests.post(
        "https://aip.baidubce.com/oauth/2.0/token",
        params=query,
        timeout=timeout_s,
    )
    resp.raise_for_status()
    body = resp.json()
    access_token = body.get("access_token")
    if access_token:
        return str(access_token)
    raise RuntimeError(f"no access_token in response: {body}")


def translate_one(
    *,
    access_token: str,
    q: str,
    from_lang: str,
    to_lang: str,
    timeout_s: float,
) -> dict:
    """Translate ``q`` with Baidu's ``texttrans`` endpoint and return the decoded JSON.

    Args (keyword-only):
        access_token: OAuth token, e.g. from ``get_access_token``.
        q: Text to translate.
        from_lang: Source language code.
        to_lang: Target language code.
        timeout_s: HTTP timeout in seconds.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.HTTPError: on a non-2xx HTTP status.
        RuntimeError: when the body carries a non-zero ``error_code``. Baidu AI
            Cloud APIs report API-level failures (invalid token, quota, bad
            language pair, ...) as ``error_code``/``error_msg`` in the JSON body,
            which ``raise_for_status`` does not catch; without this check a
            failed call would be benchmarked as a success.
    """
    url = "https://aip.baidubce.com/rpc/2.0/mt/texttrans/v1"
    payload = json.dumps({"from": from_lang, "to": to_lang, "q": q}, ensure_ascii=False)
    headers = {"Content-Type": "application/json", "Accept": "application/json"}
    r = requests.post(
        url,
        # Let requests URL-encode the token instead of concatenating it raw.
        params={"access_token": access_token},
        headers=headers,
        data=payload.encode("utf-8"),
        timeout=timeout_s,
    )
    r.raise_for_status()
    data = r.json()
    # Treat absent / 0 as success; anything else is an API-level failure.
    if isinstance(data, dict) and data.get("error_code") not in (None, 0, "0"):
        raise RuntimeError(
            f"texttrans error {data.get('error_code')}: {data.get('error_msg')}"
        )
    return data


def random_english_samples(n: int, seed: int | None) -> list[str]:
    rng = random.Random(seed)

    starters = [
        "Hello",
        "Please",
        "Could you",
        "I wonder if you can",
        "Let's",
        "We should",
        "It's important to",
        "Don't forget to",
        "I need to",
        "Can we",
    ]
    verbs = [
        "check",
        "compare",
        "translate",
        "summarize",
        "review",
        "optimize",
        "measure",
        "confirm",
        "fix",
        "verify",
    ]
    objects = [
        "the latest order status",
        "this short sentence",
        "the product title and description",
        "our search results",
        "the API response payload",
        "the translation quality",
        "the page load time",
        "a few random examples",
        "the error message",
        "the performance metrics",
    ]
    endings = [
        "today.",
        "right now.",
        "before we ship.",
        "in the next minute.",
        "without changing the meaning.",
        "as soon as possible.",
        "for a quick smoke test.",
        "with a clear output.",
        "and report the timing.",
        "to ensure everything works.",
    ]

    samples: list[str] = []
    while len(samples) < n:
        s = f"{rng.choice(starters)} {rng.choice(verbs)} {rng.choice(objects)} {rng.choice(endings)}"
        # Add a little variation
        if rng.random() < 0.25:
            s = s.replace("the ", "our ", 1)
        if rng.random() < 0.20:
            s = s.replace(".", "!")
        samples.append(s)
    return samples


@dataclass(frozen=True)
class Timing:
    text: str
    seconds: float
    ok: bool
    error: str | None
    sample_out: str | None


def _extract_translation(out: object) -> str | None:
    """Best-effort pull of the first translated string from a texttrans response.

    Reads ``out["result"]["trans_result"]``. If that is a non-empty list whose
    first item is a dict, the item's ``"dst"`` value is used. Returns None
    whenever the shape does not match or the value found is not a string.
    """
    if not isinstance(out, dict):
        return None
    result = out.get("result")
    if not isinstance(result, dict):
        return None
    translated = result.get("trans_result")
    if isinstance(translated, list) and translated:
        first = translated[0]
        if isinstance(first, dict):
            translated = first.get("dst")
    return translated if isinstance(translated, str) else None


def _redact(text: str, sensitive: tuple[str, ...]) -> str:
    """Replace each non-empty value of ``sensitive`` in ``text`` with its masked form.

    ``requests`` exception messages can include the request URL, and the URLs
    used here carry credentials or the access token in their query strings.
    """
    for value in sensitive:
        if value:
            text = text.replace(value, _mask(value))
    return text


def main() -> int:
    """Benchmark Baidu texttrans latency over randomly generated English samples.

    Reads ``BAIDU_API_KEY`` / ``BAIDU_SECRET_KEY`` from the environment, fetches
    an access token, translates ``--n`` generated sentences sequentially, and
    prints per-request timings, a summary, sample translations and errors.
    Credentials and the token are masked in everything printed.

    Returns:
        0 if every translation request succeeded, 2 if at least one failed.

    Raises:
        SystemExit: on a non-positive ``--n``, missing credentials, or a failed
            token fetch.
    """
    p = argparse.ArgumentParser(
        description="Benchmark Baidu texttrans latency with random samples."
    )
    p.add_argument("--n", type=int, default=15, help="Number of samples (10-20 suggested).")
    p.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.")
    p.add_argument("--from", dest="from_lang", default="en", help="Source language code.")
    p.add_argument("--to", dest="to_lang", default="zh", help="Target language code.")
    p.add_argument("--timeout", type=float, default=20.0, help="HTTP timeout seconds.")
    p.add_argument(
        "--print-first",
        type=int,
        default=5,
        help="Print first N translations for sanity check.",
    )
    args = p.parse_args()

    if args.n <= 0:
        raise SystemExit("--n must be positive")

    api_key = os.environ.get("BAIDU_API_KEY", "").strip()
    secret_key = os.environ.get("BAIDU_SECRET_KEY", "").strip()
    if not api_key or not secret_key:
        raise SystemExit(
            "Missing BAIDU_API_KEY / BAIDU_SECRET_KEY env vars. "
            "Example:\n"
            "  export BAIDU_API_KEY='...'\n"
            "  export BAIDU_SECRET_KEY='...'\n"
        )

    print("Baidu credentials:")
    print(f"  BAIDU_API_KEY={_mask(api_key)}")
    print(f"  BAIDU_SECRET_KEY={_mask(secret_key)}")
    print()

    t0 = _now()
    try:
        token = get_access_token(api_key, secret_key, args.timeout)
    except (requests.RequestException, RuntimeError, ValueError) as e:
        # An uncaught HTTPError would print the token URL, whose query string
        # contains the raw API key and secret key.
        raise SystemExit(
            "Failed to acquire access token: "
            + _redact(str(e), (api_key, secret_key))
        ) from None
    token_s = _now() - t0
    print(f"Access token acquired in {token_s:.3f}s")
    print()

    sensitive = (token, api_key, secret_key)
    samples = random_english_samples(args.n, args.seed)
    timings: list[Timing] = []

    for i, text in enumerate(samples, start=1):
        start = _now()
        try:
            out = translate_one(
                access_token=token,
                q=text,
                from_lang=args.from_lang,
                to_lang=args.to_lang,
                timeout_s=args.timeout,
            )
        except Exception as e:  # noqa: BLE001 - benchmark script, keep it simple
            elapsed = _now() - start
            timing = Timing(
                text=text,
                seconds=elapsed,
                ok=False,
                # Error text may embed the request URL (and thus the token).
                error=_redact(str(e), sensitive),
                sample_out=None,
            )
        else:
            elapsed = _now() - start
            timing = Timing(
                text=text,
                seconds=elapsed,
                ok=True,
                error=None,
                sample_out=_extract_translation(out),
            )
        timings.append(timing)
        print(f"[{i:02d}/{len(samples):02d}] {timing.seconds:.3f}s {'OK' if timing.ok else 'ERR'}")

    ok = [t for t in timings if t.ok]
    errs = [t for t in timings if not t.ok]
    latencies = [t.seconds for t in ok]

    print()
    print("=== Summary ===")
    print(f"token_time_s: {token_s:.3f}")
    print(f"requests_total: {len(timings)}")
    print(f"requests_ok: {len(ok)}")
    print(f"requests_err: {len(errs)}")
    if latencies:
        print(f"total_time_s: {sum(latencies):.3f}")
        print(f"avg_s: {statistics.mean(latencies):.3f}")
        print(f"p50_s: {statistics.median(latencies):.3f}")
        print(f"min_s: {min(latencies):.3f}")
        print(f"max_s: {max(latencies):.3f}")
    else:
        print("No successful requests; cannot compute latency stats.")

    if args.print_first > 0:
        print()
        print("=== Samples (first N) ===")
        for t in ok[: args.print_first]:
            print(f"- IN:  {t.text}")
            if t.sample_out:
                print(f"  OUT: {t.sample_out}")
            else:
                print("  OUT: (could not parse translation field; see raw JSON by editing script if needed)")

    if errs:
        print()
        print("=== Errors (first 3) ===")
        for t in errs[:3]:
            print(f"- {t.seconds:.3f}s: {t.error}")

    return 0 if not errs else 2


if __name__ == "__main__":
    # Use main()'s return value as the process exit status (0 or 2 per main()).
    _exit_code = main()
    raise SystemExit(_exit_code)