"""Local translation backends powered by CTranslate2."""

from __future__ import annotations

import logging
import os
import shutil
import subprocess
import sys
import threading
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union

from transformers import AutoTokenizer

from translation.languages import MARIAN_LANGUAGE_DIRECTIONS, NLLB_LANGUAGE_CODES

logger = logging.getLogger(__name__)


def _resolve_device(device: Optional[str]) -> str:
    value = str(device or "auto").strip().lower()
    if value not in {"auto", "cpu", "cuda"}:
        raise ValueError(f"Unsupported CTranslate2 device: {device}")
    return value


def _resolve_compute_type(
    torch_dtype: Optional[str],
    compute_type: Optional[str],
    device: str,
) -> str:
    value = str(compute_type or torch_dtype or "default").strip().lower()
    if value in {"auto", "default"}:
        return "float16" if device == "cuda" else "default"
    if value in {"float16", "fp16", "half"}:
        return "float16"
    if value in {"bfloat16", "bf16"}:
        return "bfloat16"
    if value in {"float32", "fp32"}:
        return "float32"
    if value in {
        "int8",
        "int8_float32",
        "int8_float16",
        "int8_bfloat16",
        "int16",
    }:
        return value
    raise ValueError(f"Unsupported CTranslate2 compute type: {compute_type or torch_dtype}")


def _derive_ct2_model_dir(model_dir: str, compute_type: str) -> str:
    normalized = compute_type.replace("_", "-")
    return str(Path(model_dir).expanduser() / f"ctranslate2-{normalized}")


def _resolve_converter_binary() -> str:
    candidate = shutil.which("ct2-transformers-converter")
    if candidate:
        return candidate
    venv_candidate = Path(sys.executable).absolute().parent / "ct2-transformers-converter"
    if venv_candidate.exists():
        return str(venv_candidate)
    raise RuntimeError(
        "ct2-transformers-converter was not found. "
        "Ensure ctranslate2 is installed in the active translator environment."
    )


class LocalCTranslate2TranslationBackend:
    """Base backend for local CTranslate2 translation models."""

    def __init__(
        self,
        *,
        name: str,
        model_id: str,
        model_dir: str,
        device: str,
        torch_dtype: str,
        batch_size: int,
        max_input_length: int,
        max_new_tokens: int,
        num_beams: int,
        ct2_model_dir: Optional[str] = None,
        ct2_compute_type: Optional[str] = None,
        ct2_auto_convert: bool = True,
        ct2_conversion_quantization: Optional[str] = None,
        ct2_inter_threads: int = 1,
        ct2_intra_threads: int = 0,
        ct2_max_queued_batches: int = 0,
        ct2_batch_type: str = "examples",
    ) -> None:
        self.model = name
        self.model_id = model_id
        self.model_dir = model_dir
        self.device = _resolve_device(device)
        self.compute_type = _resolve_compute_type(torch_dtype, ct2_compute_type, self.device)
        self.batch_size = int(batch_size)
        self.max_input_length = int(max_input_length)
        self.max_new_tokens = int(max_new_tokens)
        self.num_beams = int(num_beams)
        self.ct2_model_dir = str(ct2_model_dir or _derive_ct2_model_dir(model_dir, self.compute_type))
        self.ct2_auto_convert = bool(ct2_auto_convert)
        self.ct2_conversion_quantization = _resolve_compute_type(
            torch_dtype,
            ct2_conversion_quantization or self.compute_type,
            self.device,
        )
        self.ct2_inter_threads = int(ct2_inter_threads)
        self.ct2_intra_threads = int(ct2_intra_threads)
        self.ct2_max_queued_batches = int(ct2_max_queued_batches)
        self.ct2_batch_type = str(ct2_batch_type or "examples").strip().lower()
        if self.ct2_batch_type not in {"examples", "tokens"}:
            raise ValueError(f"Unsupported CTranslate2 batch type: {ct2_batch_type}")
        self._tokenizer_lock = threading.Lock()
        self._load_runtime()

    @property
    def supports_batch(self) -> bool:
        return True

    def _tokenizer_source(self) -> str:
        return self.model_dir if os.path.exists(self.model_dir) else self.model_id

    def _model_source(self) -> str:
        return self.model_dir if os.path.exists(self.model_dir) else self.model_id

    def _tokenizer_kwargs(self) -> Dict[str, object]:
        return {}

    def _translator_kwargs(self) -> Dict[str, object]:
        return {
            "device": self.device,
            "compute_type": self.compute_type,
            "inter_threads": self.ct2_inter_threads,
            "intra_threads": self.ct2_intra_threads,
            "max_queued_batches": self.ct2_max_queued_batches,
        }

    def _load_runtime(self) -> None:
        try:
            import ctranslate2
        except ImportError as exc:
            raise RuntimeError(
                "CTranslate2 is required for local Marian/NLLB translation. "
                "Install the translator service dependencies again after adding ctranslate2."
            ) from exc

        tokenizer_source = self._tokenizer_source()
        model_source = self._model_source()
        self._ensure_converted_model(model_source)
        logger.info(
            "Loading CTranslate2 translation model | name=%s ct2_model_dir=%s tokenizer=%s device=%s compute_type=%s",
            self.model,
            self.ct2_model_dir,
            tokenizer_source,
            self.device,
            self.compute_type,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, **self._tokenizer_kwargs())
        # Expand "~" here as well, matching the path checked by _ensure_converted_model.
        ct2_path = str(Path(self.ct2_model_dir).expanduser())
        self.translator = ctranslate2.Translator(ct2_path, **self._translator_kwargs())
        if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _ensure_converted_model(self, model_source: str) -> None:
        ct2_path = Path(self.ct2_model_dir).expanduser()
        if (ct2_path / "model.bin").exists():
            return
        if not self.ct2_auto_convert:
            raise FileNotFoundError(
                f"CTranslate2 model not found for '{self.model}': {ct2_path}. "
                "Enable ct2_auto_convert or pre-convert the model."
            )

        ct2_path.parent.mkdir(parents=True, exist_ok=True)
        converter = _resolve_converter_binary()
        logger.info(
            "Converting translation model to CTranslate2 | name=%s source=%s output=%s quantization=%s",
            self.model,
            model_source,
            ct2_path,
            self.ct2_conversion_quantization,
        )
        command = [converter, "--model", model_source, "--output_dir", str(ct2_path)]
        # The converter's --quantization option does not accept "default"; in that
        # case omit the flag and let the converter apply its own default.
        if self.ct2_conversion_quantization != "default":
            command += ["--quantization", self.ct2_conversion_quantization]
        try:
            subprocess.run(
                command,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
        except subprocess.CalledProcessError as exc:
            stderr = (exc.stderr or "").strip()
            raise RuntimeError(
                f"Failed to convert model '{self.model}' to CTranslate2: {stderr or exc}"
            ) from exc

    def _normalize_texts(self, text: Union[str, Sequence[str]]) -> List[str]:
        if isinstance(text, str):
            return [text]
        return ["" if item is None else str(item) for item in text]

    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
        del source_lang, target_lang

    def _encode_source_tokens(
        self,
        texts: List[str],
        source_lang: Optional[str],
        target_lang: str,
    ) -> List[List[str]]:
        del source_lang, target_lang
        with self._tokenizer_lock:
            encoded = self.tokenizer(
                texts,
                truncation=True,
                max_length=self.max_input_length,
                padding=False,
            )
        input_ids = encoded["input_ids"]
        return [self.tokenizer.convert_ids_to_tokens(ids) for ids in input_ids]

    def _target_prefixes(
        self,
        count: int,
        source_lang: Optional[str],
        target_lang: str,
    ) -> Optional[List[Optional[List[str]]]]:
        del count, source_lang, target_lang
        return None

    def _postprocess_hypothesis(
        self,
        tokens: List[str],
        source_lang: Optional[str],
        target_lang: str,
    ) -> List[str]:
        del source_lang, target_lang
        return tokens

    def _decode_tokens(self, tokens: List[str]) -> Optional[str]:
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        text = self.tokenizer.decode(token_ids, skip_special_tokens=True).strip()
        return text or None

    def _translate_batch(
        self,
        texts: List[str],
        target_lang: str,
        source_lang: Optional[str] = None,
    ) -> List[Optional[str]]:
        self._validate_languages(source_lang, target_lang)
        source_tokens = self._encode_source_tokens(texts, source_lang, target_lang)
        target_prefix = self._target_prefixes(len(source_tokens), source_lang, target_lang)
        results = self.translator.translate_batch(
            source_tokens,
            target_prefix=target_prefix,
            max_batch_size=self.batch_size,
            batch_type=self.ct2_batch_type,
            beam_size=self.num_beams,
            max_input_length=self.max_input_length,
            max_decoding_length=self.max_new_tokens,
        )
        outputs: List[Optional[str]] = []
        for result in results:
            hypothesis = result.hypotheses[0] if result.hypotheses else []
            processed = self._postprocess_hypothesis(hypothesis, source_lang, target_lang)
            outputs.append(self._decode_tokens(processed))
        return outputs

    def translate(
        self,
        text: Union[str, Sequence[str]],
        target_lang: str,
        source_lang: Optional[str] = None,
        scene: Optional[str] = None,
    ) -> Union[Optional[str], List[Optional[str]]]:
        del scene
        is_single = isinstance(text, str)
        texts = self._normalize_texts(text)
        outputs: List[Optional[str]] = []
        for start in range(0, len(texts), self.batch_size):
            chunk = texts[start:start + self.batch_size]
            if not any(item.strip() for item in chunk):
                # Every item in this chunk is blank: skip the model and emit None for each.
                outputs.extend([None] * len(chunk))
                continue
            outputs.extend(self._translate_batch(chunk, target_lang=target_lang, source_lang=source_lang))
        return outputs[0] if is_single else outputs


class MarianCTranslate2TranslationBackend(LocalCTranslate2TranslationBackend):
    """Local backend for Marian/OPUS MT models on CTranslate2."""

    def __init__(
        self,
        *,
        name: str,
        model_id: str,
        model_dir: str,
        device: str,
        torch_dtype: str,
        batch_size: int,
        max_input_length: int,
        max_new_tokens: int,
        num_beams: int,
        source_langs: Sequence[str],
        target_langs: Sequence[str],
        ct2_model_dir: Optional[str] = None,
        ct2_compute_type: Optional[str] = None,
        ct2_auto_convert: bool = True,
        ct2_conversion_quantization: Optional[str] = None,
        ct2_inter_threads: int = 1,
        ct2_intra_threads: int = 0,
        ct2_max_queued_batches: int = 0,
        ct2_batch_type: str = "examples",
    ) -> None:
        self.source_langs = {str(lang).strip().lower() for lang in source_langs if str(lang).strip()}
        self.target_langs = {str(lang).strip().lower() for lang in target_langs if str(lang).strip()}
        super().__init__(
            name=name,
            model_id=model_id,
            model_dir=model_dir,
            device=device,
            torch_dtype=torch_dtype,
            batch_size=batch_size,
            max_input_length=max_input_length,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            ct2_model_dir=ct2_model_dir,
            ct2_compute_type=ct2_compute_type,
            ct2_auto_convert=ct2_auto_convert,
            ct2_conversion_quantization=ct2_conversion_quantization,
            ct2_inter_threads=ct2_inter_threads,
            ct2_intra_threads=ct2_intra_threads,
            ct2_max_queued_batches=ct2_max_queued_batches,
            ct2_batch_type=ct2_batch_type,
        )

    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
        src = str(source_lang or "").strip().lower()
        tgt = str(target_lang or "").strip().lower()
        if self.source_langs and src not in self.source_langs:
            raise ValueError(
                f"Model '{self.model}' only supports source languages: {sorted(self.source_langs)}"
            )
        if self.target_langs and tgt not in self.target_langs:
            raise ValueError(
                f"Model '{self.model}' only supports target languages: {sorted(self.target_langs)}"
            )


class NLLBCTranslate2TranslationBackend(LocalCTranslate2TranslationBackend):
    """Local backend for NLLB models on CTranslate2."""

    def __init__(
        self,
        *,
        name: str,
        model_id: str,
        model_dir: str,
        device: str,
        torch_dtype: str,
        batch_size: int,
        max_input_length: int,
        max_new_tokens: int,
        num_beams: int,
        language_codes: Optional[Dict[str, str]] = None,
        ct2_model_dir: Optional[str] = None,
        ct2_compute_type: Optional[str] = None,
        ct2_auto_convert: bool = True,
        ct2_conversion_quantization: Optional[str] = None,
        ct2_inter_threads: int = 1,
        ct2_intra_threads: int = 0,
        ct2_max_queued_batches: int = 0,
        ct2_batch_type: str = "examples",
    ) -> None:
        overrides = language_codes or {}
        self.language_codes = {
            **NLLB_LANGUAGE_CODES,
            **{str(k).strip().lower(): str(v).strip() for k, v in overrides.items() if str(k).strip()},
        }
        self._tokenizers_by_source: Dict[str, object] = {}
        super().__init__(
            name=name,
            model_id=model_id,
            model_dir=model_dir,
            device=device,
            torch_dtype=torch_dtype,
            batch_size=batch_size,
            max_input_length=max_input_length,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            ct2_model_dir=ct2_model_dir,
            ct2_compute_type=ct2_compute_type,
            ct2_auto_convert=ct2_auto_convert,
            ct2_conversion_quantization=ct2_conversion_quantization,
            ct2_inter_threads=ct2_inter_threads,
            ct2_intra_threads=ct2_intra_threads,
            ct2_max_queued_batches=ct2_max_queued_batches,
            ct2_batch_type=ct2_batch_type,
        )

    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
        src = str(source_lang or "").strip().lower()
        tgt = str(target_lang or "").strip().lower()
        if not src:
            raise ValueError(f"Model '{self.model}' requires source_lang")
        if src not in self.language_codes:
            raise ValueError(f"Unsupported NLLB source language: {source_lang}")
        if tgt not in self.language_codes:
            raise ValueError(f"Unsupported NLLB target language: {target_lang}")

    def _get_tokenizer_for_source(self, source_lang: str):
        src_code = self.language_codes[source_lang]
        with self._tokenizer_lock:
            tokenizer = self._tokenizers_by_source.get(src_code)
            if tokenizer is None:
                tokenizer = AutoTokenizer.from_pretrained(self._tokenizer_source(), src_lang=src_code)
                if tokenizer.pad_token is None and tokenizer.eos_token is not None:
                    tokenizer.pad_token = tokenizer.eos_token
                self._tokenizers_by_source[src_code] = tokenizer
            return tokenizer

    def _encode_source_tokens(
        self,
        texts: List[str],
        source_lang: Optional[str],
        target_lang: str,
    ) -> List[List[str]]:
        del target_lang
        source_key = str(source_lang or "").strip().lower()
        tokenizer = self._get_tokenizer_for_source(source_key)
        # Serialize tokenizer calls as the base class does. The lock is taken only after
        # _get_tokenizer_for_source returns, because threading.Lock is not reentrant.
        with self._tokenizer_lock:
            encoded = tokenizer(
                texts,
                truncation=True,
                max_length=self.max_input_length,
                padding=False,
            )
        input_ids = encoded["input_ids"]
        return [tokenizer.convert_ids_to_tokens(ids) for ids in input_ids]

    def _target_prefixes(
        self,
        count: int,
        source_lang: Optional[str],
        target_lang: str,
    ) -> Optional[List[Optional[List[str]]]]:
        del source_lang
        tgt_code = self.language_codes[str(target_lang).strip().lower()]
        return [[tgt_code] for _ in range(count)]

    def _postprocess_hypothesis(
        self,
        tokens: List[str],
        source_lang: Optional[str],
        target_lang: str,
    ) -> List[str]:
        del source_lang
        tgt_code = self.language_codes[str(target_lang).strip().lower()]
        if tokens and tokens[0] == tgt_code:
            return tokens[1:]
        return tokens


def get_marian_language_direction(model_name: str) -> tuple[str, str]:
    direction = MARIAN_LANGUAGE_DIRECTIONS.get(model_name)
    if direction is None:
        raise ValueError(f"Translation capability '{model_name}' is not registered with Marian language directions")
    return direction
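

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not called by the backends above).
# It restates the chunking loop of LocalCTranslate2TranslationBackend.translate
# in plain Python so its order-preserving behaviour can be exercised without
# ctranslate2 or transformers. The name `_sketch_chunked_translate` and the
# stand-in `translate_chunk` callable are assumptions made for illustration;
# the real backend calls `_translate_batch`, which tokenizes each chunk and
# runs `Translator.translate_batch` on it.
# ---------------------------------------------------------------------------
from typing import Callable, List, Optional


def _sketch_chunked_translate(
    texts: List[str],
    batch_size: int,
    translate_chunk: Callable[[List[str]], List[Optional[str]]],
) -> List[Optional[str]]:
    """Translate ``texts`` in chunks of ``batch_size``, preserving input order.

    A chunk whose items are all blank is never passed to ``translate_chunk``;
    each of its entries maps to ``None``, as in ``translate``.
    """
    outputs: List[Optional[str]] = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        if not any(item.strip() for item in chunk):
            # All-blank chunk: no model call, one None per input item.
            outputs.extend([None] * len(chunk))
            continue
        outputs.extend(translate_chunk(chunk))
    return outputs


# Usage, with a stand-in that upper-cases text instead of translating it:
#   _sketch_chunked_translate(["a", "b", " ", "", "c"], 2, lambda c: [s.upper() for s in c])
#   -> ["A", "B", None, None, "C"]
# Note that, as in translate, blanks mixed into a chunk with real text are still
# sent to the model; only chunks that are entirely blank are skipped.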