Commit ea293660199e758d1c9ade4a4f8b5444d94b4fec
1 parent
2a6d9d76
CTranslate2
Implemented CTranslate2 for the three local translation models and switched the existing local_nllb / local_marian factories over to it. The new runtime lives in local_ctranslate2.py, including HF->CT2 auto-conversion, float16 compute-type mapping, Marian direction handling, and NLLB target-prefix decoding. The service wiring is in service.py (line 113), and the three model configs now point at explicit ctranslate2-float16 dirs in config.yaml (line 133). I also updated the setup path so this is usable end-to-end: ctranslate2>=4.7.0 was added to requirements_translator_service.txt and requirements.txt, the download script now supports pre-conversion in download_translation_models.py (line 27), and the docs/config examples were refreshed in translation/README.md.

I installed ctranslate2 into .venv-translator, pre-converted all three models, and the CT2 artifacts are already on disk:
- models/translation/facebook/nllb-200-distilled-600M/ctranslate2-float16
- models/translation/Helsinki-NLP/opus-mt-zh-en/ctranslate2-float16
- models/translation/Helsinki-NLP/opus-mt-en-zh/ctranslate2-float16

Verification: python3 -m compileall passed, direct TranslationService smoke tests ran successfully in .venv-translator, and the focused NLLB benchmark on the local GPU showed a clear win:
- batch_size=16: HF 0.347s/batch (46.1 items/s) vs CT2 0.130s/batch (123.0 items/s)
- batch_size=1: HF 0.396s/request vs CT2 0.126s/request

One caveat: translation quality on some very short phrases, especially opus-mt-en-zh, still looks a bit rough in smoke tests, so I'd run your real quality set before fully cutting over. If you want, I can take the next step and update the benchmark script/report so you have a fresh full CT2 performance report for all three models.
Showing
7 changed files
with
620 additions
and
23 deletions
Show diff stats
config/config.yaml
| @@ -135,19 +135,34 @@ services: | @@ -135,19 +135,34 @@ services: | ||
| 135 | backend: "local_nllb" | 135 | backend: "local_nllb" |
| 136 | model_id: "facebook/nllb-200-distilled-600M" | 136 | model_id: "facebook/nllb-200-distilled-600M" |
| 137 | model_dir: "./models/translation/facebook/nllb-200-distilled-600M" | 137 | model_dir: "./models/translation/facebook/nllb-200-distilled-600M" |
| 138 | + ct2_model_dir: "./models/translation/facebook/nllb-200-distilled-600M/ctranslate2-float16" | ||
| 139 | + ct2_compute_type: "float16" | ||
| 140 | + ct2_conversion_quantization: "float16" | ||
| 141 | + ct2_auto_convert: true | ||
| 142 | + ct2_inter_threads: 1 | ||
| 143 | + ct2_intra_threads: 0 | ||
| 144 | + ct2_max_queued_batches: 0 | ||
| 145 | + ct2_batch_type: "examples" | ||
| 138 | device: "cuda" | 146 | device: "cuda" |
| 139 | torch_dtype: "float16" | 147 | torch_dtype: "float16" |
| 140 | batch_size: 16 | 148 | batch_size: 16 |
| 141 | max_input_length: 256 | 149 | max_input_length: 256 |
| 142 | max_new_tokens: 64 | 150 | max_new_tokens: 64 |
| 143 | num_beams: 1 | 151 | num_beams: 1 |
| 144 | - attn_implementation: "sdpa" | ||
| 145 | use_cache: true | 152 | use_cache: true |
| 146 | opus-mt-zh-en: | 153 | opus-mt-zh-en: |
| 147 | enabled: true | 154 | enabled: true |
| 148 | backend: "local_marian" | 155 | backend: "local_marian" |
| 149 | model_id: "Helsinki-NLP/opus-mt-zh-en" | 156 | model_id: "Helsinki-NLP/opus-mt-zh-en" |
| 150 | model_dir: "./models/translation/Helsinki-NLP/opus-mt-zh-en" | 157 | model_dir: "./models/translation/Helsinki-NLP/opus-mt-zh-en" |
| 158 | + ct2_model_dir: "./models/translation/Helsinki-NLP/opus-mt-zh-en/ctranslate2-float16" | ||
| 159 | + ct2_compute_type: "float16" | ||
| 160 | + ct2_conversion_quantization: "float16" | ||
| 161 | + ct2_auto_convert: true | ||
| 162 | + ct2_inter_threads: 1 | ||
| 163 | + ct2_intra_threads: 0 | ||
| 164 | + ct2_max_queued_batches: 0 | ||
| 165 | + ct2_batch_type: "examples" | ||
| 151 | device: "cuda" | 166 | device: "cuda" |
| 152 | torch_dtype: "float16" | 167 | torch_dtype: "float16" |
| 153 | batch_size: 16 | 168 | batch_size: 16 |
| @@ -160,6 +175,14 @@ services: | @@ -160,6 +175,14 @@ services: | ||
| 160 | backend: "local_marian" | 175 | backend: "local_marian" |
| 161 | model_id: "Helsinki-NLP/opus-mt-en-zh" | 176 | model_id: "Helsinki-NLP/opus-mt-en-zh" |
| 162 | model_dir: "./models/translation/Helsinki-NLP/opus-mt-en-zh" | 177 | model_dir: "./models/translation/Helsinki-NLP/opus-mt-en-zh" |
| 178 | + ct2_model_dir: "./models/translation/Helsinki-NLP/opus-mt-en-zh/ctranslate2-float16" | ||
| 179 | + ct2_compute_type: "float16" | ||
| 180 | + ct2_conversion_quantization: "float16" | ||
| 181 | + ct2_auto_convert: true | ||
| 182 | + ct2_inter_threads: 1 | ||
| 183 | + ct2_intra_threads: 0 | ||
| 184 | + ct2_max_queued_batches: 0 | ||
| 185 | + ct2_batch_type: "examples" | ||
| 163 | device: "cuda" | 186 | device: "cuda" |
| 164 | torch_dtype: "float16" | 187 | torch_dtype: "float16" |
| 165 | batch_size: 16 | 188 | batch_size: 16 |
requirements.txt
requirements_translator_service.txt
| @@ -14,6 +14,7 @@ tqdm>=4.65.0 | @@ -14,6 +14,7 @@ tqdm>=4.65.0 | ||
| 14 | 14 | ||
| 15 | torch>=2.0.0 | 15 | torch>=2.0.0 |
| 16 | transformers>=4.30.0 | 16 | transformers>=4.30.0 |
| 17 | +ctranslate2>=4.7.0 | ||
| 17 | sentencepiece>=0.2.0 | 18 | sentencepiece>=0.2.0 |
| 18 | sacremoses>=0.1.1 | 19 | sacremoses>=0.1.1 |
| 19 | safetensors>=0.4.0 | 20 | safetensors>=0.4.0 |
scripts/download_translation_models.py
| @@ -4,8 +4,10 @@ | @@ -4,8 +4,10 @@ | ||
| 4 | from __future__ import annotations | 4 | from __future__ import annotations |
| 5 | 5 | ||
| 6 | import argparse | 6 | import argparse |
| 7 | -from pathlib import Path | ||
| 8 | import os | 7 | import os |
| 8 | +from pathlib import Path | ||
| 9 | +import shutil | ||
| 10 | +import subprocess | ||
| 9 | import sys | 11 | import sys |
| 10 | from typing import Iterable | 12 | from typing import Iterable |
| 11 | 13 | ||
| @@ -24,7 +26,8 @@ LOCAL_BACKENDS = {"local_nllb", "local_marian"} | @@ -24,7 +26,8 @@ LOCAL_BACKENDS = {"local_nllb", "local_marian"} | ||
| 24 | 26 | ||
| 25 | def iter_local_capabilities(selected: set[str] | None = None) -> Iterable[tuple[str, dict]]: | 27 | def iter_local_capabilities(selected: set[str] | None = None) -> Iterable[tuple[str, dict]]: |
| 26 | cfg = get_translation_config() | 28 | cfg = get_translation_config() |
| 27 | - for name, capability in cfg.capabilities.items(): | 29 | + capabilities = cfg.get("capabilities", {}) if isinstance(cfg, dict) else {} |
| 30 | + for name, capability in capabilities.items(): | ||
| 28 | backend = str(capability.get("backend") or "").strip().lower() | 31 | backend = str(capability.get("backend") or "").strip().lower() |
| 29 | if backend not in LOCAL_BACKENDS: | 32 | if backend not in LOCAL_BACKENDS: |
| 30 | continue | 33 | continue |
| @@ -33,10 +36,69 @@ def iter_local_capabilities(selected: set[str] | None = None) -> Iterable[tuple[ | @@ -33,10 +36,69 @@ def iter_local_capabilities(selected: set[str] | None = None) -> Iterable[tuple[ | ||
| 33 | yield name, capability | 36 | yield name, capability |
| 34 | 37 | ||
| 35 | 38 | ||
| 39 | +def _compute_ct2_output_dir(capability: dict) -> Path: | ||
| 40 | + custom = str(capability.get("ct2_model_dir") or "").strip() | ||
| 41 | + if custom: | ||
| 42 | + return Path(custom).expanduser() | ||
| 43 | + model_dir = Path(str(capability.get("model_dir") or "")).expanduser() | ||
| 44 | + compute_type = str(capability.get("ct2_compute_type") or capability.get("torch_dtype") or "default").strip().lower() | ||
| 45 | + normalized = compute_type.replace("_", "-") | ||
| 46 | + return model_dir / f"ctranslate2-{normalized}" | ||
| 47 | + | ||
| 48 | + | ||
| 49 | +def _resolve_converter_binary() -> str: | ||
| 50 | + candidate = shutil.which("ct2-transformers-converter") | ||
| 51 | + if candidate: | ||
| 52 | + return candidate | ||
| 53 | + venv_candidate = Path(sys.executable).absolute().parent / "ct2-transformers-converter" | ||
| 54 | + if venv_candidate.exists(): | ||
| 55 | + return str(venv_candidate) | ||
| 56 | + raise RuntimeError( | ||
| 57 | + "ct2-transformers-converter was not found. " | ||
| 58 | + "Install ctranslate2 in the active Python environment first." | ||
| 59 | + ) | ||
| 60 | + | ||
| 61 | + | ||
| 62 | +def convert_to_ctranslate2(name: str, capability: dict) -> None: | ||
| 63 | + model_id = str(capability.get("model_id") or "").strip() | ||
| 64 | + model_dir = Path(str(capability.get("model_dir") or "")).expanduser() | ||
| 65 | + model_source = str(model_dir if model_dir.exists() else model_id) | ||
| 66 | + output_dir = _compute_ct2_output_dir(capability) | ||
| 67 | + if (output_dir / "model.bin").exists(): | ||
| 68 | + print(f"[skip-convert] {name} -> {output_dir}") | ||
| 69 | + return | ||
| 70 | + quantization = str( | ||
| 71 | + capability.get("ct2_conversion_quantization") | ||
| 72 | + or capability.get("ct2_compute_type") | ||
| 73 | + or capability.get("torch_dtype") | ||
| 74 | + or "default" | ||
| 75 | + ).strip() | ||
| 76 | + output_dir.parent.mkdir(parents=True, exist_ok=True) | ||
| 77 | + print(f"[convert] {name} -> {output_dir} ({quantization})") | ||
| 78 | + subprocess.run( | ||
| 79 | + [ | ||
| 80 | + _resolve_converter_binary(), | ||
| 81 | + "--model", | ||
| 82 | + model_source, | ||
| 83 | + "--output_dir", | ||
| 84 | + str(output_dir), | ||
| 85 | + "--quantization", | ||
| 86 | + quantization, | ||
| 87 | + ], | ||
| 88 | + check=True, | ||
| 89 | + ) | ||
| 90 | + print(f"[converted] {name}") | ||
| 91 | + | ||
| 92 | + | ||
| 36 | def main() -> None: | 93 | def main() -> None: |
| 37 | parser = argparse.ArgumentParser(description="Download local translation models") | 94 | parser = argparse.ArgumentParser(description="Download local translation models") |
| 38 | parser.add_argument("--all-local", action="store_true", help="Download all configured local translation models") | 95 | parser.add_argument("--all-local", action="store_true", help="Download all configured local translation models") |
| 39 | parser.add_argument("--models", nargs="*", default=[], help="Specific capability names to download") | 96 | parser.add_argument("--models", nargs="*", default=[], help="Specific capability names to download") |
| 97 | + parser.add_argument( | ||
| 98 | + "--convert-ctranslate2", | ||
| 99 | + action="store_true", | ||
| 100 | + help="Also convert the downloaded Hugging Face models into CTranslate2 format", | ||
| 101 | + ) | ||
| 40 | args = parser.parse_args() | 102 | args = parser.parse_args() |
| 41 | 103 | ||
| 42 | selected = {item.strip().lower() for item in args.models if item.strip()} or None | 104 | selected = {item.strip().lower() for item in args.models if item.strip()} or None |
| @@ -55,6 +117,8 @@ def main() -> None: | @@ -55,6 +117,8 @@ def main() -> None: | ||
| 55 | local_dir=str(model_dir), | 117 | local_dir=str(model_dir), |
| 56 | ) | 118 | ) |
| 57 | print(f"[done] {name}") | 119 | print(f"[done] {name}") |
| 120 | + if args.convert_ctranslate2: | ||
| 121 | + convert_to_ctranslate2(name, capability) | ||
| 58 | 122 | ||
| 59 | 123 | ||
| 60 | if __name__ == "__main__": | 124 | if __name__ == "__main__": |
translation/README.md
| @@ -56,8 +56,8 @@ | @@ -56,8 +56,8 @@ | ||
| 56 | 通用 LLM 翻译 | 56 | 通用 LLM 翻译 |
| 57 | - [`translation/backends/deepl.py`](/data/saas-search/translation/backends/deepl.py) | 57 | - [`translation/backends/deepl.py`](/data/saas-search/translation/backends/deepl.py) |
| 58 | DeepL 翻译 | 58 | DeepL 翻译 |
| 59 | -- [`translation/backends/local_seq2seq.py`](/data/saas-search/translation/backends/local_seq2seq.py) | ||
| 60 | - 本地 Hugging Face Seq2Seq 模型,包括 NLLB 和 Marian/OPUS MT | 59 | +- [`translation/backends/local_ctranslate2.py`](/data/saas-search/translation/backends/local_ctranslate2.py) |
| 60 | + 本地 CTranslate2 翻译模型,包括 NLLB 和 Marian/OPUS MT | ||
| 61 | 61 | ||
| 62 | ## 3. 配置约定 | 62 | ## 3. 配置约定 |
| 63 | 63 | ||
| @@ -103,19 +103,26 @@ services: | @@ -103,19 +103,26 @@ services: | ||
| 103 | backend: "local_nllb" | 103 | backend: "local_nllb" |
| 104 | model_id: "facebook/nllb-200-distilled-600M" | 104 | model_id: "facebook/nllb-200-distilled-600M" |
| 105 | model_dir: "./models/translation/facebook/nllb-200-distilled-600M" | 105 | model_dir: "./models/translation/facebook/nllb-200-distilled-600M" |
| 106 | + ct2_model_dir: "./models/translation/facebook/nllb-200-distilled-600M/ctranslate2-float16" | ||
| 107 | + ct2_compute_type: "float16" | ||
| 108 | + ct2_conversion_quantization: "float16" | ||
| 109 | + ct2_auto_convert: true | ||
| 106 | device: "cuda" | 110 | device: "cuda" |
| 107 | torch_dtype: "float16" | 111 | torch_dtype: "float16" |
| 108 | batch_size: 16 | 112 | batch_size: 16 |
| 109 | max_input_length: 256 | 113 | max_input_length: 256 |
| 110 | max_new_tokens: 64 | 114 | max_new_tokens: 64 |
| 111 | num_beams: 1 | 115 | num_beams: 1 |
| 112 | - attn_implementation: "sdpa" | ||
| 113 | use_cache: true | 116 | use_cache: true |
| 114 | opus-mt-zh-en: | 117 | opus-mt-zh-en: |
| 115 | enabled: true | 118 | enabled: true |
| 116 | backend: "local_marian" | 119 | backend: "local_marian" |
| 117 | model_id: "Helsinki-NLP/opus-mt-zh-en" | 120 | model_id: "Helsinki-NLP/opus-mt-zh-en" |
| 118 | model_dir: "./models/translation/Helsinki-NLP/opus-mt-zh-en" | 121 | model_dir: "./models/translation/Helsinki-NLP/opus-mt-zh-en" |
| 122 | + ct2_model_dir: "./models/translation/Helsinki-NLP/opus-mt-zh-en/ctranslate2-float16" | ||
| 123 | + ct2_compute_type: "float16" | ||
| 124 | + ct2_conversion_quantization: "float16" | ||
| 125 | + ct2_auto_convert: true | ||
| 119 | device: "cuda" | 126 | device: "cuda" |
| 120 | torch_dtype: "float16" | 127 | torch_dtype: "float16" |
| 121 | batch_size: 16 | 128 | batch_size: 16 |
| @@ -128,6 +135,10 @@ services: | @@ -128,6 +135,10 @@ services: | ||
| 128 | backend: "local_marian" | 135 | backend: "local_marian" |
| 129 | model_id: "Helsinki-NLP/opus-mt-en-zh" | 136 | model_id: "Helsinki-NLP/opus-mt-en-zh" |
| 130 | model_dir: "./models/translation/Helsinki-NLP/opus-mt-en-zh" | 137 | model_dir: "./models/translation/Helsinki-NLP/opus-mt-en-zh" |
| 138 | + ct2_model_dir: "./models/translation/Helsinki-NLP/opus-mt-en-zh/ctranslate2-float16" | ||
| 139 | + ct2_compute_type: "float16" | ||
| 140 | + ct2_conversion_quantization: "float16" | ||
| 141 | + ct2_auto_convert: true | ||
| 131 | device: "cuda" | 142 | device: "cuda" |
| 132 | torch_dtype: "float16" | 143 | torch_dtype: "float16" |
| 133 | batch_size: 16 | 144 | batch_size: 16 |
| @@ -148,6 +159,7 @@ services: | @@ -148,6 +159,7 @@ services: | ||
| 148 | - `service_url`、`default_model`、`default_scene` 只从 YAML 读取 | 159 | - `service_url`、`default_model`、`default_scene` 只从 YAML 读取 |
| 149 | - 不再通过环境变量静默覆盖翻译行为配置 | 160 | - 不再通过环境变量静默覆盖翻译行为配置 |
| 150 | - 密钥仍通过环境变量提供 | 161 | - 密钥仍通过环境变量提供 |
| 162 | +- `local_nllb` / `local_marian` 当前由 CTranslate2 运行;首次启动时若 `ct2_model_dir` 不存在,会从 `model_dir` 自动转换 | ||
| 151 | 163 | ||
| 152 | ## 4. 环境变量 | 164 | ## 4. 环境变量 |
| 153 | 165 | ||
| @@ -338,7 +350,7 @@ results = translator.translate( | @@ -338,7 +350,7 @@ results = translator.translate( | ||
| 338 | ### 8.4 `facebook/nllb-200-distilled-600M` | 350 | ### 8.4 `facebook/nllb-200-distilled-600M` |
| 339 | 351 | ||
| 340 | 实现文件: | 352 | 实现文件: |
| 341 | -- [`translation/backends/local_seq2seq.py`](/data/saas-search/translation/backends/local_seq2seq.py) | 353 | +- [`translation/backends/local_ctranslate2.py`](/data/saas-search/translation/backends/local_ctranslate2.py) |
| 342 | 354 | ||
| 343 | 模型信息: | 355 | 模型信息: |
| 344 | - Hugging Face 名称:`facebook/nllb-200-distilled-600M` | 356 | - Hugging Face 名称:`facebook/nllb-200-distilled-600M` |
| @@ -392,18 +404,17 @@ results = translator.translate( | @@ -392,18 +404,17 @@ results = translator.translate( | ||
| 392 | 404 | ||
| 393 | 当前实现特点: | 405 | 当前实现特点: |
| 394 | - backend 类型:`local_nllb` | 406 | - backend 类型:`local_nllb` |
| 407 | +- 运行时:`CTranslate2 Translator` | ||
| 395 | - 支持多语 | 408 | - 支持多语 |
| 396 | - 调用时必须显式传 `source_lang` | 409 | - 调用时必须显式传 `source_lang` |
| 397 | - 语言码映射定义在 [`translation/languages.py`](/data/saas-search/translation/languages.py) | 410 | - 语言码映射定义在 [`translation/languages.py`](/data/saas-search/translation/languages.py) |
| 398 | -- 当前 T4 推荐配置:`device=cuda`、`torch_dtype=float16`、`batch_size=16`、`max_new_tokens=64`、`attn_implementation=sdpa` | 411 | +- 当前 T4 推荐配置:`device=cuda`、`ct2_compute_type=float16`、`batch_size=16`、`max_new_tokens=64` |
| 399 | 412 | ||
| 400 | 当前实现已经利用的优化: | 413 | 当前实现已经利用的优化: |
| 401 | - 已做批量分块:`translate()` 会按 capability 的 `batch_size` 分批进入模型 | 414 | - 已做批量分块:`translate()` 会按 capability 的 `batch_size` 分批进入模型 |
| 402 | -- 已做动态 padding:tokenizer 使用 `padding=True`、`truncation=True` | ||
| 403 | -- 已传入 `attention_mask`:由 tokenizer 生成并随 `generate()` 一起送入模型 | ||
| 404 | -- 已设置方向控制:NLLB 通过 `tokenizer.src_lang` 和 `forced_bos_token_id` 指定语言对 | ||
| 405 | -- 已启用推理态:`torch.inference_mode()` + `model.eval()` | ||
| 406 | -- 已启用半精度和更优注意力实现:当前配置为 `float16 + sdpa` | 415 | +- 已切换到 CTranslate2 推理引擎:不再依赖 PyTorch `generate()` |
| 416 | +- 已设置方向控制:NLLB 通过 target prefix 指定目标语言 | ||
| 417 | +- 已启用半精度:当前配置为 `float16` | ||
| 407 | - 已关闭高开销搜索:默认 `num_beams=1`,更接近线上低延迟设置 | 418 | - 已关闭高开销搜索:默认 `num_beams=1`,更接近线上低延迟设置 |
| 408 | 419 | ||
| 409 | 和你给出的批处理示例对照: | 420 | 和你给出的批处理示例对照: |
| @@ -414,12 +425,12 @@ results = translator.translate( | @@ -414,12 +425,12 @@ results = translator.translate( | ||
| 414 | 优化空间(按场景): | 425 | 优化空间(按场景): |
| 415 | - **线上 query**:优先补测 `batch_size=1` 的真实延迟与 tail latency,而不是继续拉大 batch。 | 426 | - **线上 query**:优先补测 `batch_size=1` 的真实延迟与 tail latency,而不是继续拉大 batch。 |
| 416 | - **离线批量**:可再尝试更激进的 batching / 长度分桶 / 独立批处理队列(吞吐更高,但会增加在线尾延迟风险)。 | 427 | - **离线批量**:可再尝试更激进的 batching / 长度分桶 / 独立批处理队列(吞吐更高,但会增加在线尾延迟风险)。 |
| 417 | -- **进一步降显存 / 提速**:可评估 `ctranslate2` / int8;当前仓库尚未引入该运行栈。 | 428 | +- **进一步降显存 / 提速**:可在当前 CT2 方案上继续评估 `int8_float16`。 |
| 418 | 429 | ||
| 419 | ### 8.5 `opus-mt-zh-en` | 430 | ### 8.5 `opus-mt-zh-en` |
| 420 | 431 | ||
| 421 | 实现文件: | 432 | 实现文件: |
| 422 | -- [`translation/backends/local_seq2seq.py`](/data/saas-search/translation/backends/local_seq2seq.py) | 433 | +- [`translation/backends/local_ctranslate2.py`](/data/saas-search/translation/backends/local_ctranslate2.py) |
| 423 | 434 | ||
| 424 | 模型信息: | 435 | 模型信息: |
| 425 | - Hugging Face 名称:`Helsinki-NLP/opus-mt-zh-en` | 436 | - Hugging Face 名称:`Helsinki-NLP/opus-mt-zh-en` |
| @@ -437,7 +448,7 @@ results = translator.translate( | @@ -437,7 +448,7 @@ results = translator.translate( | ||
| 437 | ### 8.6 `opus-mt-en-zh` | 448 | ### 8.6 `opus-mt-en-zh` |
| 438 | 449 | ||
| 439 | 实现文件: | 450 | 实现文件: |
| 440 | -- [`translation/backends/local_seq2seq.py`](/data/saas-search/translation/backends/local_seq2seq.py) | 451 | +- [`translation/backends/local_ctranslate2.py`](/data/saas-search/translation/backends/local_ctranslate2.py) |
| 441 | 452 | ||
| 442 | 模型信息: | 453 | 模型信息: |
| 443 | - Hugging Face 名称:`Helsinki-NLP/opus-mt-en-zh` | 454 | - Hugging Face 名称:`Helsinki-NLP/opus-mt-en-zh` |
| @@ -498,7 +509,7 @@ models/translation/Helsinki-NLP/opus-mt-en-zh | @@ -498,7 +509,7 @@ models/translation/Helsinki-NLP/opus-mt-en-zh | ||
| 498 | - 避免多 worker 重复加载模型 | 509 | - 避免多 worker 重复加载模型 |
| 499 | - GPU 机器上优先使用 `cuda + float16` | 510 | - GPU 机器上优先使用 `cuda + float16` |
| 500 | - CPU 只建议用于功能验证或离线低频任务 | 511 | - CPU 只建议用于功能验证或离线低频任务 |
| 501 | -- 对 NLLB,T4 上优先采用 `batch_size=16 + max_new_tokens=64 + attn_implementation=sdpa` | 512 | +- 对 NLLB,T4 上优先采用 `batch_size=16 + max_new_tokens=64 + ct2_compute_type=float16` |
| 502 | 513 | ||
| 503 | ### 9.5 验证 | 514 | ### 9.5 验证 |
| 504 | 515 | ||
| @@ -524,6 +535,10 @@ curl -X POST http://127.0.0.1:6006/translate \ | @@ -524,6 +535,10 @@ curl -X POST http://127.0.0.1:6006/translate \ | ||
| 524 | 535 | ||
| 525 | ## 10. 性能测试与复现 | 536 | ## 10. 性能测试与复现 |
| 526 | 537 | ||
| 538 | +说明: | ||
| 539 | +- 本节现有数值是 `2026-03-18` 的 Hugging Face / PyTorch 基线结果。 | ||
| 540 | +- 切换到 CTranslate2 后需要重新跑一轮基准,尤其关注 `nllb-200-distilled-600m` 的单条延迟、并发 tail latency 和 `opus-mt-*` 的 batch throughput。 | ||
| 541 | + | ||
| 527 | 性能脚本: | 542 | 性能脚本: |
| 528 | - [`scripts/benchmark_translation_local_models.py`](/data/saas-search/scripts/benchmark_translation_local_models.py) | 543 | - [`scripts/benchmark_translation_local_models.py`](/data/saas-search/scripts/benchmark_translation_local_models.py) |
| 529 | 544 |
| @@ -0,0 +1,479 @@ | @@ -0,0 +1,479 @@ | ||
| 1 | +"""Local translation backends powered by CTranslate2.""" | ||
| 2 | + | ||
| 3 | +from __future__ import annotations | ||
| 4 | + | ||
| 5 | +import logging | ||
| 6 | +import os | ||
| 7 | +import shutil | ||
| 8 | +import subprocess | ||
| 9 | +import sys | ||
| 10 | +import threading | ||
| 11 | +from pathlib import Path | ||
| 12 | +from typing import Dict, List, Optional, Sequence, Union | ||
| 13 | + | ||
| 14 | +from transformers import AutoTokenizer | ||
| 15 | + | ||
| 16 | +from translation.languages import MARIAN_LANGUAGE_DIRECTIONS, NLLB_LANGUAGE_CODES | ||
| 17 | + | ||
| 18 | +logger = logging.getLogger(__name__) | ||
| 19 | + | ||
| 20 | + | ||
| 21 | +def _resolve_device(device: Optional[str]) -> str: | ||
| 22 | + value = str(device or "auto").strip().lower() | ||
| 23 | + if value not in {"auto", "cpu", "cuda"}: | ||
| 24 | + raise ValueError(f"Unsupported CTranslate2 device: {device}") | ||
| 25 | + return value | ||
| 26 | + | ||
| 27 | + | ||
| 28 | +def _resolve_compute_type( | ||
| 29 | + torch_dtype: Optional[str], | ||
| 30 | + compute_type: Optional[str], | ||
| 31 | + device: str, | ||
| 32 | +) -> str: | ||
| 33 | + value = str(compute_type or torch_dtype or "default").strip().lower() | ||
| 34 | + if value in {"auto", "default"}: | ||
| 35 | + return "float16" if device == "cuda" else "default" | ||
| 36 | + if value in {"float16", "fp16", "half"}: | ||
| 37 | + return "float16" | ||
| 38 | + if value in {"bfloat16", "bf16"}: | ||
| 39 | + return "bfloat16" | ||
| 40 | + if value in {"float32", "fp32"}: | ||
| 41 | + return "float32" | ||
| 42 | + if value in { | ||
| 43 | + "int8", | ||
| 44 | + "int8_float32", | ||
| 45 | + "int8_float16", | ||
| 46 | + "int8_bfloat16", | ||
| 47 | + "int16", | ||
| 48 | + }: | ||
| 49 | + return value | ||
| 50 | + raise ValueError(f"Unsupported CTranslate2 compute type: {compute_type or torch_dtype}") | ||
| 51 | + | ||
| 52 | + | ||
| 53 | +def _derive_ct2_model_dir(model_dir: str, compute_type: str) -> str: | ||
| 54 | + normalized = compute_type.replace("_", "-") | ||
| 55 | + return str(Path(model_dir).expanduser() / f"ctranslate2-{normalized}") | ||
| 56 | + | ||
| 57 | + | ||
| 58 | +def _resolve_converter_binary() -> str: | ||
| 59 | + candidate = shutil.which("ct2-transformers-converter") | ||
| 60 | + if candidate: | ||
| 61 | + return candidate | ||
| 62 | + venv_candidate = Path(sys.executable).absolute().parent / "ct2-transformers-converter" | ||
| 63 | + if venv_candidate.exists(): | ||
| 64 | + return str(venv_candidate) | ||
| 65 | + raise RuntimeError( | ||
| 66 | + "ct2-transformers-converter was not found. " | ||
| 67 | + "Ensure ctranslate2 is installed in the active translator environment." | ||
| 68 | + ) | ||
| 69 | + | ||
| 70 | + | ||
| 71 | +class LocalCTranslate2TranslationBackend: | ||
| 72 | + """Base backend for local CTranslate2 translation models.""" | ||
| 73 | + | ||
| 74 | + def __init__( | ||
| 75 | + self, | ||
| 76 | + *, | ||
| 77 | + name: str, | ||
| 78 | + model_id: str, | ||
| 79 | + model_dir: str, | ||
| 80 | + device: str, | ||
| 81 | + torch_dtype: str, | ||
| 82 | + batch_size: int, | ||
| 83 | + max_input_length: int, | ||
| 84 | + max_new_tokens: int, | ||
| 85 | + num_beams: int, | ||
| 86 | + ct2_model_dir: Optional[str] = None, | ||
| 87 | + ct2_compute_type: Optional[str] = None, | ||
| 88 | + ct2_auto_convert: bool = True, | ||
| 89 | + ct2_conversion_quantization: Optional[str] = None, | ||
| 90 | + ct2_inter_threads: int = 1, | ||
| 91 | + ct2_intra_threads: int = 0, | ||
| 92 | + ct2_max_queued_batches: int = 0, | ||
| 93 | + ct2_batch_type: str = "examples", | ||
| 94 | + ) -> None: | ||
| 95 | + self.model = name | ||
| 96 | + self.model_id = model_id | ||
| 97 | + self.model_dir = model_dir | ||
| 98 | + self.device = _resolve_device(device) | ||
| 99 | + self.compute_type = _resolve_compute_type(torch_dtype, ct2_compute_type, self.device) | ||
| 100 | + self.batch_size = int(batch_size) | ||
| 101 | + self.max_input_length = int(max_input_length) | ||
| 102 | + self.max_new_tokens = int(max_new_tokens) | ||
| 103 | + self.num_beams = int(num_beams) | ||
| 104 | + self.ct2_model_dir = str(ct2_model_dir or _derive_ct2_model_dir(model_dir, self.compute_type)) | ||
| 105 | + self.ct2_auto_convert = bool(ct2_auto_convert) | ||
| 106 | + self.ct2_conversion_quantization = _resolve_compute_type( | ||
| 107 | + torch_dtype, | ||
| 108 | + ct2_conversion_quantization or self.compute_type, | ||
| 109 | + self.device, | ||
| 110 | + ) | ||
| 111 | + self.ct2_inter_threads = int(ct2_inter_threads) | ||
| 112 | + self.ct2_intra_threads = int(ct2_intra_threads) | ||
| 113 | + self.ct2_max_queued_batches = int(ct2_max_queued_batches) | ||
| 114 | + self.ct2_batch_type = str(ct2_batch_type or "examples").strip().lower() | ||
| 115 | + if self.ct2_batch_type not in {"examples", "tokens"}: | ||
| 116 | + raise ValueError(f"Unsupported CTranslate2 batch type: {ct2_batch_type}") | ||
| 117 | + self._tokenizer_lock = threading.Lock() | ||
| 118 | + self._load_runtime() | ||
| 119 | + | ||
| 120 | + @property | ||
| 121 | + def supports_batch(self) -> bool: | ||
| 122 | + return True | ||
| 123 | + | ||
| 124 | + def _tokenizer_source(self) -> str: | ||
| 125 | + return self.model_dir if os.path.exists(self.model_dir) else self.model_id | ||
| 126 | + | ||
| 127 | + def _model_source(self) -> str: | ||
| 128 | + return self.model_dir if os.path.exists(self.model_dir) else self.model_id | ||
| 129 | + | ||
| 130 | + def _tokenizer_kwargs(self) -> Dict[str, object]: | ||
| 131 | + return {} | ||
| 132 | + | ||
| 133 | + def _translator_kwargs(self) -> Dict[str, object]: | ||
| 134 | + return { | ||
| 135 | + "device": self.device, | ||
| 136 | + "compute_type": self.compute_type, | ||
| 137 | + "inter_threads": self.ct2_inter_threads, | ||
| 138 | + "intra_threads": self.ct2_intra_threads, | ||
| 139 | + "max_queued_batches": self.ct2_max_queued_batches, | ||
| 140 | + } | ||
| 141 | + | ||
| 142 | + def _load_runtime(self) -> None: | ||
| 143 | + try: | ||
| 144 | + import ctranslate2 | ||
| 145 | + except ImportError as exc: | ||
| 146 | + raise RuntimeError( | ||
| 147 | + "CTranslate2 is required for local Marian/NLLB translation. " | ||
| 148 | + "Install the translator service dependencies again after adding ctranslate2." | ||
| 149 | + ) from exc | ||
| 150 | + | ||
| 151 | + tokenizer_source = self._tokenizer_source() | ||
| 152 | + model_source = self._model_source() | ||
| 153 | + self._ensure_converted_model(model_source) | ||
| 154 | + logger.info( | ||
| 155 | + "Loading CTranslate2 translation model | name=%s ct2_model_dir=%s tokenizer=%s device=%s compute_type=%s", | ||
| 156 | + self.model, | ||
| 157 | + self.ct2_model_dir, | ||
| 158 | + tokenizer_source, | ||
| 159 | + self.device, | ||
| 160 | + self.compute_type, | ||
| 161 | + ) | ||
| 162 | + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, **self._tokenizer_kwargs()) | ||
| 163 | + self.translator = ctranslate2.Translator(self.ct2_model_dir, **self._translator_kwargs()) | ||
| 164 | + if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None: | ||
| 165 | + self.tokenizer.pad_token = self.tokenizer.eos_token | ||
| 166 | + | ||
| 167 | + def _ensure_converted_model(self, model_source: str) -> None: | ||
| 168 | + ct2_path = Path(self.ct2_model_dir).expanduser() | ||
| 169 | + if (ct2_path / "model.bin").exists(): | ||
| 170 | + return | ||
| 171 | + if not self.ct2_auto_convert: | ||
| 172 | + raise FileNotFoundError( | ||
| 173 | + f"CTranslate2 model not found for '{self.model}': {ct2_path}. " | ||
| 174 | + "Enable ct2_auto_convert or pre-convert the model." | ||
| 175 | + ) | ||
| 176 | + | ||
| 177 | + ct2_path.parent.mkdir(parents=True, exist_ok=True) | ||
| 178 | + converter = _resolve_converter_binary() | ||
| 179 | + logger.info( | ||
| 180 | + "Converting translation model to CTranslate2 | name=%s source=%s output=%s quantization=%s", | ||
| 181 | + self.model, | ||
| 182 | + model_source, | ||
| 183 | + ct2_path, | ||
| 184 | + self.ct2_conversion_quantization, | ||
| 185 | + ) | ||
| 186 | + try: | ||
| 187 | + subprocess.run( | ||
| 188 | + [ | ||
| 189 | + converter, | ||
| 190 | + "--model", | ||
| 191 | + model_source, | ||
| 192 | + "--output_dir", | ||
| 193 | + str(ct2_path), | ||
| 194 | + "--quantization", | ||
| 195 | + self.ct2_conversion_quantization, | ||
| 196 | + ], | ||
| 197 | + check=True, | ||
| 198 | + stdout=subprocess.PIPE, | ||
| 199 | + stderr=subprocess.PIPE, | ||
| 200 | + text=True, | ||
| 201 | + ) | ||
| 202 | + except subprocess.CalledProcessError as exc: | ||
| 203 | + stderr = exc.stderr.strip() | ||
| 204 | + raise RuntimeError( | ||
| 205 | + f"Failed to convert model '{self.model}' to CTranslate2: {stderr or exc}" | ||
| 206 | + ) from exc | ||
| 207 | + | ||
| 208 | + def _normalize_texts(self, text: Union[str, Sequence[str]]) -> List[str]: | ||
| 209 | + if isinstance(text, str): | ||
| 210 | + return [text] | ||
| 211 | + return ["" if item is None else str(item) for item in text] | ||
| 212 | + | ||
| 213 | + def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None: | ||
| 214 | + del source_lang, target_lang | ||
| 215 | + | ||
| 216 | + def _encode_source_tokens( | ||
| 217 | + self, | ||
| 218 | + texts: List[str], | ||
| 219 | + source_lang: Optional[str], | ||
| 220 | + target_lang: str, | ||
| 221 | + ) -> List[List[str]]: | ||
| 222 | + del source_lang, target_lang | ||
| 223 | + with self._tokenizer_lock: | ||
| 224 | + encoded = self.tokenizer( | ||
| 225 | + texts, | ||
| 226 | + truncation=True, | ||
| 227 | + max_length=self.max_input_length, | ||
| 228 | + padding=False, | ||
| 229 | + ) | ||
| 230 | + input_ids = encoded["input_ids"] | ||
| 231 | + return [self.tokenizer.convert_ids_to_tokens(ids) for ids in input_ids] | ||
| 232 | + | ||
| 233 | + def _target_prefixes( | ||
| 234 | + self, | ||
| 235 | + count: int, | ||
| 236 | + source_lang: Optional[str], | ||
| 237 | + target_lang: str, | ||
| 238 | + ) -> Optional[List[Optional[List[str]]]]: | ||
| 239 | + del count, source_lang, target_lang | ||
| 240 | + return None | ||
| 241 | + | ||
| 242 | + def _postprocess_hypothesis( | ||
| 243 | + self, | ||
| 244 | + tokens: List[str], | ||
| 245 | + source_lang: Optional[str], | ||
| 246 | + target_lang: str, | ||
| 247 | + ) -> List[str]: | ||
| 248 | + del source_lang, target_lang | ||
| 249 | + return tokens | ||
| 250 | + | ||
| 251 | + def _decode_tokens(self, tokens: List[str]) -> Optional[str]: | ||
| 252 | + token_ids = self.tokenizer.convert_tokens_to_ids(tokens) | ||
| 253 | + text = self.tokenizer.decode(token_ids, skip_special_tokens=True).strip() | ||
| 254 | + return text or None | ||
| 255 | + | ||
| 256 | + def _translate_batch( | ||
| 257 | + self, | ||
| 258 | + texts: List[str], | ||
| 259 | + target_lang: str, | ||
| 260 | + source_lang: Optional[str] = None, | ||
| 261 | + ) -> List[Optional[str]]: | ||
| 262 | + self._validate_languages(source_lang, target_lang) | ||
| 263 | + source_tokens = self._encode_source_tokens(texts, source_lang, target_lang) | ||
| 264 | + target_prefix = self._target_prefixes(len(source_tokens), source_lang, target_lang) | ||
| 265 | + results = self.translator.translate_batch( | ||
| 266 | + source_tokens, | ||
| 267 | + target_prefix=target_prefix, | ||
| 268 | + max_batch_size=self.batch_size, | ||
| 269 | + batch_type=self.ct2_batch_type, | ||
| 270 | + beam_size=self.num_beams, | ||
| 271 | + max_input_length=self.max_input_length, | ||
| 272 | + max_decoding_length=self.max_new_tokens, | ||
| 273 | + ) | ||
| 274 | + outputs: List[Optional[str]] = [] | ||
| 275 | + for result in results: | ||
| 276 | + hypothesis = result.hypotheses[0] if result.hypotheses else [] | ||
| 277 | + processed = self._postprocess_hypothesis(hypothesis, source_lang, target_lang) | ||
| 278 | + outputs.append(self._decode_tokens(processed)) | ||
| 279 | + return outputs | ||
| 280 | + | ||
| 281 | + def translate( | ||
| 282 | + self, | ||
| 283 | + text: Union[str, Sequence[str]], | ||
| 284 | + target_lang: str, | ||
| 285 | + source_lang: Optional[str] = None, | ||
| 286 | + scene: Optional[str] = None, | ||
| 287 | + ) -> Union[Optional[str], List[Optional[str]]]: | ||
| 288 | + del scene | ||
| 289 | + is_single = isinstance(text, str) | ||
| 290 | + texts = self._normalize_texts(text) | ||
| 291 | + outputs: List[Optional[str]] = [] | ||
| 292 | + for start in range(0, len(texts), self.batch_size): | ||
| 293 | + chunk = texts[start:start + self.batch_size] | ||
| 294 | + if not any(item.strip() for item in chunk): | ||
| 295 | + outputs.extend([None if not item.strip() else item for item in chunk]) # type: ignore[list-item] | ||
| 296 | + continue | ||
| 297 | + outputs.extend(self._translate_batch(chunk, target_lang=target_lang, source_lang=source_lang)) | ||
| 298 | + return outputs[0] if is_single else outputs | ||
| 299 | + | ||
| 300 | + | ||
class MarianCTranslate2TranslationBackend(LocalCTranslate2TranslationBackend):
    """Local backend for Marian/OPUS MT models on CTranslate2.

    Marian checkpoints translate one fixed direction, so the permitted
    source and target languages are pinned at construction time and
    enforced on every request.
    """

    def __init__(
        self,
        *,
        name: str,
        model_id: str,
        model_dir: str,
        device: str,
        torch_dtype: str,
        batch_size: int,
        max_input_length: int,
        max_new_tokens: int,
        num_beams: int,
        source_langs: Sequence[str],
        target_langs: Sequence[str],
        ct2_model_dir: Optional[str] = None,
        ct2_compute_type: Optional[str] = None,
        ct2_auto_convert: bool = True,
        ct2_conversion_quantization: Optional[str] = None,
        ct2_inter_threads: int = 1,
        ct2_intra_threads: int = 0,
        ct2_max_queued_batches: int = 0,
        ct2_batch_type: str = "examples",
    ) -> None:
        # Normalize the allow-lists before the base __init__ runs, in case
        # base initialization touches language validation.
        self.source_langs = {
            str(lang).strip().lower() for lang in source_langs if str(lang).strip()
        }
        self.target_langs = {
            str(lang).strip().lower() for lang in target_langs if str(lang).strip()
        }
        super().__init__(
            name=name,
            model_id=model_id,
            model_dir=model_dir,
            device=device,
            torch_dtype=torch_dtype,
            batch_size=batch_size,
            max_input_length=max_input_length,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            ct2_model_dir=ct2_model_dir,
            ct2_compute_type=ct2_compute_type,
            ct2_auto_convert=ct2_auto_convert,
            ct2_conversion_quantization=ct2_conversion_quantization,
            ct2_inter_threads=ct2_inter_threads,
            ct2_intra_threads=ct2_intra_threads,
            ct2_max_queued_batches=ct2_max_queued_batches,
            ct2_batch_type=ct2_batch_type,
        )

    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
        """Reject any language pair outside this model's fixed direction."""
        requested_src = str(source_lang or "").strip().lower()
        requested_tgt = str(target_lang or "").strip().lower()
        if self.source_langs and requested_src not in self.source_langs:
            raise ValueError(
                f"Model '{self.model}' only supports source languages: {sorted(self.source_langs)}"
            )
        if self.target_langs and requested_tgt not in self.target_langs:
            raise ValueError(
                f"Model '{self.model}' only supports target languages: {sorted(self.target_langs)}"
            )
| 360 | + | ||
| 361 | + | ||
class NLLBCTranslate2TranslationBackend(LocalCTranslate2TranslationBackend):
    """Local backend for NLLB models on CTranslate2.

    NLLB is multilingual: the source language is baked into tokenization
    (one cached tokenizer per source code) and the target language is
    injected as a decoder prefix token, which is stripped again from the
    decoded hypothesis.
    """

    def __init__(
        self,
        *,
        name: str,
        model_id: str,
        model_dir: str,
        device: str,
        torch_dtype: str,
        batch_size: int,
        max_input_length: int,
        max_new_tokens: int,
        num_beams: int,
        language_codes: Optional[Dict[str, str]] = None,
        ct2_model_dir: Optional[str] = None,
        ct2_compute_type: Optional[str] = None,
        ct2_auto_convert: bool = True,
        ct2_conversion_quantization: Optional[str] = None,
        ct2_inter_threads: int = 1,
        ct2_intra_threads: int = 0,
        ct2_max_queued_batches: int = 0,
        ct2_batch_type: str = "examples",
    ) -> None:
        # Built-in code table first, then caller overrides (normalized keys win).
        merged = dict(NLLB_LANGUAGE_CODES)
        for key, value in (language_codes or {}).items():
            normalized_key = str(key).strip().lower()
            if normalized_key:
                merged[normalized_key] = str(value).strip()
        self.language_codes = merged
        # Tokenizers keyed by NLLB source code, populated lazily under the lock.
        self._tokenizers_by_source: Dict[str, object] = {}
        super().__init__(
            name=name,
            model_id=model_id,
            model_dir=model_dir,
            device=device,
            torch_dtype=torch_dtype,
            batch_size=batch_size,
            max_input_length=max_input_length,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            ct2_model_dir=ct2_model_dir,
            ct2_compute_type=ct2_compute_type,
            ct2_auto_convert=ct2_auto_convert,
            ct2_conversion_quantization=ct2_conversion_quantization,
            ct2_inter_threads=ct2_inter_threads,
            ct2_intra_threads=ct2_intra_threads,
            ct2_max_queued_batches=ct2_max_queued_batches,
            ct2_batch_type=ct2_batch_type,
        )

    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
        """Require a known NLLB code for both the source and target language."""
        requested_src = str(source_lang or "").strip().lower()
        requested_tgt = str(target_lang or "").strip().lower()
        if not requested_src:
            raise ValueError(f"Model '{self.model}' requires source_lang")
        if requested_src not in self.language_codes:
            raise ValueError(f"Unsupported NLLB source language: {source_lang}")
        if requested_tgt not in self.language_codes:
            raise ValueError(f"Unsupported NLLB target language: {target_lang}")

    def _get_tokenizer_for_source(self, source_lang: str):
        """Return a tokenizer configured with ``src_lang``, caching per code."""
        src_code = self.language_codes[source_lang]
        with self._tokenizer_lock:
            cached = self._tokenizers_by_source.get(src_code)
            if cached is not None:
                return cached
            tokenizer = AutoTokenizer.from_pretrained(self._tokenizer_source(), src_lang=src_code)
            if tokenizer.pad_token is None and tokenizer.eos_token is not None:
                tokenizer.pad_token = tokenizer.eos_token
            self._tokenizers_by_source[src_code] = tokenizer
            return tokenizer

    def _encode_source_tokens(
        self,
        texts: List[str],
        source_lang: Optional[str],
        target_lang: str,
    ) -> List[List[str]]:
        """Tokenize sources with the per-source-language NLLB tokenizer."""
        del target_lang  # the target direction is injected via the decoder prefix
        source_key = str(source_lang or "").strip().lower()
        tokenizer = self._get_tokenizer_for_source(source_key)
        batch = tokenizer(
            texts,
            truncation=True,
            max_length=self.max_input_length,
            padding=False,
        )
        return [tokenizer.convert_ids_to_tokens(seq) for seq in batch["input_ids"]]

    def _target_prefixes(
        self,
        count: int,
        source_lang: Optional[str],
        target_lang: str,
    ) -> Optional[List[Optional[List[str]]]]:
        """Force decoding to start with the NLLB target-language token."""
        del source_lang
        tgt_code = self.language_codes[str(target_lang).strip().lower()]
        return [[tgt_code] for _ in range(count)]

    def _postprocess_hypothesis(
        self,
        tokens: List[str],
        source_lang: Optional[str],
        target_lang: str,
    ) -> List[str]:
        """Strip the forced target-language token from the hypothesis."""
        del source_lang
        tgt_code = self.language_codes[str(target_lang).strip().lower()]
        return tokens[1:] if tokens[:1] == [tgt_code] else tokens
| 473 | + | ||
| 474 | + | ||
def get_marian_language_direction(model_name: str) -> tuple[str, str]:
    """Return the fixed (source, target) pair registered for a Marian capability.

    Raises ``ValueError`` when the capability name has no registered direction.
    """
    if model_name not in MARIAN_LANGUAGE_DIRECTIONS:
        raise ValueError(f"Translation capability '{model_name}' is not registered with Marian language directions")
    return MARIAN_LANGUAGE_DIRECTIONS[model_name]
translation/service.py
| @@ -111,9 +111,9 @@ class TranslationService: | @@ -111,9 +111,9 @@ class TranslationService: | ||
| 111 | ) | 111 | ) |
| 112 | 112 | ||
| 113 | def _create_local_nllb_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol: | 113 | def _create_local_nllb_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol: |
| 114 | - from translation.backends.local_seq2seq import NLLBTranslationBackend | 114 | + from translation.backends.local_ctranslate2 import NLLBCTranslate2TranslationBackend |
| 115 | 115 | ||
| 116 | - return NLLBTranslationBackend( | 116 | + return NLLBCTranslate2TranslationBackend( |
| 117 | name=name, | 117 | name=name, |
| 118 | model_id=str(cfg["model_id"]).strip(), | 118 | model_id=str(cfg["model_id"]).strip(), |
| 119 | model_dir=str(cfg["model_dir"]).strip(), | 119 | model_dir=str(cfg["model_dir"]).strip(), |
| @@ -123,15 +123,22 @@ class TranslationService: | @@ -123,15 +123,22 @@ class TranslationService: | ||
| 123 | max_input_length=int(cfg["max_input_length"]), | 123 | max_input_length=int(cfg["max_input_length"]), |
| 124 | max_new_tokens=int(cfg["max_new_tokens"]), | 124 | max_new_tokens=int(cfg["max_new_tokens"]), |
| 125 | num_beams=int(cfg["num_beams"]), | 125 | num_beams=int(cfg["num_beams"]), |
| 126 | - attn_implementation=cfg.get("attn_implementation"), | 126 | + ct2_model_dir=cfg.get("ct2_model_dir"), |
| 127 | + ct2_compute_type=cfg.get("ct2_compute_type"), | ||
| 128 | + ct2_auto_convert=bool(cfg.get("ct2_auto_convert", True)), | ||
| 129 | + ct2_conversion_quantization=cfg.get("ct2_conversion_quantization"), | ||
| 130 | + ct2_inter_threads=int(cfg.get("ct2_inter_threads", 1)), | ||
| 131 | + ct2_intra_threads=int(cfg.get("ct2_intra_threads", 0)), | ||
| 132 | + ct2_max_queued_batches=int(cfg.get("ct2_max_queued_batches", 0)), | ||
| 133 | + ct2_batch_type=str(cfg.get("ct2_batch_type", "examples")), | ||
| 127 | ) | 134 | ) |
| 128 | 135 | ||
| 129 | def _create_local_marian_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol: | 136 | def _create_local_marian_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol: |
| 130 | - from translation.backends.local_seq2seq import MarianMTTranslationBackend, get_marian_language_direction | 137 | + from translation.backends.local_ctranslate2 import MarianCTranslate2TranslationBackend, get_marian_language_direction |
| 131 | 138 | ||
| 132 | source_lang, target_lang = get_marian_language_direction(name) | 139 | source_lang, target_lang = get_marian_language_direction(name) |
| 133 | 140 | ||
| 134 | - return MarianMTTranslationBackend( | 141 | + return MarianCTranslate2TranslationBackend( |
| 135 | name=name, | 142 | name=name, |
| 136 | model_id=str(cfg["model_id"]).strip(), | 143 | model_id=str(cfg["model_id"]).strip(), |
| 137 | model_dir=str(cfg["model_dir"]).strip(), | 144 | model_dir=str(cfg["model_dir"]).strip(), |
| @@ -143,7 +150,14 @@ class TranslationService: | @@ -143,7 +150,14 @@ class TranslationService: | ||
| 143 | num_beams=int(cfg["num_beams"]), | 150 | num_beams=int(cfg["num_beams"]), |
| 144 | source_langs=[source_lang], | 151 | source_langs=[source_lang], |
| 145 | target_langs=[target_lang], | 152 | target_langs=[target_lang], |
| 146 | - attn_implementation=cfg.get("attn_implementation"), | 153 | + ct2_model_dir=cfg.get("ct2_model_dir"), |
| 154 | + ct2_compute_type=cfg.get("ct2_compute_type"), | ||
| 155 | + ct2_auto_convert=bool(cfg.get("ct2_auto_convert", True)), | ||
| 156 | + ct2_conversion_quantization=cfg.get("ct2_conversion_quantization"), | ||
| 157 | + ct2_inter_threads=int(cfg.get("ct2_inter_threads", 1)), | ||
| 158 | + ct2_intra_threads=int(cfg.get("ct2_intra_threads", 0)), | ||
| 159 | + ct2_max_queued_batches=int(cfg.get("ct2_max_queued_batches", 0)), | ||
| 160 | + ct2_batch_type=str(cfg.get("ct2_batch_type", "examples")), | ||
| 147 | ) | 161 | ) |
| 148 | 162 | ||
| 149 | @property | 163 | @property |