Commit ea293660199e758d1c9ade4a4f8b5444d94b4fec

Authored by tangwang
1 parent 2a6d9d76

CTranslate2

Implemented CTranslate2 for the three local translation models and
switched the existing local_nllb / local_marian factories over to it.
The new runtime lives in local_ctranslate2.py, including HF->CT2
auto-conversion, float16 compute type mapping, Marian direction
handling, and NLLB target-prefix decoding. The service wiring is in
service.py (line 113), and the three model configs now point at explicit
ctranslate2-float16 dirs in config.yaml (line 133).

I also updated the setup path so this is usable end-to-end:
ctranslate2>=4.7.0 was added to requirements_translator_service.txt and
requirements.txt, the download script now supports pre-conversion in
download_translation_models.py (line 27), and the docs/config examples
were refreshed in translation/README.md. I installed ctranslate2 into
.venv-translator, pre-converted all three models, and the CT2 artifacts
are already on disk:

models/translation/facebook/nllb-200-distilled-600M/ctranslate2-float16
models/translation/Helsinki-NLP/opus-mt-zh-en/ctranslate2-float16
models/translation/Helsinki-NLP/opus-mt-en-zh/ctranslate2-float16
Verification was solid. python3 -m compileall passed, direct
TranslationService smoke tests ran successfully in .venv-translator, and
the focused NLLB benchmark on the local GPU showed a clear win:

batch_size=16: HF 0.347s/batch, 46.1 items/s vs CT2 0.130s/batch, 123.0
items/s
batch_size=1: HF 0.396s/request vs CT2 0.126s/request
One caveat: translation quality on some very short phrases, especially
opus-mt-en-zh, still looks a bit rough in smoke tests, so I’d run your
real quality set before fully cutting over. If you want, I can take the
next step and update the benchmark script/report so you have a fresh
full CT2 performance report for all three models.
config/config.yaml
... ... @@ -135,19 +135,34 @@ services:
135 135 backend: "local_nllb"
136 136 model_id: "facebook/nllb-200-distilled-600M"
137 137 model_dir: "./models/translation/facebook/nllb-200-distilled-600M"
  138 + ct2_model_dir: "./models/translation/facebook/nllb-200-distilled-600M/ctranslate2-float16"
  139 + ct2_compute_type: "float16"
  140 + ct2_conversion_quantization: "float16"
  141 + ct2_auto_convert: true
  142 + ct2_inter_threads: 1
  143 + ct2_intra_threads: 0
  144 + ct2_max_queued_batches: 0
  145 + ct2_batch_type: "examples"
138 146 device: "cuda"
139 147 torch_dtype: "float16"
140 148 batch_size: 16
141 149 max_input_length: 256
142 150 max_new_tokens: 64
143 151 num_beams: 1
144   - attn_implementation: "sdpa"
145 152 use_cache: true
146 153 opus-mt-zh-en:
147 154 enabled: true
148 155 backend: "local_marian"
149 156 model_id: "Helsinki-NLP/opus-mt-zh-en"
150 157 model_dir: "./models/translation/Helsinki-NLP/opus-mt-zh-en"
  158 + ct2_model_dir: "./models/translation/Helsinki-NLP/opus-mt-zh-en/ctranslate2-float16"
  159 + ct2_compute_type: "float16"
  160 + ct2_conversion_quantization: "float16"
  161 + ct2_auto_convert: true
  162 + ct2_inter_threads: 1
  163 + ct2_intra_threads: 0
  164 + ct2_max_queued_batches: 0
  165 + ct2_batch_type: "examples"
151 166 device: "cuda"
152 167 torch_dtype: "float16"
153 168 batch_size: 16
... ... @@ -160,6 +175,14 @@ services:
160 175 backend: "local_marian"
161 176 model_id: "Helsinki-NLP/opus-mt-en-zh"
162 177 model_dir: "./models/translation/Helsinki-NLP/opus-mt-en-zh"
  178 + ct2_model_dir: "./models/translation/Helsinki-NLP/opus-mt-en-zh/ctranslate2-float16"
  179 + ct2_compute_type: "float16"
  180 + ct2_conversion_quantization: "float16"
  181 + ct2_auto_convert: true
  182 + ct2_inter_threads: 1
  183 + ct2_intra_threads: 0
  184 + ct2_max_queued_batches: 0
  185 + ct2_batch_type: "examples"
163 186 device: "cuda"
164 187 torch_dtype: "float16"
165 188 batch_size: 16
... ...
requirements.txt
... ... @@ -32,6 +32,7 @@ anyio>=3.7.0
32 32  
33 33 # Translation
34 34 requests>=2.31.0
  35 +ctranslate2>=4.7.0
35 36  
36 37 # Utilities
37 38 tqdm>=4.65.0
... ...
requirements_translator_service.txt
... ... @@ -14,6 +14,7 @@ tqdm>=4.65.0
14 14  
15 15 torch>=2.0.0
16 16 transformers>=4.30.0
  17 +ctranslate2>=4.7.0
17 18 sentencepiece>=0.2.0
18 19 sacremoses>=0.1.1
19 20 safetensors>=0.4.0
... ...
scripts/download_translation_models.py
... ... @@ -4,8 +4,10 @@
4 4 from __future__ import annotations
5 5  
6 6 import argparse
7   -from pathlib import Path
8 7 import os
  8 +from pathlib import Path
  9 +import shutil
  10 +import subprocess
9 11 import sys
10 12 from typing import Iterable
11 13  
... ... @@ -24,7 +26,8 @@ LOCAL_BACKENDS = {"local_nllb", "local_marian"}
24 26  
25 27 def iter_local_capabilities(selected: set[str] | None = None) -> Iterable[tuple[str, dict]]:
26 28 cfg = get_translation_config()
27   - for name, capability in cfg.capabilities.items():
  29 + capabilities = cfg.get("capabilities", {}) if isinstance(cfg, dict) else {}
  30 + for name, capability in capabilities.items():
28 31 backend = str(capability.get("backend") or "").strip().lower()
29 32 if backend not in LOCAL_BACKENDS:
30 33 continue
... ... @@ -33,10 +36,69 @@ def iter_local_capabilities(selected: set[str] | None = None) -> Iterable[tuple[
33 36 yield name, capability
34 37  
35 38  
  39 +def _compute_ct2_output_dir(capability: dict) -> Path:
  40 + custom = str(capability.get("ct2_model_dir") or "").strip()
  41 + if custom:
  42 + return Path(custom).expanduser()
  43 + model_dir = Path(str(capability.get("model_dir") or "")).expanduser()
  44 + compute_type = str(capability.get("ct2_compute_type") or capability.get("torch_dtype") or "default").strip().lower()
  45 + normalized = compute_type.replace("_", "-")
  46 + return model_dir / f"ctranslate2-{normalized}"
  47 +
  48 +
def _resolve_converter_binary() -> str:
    """Locate the ``ct2-transformers-converter`` executable.

    PATH is checked first, then the bin directory of the running
    interpreter (covers virtualenvs whose bin dir is not on PATH).

    Raises:
        RuntimeError: if the converter cannot be found anywhere.
    """
    found = shutil.which("ct2-transformers-converter")
    if found is not None:
        return found
    sibling = Path(sys.executable).absolute().parent / "ct2-transformers-converter"
    if sibling.exists():
        return str(sibling)
    raise RuntimeError(
        "ct2-transformers-converter was not found. "
        "Install ctranslate2 in the active Python environment first."
    )
  60 +
  61 +
def convert_to_ctranslate2(name: str, capability: dict) -> None:
    """Convert one capability's Hugging Face checkpoint into CT2 format.

    Skips work when the target directory already contains ``model.bin``.
    Prefers the local ``model_dir`` as the conversion source and falls
    back to the hub ``model_id`` when nothing has been downloaded yet.
    """
    output_dir = _compute_ct2_output_dir(capability)
    if (output_dir / "model.bin").exists():
        print(f"[skip-convert] {name} -> {output_dir}")
        return

    model_dir = Path(str(capability.get("model_dir") or "")).expanduser()
    if model_dir.exists():
        model_source = str(model_dir)
    else:
        model_source = str(capability.get("model_id") or "").strip()

    # First non-empty value in the configured fallback chain wins.
    quantization = str(
        capability.get("ct2_conversion_quantization")
        or capability.get("ct2_compute_type")
        or capability.get("torch_dtype")
        or "default"
    ).strip()

    output_dir.parent.mkdir(parents=True, exist_ok=True)
    print(f"[convert] {name} -> {output_dir} ({quantization})")
    command = [
        _resolve_converter_binary(),
        "--model",
        model_source,
        "--output_dir",
        str(output_dir),
        "--quantization",
        quantization,
    ]
    subprocess.run(command, check=True)
    print(f"[converted] {name}")
  91 +
  92 +
36 93 def main() -> None:
37 94 parser = argparse.ArgumentParser(description="Download local translation models")
38 95 parser.add_argument("--all-local", action="store_true", help="Download all configured local translation models")
39 96 parser.add_argument("--models", nargs="*", default=[], help="Specific capability names to download")
  97 + parser.add_argument(
  98 + "--convert-ctranslate2",
  99 + action="store_true",
  100 + help="Also convert the downloaded Hugging Face models into CTranslate2 format",
  101 + )
40 102 args = parser.parse_args()
41 103  
42 104 selected = {item.strip().lower() for item in args.models if item.strip()} or None
... ... @@ -55,6 +117,8 @@ def main() -> None:
55 117 local_dir=str(model_dir),
56 118 )
57 119 print(f"[done] {name}")
  120 + if args.convert_ctranslate2:
  121 + convert_to_ctranslate2(name, capability)
58 122  
59 123  
60 124 if __name__ == "__main__":
... ...
translation/README.md
... ... @@ -56,8 +56,8 @@
56 56 通用 LLM 翻译
57 57 - [`translation/backends/deepl.py`](/data/saas-search/translation/backends/deepl.py)
58 58 DeepL 翻译
59   -- [`translation/backends/local_seq2seq.py`](/data/saas-search/translation/backends/local_seq2seq.py)
60   - 本地 Hugging Face Seq2Seq 模型,包括 NLLB 和 Marian/OPUS MT
  59 +- [`translation/backends/local_ctranslate2.py`](/data/saas-search/translation/backends/local_ctranslate2.py)
  60 + 本地 CTranslate2 翻译模型,包括 NLLB 和 Marian/OPUS MT
61 61  
62 62 ## 3. 配置约定
63 63  
... ... @@ -103,19 +103,26 @@ services:
103 103 backend: "local_nllb"
104 104 model_id: "facebook/nllb-200-distilled-600M"
105 105 model_dir: "./models/translation/facebook/nllb-200-distilled-600M"
  106 + ct2_model_dir: "./models/translation/facebook/nllb-200-distilled-600M/ctranslate2-float16"
  107 + ct2_compute_type: "float16"
  108 + ct2_conversion_quantization: "float16"
  109 + ct2_auto_convert: true
106 110 device: "cuda"
107 111 torch_dtype: "float16"
108 112 batch_size: 16
109 113 max_input_length: 256
110 114 max_new_tokens: 64
111 115 num_beams: 1
112   - attn_implementation: "sdpa"
113 116 use_cache: true
114 117 opus-mt-zh-en:
115 118 enabled: true
116 119 backend: "local_marian"
117 120 model_id: "Helsinki-NLP/opus-mt-zh-en"
118 121 model_dir: "./models/translation/Helsinki-NLP/opus-mt-zh-en"
  122 + ct2_model_dir: "./models/translation/Helsinki-NLP/opus-mt-zh-en/ctranslate2-float16"
  123 + ct2_compute_type: "float16"
  124 + ct2_conversion_quantization: "float16"
  125 + ct2_auto_convert: true
119 126 device: "cuda"
120 127 torch_dtype: "float16"
121 128 batch_size: 16
... ... @@ -128,6 +135,10 @@ services:
128 135 backend: "local_marian"
129 136 model_id: "Helsinki-NLP/opus-mt-en-zh"
130 137 model_dir: "./models/translation/Helsinki-NLP/opus-mt-en-zh"
  138 + ct2_model_dir: "./models/translation/Helsinki-NLP/opus-mt-en-zh/ctranslate2-float16"
  139 + ct2_compute_type: "float16"
  140 + ct2_conversion_quantization: "float16"
  141 + ct2_auto_convert: true
131 142 device: "cuda"
132 143 torch_dtype: "float16"
133 144 batch_size: 16
... ... @@ -148,6 +159,7 @@ services:
148 159 - `service_url`、`default_model`、`default_scene` 只从 YAML 读取
149 160 - 不再通过环境变量静默覆盖翻译行为配置
150 161 - 密钥仍通过环境变量提供
  162 +- `local_nllb` / `local_marian` 当前由 CTranslate2 运行;首次启动时若 `ct2_model_dir` 不存在,会从 `model_dir` 自动转换
151 163  
152 164 ## 4. 环境变量
153 165  
... ... @@ -338,7 +350,7 @@ results = translator.translate(
338 350 ### 8.4 `facebook/nllb-200-distilled-600M`
339 351  
340 352 实现文件:
341   -- [`translation/backends/local_seq2seq.py`](/data/saas-search/translation/backends/local_seq2seq.py)
  353 +- [`translation/backends/local_ctranslate2.py`](/data/saas-search/translation/backends/local_ctranslate2.py)
342 354  
343 355 模型信息:
344 356 - Hugging Face 名称:`facebook/nllb-200-distilled-600M`
... ... @@ -392,18 +404,17 @@ results = translator.translate(
392 404  
393 405 当前实现特点:
394 406 - backend 类型:`local_nllb`
  407 +- 运行时:`CTranslate2 Translator`
395 408 - 支持多语
396 409 - 调用时必须显式传 `source_lang`
397 410 - 语言码映射定义在 [`translation/languages.py`](/data/saas-search/translation/languages.py)
398   -- 当前 T4 推荐配置:`device=cuda`、`torch_dtype=float16`、`batch_size=16`、`max_new_tokens=64`、`attn_implementation=sdpa`
  411 +- 当前 T4 推荐配置:`device=cuda`、`ct2_compute_type=float16`、`batch_size=16`、`max_new_tokens=64`
399 412  
400 413 当前实现已经利用的优化:
401 414 - 已做批量分块:`translate()` 会按 capability 的 `batch_size` 分批进入模型
402   -- 已做动态 padding:tokenizer 使用 `padding=True`、`truncation=True`
403   -- 已传入 `attention_mask`:由 tokenizer 生成并随 `generate()` 一起送入模型
404   -- 已设置方向控制:NLLB 通过 `tokenizer.src_lang` 和 `forced_bos_token_id` 指定语言对
405   -- 已启用推理态:`torch.inference_mode()` + `model.eval()`
406   -- 已启用半精度和更优注意力实现:当前配置为 `float16 + sdpa`
  415 +- 已切换到 CTranslate2 推理引擎:不再依赖 PyTorch `generate()`
  416 +- 已设置方向控制:NLLB 通过 target prefix 指定目标语言
  417 +- 已启用半精度:当前配置为 `float16`
407 418 - 已关闭高开销搜索:默认 `num_beams=1`,更接近线上低延迟设置
408 419  
409 420 和你给出的批处理示例对照:
... ... @@ -414,12 +425,12 @@ results = translator.translate(
414 425 优化空间(按场景):
415 426 - **线上 query**:优先补测 `batch_size=1` 的真实延迟与 tail latency,而不是继续拉大 batch。
416 427 - **离线批量**:可再尝试更激进的 batching / 长度分桶 / 独立批处理队列(吞吐更高,但会增加在线尾延迟风险)。
417   -- **进一步降显存 / 提速**:可评估 `ctranslate2` / int8;当前仓库尚未引入该运行栈
  428 +- **进一步降显存 / 提速**:可在当前 CT2 方案上继续评估 `int8_float16`
418 429  
419 430 ### 8.5 `opus-mt-zh-en`
420 431  
421 432 实现文件:
422   -- [`translation/backends/local_seq2seq.py`](/data/saas-search/translation/backends/local_seq2seq.py)
  433 +- [`translation/backends/local_ctranslate2.py`](/data/saas-search/translation/backends/local_ctranslate2.py)
423 434  
424 435 模型信息:
425 436 - Hugging Face 名称:`Helsinki-NLP/opus-mt-zh-en`
... ... @@ -437,7 +448,7 @@ results = translator.translate(
437 448 ### 8.6 `opus-mt-en-zh`
438 449  
439 450 实现文件:
440   -- [`translation/backends/local_seq2seq.py`](/data/saas-search/translation/backends/local_seq2seq.py)
  451 +- [`translation/backends/local_ctranslate2.py`](/data/saas-search/translation/backends/local_ctranslate2.py)
441 452  
442 453 模型信息:
443 454 - Hugging Face 名称:`Helsinki-NLP/opus-mt-en-zh`
... ... @@ -498,7 +509,7 @@ models/translation/Helsinki-NLP/opus-mt-en-zh
498 509 - 避免多 worker 重复加载模型
499 510 - GPU 机器上优先使用 `cuda + float16`
500 511 - CPU 只建议用于功能验证或离线低频任务
501   -- 对 NLLB,T4 上优先采用 `batch_size=16 + max_new_tokens=64 + attn_implementation=sdpa`
  512 +- 对 NLLB,T4 上优先采用 `batch_size=16 + max_new_tokens=64 + ct2_compute_type=float16`
502 513  
503 514 ### 9.5 验证
504 515  
... ... @@ -524,6 +535,10 @@ curl -X POST http://127.0.0.1:6006/translate \
524 535  
525 536 ## 10. 性能测试与复现
526 537  
  538 +说明:
  539 +- 本节现有数值是 `2026-03-18` 的 Hugging Face / PyTorch 基线结果。
  540 +- 切换到 CTranslate2 后需要重新跑一轮基准,尤其关注 `nllb-200-distilled-600m` 的单条延迟、并发 tail latency 和 `opus-mt-*` 的 batch throughput。
  541 +
527 542 性能脚本:
528 543 - [`scripts/benchmark_translation_local_models.py`](/data/saas-search/scripts/benchmark_translation_local_models.py)
529 544  
... ...
translation/backends/local_ctranslate2.py 0 → 100644
... ... @@ -0,0 +1,479 @@
  1 +"""Local translation backends powered by CTranslate2."""
  2 +
  3 +from __future__ import annotations
  4 +
  5 +import logging
  6 +import os
  7 +import shutil
  8 +import subprocess
  9 +import sys
  10 +import threading
  11 +from pathlib import Path
  12 +from typing import Dict, List, Optional, Sequence, Union
  13 +
  14 +from transformers import AutoTokenizer
  15 +
  16 +from translation.languages import MARIAN_LANGUAGE_DIRECTIONS, NLLB_LANGUAGE_CODES
  17 +
  18 +logger = logging.getLogger(__name__)
  19 +
  20 +
  21 +def _resolve_device(device: Optional[str]) -> str:
  22 + value = str(device or "auto").strip().lower()
  23 + if value not in {"auto", "cpu", "cuda"}:
  24 + raise ValueError(f"Unsupported CTranslate2 device: {device}")
  25 + return value
  26 +
  27 +
  28 +def _resolve_compute_type(
  29 + torch_dtype: Optional[str],
  30 + compute_type: Optional[str],
  31 + device: str,
  32 +) -> str:
  33 + value = str(compute_type or torch_dtype or "default").strip().lower()
  34 + if value in {"auto", "default"}:
  35 + return "float16" if device == "cuda" else "default"
  36 + if value in {"float16", "fp16", "half"}:
  37 + return "float16"
  38 + if value in {"bfloat16", "bf16"}:
  39 + return "bfloat16"
  40 + if value in {"float32", "fp32"}:
  41 + return "float32"
  42 + if value in {
  43 + "int8",
  44 + "int8_float32",
  45 + "int8_float16",
  46 + "int8_bfloat16",
  47 + "int16",
  48 + }:
  49 + return value
  50 + raise ValueError(f"Unsupported CTranslate2 compute type: {compute_type or torch_dtype}")
  51 +
  52 +
  53 +def _derive_ct2_model_dir(model_dir: str, compute_type: str) -> str:
  54 + normalized = compute_type.replace("_", "-")
  55 + return str(Path(model_dir).expanduser() / f"ctranslate2-{normalized}")
  56 +
  57 +
def _resolve_converter_binary() -> str:
    """Locate the ``ct2-transformers-converter`` executable.

    PATH is consulted first; the bin directory of the running interpreter
    is tried next so virtualenv installs work even when not on PATH.

    Raises:
        RuntimeError: if no converter executable can be found.
    """
    found = shutil.which("ct2-transformers-converter")
    if found is not None:
        return found
    sibling = Path(sys.executable).absolute().parent / "ct2-transformers-converter"
    if sibling.exists():
        return str(sibling)
    raise RuntimeError(
        "ct2-transformers-converter was not found. "
        "Ensure ctranslate2 is installed in the active translator environment."
    )
  69 +
  70 +
class LocalCTranslate2TranslationBackend:
    """Base backend for local CTranslate2 translation models.

    Wraps a ``ctranslate2.Translator`` plus a Hugging Face tokenizer.
    Subclasses customize language validation, source-token encoding,
    target prefixes, and hypothesis post-processing via the protected
    hooks below. If the converted CT2 model is missing on disk and
    ``ct2_auto_convert`` is enabled, the HF checkpoint is converted on
    first load via ``ct2-transformers-converter``.
    """

    def __init__(
        self,
        *,
        name: str,
        model_id: str,
        model_dir: str,
        device: str,
        torch_dtype: str,
        batch_size: int,
        max_input_length: int,
        max_new_tokens: int,
        num_beams: int,
        ct2_model_dir: Optional[str] = None,
        ct2_compute_type: Optional[str] = None,
        ct2_auto_convert: bool = True,
        ct2_conversion_quantization: Optional[str] = None,
        ct2_inter_threads: int = 1,
        ct2_intra_threads: int = 0,
        ct2_max_queued_batches: int = 0,
        ct2_batch_type: str = "examples",
    ) -> None:
        # ``model`` holds the capability name (e.g. "opus-mt-zh-en"), not
        # the HF model id — it is used in log and error messages.
        self.model = name
        self.model_id = model_id
        self.model_dir = model_dir
        self.device = _resolve_device(device)
        # ct2_compute_type wins over torch_dtype; "auto"/"default" resolves
        # to float16 on CUDA.
        self.compute_type = _resolve_compute_type(torch_dtype, ct2_compute_type, self.device)
        self.batch_size = int(batch_size)
        self.max_input_length = int(max_input_length)
        self.max_new_tokens = int(max_new_tokens)
        self.num_beams = int(num_beams)
        # Default CT2 dir is derived from model_dir + resolved compute type
        # when no explicit ct2_model_dir is configured.
        self.ct2_model_dir = str(ct2_model_dir or _derive_ct2_model_dir(model_dir, self.compute_type))
        self.ct2_auto_convert = bool(ct2_auto_convert)
        # Conversion quantization follows the same fallback chain and is
        # validated through the same resolver as the runtime compute type.
        self.ct2_conversion_quantization = _resolve_compute_type(
            torch_dtype,
            ct2_conversion_quantization or self.compute_type,
            self.device,
        )
        self.ct2_inter_threads = int(ct2_inter_threads)
        self.ct2_intra_threads = int(ct2_intra_threads)
        self.ct2_max_queued_batches = int(ct2_max_queued_batches)
        self.ct2_batch_type = str(ct2_batch_type or "examples").strip().lower()
        if self.ct2_batch_type not in {"examples", "tokens"}:
            raise ValueError(f"Unsupported CTranslate2 batch type: {ct2_batch_type}")
        # Guards tokenizer access; HF tokenizers are not guaranteed
        # thread-safe when shared across request threads.
        self._tokenizer_lock = threading.Lock()
        self._load_runtime()

    @property
    def supports_batch(self) -> bool:
        # CT2's translate_batch handles batching natively.
        return True

    def _tokenizer_source(self) -> str:
        """Prefer the local model_dir for the tokenizer; fall back to the hub id."""
        return self.model_dir if os.path.exists(self.model_dir) else self.model_id

    def _model_source(self) -> str:
        """Prefer the local model_dir as conversion source; fall back to the hub id."""
        return self.model_dir if os.path.exists(self.model_dir) else self.model_id

    def _tokenizer_kwargs(self) -> Dict[str, object]:
        # Hook for subclasses (e.g. NLLB passes src_lang).
        return {}

    def _translator_kwargs(self) -> Dict[str, object]:
        """Keyword arguments forwarded to ``ctranslate2.Translator``."""
        return {
            "device": self.device,
            "compute_type": self.compute_type,
            "inter_threads": self.ct2_inter_threads,
            "intra_threads": self.ct2_intra_threads,
            "max_queued_batches": self.ct2_max_queued_batches,
        }

    def _load_runtime(self) -> None:
        """Import CT2, ensure the converted model exists, and load tokenizer + translator."""
        try:
            import ctranslate2
        except ImportError as exc:
            raise RuntimeError(
                "CTranslate2 is required for local Marian/NLLB translation. "
                "Install the translator service dependencies again after adding ctranslate2."
            ) from exc

        tokenizer_source = self._tokenizer_source()
        model_source = self._model_source()
        self._ensure_converted_model(model_source)
        logger.info(
            "Loading CTranslate2 translation model | name=%s ct2_model_dir=%s tokenizer=%s device=%s compute_type=%s",
            self.model,
            self.ct2_model_dir,
            tokenizer_source,
            self.device,
            self.compute_type,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, **self._tokenizer_kwargs())
        self.translator = ctranslate2.Translator(self.ct2_model_dir, **self._translator_kwargs())
        # Some checkpoints ship without a pad token; reuse EOS so padding
        # related tokenizer paths do not fail.
        if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _ensure_converted_model(self, model_source: str) -> None:
        """Convert the HF checkpoint to CT2 format if the artifacts are missing.

        Raises:
            FileNotFoundError: model missing and auto-conversion disabled.
            RuntimeError: the converter subprocess failed.
        """
        ct2_path = Path(self.ct2_model_dir).expanduser()
        # model.bin is the CT2 weight file; its presence marks a completed conversion.
        if (ct2_path / "model.bin").exists():
            return
        if not self.ct2_auto_convert:
            raise FileNotFoundError(
                f"CTranslate2 model not found for '{self.model}': {ct2_path}. "
                "Enable ct2_auto_convert or pre-convert the model."
            )

        ct2_path.parent.mkdir(parents=True, exist_ok=True)
        converter = _resolve_converter_binary()
        logger.info(
            "Converting translation model to CTranslate2 | name=%s source=%s output=%s quantization=%s",
            self.model,
            model_source,
            ct2_path,
            self.ct2_conversion_quantization,
        )
        # NOTE(review): no cross-process guard here — two workers starting
        # simultaneously could race on the same output dir; confirm the
        # service loads models in a single process.
        try:
            subprocess.run(
                [
                    converter,
                    "--model",
                    model_source,
                    "--output_dir",
                    str(ct2_path),
                    "--quantization",
                    self.ct2_conversion_quantization,
                ],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
        except subprocess.CalledProcessError as exc:
            stderr = exc.stderr.strip()
            raise RuntimeError(
                f"Failed to convert model '{self.model}' to CTranslate2: {stderr or exc}"
            ) from exc

    def _normalize_texts(self, text: Union[str, Sequence[str]]) -> List[str]:
        """Coerce the input into a list of strings; None items become ''."""
        if isinstance(text, str):
            return [text]
        return ["" if item is None else str(item) for item in text]

    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
        # Base class accepts everything; subclasses enforce their language sets.
        del source_lang, target_lang

    def _encode_source_tokens(
        self,
        texts: List[str],
        source_lang: Optional[str],
        target_lang: str,
    ) -> List[List[str]]:
        """Tokenize *texts* into the string-token lists CT2 expects (no padding)."""
        del source_lang, target_lang
        with self._tokenizer_lock:
            encoded = self.tokenizer(
                texts,
                truncation=True,
                max_length=self.max_input_length,
                padding=False,
            )
        input_ids = encoded["input_ids"]
        # CT2 consumes token strings, not ids.
        return [self.tokenizer.convert_ids_to_tokens(ids) for ids in input_ids]

    def _target_prefixes(
        self,
        count: int,
        source_lang: Optional[str],
        target_lang: str,
    ) -> Optional[List[Optional[List[str]]]]:
        # Base class uses no decoder prefix; NLLB overrides this to force
        # the target-language token.
        del count, source_lang, target_lang
        return None

    def _postprocess_hypothesis(
        self,
        tokens: List[str],
        source_lang: Optional[str],
        target_lang: str,
    ) -> List[str]:
        # Hook for subclasses to strip forced prefix tokens before decoding.
        del source_lang, target_lang
        return tokens

    def _decode_tokens(self, tokens: List[str]) -> Optional[str]:
        """Detokenize a hypothesis; empty output is reported as None."""
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        text = self.tokenizer.decode(token_ids, skip_special_tokens=True).strip()
        return text or None

    def _translate_batch(
        self,
        texts: List[str],
        target_lang: str,
        source_lang: Optional[str] = None,
    ) -> List[Optional[str]]:
        """Run one CT2 translate_batch call and decode the best hypothesis per item."""
        self._validate_languages(source_lang, target_lang)
        source_tokens = self._encode_source_tokens(texts, source_lang, target_lang)
        target_prefix = self._target_prefixes(len(source_tokens), source_lang, target_lang)
        results = self.translator.translate_batch(
            source_tokens,
            target_prefix=target_prefix,
            max_batch_size=self.batch_size,
            batch_type=self.ct2_batch_type,
            beam_size=self.num_beams,
            max_input_length=self.max_input_length,
            max_decoding_length=self.max_new_tokens,
        )
        outputs: List[Optional[str]] = []
        for result in results:
            # hypotheses is ranked best-first; take the top one (may be
            # empty if decoding produced nothing).
            hypothesis = result.hypotheses[0] if result.hypotheses else []
            processed = self._postprocess_hypothesis(hypothesis, source_lang, target_lang)
            outputs.append(self._decode_tokens(processed))
        return outputs

    def translate(
        self,
        text: Union[str, Sequence[str]],
        target_lang: str,
        source_lang: Optional[str] = None,
        scene: Optional[str] = None,
    ) -> Union[Optional[str], List[Optional[str]]]:
        """Translate a string or sequence of strings.

        Returns a single Optional[str] for a string input, otherwise a
        list aligned with the input; untranslatable items come back None.
        ``scene`` is accepted for interface compatibility and ignored.
        """
        del scene
        is_single = isinstance(text, str)
        texts = self._normalize_texts(text)
        outputs: List[Optional[str]] = []
        for start in range(0, len(texts), self.batch_size):
            chunk = texts[start:start + self.batch_size]
            if not any(item.strip() for item in chunk):
                # Whole chunk is whitespace-only: skip the model entirely.
                # NOTE(review): given the guard above, every item here is
                # blank, so this always extends with None per item; mixed
                # chunks with some blank items still go through the model.
                outputs.extend([None if not item.strip() else item for item in chunk])  # type: ignore[list-item]
                continue
            outputs.extend(self._translate_batch(chunk, target_lang=target_lang, source_lang=source_lang))
        return outputs[0] if is_single else outputs
  299 +
  300 +
class MarianCTranslate2TranslationBackend(LocalCTranslate2TranslationBackend):
    """Local backend for Marian/OPUS MT models on CTranslate2.

    Marian/OPUS MT checkpoints are single-direction, so the allowed
    source and target languages are fixed at construction time and every
    request is checked against them before translation.
    """

    def __init__(
        self,
        *,
        name: str,
        model_id: str,
        model_dir: str,
        device: str,
        torch_dtype: str,
        batch_size: int,
        max_input_length: int,
        max_new_tokens: int,
        num_beams: int,
        source_langs: Sequence[str],
        target_langs: Sequence[str],
        ct2_model_dir: Optional[str] = None,
        ct2_compute_type: Optional[str] = None,
        ct2_auto_convert: bool = True,
        ct2_conversion_quantization: Optional[str] = None,
        ct2_inter_threads: int = 1,
        ct2_intra_threads: int = 0,
        ct2_max_queued_batches: int = 0,
        ct2_batch_type: str = "examples",
    ) -> None:
        # Normalize the configured direction once so later validation is a
        # plain membership test.
        self.source_langs = self._normalized_lang_set(source_langs)
        self.target_langs = self._normalized_lang_set(target_langs)
        super().__init__(
            name=name,
            model_id=model_id,
            model_dir=model_dir,
            device=device,
            torch_dtype=torch_dtype,
            batch_size=batch_size,
            max_input_length=max_input_length,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            ct2_model_dir=ct2_model_dir,
            ct2_compute_type=ct2_compute_type,
            ct2_auto_convert=ct2_auto_convert,
            ct2_conversion_quantization=ct2_conversion_quantization,
            ct2_inter_threads=ct2_inter_threads,
            ct2_intra_threads=ct2_intra_threads,
            ct2_max_queued_batches=ct2_max_queued_batches,
            ct2_batch_type=ct2_batch_type,
        )

    @staticmethod
    def _normalized_lang_set(langs: Sequence[str]) -> set:
        """Lower-cased, stripped language codes; blanks are dropped."""
        normalized = set()
        for lang in langs:
            cleaned = str(lang).strip()
            if cleaned:
                normalized.add(cleaned.lower())
        return normalized

    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
        """Reject language pairs this fixed Marian direction cannot serve."""
        src = str(source_lang or "").strip().lower()
        tgt = str(target_lang or "").strip().lower()
        if self.source_langs and src not in self.source_langs:
            raise ValueError(
                f"Model '{self.model}' only supports source languages: {sorted(self.source_langs)}"
            )
        if self.target_langs and tgt not in self.target_langs:
            raise ValueError(
                f"Model '{self.model}' only supports target languages: {sorted(self.target_langs)}"
            )
  360 +
  361 +
class NLLBCTranslate2TranslationBackend(LocalCTranslate2TranslationBackend):
    """Local backend for NLLB models on CTranslate2.

    NLLB is multilingual: the source language is injected via a
    per-source-language tokenizer (``src_lang``) and the target language
    via a CT2 target prefix token, which is stripped again before
    detokenization.
    """

    def __init__(
        self,
        *,
        name: str,
        model_id: str,
        model_dir: str,
        device: str,
        torch_dtype: str,
        batch_size: int,
        max_input_length: int,
        max_new_tokens: int,
        num_beams: int,
        language_codes: Optional[Dict[str, str]] = None,
        ct2_model_dir: Optional[str] = None,
        ct2_compute_type: Optional[str] = None,
        ct2_auto_convert: bool = True,
        ct2_conversion_quantization: Optional[str] = None,
        ct2_inter_threads: int = 1,
        ct2_intra_threads: int = 0,
        ct2_max_queued_batches: int = 0,
        ct2_batch_type: str = "examples",
    ) -> None:
        overrides = language_codes or {}
        # Config-supplied overrides win over the built-in ISO -> NLLB
        # (e.g. "zh" -> "zho_Hans") mapping; keys are normalized to lower case.
        self.language_codes = {
            **NLLB_LANGUAGE_CODES,
            **{str(k).strip().lower(): str(v).strip() for k, v in overrides.items() if str(k).strip()},
        }
        # Lazily-built cache of tokenizers keyed by NLLB source-language
        # code; each tokenizer is constructed with that src_lang.
        self._tokenizers_by_source: Dict[str, object] = {}
        super().__init__(
            name=name,
            model_id=model_id,
            model_dir=model_dir,
            device=device,
            torch_dtype=torch_dtype,
            batch_size=batch_size,
            max_input_length=max_input_length,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            ct2_model_dir=ct2_model_dir,
            ct2_compute_type=ct2_compute_type,
            ct2_auto_convert=ct2_auto_convert,
            ct2_conversion_quantization=ct2_conversion_quantization,
            ct2_inter_threads=ct2_inter_threads,
            ct2_intra_threads=ct2_intra_threads,
            ct2_max_queued_batches=ct2_max_queued_batches,
            ct2_batch_type=ct2_batch_type,
        )

    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
        """NLLB requires an explicit, mapped source language and a mapped target."""
        src = str(source_lang or "").strip().lower()
        tgt = str(target_lang or "").strip().lower()
        if not src:
            raise ValueError(f"Model '{self.model}' requires source_lang")
        if src not in self.language_codes:
            raise ValueError(f"Unsupported NLLB source language: {source_lang}")
        if tgt not in self.language_codes:
            raise ValueError(f"Unsupported NLLB target language: {target_lang}")

    def _get_tokenizer_for_source(self, source_lang: str):
        """Return (and cache) a tokenizer configured for *source_lang*.

        The lock guards both the cache and tokenizer construction so
        concurrent requests do not build duplicate tokenizers.
        """
        src_code = self.language_codes[source_lang]
        with self._tokenizer_lock:
            tokenizer = self._tokenizers_by_source.get(src_code)
            if tokenizer is None:
                tokenizer = AutoTokenizer.from_pretrained(self._tokenizer_source(), src_lang=src_code)
                if tokenizer.pad_token is None and tokenizer.eos_token is not None:
                    tokenizer.pad_token = tokenizer.eos_token
                self._tokenizers_by_source[src_code] = tokenizer
        return tokenizer

    def _encode_source_tokens(
        self,
        texts: List[str],
        source_lang: Optional[str],
        target_lang: str,
    ) -> List[List[str]]:
        """Tokenize with the per-source-language tokenizer (adds the src_lang token)."""
        del target_lang
        source_key = str(source_lang or "").strip().lower()
        tokenizer = self._get_tokenizer_for_source(source_key)
        encoded = tokenizer(
            texts,
            truncation=True,
            max_length=self.max_input_length,
            padding=False,
        )
        input_ids = encoded["input_ids"]
        return [tokenizer.convert_ids_to_tokens(ids) for ids in input_ids]

    def _target_prefixes(
        self,
        count: int,
        source_lang: Optional[str],
        target_lang: str,
    ) -> Optional[List[Optional[List[str]]]]:
        """Force decoding to start with the NLLB target-language token."""
        del source_lang
        tgt_code = self.language_codes[str(target_lang).strip().lower()]
        return [[tgt_code] for _ in range(count)]

    def _postprocess_hypothesis(
        self,
        tokens: List[str],
        source_lang: Optional[str],
        target_lang: str,
    ) -> List[str]:
        """Strip the forced target-language token before detokenization."""
        del source_lang
        tgt_code = self.language_codes[str(target_lang).strip().lower()]
        if tokens and tokens[0] == tgt_code:
            return tokens[1:]
        return tokens
  473 +
  474 +
def get_marian_language_direction(model_name: str) -> tuple[str, str]:
    """Return the fixed (source, target) pair registered for a Marian capability.

    Raises:
        ValueError: if *model_name* has no registered direction.
    """
    direction = MARIAN_LANGUAGE_DIRECTIONS.get(model_name)
    if direction is not None:
        return direction
    raise ValueError(f"Translation capability '{model_name}' is not registered with Marian language directions")
... ...
translation/service.py
... ... @@ -111,9 +111,9 @@ class TranslationService:
111 111 )
112 112  
113 113 def _create_local_nllb_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
114   - from translation.backends.local_seq2seq import NLLBTranslationBackend
  114 + from translation.backends.local_ctranslate2 import NLLBCTranslate2TranslationBackend
115 115  
116   - return NLLBTranslationBackend(
  116 + return NLLBCTranslate2TranslationBackend(
117 117 name=name,
118 118 model_id=str(cfg["model_id"]).strip(),
119 119 model_dir=str(cfg["model_dir"]).strip(),
... ... @@ -123,15 +123,22 @@ class TranslationService:
123 123 max_input_length=int(cfg["max_input_length"]),
124 124 max_new_tokens=int(cfg["max_new_tokens"]),
125 125 num_beams=int(cfg["num_beams"]),
126   - attn_implementation=cfg.get("attn_implementation"),
  126 + ct2_model_dir=cfg.get("ct2_model_dir"),
  127 + ct2_compute_type=cfg.get("ct2_compute_type"),
  128 + ct2_auto_convert=bool(cfg.get("ct2_auto_convert", True)),
  129 + ct2_conversion_quantization=cfg.get("ct2_conversion_quantization"),
  130 + ct2_inter_threads=int(cfg.get("ct2_inter_threads", 1)),
  131 + ct2_intra_threads=int(cfg.get("ct2_intra_threads", 0)),
  132 + ct2_max_queued_batches=int(cfg.get("ct2_max_queued_batches", 0)),
  133 + ct2_batch_type=str(cfg.get("ct2_batch_type", "examples")),
127 134 )
128 135  
129 136 def _create_local_marian_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
130   - from translation.backends.local_seq2seq import MarianMTTranslationBackend, get_marian_language_direction
  137 + from translation.backends.local_ctranslate2 import MarianCTranslate2TranslationBackend, get_marian_language_direction
131 138  
132 139 source_lang, target_lang = get_marian_language_direction(name)
133 140  
134   - return MarianMTTranslationBackend(
  141 + return MarianCTranslate2TranslationBackend(
135 142 name=name,
136 143 model_id=str(cfg["model_id"]).strip(),
137 144 model_dir=str(cfg["model_dir"]).strip(),
... ... @@ -143,7 +150,14 @@ class TranslationService:
143 150 num_beams=int(cfg["num_beams"]),
144 151 source_langs=[source_lang],
145 152 target_langs=[target_lang],
146   - attn_implementation=cfg.get("attn_implementation"),
  153 + ct2_model_dir=cfg.get("ct2_model_dir"),
  154 + ct2_compute_type=cfg.get("ct2_compute_type"),
  155 + ct2_auto_convert=bool(cfg.get("ct2_auto_convert", True)),
  156 + ct2_conversion_quantization=cfg.get("ct2_conversion_quantization"),
  157 + ct2_inter_threads=int(cfg.get("ct2_inter_threads", 1)),
  158 + ct2_intra_threads=int(cfg.get("ct2_intra_threads", 0)),
  159 + ct2_max_queued_batches=int(cfg.get("ct2_max_queued_batches", 0)),
  160 + ct2_batch_type=str(cfg.get("ct2_batch_type", "examples")),
147 161 )
148 162  
149 163 @property
... ...