Blame view

scripts/translation/download_translation_models.py 3.85 KB
0fd2f875   tangwang   translate
1
2
3
4
5
6
  #!/usr/bin/env python3
  """Download local translation models declared in services.translation.capabilities."""
  
  from __future__ import annotations
  
  import argparse
0fd2f875   tangwang   translate
7
  import os
ea293660   tangwang   CTranslate2
8
  from pathlib import Path
0fd2f875   tangwang   translate
9
10
11
12
13
  import sys
  from typing import Iterable
  
  # huggingface_hub snapshots its HF_* environment variables into
  # ``huggingface_hub.constants`` when it is first imported, so this default must be
  # set *before* the import below; setting it afterwards has no effect.
  os.environ.setdefault("HF_HUB_DISABLE_XET", "1")

  from huggingface_hub import snapshot_download

  # Repository root: ``parents[2]`` of this file (the directory that contains
  # ``scripts/``). It is put on sys.path so the project-local ``config`` and
  # ``translation`` packages can be imported when this file is run as a plain script.
  PROJECT_ROOT = Path(__file__).resolve().parents[2]
  if str(PROJECT_ROOT) not in sys.path:
      sys.path.insert(0, str(PROJECT_ROOT))

  from config.services_config import get_translation_config
  from translation.ct2_conversion import convert_transformers_model
0fd2f875   tangwang   translate
21
22
23
24
25
26
27
  
  
  # Backend identifiers (compared after strip + lower-casing) whose models this script downloads.
  LOCAL_BACKENDS = {"local_nllb", "local_marian"}


  def iter_local_capabilities(selected: set[str] | None = None) -> Iterable[tuple[str, dict]]:
      """Yield ``(name, capability)`` pairs for translation capabilities with a local backend.

      Reads the ``capabilities`` mapping of the translation config and yields every
      entry whose ``backend`` (stripped, lower-cased) is in ``LOCAL_BACKENDS``.

      Args:
          selected: Optional set of capability names to restrict the output to.
              Matching ignores case and surrounding whitespace on both sides, so it
              agrees with the lower-cased names ``main`` builds from ``--models``.
              ``None`` or an empty set means "all local capabilities".

      Yields:
          The original (un-normalized) capability name and its config dict.
      """
      wanted = {str(item).strip().lower() for item in selected} if selected else None
      cfg = get_translation_config()
      # ``capabilities: null`` in the config would otherwise crash on ``.items()``.
      capabilities = (cfg.get("capabilities") if isinstance(cfg, dict) else None) or {}
      for name, capability in capabilities.items():
          # Malformed (non-mapping) entries cannot describe a local backend; skip them
          # instead of failing with an AttributeError on ``.get``.
          if not isinstance(capability, dict):
              continue
          backend = str(capability.get("backend") or "").strip().lower()
          if backend not in LOCAL_BACKENDS:
              continue
          if wanted and str(name).strip().lower() not in wanted:
              continue
          yield name, capability
  
  
ea293660   tangwang   CTranslate2
38
39
40
41
42
43
44
45
46
47
  def _compute_ct2_output_dir(capability: dict) -> Path:
      """Return the directory that holds (or will hold) a capability's CTranslate2 export.

      A non-blank ``ct2_model_dir`` is used as-is (with ``~`` expanded). Otherwise
      the result is ``<model_dir>/ctranslate2-<type>``. Here ``<type>`` is
      ``ct2_compute_type``, falling back to ``torch_dtype`` and then ``"default"``.
      It is stripped, lower-cased, and has underscores turned into hyphens, so
      ``int8_float16`` becomes ``int8-float16``.
      """
      explicit_dir = str(capability.get("ct2_model_dir") or "").strip()
      if explicit_dir:
          return Path(explicit_dir).expanduser()

      base_dir = Path(str(capability.get("model_dir") or "")).expanduser()
      type_label = capability.get("ct2_compute_type") or capability.get("torch_dtype") or "default"
      suffix = "-".join(str(type_label).strip().lower().split("_"))
      return base_dir / ("ctranslate2-" + suffix)
  
  
ea293660   tangwang   CTranslate2
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
  def convert_to_ctranslate2(name: str, capability: dict) -> None:
      """Convert one capability's Hugging Face model into CTranslate2 format.

      Conversion is skipped when the target directory (from
      ``_compute_ct2_output_dir``) already contains ``model.bin``. Otherwise the
      converter reads from the local ``model_dir`` snapshot when it is configured and
      exists, falling back to ``model_id``. The quantization passed to the converter is
      ``ct2_conversion_quantization``, else ``ct2_compute_type``, else
      ``torch_dtype``, else ``"default"``.

      Args:
          name: Capability name, used only in progress messages.
          capability: The capability's config mapping.

      Raises:
          ValueError: if conversion is needed but neither an existing ``model_dir``
              nor a ``model_id`` is configured.
      """
      output_dir = _compute_ct2_output_dir(capability)
      if (output_dir / "model.bin").exists():
          print(f"[skip-convert] {name} -> {output_dir}")
          return

      model_id = str(capability.get("model_id") or "").strip()
      raw_model_dir = str(capability.get("model_dir") or "")
      # Only consider model_dir when it is actually set: Path("") is Path("."), which
      # always exists and would make the converter read the current working directory.
      model_dir = Path(raw_model_dir).expanduser() if raw_model_dir.strip() else None
      model_source = str(model_dir) if model_dir is not None and model_dir.exists() else model_id
      if not model_source:
          raise ValueError(f"Capability '{name}' must define model_id or an existing model_dir")

      quantization = str(
          capability.get("ct2_conversion_quantization")
          or capability.get("ct2_compute_type")
          or capability.get("torch_dtype")
          or "default"
      ).strip()
      output_dir.parent.mkdir(parents=True, exist_ok=True)
      print(f"[convert] {name} -> {output_dir} ({quantization})")
      convert_transformers_model(model_source, str(output_dir), quantization)
      print(f"[converted] {name}")
  
  
0fd2f875   tangwang   translate
68
69
70
71
  def main() -> None:
      """CLI entry point: download local translation models and optionally convert them.

      Each selected capability is fetched with ``snapshot_download`` into its
      ``model_dir``. With ``--convert-ctranslate2`` it is then passed to
      ``convert_to_ctranslate2``.

      Exits through ``parser.error`` when neither ``--all-local`` nor ``--models``
      selects anything. It also exits that way when a name given to ``--models``
      does not match a configured capability with a local backend. This check runs
      before any download starts.

      Raises:
          ValueError: if a selected capability is missing ``model_id`` or ``model_dir``.
      """
      parser = argparse.ArgumentParser(description="Download local translation models")
      parser.add_argument("--all-local", action="store_true", help="Download all configured local translation models")
      parser.add_argument("--models", nargs="*", default=[], help="Specific capability names to download")
      parser.add_argument(
          "--convert-ctranslate2",
          action="store_true",
          help="Also convert the downloaded Hugging Face models into CTranslate2 format",
      )
      args = parser.parse_args()

      selected = {item.strip().lower() for item in args.models if item.strip()} or None
      if not args.all_local and not selected:
          parser.error("pass --all-local or --models <name> ...")

      # Resolve the whole work list first so a mistyped or non-local --models name is
      # reported up front instead of being silently ignored.
      capabilities = list(iter_local_capabilities(selected))
      if selected:
          matched = {str(name).strip().lower() for name, _ in capabilities}
          missing = sorted(selected - matched)
          if missing:
              parser.error("unknown or non-local translation capabilities: " + ", ".join(missing))

      for name, capability in capabilities:
          model_id = str(capability.get("model_id") or "").strip()
          raw_model_dir = str(capability.get("model_dir") or "")
          # Validate the raw string, not the Path: Path("") is Path(".") and a Path is
          # always truthy, so an unset model_dir would otherwise download into the CWD.
          if not model_id or not raw_model_dir.strip():
              raise ValueError(f"Capability '{name}' must define model_id and model_dir")
          model_dir = Path(raw_model_dir).expanduser()
          model_dir.parent.mkdir(parents=True, exist_ok=True)
          print(f"[download] {name} -> {model_dir} ({model_id})")
          snapshot_download(
              repo_id=model_id,
              local_dir=str(model_dir),
          )
          print(f"[done] {name}")
          if args.convert_ctranslate2:
              convert_to_ctranslate2(name, capability)
0fd2f875   tangwang   translate
97
98
99
100
  
  
  if __name__ == "__main__":
      # Run the CLI only when executed as a script, not when this module is imported.
      main()