# download_translation_models.py (3.85 KB) — file-listing artifact, commented out so the shebang below can work
#!/usr/bin/env python3
"""Download local translation models declared in services.translation.capabilities."""

from __future__ import annotations

import argparse
import os
from pathlib import Path
import sys
from typing import Iterable

from huggingface_hub import snapshot_download

# Make the repository root importable when this script is run directly
# (the script lives two directory levels below the project root).
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
# Opt out of the Hugging Face Xet download backend unless the caller already set it.
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")

# Project-local imports must come after the sys.path bootstrap above.
from config.services_config import get_translation_config
from translation.ct2_conversion import convert_transformers_model


# Backend identifiers (lowercase) that denote models hosted on local disk.
LOCAL_BACKENDS = {"local_nllb", "local_marian"}


def iter_local_capabilities(selected: set[str] | None = None) -> Iterable[tuple[str, dict]]:
    """Yield ``(name, capability)`` pairs for locally-hosted translation backends.

    Args:
        selected: Optional set of lowercase capability names (as produced by
            ``main``'s ``--models`` parsing). ``None``/empty means all local
            capabilities.

    Yields:
        The capability name exactly as declared in config, and its dict.
    """
    cfg = get_translation_config()
    capabilities = cfg.get("capabilities", {}) if isinstance(cfg, dict) else {}
    for name, capability in capabilities.items():
        backend = str(capability.get("backend") or "").strip().lower()
        if backend not in LOCAL_BACKENDS:
            continue
        # Compare case-insensitively: main() lowercases --models input, while
        # config capability names may use any casing.
        if selected and str(name).strip().lower() not in selected:
            continue
        yield name, capability


def _compute_ct2_output_dir(capability: dict) -> Path:
    custom = str(capability.get("ct2_model_dir") or "").strip()
    if custom:
        return Path(custom).expanduser()
    model_dir = Path(str(capability.get("model_dir") or "")).expanduser()
    compute_type = str(capability.get("ct2_compute_type") or capability.get("torch_dtype") or "default").strip().lower()
    normalized = compute_type.replace("_", "-")
    return model_dir / f"ctranslate2-{normalized}"


def convert_to_ctranslate2(name: str, capability: dict) -> None:
    """Convert a downloaded Hugging Face model into CTranslate2 format.

    The conversion is skipped when the target ``model.bin`` already exists.
    The on-disk download directory is preferred as the conversion source,
    falling back to the hub ``model_id`` when nothing was downloaded yet.
    """
    model_id = str(capability.get("model_id") or "").strip()
    model_dir = Path(str(capability.get("model_dir") or "")).expanduser()
    output_dir = _compute_ct2_output_dir(capability)
    if (output_dir / "model.bin").exists():
        print(f"[skip-convert] {name} -> {output_dir}")
        return
    # Use the local copy when present; otherwise let the converter resolve the id.
    if model_dir.exists():
        model_source = str(model_dir)
    else:
        model_source = model_id
    # Quantization falls back through progressively more generic hints.
    quantization = str(
        capability.get("ct2_conversion_quantization")
        or capability.get("ct2_compute_type")
        or capability.get("torch_dtype")
        or "default"
    ).strip()
    output_dir.parent.mkdir(parents=True, exist_ok=True)
    print(f"[convert] {name} -> {output_dir} ({quantization})")
    convert_transformers_model(model_source, str(output_dir), quantization)
    print(f"[converted] {name}")


def main() -> None:
    """CLI entry point: download (and optionally convert) local translation models.

    Raises:
        ValueError: If a selected capability omits ``model_id`` or ``model_dir``.
    """
    parser = argparse.ArgumentParser(description="Download local translation models")
    parser.add_argument("--all-local", action="store_true", help="Download all configured local translation models")
    parser.add_argument("--models", nargs="*", default=[], help="Specific capability names to download")
    parser.add_argument(
        "--convert-ctranslate2",
        action="store_true",
        help="Also convert the downloaded Hugging Face models into CTranslate2 format",
    )
    args = parser.parse_args()

    # Lowercase the selection so matching against capability names is case-insensitive.
    selected = {item.strip().lower() for item in args.models if item.strip()} or None
    if not args.all_local and not selected:
        parser.error("pass --all-local or --models <name> ...")

    for name, capability in iter_local_capabilities(selected):
        model_id = str(capability.get("model_id") or "").strip()
        raw_model_dir = str(capability.get("model_dir") or "").strip()
        # Validate the raw string, not the Path: Path("") becomes Path("."),
        # and Path objects are always truthy, so `not Path(...)` can never
        # catch a missing model_dir.
        if not model_id or not raw_model_dir:
            raise ValueError(f"Capability '{name}' must define model_id and model_dir")
        model_dir = Path(raw_model_dir).expanduser()
        # Ensure the parent exists; snapshot_download creates the leaf directory.
        model_dir.parent.mkdir(parents=True, exist_ok=True)
        print(f"[download] {name} -> {model_dir} ({model_id})")
        snapshot_download(
            repo_id=model_id,
            local_dir=str(model_dir),
        )
        print(f"[done] {name}")
        if args.convert_ctranslate2:
            convert_to_ctranslate2(name, capability)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()