"""Local translation backends powered by CTranslate2.""" from __future__ import annotations import logging import os import shutil import subprocess import sys import threading from pathlib import Path from typing import Dict, List, Optional, Sequence, Union from transformers import AutoTokenizer from translation.languages import ( MARIAN_LANGUAGE_DIRECTIONS, build_nllb_language_catalog, normalize_language_key, resolve_nllb_language_code, ) from translation.text_splitter import ( compute_safe_input_token_limit, join_translated_segments, split_text_for_translation, ) logger = logging.getLogger(__name__) def _text_preview(text: Optional[str], limit: int = 32) -> str: return str(text or "").replace("\n", "\\n")[:limit] def _summarize_lengths(values: Sequence[int]) -> str: if not values: return "[]" total = sum(values) return f"min={min(values)} max={max(values)} avg={total / len(values):.1f}" def _resolve_device(device: Optional[str]) -> str: value = str(device or "auto").strip().lower() if value not in {"auto", "cpu", "cuda"}: raise ValueError(f"Unsupported CTranslate2 device: {device}") return value def _resolve_compute_type( torch_dtype: Optional[str], compute_type: Optional[str], device: str, ) -> str: value = str(compute_type or torch_dtype or "default").strip().lower() if value in {"auto", "default"}: return "float16" if device == "cuda" else "default" if value in {"float16", "fp16", "half"}: return "float16" if value in {"bfloat16", "bf16"}: return "bfloat16" if value in {"float32", "fp32"}: return "float32" if value in { "int8", "int8_float32", "int8_float16", "int8_bfloat16", "int16", }: return value raise ValueError(f"Unsupported CTranslate2 compute type: {compute_type or torch_dtype}") def _derive_ct2_model_dir(model_dir: str, compute_type: str) -> str: normalized = compute_type.replace("_", "-") return str(Path(model_dir).expanduser() / f"ctranslate2-{normalized}") def _resolve_converter_binary() -> str: candidate = 
shutil.which("ct2-transformers-converter") if candidate: return candidate venv_candidate = Path(sys.executable).absolute().parent / "ct2-transformers-converter" if venv_candidate.exists(): return str(venv_candidate) raise RuntimeError( "ct2-transformers-converter was not found. " "Ensure ctranslate2 is installed in the active translator environment." ) class LocalCTranslate2TranslationBackend: """Base backend for local CTranslate2 translation models.""" def __init__( self, *, name: str, model_id: str, model_dir: str, device: str, torch_dtype: str, batch_size: int, max_input_length: int, max_new_tokens: int, num_beams: int, ct2_model_dir: Optional[str] = None, ct2_compute_type: Optional[str] = None, ct2_auto_convert: bool = True, ct2_conversion_quantization: Optional[str] = None, ct2_inter_threads: int = 1, ct2_intra_threads: int = 0, ct2_max_queued_batches: int = 0, ct2_batch_type: str = "examples", ct2_decoding_length_mode: str = "fixed", ct2_decoding_length_extra: int = 0, ct2_decoding_length_min: int = 1, ) -> None: self.model = name self.model_id = model_id self.model_dir = model_dir self.device = _resolve_device(device) self.compute_type = _resolve_compute_type(torch_dtype, ct2_compute_type, self.device) self.batch_size = int(batch_size) self.max_input_length = int(max_input_length) self.max_new_tokens = int(max_new_tokens) self.num_beams = int(num_beams) self.ct2_model_dir = str(ct2_model_dir or _derive_ct2_model_dir(model_dir, self.compute_type)) self.ct2_auto_convert = bool(ct2_auto_convert) self.ct2_conversion_quantization = _resolve_compute_type( torch_dtype, ct2_conversion_quantization or self.compute_type, self.device, ) self.ct2_inter_threads = int(ct2_inter_threads) self.ct2_intra_threads = int(ct2_intra_threads) self.ct2_max_queued_batches = int(ct2_max_queued_batches) self.ct2_batch_type = str(ct2_batch_type or "examples").strip().lower() if self.ct2_batch_type not in {"examples", "tokens"}: raise ValueError(f"Unsupported CTranslate2 batch type: 
{ct2_batch_type}") self.ct2_decoding_length_mode = str(ct2_decoding_length_mode or "fixed").strip().lower() if self.ct2_decoding_length_mode not in {"fixed", "source"}: raise ValueError(f"Unsupported CTranslate2 decoding length mode: {ct2_decoding_length_mode}") self.ct2_decoding_length_extra = int(ct2_decoding_length_extra) self.ct2_decoding_length_min = max(1, int(ct2_decoding_length_min)) self._tokenizer_lock = threading.Lock() self._load_runtime() @property def supports_batch(self) -> bool: return True def _tokenizer_source(self) -> str: return self.model_dir if os.path.exists(self.model_dir) else self.model_id def _model_source(self) -> str: return self.model_dir if os.path.exists(self.model_dir) else self.model_id def _tokenizer_kwargs(self) -> Dict[str, object]: return {} def _translator_kwargs(self) -> Dict[str, object]: return { "device": self.device, "compute_type": self.compute_type, "inter_threads": self.ct2_inter_threads, "intra_threads": self.ct2_intra_threads, "max_queued_batches": self.ct2_max_queued_batches, } def _load_runtime(self) -> None: try: import ctranslate2 except ImportError as exc: raise RuntimeError( "CTranslate2 is required for local Marian/NLLB translation. " "Install the translator service dependencies again after adding ctranslate2." 
) from exc tokenizer_source = self._tokenizer_source() model_source = self._model_source() self._ensure_converted_model(model_source) logger.info( "Loading CTranslate2 translation model | name=%s ct2_model_dir=%s tokenizer=%s device=%s compute_type=%s", self.model, self.ct2_model_dir, tokenizer_source, self.device, self.compute_type, ) self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, **self._tokenizer_kwargs()) self.translator = ctranslate2.Translator(self.ct2_model_dir, **self._translator_kwargs()) if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None: self.tokenizer.pad_token = self.tokenizer.eos_token def _ensure_converted_model(self, model_source: str) -> None: ct2_path = Path(self.ct2_model_dir).expanduser() if (ct2_path / "model.bin").exists(): return if not self.ct2_auto_convert: raise FileNotFoundError( f"CTranslate2 model not found for '{self.model}': {ct2_path}. " "Enable ct2_auto_convert or pre-convert the model." ) ct2_path.parent.mkdir(parents=True, exist_ok=True) converter = _resolve_converter_binary() logger.info( "Converting translation model to CTranslate2 | name=%s source=%s output=%s quantization=%s", self.model, model_source, ct2_path, self.ct2_conversion_quantization, ) try: subprocess.run( [ converter, "--model", model_source, "--output_dir", str(ct2_path), "--quantization", self.ct2_conversion_quantization, ], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, ) except subprocess.CalledProcessError as exc: stderr = exc.stderr.strip() raise RuntimeError( f"Failed to convert model '{self.model}' to CTranslate2: {stderr or exc}" ) from exc def _normalize_texts(self, text: Union[str, Sequence[str]]) -> List[str]: if isinstance(text, str): return [text] return ["" if item is None else str(item) for item in text] def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None: del source_lang, target_lang def _encode_source_tokens( self, texts: List[str], source_lang: 
Optional[str], target_lang: str, ) -> List[List[str]]: del source_lang, target_lang with self._tokenizer_lock: encoded = self.tokenizer( texts, truncation=True, max_length=self.max_input_length, padding=False, ) input_ids = encoded["input_ids"] return [self.tokenizer.convert_ids_to_tokens(ids) for ids in input_ids] def _target_prefixes( self, count: int, source_lang: Optional[str], target_lang: str, ) -> Optional[List[Optional[List[str]]]]: del count, source_lang, target_lang return None def _resolve_max_decoding_length(self, source_tokens: Sequence[Sequence[str]]) -> int: if self.ct2_decoding_length_mode != "source": return self.max_new_tokens if not source_tokens: return self.max_new_tokens max_source_length = max(len(tokens) for tokens in source_tokens) dynamic_length = max(self.ct2_decoding_length_min, max_source_length + self.ct2_decoding_length_extra) return min(self.max_new_tokens, dynamic_length) def _postprocess_hypothesis( self, tokens: List[str], source_lang: Optional[str], target_lang: str, ) -> List[str]: del source_lang, target_lang return tokens def _decode_tokens(self, tokens: List[str]) -> Optional[str]: token_ids = self.tokenizer.convert_tokens_to_ids(tokens) text = self.tokenizer.decode(token_ids, skip_special_tokens=True).strip() return text or None def _translate_batch( self, texts: List[str], target_lang: str, source_lang: Optional[str] = None, ) -> List[Optional[str]]: self._validate_languages(source_lang, target_lang) source_tokens = self._encode_source_tokens(texts, source_lang, target_lang) target_prefix = self._target_prefixes(len(source_tokens), source_lang, target_lang) max_decoding_length = self._resolve_max_decoding_length(source_tokens) logger.info( "Translation model batch detail | model=%s segment_count=%s token_lengths=%s max_decoding_length=%s batch_type=%s beam_size=%s target_lang=%s source_lang=%s", self.model, len(source_tokens), _summarize_lengths([len(tokens) for tokens in source_tokens]), max_decoding_length, 
self.ct2_batch_type, self.num_beams, target_lang, source_lang or "auto", ) results = self.translator.translate_batch( source_tokens, target_prefix=target_prefix, max_batch_size=self.batch_size, batch_type=self.ct2_batch_type, beam_size=self.num_beams, max_input_length=self.max_input_length, max_decoding_length=max_decoding_length, ) outputs: List[Optional[str]] = [] for result in results: hypothesis = result.hypotheses[0] if result.hypotheses else [] processed = self._postprocess_hypothesis(hypothesis, source_lang, target_lang) outputs.append(self._decode_tokens(processed)) return outputs def _token_count( self, text: str, target_lang: str, source_lang: Optional[str] = None, ) -> int: encoded = self._encode_source_tokens([text], source_lang, target_lang) return len(encoded[0]) if encoded else 0 def _effective_input_token_limit(self, target_lang: str, source_lang: Optional[str] = None) -> int: del target_lang, source_lang return compute_safe_input_token_limit( max_input_length=self.max_input_length, max_new_tokens=self.max_new_tokens, decoding_length_mode=self.ct2_decoding_length_mode, decoding_length_extra=self.ct2_decoding_length_extra, ) def _split_text_if_needed( self, text: str, target_lang: str, source_lang: Optional[str] = None, ) -> List[str]: limit = self._effective_input_token_limit(target_lang, source_lang) return split_text_for_translation( text, max_tokens=limit, token_length_fn=lambda value: self._token_count( value, target_lang=target_lang, source_lang=source_lang, ), ) def _log_segmentation_summary( self, *, texts: Sequence[str], segment_plans: Sequence[Sequence[str]], target_lang: str, source_lang: Optional[str], ) -> None: non_empty_count = sum(1 for text in texts if text.strip()) segment_counts = [len(segments) for segments in segment_plans if segments] total_segments = sum(segment_counts) segmented_inputs = sum(1 for count in segment_counts if count > 1) logger.info( "Translation segmentation summary | model=%s inputs=%s non_empty_inputs=%s 
segmented_inputs=%s total_segments=%s batch_size=%s target_lang=%s source_lang=%s segments_per_input=%s", self.model, len(texts), non_empty_count, segmented_inputs, total_segments, self.batch_size, target_lang, source_lang or "auto", _summarize_lengths(segment_counts), ) def _translate_segment_batches( self, segments: List[str], target_lang: str, source_lang: Optional[str] = None, ) -> List[Optional[str]]: if not segments: return [] outputs: List[Optional[str]] = [] total_batches = (len(segments) + self.batch_size - 1) // self.batch_size for batch_index, start in enumerate(range(0, len(segments), self.batch_size), start=1): batch = segments[start:start + self.batch_size] logger.info( "Translation inference batch | model=%s batch_index=%s total_batches=%s segment_count=%s char_lengths=%s first_preview=%s target_lang=%s source_lang=%s", self.model, batch_index, total_batches, len(batch), _summarize_lengths([len(segment) for segment in batch]), _text_preview(batch[0] if batch else ""), target_lang, source_lang or "auto", ) outputs.extend( self._translate_batch(batch, target_lang=target_lang, source_lang=source_lang) ) return outputs def _translate_with_segmentation( self, texts: List[str], target_lang: str, source_lang: Optional[str] = None, ) -> List[Optional[str]]: segment_plans: List[List[str]] = [] flat_segments: List[str] = [] for text in texts: if not text.strip(): segment_plans.append([]) continue segments = self._split_text_if_needed(text, target_lang=target_lang, source_lang=source_lang) segment_plans.append(segments) flat_segments.extend(segments) self._log_segmentation_summary( texts=texts, segment_plans=segment_plans, target_lang=target_lang, source_lang=source_lang, ) translated_segments = ( self._translate_segment_batches(flat_segments, target_lang=target_lang, source_lang=source_lang) if flat_segments else [] ) outputs: List[Optional[str]] = [] offset = 0 for original_text, segments in zip(texts, segment_plans): if not segments: outputs.append(None if 
not original_text.strip() else original_text) continue current = translated_segments[offset:offset + len(segments)] offset += len(segments) if len(segments) == 1: outputs.append(current[0]) continue outputs.append( join_translated_segments( current, target_lang=target_lang, original_text=original_text, ) ) return outputs def translate( self, text: Union[str, Sequence[str]], target_lang: str, source_lang: Optional[str] = None, scene: Optional[str] = None, ) -> Union[Optional[str], List[Optional[str]]]: del scene is_single = isinstance(text, str) texts = self._normalize_texts(text) if not any(item.strip() for item in texts): outputs = [None if not item.strip() else item for item in texts] # type: ignore[list-item] return outputs[0] if is_single else outputs outputs = self._translate_with_segmentation(texts, target_lang=target_lang, source_lang=source_lang) return outputs[0] if is_single else outputs class MarianCTranslate2TranslationBackend(LocalCTranslate2TranslationBackend): """Local backend for Marian/OPUS MT models on CTranslate2.""" def __init__( self, *, name: str, model_id: str, model_dir: str, device: str, torch_dtype: str, batch_size: int, max_input_length: int, max_new_tokens: int, num_beams: int, source_langs: Sequence[str], target_langs: Sequence[str], ct2_model_dir: Optional[str] = None, ct2_compute_type: Optional[str] = None, ct2_auto_convert: bool = True, ct2_conversion_quantization: Optional[str] = None, ct2_inter_threads: int = 1, ct2_intra_threads: int = 0, ct2_max_queued_batches: int = 0, ct2_batch_type: str = "examples", ct2_decoding_length_mode: str = "fixed", ct2_decoding_length_extra: int = 0, ct2_decoding_length_min: int = 1, ) -> None: self.source_langs = {str(lang).strip().lower() for lang in source_langs if str(lang).strip()} self.target_langs = {str(lang).strip().lower() for lang in target_langs if str(lang).strip()} super().__init__( name=name, model_id=model_id, model_dir=model_dir, device=device, torch_dtype=torch_dtype, 
batch_size=batch_size, max_input_length=max_input_length, max_new_tokens=max_new_tokens, num_beams=num_beams, ct2_model_dir=ct2_model_dir, ct2_compute_type=ct2_compute_type, ct2_auto_convert=ct2_auto_convert, ct2_conversion_quantization=ct2_conversion_quantization, ct2_inter_threads=ct2_inter_threads, ct2_intra_threads=ct2_intra_threads, ct2_max_queued_batches=ct2_max_queued_batches, ct2_batch_type=ct2_batch_type, ct2_decoding_length_mode=ct2_decoding_length_mode, ct2_decoding_length_extra=ct2_decoding_length_extra, ct2_decoding_length_min=ct2_decoding_length_min, ) def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None: src = str(source_lang or "").strip().lower() tgt = str(target_lang or "").strip().lower() if self.source_langs and src not in self.source_langs: raise ValueError( f"Model '{self.model}' only supports source languages: {sorted(self.source_langs)}" ) if self.target_langs and tgt not in self.target_langs: raise ValueError( f"Model '{self.model}' only supports target languages: {sorted(self.target_langs)}" ) class NLLBCTranslate2TranslationBackend(LocalCTranslate2TranslationBackend): """Local backend for NLLB models on CTranslate2.""" def __init__( self, *, name: str, model_id: str, model_dir: str, device: str, torch_dtype: str, batch_size: int, max_input_length: int, max_new_tokens: int, num_beams: int, language_codes: Optional[Dict[str, str]] = None, ct2_model_dir: Optional[str] = None, ct2_compute_type: Optional[str] = None, ct2_auto_convert: bool = True, ct2_conversion_quantization: Optional[str] = None, ct2_inter_threads: int = 1, ct2_intra_threads: int = 0, ct2_max_queued_batches: int = 0, ct2_batch_type: str = "examples", ct2_decoding_length_mode: str = "fixed", ct2_decoding_length_extra: int = 0, ct2_decoding_length_min: int = 1, ) -> None: self.language_codes = build_nllb_language_catalog(language_codes) self._tokenizers_by_source: Dict[str, object] = {} super().__init__( name=name, model_id=model_id, 
model_dir=model_dir, device=device, torch_dtype=torch_dtype, batch_size=batch_size, max_input_length=max_input_length, max_new_tokens=max_new_tokens, num_beams=num_beams, ct2_model_dir=ct2_model_dir, ct2_compute_type=ct2_compute_type, ct2_auto_convert=ct2_auto_convert, ct2_conversion_quantization=ct2_conversion_quantization, ct2_inter_threads=ct2_inter_threads, ct2_intra_threads=ct2_intra_threads, ct2_max_queued_batches=ct2_max_queued_batches, ct2_batch_type=ct2_batch_type, ct2_decoding_length_mode=ct2_decoding_length_mode, ct2_decoding_length_extra=ct2_decoding_length_extra, ct2_decoding_length_min=ct2_decoding_length_min, ) def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None: if not str(source_lang or "").strip(): raise ValueError(f"Model '{self.model}' requires source_lang") if resolve_nllb_language_code(source_lang, self.language_codes) is None: raise ValueError(f"Unsupported NLLB source language: {source_lang}") if resolve_nllb_language_code(target_lang, self.language_codes) is None: raise ValueError(f"Unsupported NLLB target language: {target_lang}") def _get_tokenizer_for_source(self, source_lang: str): src_code = resolve_nllb_language_code(source_lang, self.language_codes) if src_code is None: raise ValueError(f"Unsupported NLLB source language: {source_lang}") with self._tokenizer_lock: tokenizer = self._tokenizers_by_source.get(src_code) if tokenizer is None: tokenizer = AutoTokenizer.from_pretrained(self._tokenizer_source(), src_lang=src_code) if tokenizer.pad_token is None and tokenizer.eos_token is not None: tokenizer.pad_token = tokenizer.eos_token self._tokenizers_by_source[src_code] = tokenizer return tokenizer def _encode_source_tokens( self, texts: List[str], source_lang: Optional[str], target_lang: str, ) -> List[List[str]]: del target_lang source_key = normalize_language_key(source_lang) tokenizer = self._get_tokenizer_for_source(source_key) encoded = tokenizer( texts, truncation=True, 
max_length=self.max_input_length, padding=False, ) input_ids = encoded["input_ids"] return [tokenizer.convert_ids_to_tokens(ids) for ids in input_ids] def _target_prefixes( self, count: int, source_lang: Optional[str], target_lang: str, ) -> Optional[List[Optional[List[str]]]]: del source_lang tgt_code = resolve_nllb_language_code(target_lang, self.language_codes) if tgt_code is None: raise ValueError(f"Unsupported NLLB target language: {target_lang}") return [[tgt_code] for _ in range(count)] def _postprocess_hypothesis( self, tokens: List[str], source_lang: Optional[str], target_lang: str, ) -> List[str]: del source_lang tgt_code = resolve_nllb_language_code(target_lang, self.language_codes) if tgt_code is None: raise ValueError(f"Unsupported NLLB target language: {target_lang}") if tokens and tokens[0] == tgt_code: return tokens[1:] return tokens def get_marian_language_direction(model_name: str) -> tuple[str, str]: direction = MARIAN_LANGUAGE_DIRECTIONS.get(model_name) if direction is None: raise ValueError(f"Translation capability '{model_name}' is not registered with Marian language directions") return direction
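

# Illustrative sketch only: a hypothetical helper that nothing above calls.
# It restates, as a pure function that runs without loading a model, the
# offset bookkeeping in LocalCTranslate2TranslationBackend._translate_with_segmentation.
# Per-input segment plans are flattened, translated in a single pass, and each
# input then reclaims exactly len(segments) results in plan order.
# Assumptions: `translate_fn` stands in for _translate_segment_batches, and a
# plain space join stands in for join_translated_segments. The real joining
# rules live in translation.text_splitter and are not reproduced here.
def _sketch_translate_with_segment_plans(
    texts: Sequence[str],
    segment_plans: Sequence[Sequence[str]],
    translate_fn,
) -> List[Optional[str]]:
    # translate_fn receives every segment of every input as one flat list and
    # must return one translation per segment, in the same order.
    flat_segments: List[str] = [segment for segments in segment_plans for segment in segments]
    translated = list(translate_fn(flat_segments)) if flat_segments else []
    outputs: List[Optional[str]] = []
    offset = 0
    for original_text, segments in zip(texts, segment_plans):
        if not segments:
            # Blank inputs have an empty plan and map to None, as in the backend.
            outputs.append(None if not original_text.strip() else original_text)
            continue
        current = translated[offset:offset + len(segments)]
        offset += len(segments)
        if len(segments) == 1:
            outputs.append(current[0])
        else:
            # Hypothetical stand-in join: space-separated, skipping empty results.
            outputs.append(" ".join(part for part in current if part))
    return outputs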