diff --git a/config/config.yaml b/config/config.yaml
index 8d026e3..bfa7eb5 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -135,19 +135,34 @@ services:
       backend: "local_nllb"
       model_id: "facebook/nllb-200-distilled-600M"
      model_dir: "./models/translation/facebook/nllb-200-distilled-600M"
+      ct2_model_dir: "./models/translation/facebook/nllb-200-distilled-600M/ctranslate2-float16"
+      ct2_compute_type: "float16"
+      ct2_conversion_quantization: "float16"
+      ct2_auto_convert: true
+      ct2_inter_threads: 1
+      ct2_intra_threads: 0
+      ct2_max_queued_batches: 0
+      ct2_batch_type: "examples"
       device: "cuda"
       torch_dtype: "float16"
       batch_size: 16
       max_input_length: 256
       max_new_tokens: 64
       num_beams: 1
-      attn_implementation: "sdpa"
       use_cache: true
     opus-mt-zh-en:
       enabled: true
       backend: "local_marian"
       model_id: "Helsinki-NLP/opus-mt-zh-en"
       model_dir: "./models/translation/Helsinki-NLP/opus-mt-zh-en"
+      ct2_model_dir: "./models/translation/Helsinki-NLP/opus-mt-zh-en/ctranslate2-float16"
+      ct2_compute_type: "float16"
+      ct2_conversion_quantization: "float16"
+      ct2_auto_convert: true
+      ct2_inter_threads: 1
+      ct2_intra_threads: 0
+      ct2_max_queued_batches: 0
+      ct2_batch_type: "examples"
       device: "cuda"
       torch_dtype: "float16"
       batch_size: 16
@@ -160,6 +175,14 @@ services:
       backend: "local_marian"
       model_id: "Helsinki-NLP/opus-mt-en-zh"
       model_dir: "./models/translation/Helsinki-NLP/opus-mt-en-zh"
+      ct2_model_dir: "./models/translation/Helsinki-NLP/opus-mt-en-zh/ctranslate2-float16"
+      ct2_compute_type: "float16"
+      ct2_conversion_quantization: "float16"
+      ct2_auto_convert: true
+      ct2_inter_threads: 1
+      ct2_intra_threads: 0
+      ct2_max_queued_batches: 0
+      ct2_batch_type: "examples"
       device: "cuda"
       torch_dtype: "float16"
       batch_size: 16
diff --git a/requirements.txt b/requirements.txt
index e6518e7..449fd35 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -32,6 +32,7 @@ anyio>=3.7.0
 
 # Translation
 requests>=2.31.0
+ctranslate2>=4.7.0
 
 # Utilities
 tqdm>=4.65.0
diff --git a/requirements_translator_service.txt b/requirements_translator_service.txt
index a6d72df..e8b8f18 100644
--- a/requirements_translator_service.txt
+++ b/requirements_translator_service.txt
@@ -14,6 +14,7 @@ tqdm>=4.65.0
 
 torch>=2.0.0
 transformers>=4.30.0
+ctranslate2>=4.7.0
 sentencepiece>=0.2.0
 sacremoses>=0.1.1
 safetensors>=0.4.0
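Note: a quick way to confirm the new dependency is usable in the translator environment, as a minimal sketch (the `ct2-transformers-converter` CLI ships with the `ctranslate2` wheel, which is what the scripts below rely on):

```python
# Sanity check for the new ctranslate2 dependency (run inside the translator venv).
import shutil

import ctranslate2

print("ctranslate2 version:", ctranslate2.__version__)
# The converter CLI is installed alongside the Python package; the scripts in
# this patch resolve it with shutil.which() before falling back to the venv bin dir.
print("converter path:", shutil.which("ct2-transformers-converter"))
```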
"")).expanduser() + compute_type = str(capability.get("ct2_compute_type") or capability.get("torch_dtype") or "default").strip().lower() + normalized = compute_type.replace("_", "-") + return model_dir / f"ctranslate2-{normalized}" + + +def _resolve_converter_binary() -> str: + candidate = shutil.which("ct2-transformers-converter") + if candidate: + return candidate + venv_candidate = Path(sys.executable).absolute().parent / "ct2-transformers-converter" + if venv_candidate.exists(): + return str(venv_candidate) + raise RuntimeError( + "ct2-transformers-converter was not found. " + "Install ctranslate2 in the active Python environment first." + ) + + +def convert_to_ctranslate2(name: str, capability: dict) -> None: + model_id = str(capability.get("model_id") or "").strip() + model_dir = Path(str(capability.get("model_dir") or "")).expanduser() + model_source = str(model_dir if model_dir.exists() else model_id) + output_dir = _compute_ct2_output_dir(capability) + if (output_dir / "model.bin").exists(): + print(f"[skip-convert] {name} -> {output_dir}") + return + quantization = str( + capability.get("ct2_conversion_quantization") + or capability.get("ct2_compute_type") + or capability.get("torch_dtype") + or "default" + ).strip() + output_dir.parent.mkdir(parents=True, exist_ok=True) + print(f"[convert] {name} -> {output_dir} ({quantization})") + subprocess.run( + [ + _resolve_converter_binary(), + "--model", + model_source, + "--output_dir", + str(output_dir), + "--quantization", + quantization, + ], + check=True, + ) + print(f"[converted] {name}") + + def main() -> None: parser = argparse.ArgumentParser(description="Download local translation models") parser.add_argument("--all-local", action="store_true", help="Download all configured local translation models") parser.add_argument("--models", nargs="*", default=[], help="Specific capability names to download") + parser.add_argument( + "--convert-ctranslate2", + action="store_true", + help="Also convert the downloaded Hugging Face models into CTranslate2 format", + ) args = parser.parse_args() selected = {item.strip().lower() for item in args.models if item.strip()} or None @@ -55,6 +117,8 @@ def main() -> None: local_dir=str(model_dir), ) print(f"[done] {name}") + if args.convert_ctranslate2: + convert_to_ctranslate2(name, capability) if __name__ == "__main__": diff --git a/translation/README.md b/translation/README.md index db780c8..47dd79d 100644 --- a/translation/README.md +++ b/translation/README.md @@ -56,8 +56,8 @@ 通用 LLM 翻译 - [`translation/backends/deepl.py`](/data/saas-search/translation/backends/deepl.py) DeepL 翻译 -- [`translation/backends/local_seq2seq.py`](/data/saas-search/translation/backends/local_seq2seq.py) - 本地 Hugging Face Seq2Seq 模型,包括 NLLB 和 Marian/OPUS MT +- [`translation/backends/local_ctranslate2.py`](/data/saas-search/translation/backends/local_ctranslate2.py) + 本地 CTranslate2 翻译模型,包括 NLLB 和 Marian/OPUS MT ## 3. 
diff --git a/translation/README.md b/translation/README.md
index db780c8..47dd79d 100644
--- a/translation/README.md
+++ b/translation/README.md
@@ -56,8 +56,8 @@
   General-purpose LLM translation
 - [`translation/backends/deepl.py`](/data/saas-search/translation/backends/deepl.py)
   DeepL translation
-- [`translation/backends/local_seq2seq.py`](/data/saas-search/translation/backends/local_seq2seq.py)
-  Local Hugging Face Seq2Seq models, including NLLB and Marian/OPUS MT
+- [`translation/backends/local_ctranslate2.py`](/data/saas-search/translation/backends/local_ctranslate2.py)
+  Local CTranslate2 translation models, including NLLB and Marian/OPUS MT
 
 ## 3. Configuration conventions
@@ -103,19 +103,26 @@ services:
       backend: "local_nllb"
       model_id: "facebook/nllb-200-distilled-600M"
       model_dir: "./models/translation/facebook/nllb-200-distilled-600M"
+      ct2_model_dir: "./models/translation/facebook/nllb-200-distilled-600M/ctranslate2-float16"
+      ct2_compute_type: "float16"
+      ct2_conversion_quantization: "float16"
+      ct2_auto_convert: true
       device: "cuda"
       torch_dtype: "float16"
       batch_size: 16
       max_input_length: 256
       max_new_tokens: 64
       num_beams: 1
-      attn_implementation: "sdpa"
       use_cache: true
     opus-mt-zh-en:
       enabled: true
       backend: "local_marian"
       model_id: "Helsinki-NLP/opus-mt-zh-en"
       model_dir: "./models/translation/Helsinki-NLP/opus-mt-zh-en"
+      ct2_model_dir: "./models/translation/Helsinki-NLP/opus-mt-zh-en/ctranslate2-float16"
+      ct2_compute_type: "float16"
+      ct2_conversion_quantization: "float16"
+      ct2_auto_convert: true
       device: "cuda"
       torch_dtype: "float16"
       batch_size: 16
@@ -128,6 +135,10 @@ services:
       backend: "local_marian"
       model_id: "Helsinki-NLP/opus-mt-en-zh"
       model_dir: "./models/translation/Helsinki-NLP/opus-mt-en-zh"
+      ct2_model_dir: "./models/translation/Helsinki-NLP/opus-mt-en-zh/ctranslate2-float16"
+      ct2_compute_type: "float16"
+      ct2_conversion_quantization: "float16"
+      ct2_auto_convert: true
       device: "cuda"
       torch_dtype: "float16"
       batch_size: 16
@@ -148,6 +159,7 @@ services:
 - `service_url`, `default_model`, and `default_scene` are read from YAML only
 - Translation behavior is no longer silently overridden through environment variables
 - Secrets are still supplied through environment variables
+- `local_nllb` / `local_marian` now run on CTranslate2; on first startup, if `ct2_model_dir` does not exist, it is converted automatically from `model_dir`
 
 ## 4. Environment variables
 
@@ -338,7 +350,7 @@ results = translator.translate(
 ### 8.4 `facebook/nllb-200-distilled-600M`
 
 Implementation file:
-- [`translation/backends/local_seq2seq.py`](/data/saas-search/translation/backends/local_seq2seq.py)
+- [`translation/backends/local_ctranslate2.py`](/data/saas-search/translation/backends/local_ctranslate2.py)
 
 Model information:
 - Hugging Face name: `facebook/nllb-200-distilled-600M`
@@ -392,18 +404,17 @@ results = translator.translate(
 Characteristics of the current implementation:
 - backend type: `local_nllb`
+- runtime: `CTranslate2 Translator`
 - multilingual support
 - `source_lang` must be passed explicitly at call time
 - language-code mappings are defined in [`translation/languages.py`](/data/saas-search/translation/languages.py)
-- current recommended T4 configuration: `device=cuda`, `torch_dtype=float16`, `batch_size=16`, `max_new_tokens=64`, `attn_implementation=sdpa`
+- current recommended T4 configuration: `device=cuda`, `ct2_compute_type=float16`, `batch_size=16`, `max_new_tokens=64`
 
 Optimizations the current implementation already applies:
 - batch chunking: `translate()` feeds the model in chunks of the capability's `batch_size`
-- dynamic padding: the tokenizer uses `padding=True` and `truncation=True`
-- `attention_mask` passed through: produced by the tokenizer and sent to the model along with `generate()`
-- direction control: NLLB selects the language pair via `tokenizer.src_lang` and `forced_bos_token_id`
-- inference mode enabled: `torch.inference_mode()` + `model.eval()`
-- half precision and a faster attention implementation: currently `float16 + sdpa`
+- switched to the CTranslate2 inference engine: no longer depends on PyTorch `generate()`
+- direction control: NLLB selects the target language via a target prefix
+- half precision enabled: currently configured as `float16`
 - expensive search disabled: `num_beams=1` by default, closer to a low-latency online setup
 
 Compared against the batching example you provided:
@@ -414,12 +425,12 @@
 Room for optimization (by scenario):
 - **Online queries**: first measure the real latency and tail latency at `batch_size=1`, rather than pushing the batch size further.
 - **Offline batch jobs**: more aggressive batching / length bucketing / a dedicated batch queue can be tried (higher throughput, but more risk to online tail latency).
-- **Further VRAM savings / speedup**: `ctranslate2` / int8 could be evaluated; that runtime stack is not yet in this repository.
+- **Further VRAM savings / speedup**: keep evaluating `int8_float16` on top of the current CT2 stack (see the sketch below).
 
 ### 8.5 `opus-mt-zh-en`
 
 Implementation file:
-- [`translation/backends/local_seq2seq.py`](/data/saas-search/translation/backends/local_seq2seq.py)
+- [`translation/backends/local_ctranslate2.py`](/data/saas-search/translation/backends/local_ctranslate2.py)
 
 Model information:
 - Hugging Face name: `Helsinki-NLP/opus-mt-zh-en`
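Note: a hedged sketch of the `int8_float16` follow-up flagged above. The `ctranslate2-int8-float16` directory is hypothetical (it only follows the `ctranslate2-<quantization>` naming convention this patch introduces) and presumes a separate conversion run with `--quantization int8_float16`:

```python
# Sketch: loading an int8_float16 variant for evaluation (assumes a prior
# conversion with --quantization int8_float16 into this directory).
import ctranslate2

translator = ctranslate2.Translator(
    "./models/translation/facebook/nllb-200-distilled-600M/ctranslate2-int8-float16",
    device="cuda",
    compute_type="int8_float16",
)
```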
@@ -437,7 +448,7 @@
 ### 8.6 `opus-mt-en-zh`
 
 Implementation file:
-- [`translation/backends/local_seq2seq.py`](/data/saas-search/translation/backends/local_seq2seq.py)
+- [`translation/backends/local_ctranslate2.py`](/data/saas-search/translation/backends/local_ctranslate2.py)
 
 Model information:
 - Hugging Face name: `Helsinki-NLP/opus-mt-en-zh`
@@ -498,7 +509,7 @@ models/translation/Helsinki-NLP/opus-mt-en-zh
 - Avoid loading models repeatedly across multiple workers
 - Prefer `cuda + float16` on GPU machines
 - Use CPU only for functional verification or low-frequency offline tasks
-- For NLLB on T4, prefer `batch_size=16 + max_new_tokens=64 + attn_implementation=sdpa`
+- For NLLB on T4, prefer `batch_size=16 + max_new_tokens=64 + ct2_compute_type=float16`
 
 ### 9.5 Verification
 
@@ -524,6 +535,10 @@ curl -X POST http://127.0.0.1:6006/translate \
 
 ## 10. Performance testing and reproduction
 
+Notes:
+- The numbers currently in this section are the Hugging Face / PyTorch baseline results from `2026-03-18`.
+- After the switch to CTranslate2, a fresh benchmark round is required; pay particular attention to single-request latency and concurrent tail latency for `nllb-200-distilled-600m`, and to batch throughput for `opus-mt-*`.
+
 Benchmark script:
 - [`scripts/benchmark_translation_local_models.py`](/data/saas-search/scripts/benchmark_translation_local_models.py)
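Note: the new backend below boils down to a tokenize, `translate_batch`, decode round trip. A minimal standalone sketch for the NLLB capability (language codes follow the NLLB convention; the paths are the ones from the config above). CTranslate2 consumes token strings rather than ids, which is why the backend round-trips through `convert_ids_to_tokens` / `convert_tokens_to_ids`:

```python
# Standalone NLLB round trip on CTranslate2, mirroring the backend below.
import ctranslate2
from transformers import AutoTokenizer

model_dir = "./models/translation/facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_dir, src_lang="zho_Hans")
translator = ctranslate2.Translator(
    f"{model_dir}/ctranslate2-float16", device="cuda", compute_type="float16"
)

# Encode to token strings (CTranslate2 consumes tokens, not ids).
tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode("你好,世界"))
# NLLB direction control: the target-language code is passed as a target prefix.
results = translator.translate_batch([tokens], target_prefix=[["eng_Latn"]], beam_size=1)
output = results[0].hypotheses[0][1:]  # drop the leading target-language token
print(tokenizer.decode(tokenizer.convert_tokens_to_ids(output), skip_special_tokens=True))
```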
+ ) + + +class LocalCTranslate2TranslationBackend: + """Base backend for local CTranslate2 translation models.""" + + def __init__( + self, + *, + name: str, + model_id: str, + model_dir: str, + device: str, + torch_dtype: str, + batch_size: int, + max_input_length: int, + max_new_tokens: int, + num_beams: int, + ct2_model_dir: Optional[str] = None, + ct2_compute_type: Optional[str] = None, + ct2_auto_convert: bool = True, + ct2_conversion_quantization: Optional[str] = None, + ct2_inter_threads: int = 1, + ct2_intra_threads: int = 0, + ct2_max_queued_batches: int = 0, + ct2_batch_type: str = "examples", + ) -> None: + self.model = name + self.model_id = model_id + self.model_dir = model_dir + self.device = _resolve_device(device) + self.compute_type = _resolve_compute_type(torch_dtype, ct2_compute_type, self.device) + self.batch_size = int(batch_size) + self.max_input_length = int(max_input_length) + self.max_new_tokens = int(max_new_tokens) + self.num_beams = int(num_beams) + self.ct2_model_dir = str(ct2_model_dir or _derive_ct2_model_dir(model_dir, self.compute_type)) + self.ct2_auto_convert = bool(ct2_auto_convert) + self.ct2_conversion_quantization = _resolve_compute_type( + torch_dtype, + ct2_conversion_quantization or self.compute_type, + self.device, + ) + self.ct2_inter_threads = int(ct2_inter_threads) + self.ct2_intra_threads = int(ct2_intra_threads) + self.ct2_max_queued_batches = int(ct2_max_queued_batches) + self.ct2_batch_type = str(ct2_batch_type or "examples").strip().lower() + if self.ct2_batch_type not in {"examples", "tokens"}: + raise ValueError(f"Unsupported CTranslate2 batch type: {ct2_batch_type}") + self._tokenizer_lock = threading.Lock() + self._load_runtime() + + @property + def supports_batch(self) -> bool: + return True + + def _tokenizer_source(self) -> str: + return self.model_dir if os.path.exists(self.model_dir) else self.model_id + + def _model_source(self) -> str: + return self.model_dir if os.path.exists(self.model_dir) else self.model_id + + def _tokenizer_kwargs(self) -> Dict[str, object]: + return {} + + def _translator_kwargs(self) -> Dict[str, object]: + return { + "device": self.device, + "compute_type": self.compute_type, + "inter_threads": self.ct2_inter_threads, + "intra_threads": self.ct2_intra_threads, + "max_queued_batches": self.ct2_max_queued_batches, + } + + def _load_runtime(self) -> None: + try: + import ctranslate2 + except ImportError as exc: + raise RuntimeError( + "CTranslate2 is required for local Marian/NLLB translation. " + "Install the translator service dependencies again after adding ctranslate2." 
+            ) from exc
+
+        tokenizer_source = self._tokenizer_source()
+        model_source = self._model_source()
+        self._ensure_converted_model(model_source)
+        logger.info(
+            "Loading CTranslate2 translation model | name=%s ct2_model_dir=%s tokenizer=%s device=%s compute_type=%s",
+            self.model,
+            self.ct2_model_dir,
+            tokenizer_source,
+            self.device,
+            self.compute_type,
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, **self._tokenizer_kwargs())
+        self.translator = ctranslate2.Translator(self.ct2_model_dir, **self._translator_kwargs())
+        if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+    def _ensure_converted_model(self, model_source: str) -> None:
+        ct2_path = Path(self.ct2_model_dir).expanduser()
+        if (ct2_path / "model.bin").exists():
+            return
+        if not self.ct2_auto_convert:
+            raise FileNotFoundError(
+                f"CTranslate2 model not found for '{self.model}': {ct2_path}. "
+                "Enable ct2_auto_convert or pre-convert the model."
+            )
+
+        ct2_path.parent.mkdir(parents=True, exist_ok=True)
+        converter = _resolve_converter_binary()
+        logger.info(
+            "Converting translation model to CTranslate2 | name=%s source=%s output=%s quantization=%s",
+            self.model,
+            model_source,
+            ct2_path,
+            self.ct2_conversion_quantization,
+        )
+        try:
+            subprocess.run(
+                [
+                    converter,
+                    "--model",
+                    model_source,
+                    "--output_dir",
+                    str(ct2_path),
+                    "--quantization",
+                    self.ct2_conversion_quantization,
+                ],
+                check=True,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                text=True,
+            )
+        except subprocess.CalledProcessError as exc:
+            stderr = exc.stderr.strip()
+            raise RuntimeError(
+                f"Failed to convert model '{self.model}' to CTranslate2: {stderr or exc}"
+            ) from exc
+
+    def _normalize_texts(self, text: Union[str, Sequence[str]]) -> List[str]:
+        if isinstance(text, str):
+            return [text]
+        return ["" if item is None else str(item) for item in text]
+
+    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
+        del source_lang, target_lang
+
+    def _encode_source_tokens(
+        self,
+        texts: List[str],
+        source_lang: Optional[str],
+        target_lang: str,
+    ) -> List[List[str]]:
+        del source_lang, target_lang
+        with self._tokenizer_lock:
+            encoded = self.tokenizer(
+                texts,
+                truncation=True,
+                max_length=self.max_input_length,
+                padding=False,
+            )
+        input_ids = encoded["input_ids"]
+        return [self.tokenizer.convert_ids_to_tokens(ids) for ids in input_ids]
+
+    def _target_prefixes(
+        self,
+        count: int,
+        source_lang: Optional[str],
+        target_lang: str,
+    ) -> Optional[List[Optional[List[str]]]]:
+        del count, source_lang, target_lang
+        return None
+
+    def _postprocess_hypothesis(
+        self,
+        tokens: List[str],
+        source_lang: Optional[str],
+        target_lang: str,
+    ) -> List[str]:
+        del source_lang, target_lang
+        return tokens
+
+    def _decode_tokens(self, tokens: List[str]) -> Optional[str]:
+        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+        text = self.tokenizer.decode(token_ids, skip_special_tokens=True).strip()
+        return text or None
+
+    def _translate_batch(
+        self,
+        texts: List[str],
+        target_lang: str,
+        source_lang: Optional[str] = None,
+    ) -> List[Optional[str]]:
+        self._validate_languages(source_lang, target_lang)
+        source_tokens = self._encode_source_tokens(texts, source_lang, target_lang)
+        target_prefix = self._target_prefixes(len(source_tokens), source_lang, target_lang)
+        results = self.translator.translate_batch(
+            source_tokens,
+            target_prefix=target_prefix,
+            max_batch_size=self.batch_size,
+            batch_type=self.ct2_batch_type,
+            beam_size=self.num_beams,
+            max_input_length=self.max_input_length,
+            max_decoding_length=self.max_new_tokens,
+        )
+        outputs: List[Optional[str]] = []
+        for result in results:
+            hypothesis = result.hypotheses[0] if result.hypotheses else []
+            processed = self._postprocess_hypothesis(hypothesis, source_lang, target_lang)
+            outputs.append(self._decode_tokens(processed))
+        return outputs
+
+    def translate(
+        self,
+        text: Union[str, Sequence[str]],
+        target_lang: str,
+        source_lang: Optional[str] = None,
+        scene: Optional[str] = None,
+    ) -> Union[Optional[str], List[Optional[str]]]:
+        del scene
+        is_single = isinstance(text, str)
+        texts = self._normalize_texts(text)
+        outputs: List[Optional[str]] = []
+        for start in range(0, len(texts), self.batch_size):
+            chunk = texts[start:start + self.batch_size]
+            if not any(item.strip() for item in chunk):
+                # Every item in this chunk is blank, so skip the model call.
+                outputs.extend([None] * len(chunk))
+                continue
+            outputs.extend(self._translate_batch(chunk, target_lang=target_lang, source_lang=source_lang))
+        return outputs[0] if is_single else outputs
+
+
+class MarianCTranslate2TranslationBackend(LocalCTranslate2TranslationBackend):
+    """Local backend for Marian/OPUS MT models on CTranslate2."""
+
+    def __init__(
+        self,
+        *,
+        name: str,
+        model_id: str,
+        model_dir: str,
+        device: str,
+        torch_dtype: str,
+        batch_size: int,
+        max_input_length: int,
+        max_new_tokens: int,
+        num_beams: int,
+        source_langs: Sequence[str],
+        target_langs: Sequence[str],
+        ct2_model_dir: Optional[str] = None,
+        ct2_compute_type: Optional[str] = None,
+        ct2_auto_convert: bool = True,
+        ct2_conversion_quantization: Optional[str] = None,
+        ct2_inter_threads: int = 1,
+        ct2_intra_threads: int = 0,
+        ct2_max_queued_batches: int = 0,
+        ct2_batch_type: str = "examples",
+    ) -> None:
+        self.source_langs = {str(lang).strip().lower() for lang in source_langs if str(lang).strip()}
+        self.target_langs = {str(lang).strip().lower() for lang in target_langs if str(lang).strip()}
+        super().__init__(
+            name=name,
+            model_id=model_id,
+            model_dir=model_dir,
+            device=device,
+            torch_dtype=torch_dtype,
+            batch_size=batch_size,
+            max_input_length=max_input_length,
+            max_new_tokens=max_new_tokens,
+            num_beams=num_beams,
+            ct2_model_dir=ct2_model_dir,
+            ct2_compute_type=ct2_compute_type,
+            ct2_auto_convert=ct2_auto_convert,
+            ct2_conversion_quantization=ct2_conversion_quantization,
+            ct2_inter_threads=ct2_inter_threads,
+            ct2_intra_threads=ct2_intra_threads,
+            ct2_max_queued_batches=ct2_max_queued_batches,
+            ct2_batch_type=ct2_batch_type,
+        )
+
+    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
+        src = str(source_lang or "").strip().lower()
+        tgt = str(target_lang or "").strip().lower()
+        if self.source_langs and src not in self.source_langs:
+            raise ValueError(
+                f"Model '{self.model}' only supports source languages: {sorted(self.source_langs)}"
+            )
+        if self.target_langs and tgt not in self.target_langs:
+            raise ValueError(
+                f"Model '{self.model}' only supports target languages: {sorted(self.target_langs)}"
+            )
+
+
+class NLLBCTranslate2TranslationBackend(LocalCTranslate2TranslationBackend):
+    """Local backend for NLLB models on CTranslate2."""
+
+    def __init__(
+        self,
+        *,
+        name: str,
+        model_id: str,
+        model_dir: str,
+        device: str,
+        torch_dtype: str,
+        batch_size: int,
+        max_input_length: int,
+        max_new_tokens: int,
+        num_beams: int,
+        language_codes: Optional[Dict[str, str]] = None,
+        ct2_model_dir: Optional[str] = None,
+        ct2_compute_type: Optional[str] = None,
+        ct2_auto_convert: bool = True,
+        ct2_conversion_quantization: Optional[str] = None,
+        ct2_inter_threads: int = 1,
+        ct2_intra_threads: int = 0,
+        ct2_max_queued_batches: int = 0,
+        ct2_batch_type: str = "examples",
+    ) -> None:
+        overrides = language_codes or {}
+        self.language_codes = {
+            **NLLB_LANGUAGE_CODES,
+            **{str(k).strip().lower(): str(v).strip() for k, v in overrides.items() if str(k).strip()},
+        }
+        self._tokenizers_by_source: Dict[str, object] = {}
+        super().__init__(
+            name=name,
+            model_id=model_id,
+            model_dir=model_dir,
+            device=device,
+            torch_dtype=torch_dtype,
+            batch_size=batch_size,
+            max_input_length=max_input_length,
+            max_new_tokens=max_new_tokens,
+            num_beams=num_beams,
+            ct2_model_dir=ct2_model_dir,
+            ct2_compute_type=ct2_compute_type,
+            ct2_auto_convert=ct2_auto_convert,
+            ct2_conversion_quantization=ct2_conversion_quantization,
+            ct2_inter_threads=ct2_inter_threads,
+            ct2_intra_threads=ct2_intra_threads,
+            ct2_max_queued_batches=ct2_max_queued_batches,
+            ct2_batch_type=ct2_batch_type,
+        )
+
+    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
+        src = str(source_lang or "").strip().lower()
+        tgt = str(target_lang or "").strip().lower()
+        if not src:
+            raise ValueError(f"Model '{self.model}' requires source_lang")
+        if src not in self.language_codes:
+            raise ValueError(f"Unsupported NLLB source language: {source_lang}")
+        if tgt not in self.language_codes:
+            raise ValueError(f"Unsupported NLLB target language: {target_lang}")
+
+    def _get_tokenizer_for_source(self, source_lang: str):
+        src_code = self.language_codes[source_lang]
+        with self._tokenizer_lock:
+            tokenizer = self._tokenizers_by_source.get(src_code)
+            if tokenizer is None:
+                tokenizer = AutoTokenizer.from_pretrained(self._tokenizer_source(), src_lang=src_code)
+                if tokenizer.pad_token is None and tokenizer.eos_token is not None:
+                    tokenizer.pad_token = tokenizer.eos_token
+                self._tokenizers_by_source[src_code] = tokenizer
+        return tokenizer
+
+    def _encode_source_tokens(
+        self,
+        texts: List[str],
+        source_lang: Optional[str],
+        target_lang: str,
+    ) -> List[List[str]]:
+        del target_lang
+        source_key = str(source_lang or "").strip().lower()
+        tokenizer = self._get_tokenizer_for_source(source_key)
+        encoded = tokenizer(
+            texts,
+            truncation=True,
+            max_length=self.max_input_length,
+            padding=False,
+        )
+        input_ids = encoded["input_ids"]
+        return [tokenizer.convert_ids_to_tokens(ids) for ids in input_ids]
+
+    def _target_prefixes(
+        self,
+        count: int,
+        source_lang: Optional[str],
+        target_lang: str,
+    ) -> Optional[List[Optional[List[str]]]]:
+        del source_lang
+        tgt_code = self.language_codes[str(target_lang).strip().lower()]
+        return [[tgt_code] for _ in range(count)]
+
+    def _postprocess_hypothesis(
+        self,
+        tokens: List[str],
+        source_lang: Optional[str],
+        target_lang: str,
+    ) -> List[str]:
+        del source_lang
+        tgt_code = self.language_codes[str(target_lang).strip().lower()]
+        if tokens and tokens[0] == tgt_code:
+            return tokens[1:]
+        return tokens
+
+
+def get_marian_language_direction(model_name: str) -> tuple[str, str]:
+    direction = MARIAN_LANGUAGE_DIRECTIONS.get(model_name)
+    if direction is None:
+        raise ValueError(f"Translation capability '{model_name}' is not registered with Marian language directions")
+    return direction
diff --git a/translation/service.py b/translation/service.py
index ff4349a..0afb312 100644
--- a/translation/service.py
+++ b/translation/service.py
@@ -111,9 +111,9 @@ class TranslationService:
         )
 
     def _create_local_nllb_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
-        from translation.backends.local_seq2seq import NLLBTranslationBackend
+        from translation.backends.local_ctranslate2 import NLLBCTranslate2TranslationBackend
 
-        return NLLBTranslationBackend(
+        return NLLBCTranslate2TranslationBackend(
             name=name,
             model_id=str(cfg["model_id"]).strip(),
             model_dir=str(cfg["model_dir"]).strip(),
@@ -123,15 +123,22 @@ class TranslationService:
             max_input_length=int(cfg["max_input_length"]),
             max_new_tokens=int(cfg["max_new_tokens"]),
             num_beams=int(cfg["num_beams"]),
-            attn_implementation=cfg.get("attn_implementation"),
+            ct2_model_dir=cfg.get("ct2_model_dir"),
+            ct2_compute_type=cfg.get("ct2_compute_type"),
+            ct2_auto_convert=bool(cfg.get("ct2_auto_convert", True)),
+            ct2_conversion_quantization=cfg.get("ct2_conversion_quantization"),
+            ct2_inter_threads=int(cfg.get("ct2_inter_threads", 1)),
+            ct2_intra_threads=int(cfg.get("ct2_intra_threads", 0)),
+            ct2_max_queued_batches=int(cfg.get("ct2_max_queued_batches", 0)),
+            ct2_batch_type=str(cfg.get("ct2_batch_type", "examples")),
         )
 
     def _create_local_marian_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
-        from translation.backends.local_seq2seq import MarianMTTranslationBackend, get_marian_language_direction
+        from translation.backends.local_ctranslate2 import MarianCTranslate2TranslationBackend, get_marian_language_direction
 
         source_lang, target_lang = get_marian_language_direction(name)
-        return MarianMTTranslationBackend(
+        return MarianCTranslate2TranslationBackend(
             name=name,
             model_id=str(cfg["model_id"]).strip(),
             model_dir=str(cfg["model_dir"]).strip(),
@@ -143,7 +150,14 @@ class TranslationService:
             num_beams=int(cfg["num_beams"]),
             source_langs=[source_lang],
             target_langs=[target_lang],
-            attn_implementation=cfg.get("attn_implementation"),
+            ct2_model_dir=cfg.get("ct2_model_dir"),
+            ct2_compute_type=cfg.get("ct2_compute_type"),
+            ct2_auto_convert=bool(cfg.get("ct2_auto_convert", True)),
+            ct2_conversion_quantization=cfg.get("ct2_conversion_quantization"),
+            ct2_inter_threads=int(cfg.get("ct2_inter_threads", 1)),
+            ct2_intra_threads=int(cfg.get("ct2_intra_threads", 0)),
+            ct2_max_queued_batches=int(cfg.get("ct2_max_queued_batches", 0)),
+            ct2_batch_type=str(cfg.get("ct2_batch_type", "examples")),
         )
 
     @property
-- 
libgit2 0.21.2
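Note: after swapping in the CT2 backends, a quick service-level smoke test against the `/translate` endpoint referenced in section 9.5 of the README. This is a hedged sketch only: the host and port come from the README's curl example, but the JSON field names here follow the backend's `translate()` signature and are an assumption; adjust them to the actual request schema of the service.

```python
# Smoke test against the running translator service (field names follow the
# backend's translate() signature and are assumed; check the real /translate schema).
import requests

resp = requests.post(
    "http://127.0.0.1:6006/translate",
    json={"text": "hello world", "source_lang": "en", "target_lang": "zh"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json())
```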