download_translation_models.py
4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python3
"""Download local translation models declared in services.translation.capabilities."""
from __future__ import annotations
import argparse
import os
from pathlib import Path
import shutil
import subprocess
import sys
from typing import Iterable

# huggingface_hub reads its HF_HUB_* settings from the environment when the
# package is imported, so the Xet opt-out has to be in place *before* the
# import below. setdefault() keeps any value the caller already exported.
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")

from huggingface_hub import snapshot_download  # noqa: E402

# Put the directory two levels above this file's own directory on sys.path so
# that the `config.services_config` import below resolves when this file is
# executed directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from config.services_config import get_translation_config  # noqa: E402

# Backend identifiers that iter_local_capabilities() accepts.
LOCAL_BACKENDS = {"local_nllb", "local_marian"}
def iter_local_capabilities(selected: set[str] | None = None) -> Iterable[tuple[str, dict]]:
    """Yield ``(name, capability)`` pairs for the configured local models.

    Reads ``capabilities`` from ``get_translation_config()`` and keeps the
    entries whose ``backend`` (trimmed, lower-cased) is in ``LOCAL_BACKENDS``.

    Args:
        selected: Optional capability names to restrict the result to. Names
            are compared after trimming and lower-casing on both sides, so a
            selection that was lower-cased by the CLI still matches a
            mixed-case key in the config. ``None`` or an empty set keeps
            every local capability.

    Yields:
        The capability name exactly as written in the config, and its mapping.
    """
    cfg = get_translation_config()
    capabilities = cfg.get("capabilities") if isinstance(cfg, dict) else None
    if not isinstance(capabilities, dict):
        # Missing, null or malformed section: nothing is configured.
        return

    wanted = {str(item).strip().lower() for item in selected} if selected else None
    for name, capability in capabilities.items():
        if not isinstance(capability, dict):
            # A null/scalar entry cannot declare a backend; skip it instead of
            # crashing on ``.get``.
            continue
        backend = str(capability.get("backend") or "").strip().lower()
        if backend not in LOCAL_BACKENDS:
            continue
        if wanted is not None and str(name).strip().lower() not in wanted:
            continue
        yield name, capability
def _compute_ct2_output_dir(capability: dict) -> Path:
custom = str(capability.get("ct2_model_dir") or "").strip()
if custom:
return Path(custom).expanduser()
model_dir = Path(str(capability.get("model_dir") or "")).expanduser()
compute_type = str(capability.get("ct2_compute_type") or capability.get("torch_dtype") or "default").strip().lower()
normalized = compute_type.replace("_", "-")
return model_dir / f"ctranslate2-{normalized}"
def _resolve_converter_binary() -> str:
candidate = shutil.which("ct2-transformers-converter")
if candidate:
return candidate
venv_candidate = Path(sys.executable).absolute().parent / "ct2-transformers-converter"
if venv_candidate.exists():
return str(venv_candidate)
raise RuntimeError(
"ct2-transformers-converter was not found. "
"Install ctranslate2 in the active Python environment first."
)
def convert_to_ctranslate2(name: str, capability: dict) -> None:
    """Convert one capability's Hugging Face model into CTranslate2 format.

    The conversion is skipped when ``model.bin`` already exists in the output
    directory computed by ``_compute_ct2_output_dir``. Otherwise the
    ``ct2-transformers-converter`` CLI is run on the local ``model_dir`` if it
    exists, or on ``model_id`` as a fallback.

    Args:
        name: Capability name, used only in progress messages.
        capability: Capability mapping. Reads ``model_id``, ``model_dir``,
            ``ct2_conversion_quantization``, ``ct2_compute_type`` and
            ``torch_dtype`` (plus whatever ``_compute_ct2_output_dir`` reads).

    Raises:
        ValueError: If neither an existing ``model_dir`` nor a ``model_id``
            is available to convert from.
        RuntimeError: If the converter executable cannot be located.
        subprocess.CalledProcessError: If the converter exits with an error.
    """
    output_dir = _compute_ct2_output_dir(capability)
    if (output_dir / "model.bin").exists():
        print(f"[skip-convert] {name} -> {output_dir}")
        return

    model_id = str(capability.get("model_id") or "").strip()
    model_dir_text = str(capability.get("model_dir") or "")
    # Path("") collapses to "." which always exists, so an unset model_dir
    # must not be allowed to win over model_id.
    model_dir = Path(model_dir_text).expanduser() if model_dir_text.strip() else None
    if model_dir is not None and model_dir.exists():
        model_source = str(model_dir)
    else:
        model_source = model_id
    if not model_source:
        raise ValueError(
            f"Capability '{name}' must define model_id or an existing model_dir to convert"
        )

    quantization = str(
        capability.get("ct2_conversion_quantization")
        or capability.get("ct2_compute_type")
        or capability.get("torch_dtype")
        or "default"
    ).strip()

    output_dir.parent.mkdir(parents=True, exist_ok=True)
    print(f"[convert] {name} -> {output_dir} ({quantization})")

    command = [
        _resolve_converter_binary(),
        "--model",
        model_source,
        "--output_dir",
        str(output_dir),
    ]
    # "default" is this script's placeholder, not a quantization type the
    # converter CLI accepts; leaving the flag out lets the converter apply
    # its own default instead of rejecting the arguments.
    if quantization and quantization.lower() != "default":
        command += ["--quantization", quantization]

    subprocess.run(command, check=True)
    print(f"[converted] {name}")
def main() -> None:
    """Command-line entry point: download (and optionally convert) local models.

    Options:
        --all-local: process every capability yielded by
            ``iter_local_capabilities``.
        --models NAME ...: process only the named capabilities (names are
            trimmed and lower-cased). When given, it narrows the selection
            even if ``--all-local`` is also passed.
        --convert-ctranslate2: after each download, also call
            ``convert_to_ctranslate2``.

    Raises:
        ValueError: If a selected capability lacks ``model_id`` or
            ``model_dir``.
    """
    parser = argparse.ArgumentParser(description="Download local translation models")
    parser.add_argument("--all-local", action="store_true", help="Download all configured local translation models")
    parser.add_argument("--models", nargs="*", default=[], help="Specific capability names to download")
    parser.add_argument(
        "--convert-ctranslate2",
        action="store_true",
        help="Also convert the downloaded Hugging Face models into CTranslate2 format",
    )
    args = parser.parse_args()

    selected = {item.strip().lower() for item in args.models if item.strip()} or None
    if not args.all_local and not selected:
        parser.error("pass --all-local or --models <name> ...")

    handled: set[str] = set()
    for name, capability in iter_local_capabilities(selected):
        handled.add(str(name).strip().lower())

        model_id = str(capability.get("model_id") or "").strip()
        model_dir_text = str(capability.get("model_dir") or "")
        # Validate the raw string: a Path object is always truthy and
        # Path("") silently becomes ".", which would download into the
        # current working directory.
        if not model_id or not model_dir_text.strip():
            raise ValueError(f"Capability '{name}' must define model_id and model_dir")
        model_dir = Path(model_dir_text).expanduser()

        model_dir.parent.mkdir(parents=True, exist_ok=True)
        print(f"[download] {name} -> {model_dir} ({model_id})")
        snapshot_download(
            repo_id=model_id,
            local_dir=str(model_dir),
        )
        print(f"[done] {name}")

        if args.convert_ctranslate2:
            convert_to_ctranslate2(name, capability)

    # Report requested names that matched nothing instead of exiting silently.
    if selected:
        unmatched = sorted(selected - handled)
        if unmatched:
            print(
                f"[warn] no local translation capability matches: {', '.join(unmatched)}",
                file=sys.stderr,
            )
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()