Blame view

scripts/download_translation_models.py 4.45 KB
0fd2f875   tangwang   translate
1
2
3
4
5
6
  #!/usr/bin/env python3
  """Download local translation models declared in services.translation.capabilities."""
  
  from __future__ import annotations
  
  import argparse
0fd2f875   tangwang   translate
7
  import os
ea293660   tangwang   CTranslate2
8
9
10
  from pathlib import Path
  import shutil
  import subprocess
0fd2f875   tangwang   translate
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
  import sys
  from typing import Iterable
  
  from huggingface_hub import snapshot_download
  
  # Two levels up from this file: the parent of the directory that contains the
  # script (presumably the repository root, since the script lives in scripts/).
  PROJECT_ROOT = Path(__file__).resolve().parent.parent
  # Put that directory at the front of sys.path so the `config` package imported
  # below can be resolved; the membership test avoids inserting it twice.
  if str(PROJECT_ROOT) not in sys.path:
      sys.path.insert(0, str(PROJECT_ROOT))
  # setdefault keeps any value already present in the environment.
  # NOTE(review): presumably disables the Xet transfer backend of huggingface_hub.
  # This line runs after `huggingface_hub` has already been imported above --
  # confirm the library still honours the variable when it is set this late.
  os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
  
  # Must come after the sys.path adjustment above.
  from config.services_config import get_translation_config
  
  
  # Backend identifiers that iter_local_capabilities() accepts; capabilities with
  # any other backend are skipped.
  LOCAL_BACKENDS = {"local_nllb", "local_marian"}
  
  
  def iter_local_capabilities(selected: set[str] | None = None) -> Iterable[tuple[str, dict]]:
      """Yield ``(name, capability)`` pairs for the configured local models.

      Reads the ``capabilities`` mapping returned by ``get_translation_config()``
      and keeps only entries whose ``backend`` is listed in ``LOCAL_BACKENDS``.

      Args:
          selected: Optional set of capability names to restrict the result to.
              Names are compared case-insensitively and ignoring surrounding
              whitespace. ``None`` or an empty set means no restriction.

      Yields:
          The capability name exactly as written in the config, together with
          its configuration dict, in config order.
      """
      cfg = get_translation_config()
      capabilities = cfg.get("capabilities") if isinstance(cfg, dict) else None
      if not isinstance(capabilities, dict):
          # Missing, null or malformed section: there is nothing to iterate.
          return

      # The command line lowercases the requested names, so normalise both
      # sides; otherwise a mixed-case config key could never be selected.
      wanted = {str(item).strip().lower() for item in selected} if selected else None

      for name, capability in capabilities.items():
          if not isinstance(capability, dict):
              # Tolerate null/scalar entries instead of failing on `.get`.
              continue
          backend = str(capability.get("backend") or "").strip().lower()
          if backend not in LOCAL_BACKENDS:
              continue
          if wanted is not None and str(name).strip().lower() not in wanted:
              continue
          yield name, capability
  
  
ea293660   tangwang   CTranslate2
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
  def _compute_ct2_output_dir(capability: dict) -> Path:
      """Return the directory that holds the CTranslate2 build of a capability.

      An explicit ``ct2_model_dir`` wins. Otherwise the directory is
      ``<model_dir>/ctranslate2-<type>``, where ``<type>`` is taken from
      ``ct2_compute_type``, then ``torch_dtype``, then ``"default"``, lowercased
      and with underscores replaced by hyphens.
      """
      explicit_dir = str(capability.get("ct2_model_dir") or "").strip()
      if explicit_dir:
          return Path(explicit_dir).expanduser()

      raw_type = (
          capability.get("ct2_compute_type")
          or capability.get("torch_dtype")
          or "default"
      )
      suffix = str(raw_type).strip().lower().replace("_", "-")
      base_dir = Path(str(capability.get("model_dir") or "")).expanduser()
      return base_dir / ("ctranslate2-" + suffix)
  
  
  def _resolve_converter_binary() -> str:
      """Locate the ``ct2-transformers-converter`` executable.

      Looks on ``PATH`` first, then next to the running Python interpreter.

      Returns:
          The path of the executable as a string.

      Raises:
          RuntimeError: If the executable is found in neither place.
      """
      on_path = shutil.which("ct2-transformers-converter")
      if on_path:
          return on_path

      # Fall back to the interpreter's own directory, for an environment whose
      # scripts directory is not on PATH.
      interpreter_dir = Path(sys.executable).absolute().parent
      beside_interpreter = interpreter_dir / "ct2-transformers-converter"
      if not beside_interpreter.exists():
          raise RuntimeError(
              "ct2-transformers-converter was not found. "
              "Install ctranslate2 in the active Python environment first."
          )
      return str(beside_interpreter)
  
  
  def convert_to_ctranslate2(name: str, capability: dict) -> None:
      """Convert a capability's Hugging Face model into CTranslate2 format.

      The conversion is skipped when ``model.bin`` already exists in the output
      directory computed by ``_compute_ct2_output_dir``. Otherwise the
      ``ct2-transformers-converter`` executable is run as a subprocess.

      Args:
          name: Capability name, used only in progress and error messages.
          capability: Capability configuration. Reads ``model_id``,
              ``model_dir``, ``ct2_conversion_quantization``,
              ``ct2_compute_type`` and ``torch_dtype`` (plus whatever
              ``_compute_ct2_output_dir`` reads).

      Raises:
          ValueError: If neither an existing ``model_dir`` nor a ``model_id``
              is available as the conversion source.
          RuntimeError: If the converter executable cannot be located.
          subprocess.CalledProcessError: If the converter exits non-zero.
      """
      output_dir = _compute_ct2_output_dir(capability)
      if (output_dir / "model.bin").exists():
          print(f"[skip-convert] {name} -> {output_dir}")
          return

      model_id = str(capability.get("model_id") or "").strip()
      model_dir_raw = str(capability.get("model_dir") or "")
      model_dir = Path(model_dir_raw).expanduser()
      # Path("") is the current directory, which always exists, so only trust
      # the local directory when one was actually configured.
      has_local_copy = bool(model_dir_raw.strip()) and model_dir.exists()
      model_source = str(model_dir) if has_local_copy else model_id
      if not model_source:
          raise ValueError(f"Capability '{name}' must define model_id or model_dir")

      quantization = str(
          capability.get("ct2_conversion_quantization")
          or capability.get("ct2_compute_type")
          or capability.get("torch_dtype")
          or "default"
      ).strip().lower()

      command = [
          _resolve_converter_binary(),
          "--model",
          model_source,
          "--output_dir",
          str(output_dir),
      ]
      # "default" and "auto" are runtime compute-type selectors, not weight
      # quantization types accepted by the converter; omit the flag for them so
      # the converter keeps the model's original precision.
      if quantization not in {"", "default", "auto"}:
          command += ["--quantization", quantization]

      output_dir.parent.mkdir(parents=True, exist_ok=True)
      print(f"[convert] {name} -> {output_dir} ({quantization})")
      subprocess.run(command, check=True)
      print(f"[converted] {name}")
  
  
0fd2f875   tangwang   translate
93
94
95
96
  def main() -> None:
      """Command-line entry point: download local translation models.

      Parses ``--all-local``, ``--models`` and ``--convert-ctranslate2``, then
      for every matching local capability downloads the Hugging Face snapshot
      into its ``model_dir`` and, if requested, converts it to CTranslate2.

      Raises:
          ValueError: If a selected capability lacks ``model_id`` or
              ``model_dir``.
          SystemExit: Via ``argparse`` when neither ``--all-local`` nor any
              ``--models`` name is given.
      """
      parser = argparse.ArgumentParser(description="Download local translation models")
      parser.add_argument("--all-local", action="store_true", help="Download all configured local translation models")
      parser.add_argument("--models", nargs="*", default=[], help="Specific capability names to download")
      parser.add_argument(
          "--convert-ctranslate2",
          action="store_true",
          help="Also convert the downloaded Hugging Face models into CTranslate2 format",
      )
      args = parser.parse_args()

      selected = {item.strip().lower() for item in args.models if item.strip()} or None
      if not args.all_local and not selected:
          parser.error("pass --all-local or --models <name> ...")

      targets = list(iter_local_capabilities(selected))
      if selected:
          # Surface typos and non-local names instead of silently doing nothing.
          matched = {str(name).strip().lower() for name, _ in targets}
          unmatched = sorted(selected - matched)
          if unmatched:
              print(
                  f"[warn] no local translation capability named: {', '.join(unmatched)}",
                  file=sys.stderr,
              )

      for name, capability in targets:
          model_id = str(capability.get("model_id") or "").strip()
          model_dir_raw = str(capability.get("model_dir") or "")
          # Validate the raw string: a Path object is always truthy, so testing
          # the Path would never detect a missing model_dir.
          if not model_id or not model_dir_raw.strip():
              raise ValueError(f"Capability '{name}' must define model_id and model_dir")
          model_dir = Path(model_dir_raw).expanduser()
          model_dir.parent.mkdir(parents=True, exist_ok=True)
          print(f"[download] {name} -> {model_dir} ({model_id})")
          snapshot_download(
              repo_id=model_id,
              local_dir=str(model_dir),
          )
          print(f"[done] {name}")
          if args.convert_ctranslate2:
              convert_to_ctranslate2(name, capability)
0fd2f875   tangwang   translate
122
123
124
125
  
  
  # Run the CLI only when this file is executed as a script, not when imported.
  if __name__ == "__main__":
      main()