"""Local translation backends powered by CTranslate2.""" from __future__ import annotations import logging import os import json import threading from pathlib import Path from typing import Dict, List, Optional, Sequence, Union from transformers import AutoTokenizer from translation.languages import ( MARIAN_LANGUAGE_DIRECTIONS, build_nllb_language_catalog, normalize_language_key, resolve_nllb_language_code, ) from translation.text_splitter import ( compute_safe_input_token_limit, join_translated_segments, split_text_for_translation, ) from translation.ct2_conversion import convert_transformers_model logger = logging.getLogger(__name__) def _text_preview(text: Optional[str], limit: int = 32) -> str: return str(text or "").replace("\n", "\\n")[:limit] def _summarize_lengths(values: Sequence[int]) -> str: if not values: return "[]" total = sum(values) return f"min={min(values)} max={max(values)} avg={total / len(values):.1f}" def _resolve_device(device: Optional[str]) -> str: value = str(device or "auto").strip().lower() if value not in {"auto", "cpu", "cuda"}: raise ValueError(f"Unsupported CTranslate2 device: {device}") return value def _resolve_compute_type( torch_dtype: Optional[str], compute_type: Optional[str], device: str, ) -> str: value = str(compute_type or torch_dtype or "default").strip().lower() if value in {"auto", "default"}: return "float16" if device == "cuda" else "default" if value in {"float16", "fp16", "half"}: return "float16" if value in {"bfloat16", "bf16"}: return "bfloat16" if value in {"float32", "fp32"}: return "float32" if value in { "int8", "int8_float32", "int8_float16", "int8_bfloat16", "int16", }: return value raise ValueError(f"Unsupported CTranslate2 compute type: {compute_type or torch_dtype}") def _derive_ct2_model_dir(model_dir: str, compute_type: str) -> str: normalized = compute_type.replace("_", "-") return str(Path(model_dir).expanduser() / f"ctranslate2-{normalized}") def _detect_local_model_type(model_dir: str) -> Optional[str]: config_path = Path(model_dir).expanduser() / "config.json" if not config_path.exists(): return None try: with open(config_path, "r", encoding="utf-8") as handle: payload = json.load(handle) or {} except Exception as exc: logger.warning("Failed to inspect local translation config %s: %s", config_path, exc) return None model_type = str(payload.get("model_type") or "").strip().lower() return model_type or None class LocalCTranslate2TranslationBackend: """Base backend for local CTranslate2 translation models.""" def __init__( self, *, name: str, model_id: str, model_dir: str, device: str, torch_dtype: str, batch_size: int, max_input_length: int, max_new_tokens: int, num_beams: int, ct2_model_dir: Optional[str] = None, ct2_compute_type: Optional[str] = None, ct2_auto_convert: bool = True, ct2_conversion_quantization: Optional[str] = None, ct2_inter_threads: int = 1, ct2_intra_threads: int = 0, ct2_max_queued_batches: int = 0, ct2_batch_type: str = "examples", ct2_decoding_length_mode: str = "fixed", ct2_decoding_length_extra: int = 0, ct2_decoding_length_min: int = 1, ) -> None: self.model = name self.model_id = model_id self.model_dir = model_dir self.device = _resolve_device(device) self.compute_type = _resolve_compute_type(torch_dtype, ct2_compute_type, self.device) self.batch_size = int(batch_size) self.max_input_length = int(max_input_length) self.max_new_tokens = int(max_new_tokens) self.num_beams = int(num_beams) self.ct2_model_dir = str(ct2_model_dir or _derive_ct2_model_dir(model_dir, self.compute_type)) self.ct2_auto_convert = bool(ct2_auto_convert) self.ct2_conversion_quantization = _resolve_compute_type( torch_dtype, ct2_conversion_quantization or self.compute_type, self.device, ) self.ct2_inter_threads = int(ct2_inter_threads) self.ct2_intra_threads = int(ct2_intra_threads) self.ct2_max_queued_batches = int(ct2_max_queued_batches) self.ct2_batch_type = str(ct2_batch_type or "examples").strip().lower() if self.ct2_batch_type not in {"examples", "tokens"}: raise ValueError(f"Unsupported CTranslate2 batch type: {ct2_batch_type}") self.ct2_decoding_length_mode = str(ct2_decoding_length_mode or "fixed").strip().lower() if self.ct2_decoding_length_mode not in {"fixed", "source"}: raise ValueError(f"Unsupported CTranslate2 decoding length mode: {ct2_decoding_length_mode}") self.ct2_decoding_length_extra = int(ct2_decoding_length_extra) self.ct2_decoding_length_min = max(1, int(ct2_decoding_length_min)) self._tokenizer_lock = threading.Lock() self._local_model_source = self._resolve_local_model_source() self._load_runtime() @property def supports_batch(self) -> bool: return True def _tokenizer_source(self) -> str: return self._local_model_source or self.model_id def _model_source(self) -> str: return self._local_model_source or self.model_id def _expected_local_model_types(self) -> Optional[set[str]]: return None def _resolve_local_model_source(self) -> Optional[str]: model_path = Path(self.model_dir).expanduser() if not model_path.exists(): return None if not (model_path / "config.json").exists(): logger.warning( "Local translation model_dir is incomplete | model=%s model_dir=%s missing=config.json fallback=model_id", self.model, model_path, ) return None expected_types = self._expected_local_model_types() if not expected_types: return str(model_path) detected_type = _detect_local_model_type(str(model_path)) if detected_type is None: return str(model_path) if detected_type in expected_types: return str(model_path) logger.warning( "Local translation model_dir has unexpected model_type | model=%s model_dir=%s detected=%s expected=%s fallback=model_id", self.model, model_path, detected_type, sorted(expected_types), ) return None def _tokenizer_kwargs(self) -> Dict[str, object]: return {} def _translator_kwargs(self) -> Dict[str, object]: return { "device": self.device, "compute_type": self.compute_type, "inter_threads": self.ct2_inter_threads, "intra_threads": self.ct2_intra_threads, "max_queued_batches": self.ct2_max_queued_batches, } def _load_runtime(self) -> None: try: import ctranslate2 except ImportError as exc: raise RuntimeError( "CTranslate2 is required for local Marian/NLLB translation. " "Install the translator service dependencies again after adding ctranslate2." ) from exc tokenizer_source = self._tokenizer_source() model_source = self._model_source() self._ensure_converted_model(model_source) logger.info( "Loading CTranslate2 translation model | name=%s ct2_model_dir=%s tokenizer=%s device=%s compute_type=%s", self.model, self.ct2_model_dir, tokenizer_source, self.device, self.compute_type, ) self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, **self._tokenizer_kwargs()) self.translator = ctranslate2.Translator(self.ct2_model_dir, **self._translator_kwargs()) if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None: self.tokenizer.pad_token = self.tokenizer.eos_token def _ensure_converted_model(self, model_source: str) -> None: ct2_path = Path(self.ct2_model_dir).expanduser() if (ct2_path / "model.bin").exists(): return if not self.ct2_auto_convert: raise FileNotFoundError( f"CTranslate2 model not found for '{self.model}': {ct2_path}. " "Enable ct2_auto_convert or pre-convert the model." ) ct2_path.parent.mkdir(parents=True, exist_ok=True) logger.info( "Converting translation model to CTranslate2 | name=%s source=%s output=%s quantization=%s", self.model, model_source, ct2_path, self.ct2_conversion_quantization, ) try: convert_transformers_model( model_source, str(ct2_path), self.ct2_conversion_quantization, ) except Exception as exc: raise RuntimeError( f"Failed to convert model '{self.model}' to CTranslate2: {exc}" ) from exc def _normalize_texts(self, text: Union[str, Sequence[str]]) -> List[str]: if isinstance(text, str): return [text] return ["" if item is None else str(item) for item in text] def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None: del source_lang, target_lang def _encode_source_tokens( self, texts: List[str], source_lang: Optional[str], target_lang: str, ) -> List[List[str]]: del source_lang, target_lang with self._tokenizer_lock: encoded = self.tokenizer( texts, truncation=True, max_length=self.max_input_length, padding=False, ) input_ids = encoded["input_ids"] return [self.tokenizer.convert_ids_to_tokens(ids) for ids in input_ids] def _target_prefixes( self, count: int, source_lang: Optional[str], target_lang: str, ) -> Optional[List[Optional[List[str]]]]: del count, source_lang, target_lang return None def _resolve_max_decoding_length(self, source_tokens: Sequence[Sequence[str]]) -> int: if self.ct2_decoding_length_mode != "source": return self.max_new_tokens if not source_tokens: return self.max_new_tokens max_source_length = max(len(tokens) for tokens in source_tokens) dynamic_length = max(self.ct2_decoding_length_min, max_source_length + self.ct2_decoding_length_extra) return min(self.max_new_tokens, dynamic_length) def _postprocess_hypothesis( self, tokens: List[str], source_lang: Optional[str], target_lang: str, ) -> List[str]: del source_lang, target_lang return tokens def _decode_tokens(self, tokens: List[str]) -> Optional[str]: token_ids = self.tokenizer.convert_tokens_to_ids(tokens) text = self.tokenizer.decode(token_ids, skip_special_tokens=True).strip() return text or None def _translate_batch( self, texts: List[str], target_lang: str, source_lang: Optional[str] = None, ) -> List[Optional[str]]: self._validate_languages(source_lang, target_lang) source_tokens = self._encode_source_tokens(texts, source_lang, target_lang) target_prefix = self._target_prefixes(len(source_tokens), source_lang, target_lang) max_decoding_length = self._resolve_max_decoding_length(source_tokens) logger.info( "Translation model batch detail | model=%s segment_count=%s token_lengths=%s max_decoding_length=%s batch_type=%s beam_size=%s target_lang=%s source_lang=%s", self.model, len(source_tokens), _summarize_lengths([len(tokens) for tokens in source_tokens]), max_decoding_length, self.ct2_batch_type, self.num_beams, target_lang, source_lang or "auto", ) results = self.translator.translate_batch( source_tokens, target_prefix=target_prefix, max_batch_size=self.batch_size, batch_type=self.ct2_batch_type, beam_size=self.num_beams, max_input_length=self.max_input_length, max_decoding_length=max_decoding_length, ) outputs: List[Optional[str]] = [] for result in results: hypothesis = result.hypotheses[0] if result.hypotheses else [] processed = self._postprocess_hypothesis(hypothesis, source_lang, target_lang) outputs.append(self._decode_tokens(processed)) return outputs def _token_count( self, text: str, target_lang: str, source_lang: Optional[str] = None, ) -> int: encoded = self._encode_source_tokens([text], source_lang, target_lang) return len(encoded[0]) if encoded else 0 def _effective_input_token_limit(self, target_lang: str, source_lang: Optional[str] = None) -> int: del target_lang, source_lang return compute_safe_input_token_limit( max_input_length=self.max_input_length, max_new_tokens=self.max_new_tokens, decoding_length_mode=self.ct2_decoding_length_mode, decoding_length_extra=self.ct2_decoding_length_extra, ) def _split_text_if_needed( self, text: str, target_lang: str, source_lang: Optional[str] = None, ) -> List[str]: limit = self._effective_input_token_limit(target_lang, source_lang) token_count_cache: Dict[str, int] = {} def _cached_token_count(value: str) -> int: cached = token_count_cache.get(value) if cached is not None: return cached count = self._token_count( value, target_lang=target_lang, source_lang=source_lang, ) token_count_cache[value] = count return count return split_text_for_translation( text, max_tokens=limit, token_length_fn=_cached_token_count, ) def _log_segmentation_summary( self, *, texts: Sequence[str], segment_plans: Sequence[Sequence[str]], target_lang: str, source_lang: Optional[str], ) -> None: non_empty_count = sum(1 for text in texts if text.strip()) segment_counts = [len(segments) for segments in segment_plans if segments] total_segments = sum(segment_counts) segmented_inputs = sum(1 for count in segment_counts if count > 1) logger.info( "Translation segmentation summary | model=%s inputs=%s non_empty_inputs=%s segmented_inputs=%s total_segments=%s batch_size=%s target_lang=%s source_lang=%s segments_per_input=%s", self.model, len(texts), non_empty_count, segmented_inputs, total_segments, self.batch_size, target_lang, source_lang or "auto", _summarize_lengths(segment_counts), ) def _translate_segment_batches( self, segments: List[str], target_lang: str, source_lang: Optional[str] = None, ) -> List[Optional[str]]: if not segments: return [] outputs: List[Optional[str]] = [] total_batches = (len(segments) + self.batch_size - 1) // self.batch_size for batch_index, start in enumerate(range(0, len(segments), self.batch_size), start=1): batch = segments[start:start + self.batch_size] logger.info( "Translation inference batch | model=%s batch_index=%s total_batches=%s segment_count=%s char_lengths=%s first_preview=%s target_lang=%s source_lang=%s", self.model, batch_index, total_batches, len(batch), _summarize_lengths([len(segment) for segment in batch]), _text_preview(batch[0] if batch else ""), target_lang, source_lang or "auto", ) outputs.extend( self._translate_batch(batch, target_lang=target_lang, source_lang=source_lang) ) return outputs def _translate_with_segmentation( self, texts: List[str], target_lang: str, source_lang: Optional[str] = None, ) -> List[Optional[str]]: segment_plans: List[List[str]] = [] flat_segments: List[str] = [] for text in texts: if not text.strip(): segment_plans.append([]) continue segments = self._split_text_if_needed(text, target_lang=target_lang, source_lang=source_lang) segment_plans.append(segments) flat_segments.extend(segments) self._log_segmentation_summary( texts=texts, segment_plans=segment_plans, target_lang=target_lang, source_lang=source_lang, ) translated_segments = ( self._translate_segment_batches(flat_segments, target_lang=target_lang, source_lang=source_lang) if flat_segments else [] ) outputs: List[Optional[str]] = [] offset = 0 for original_text, segments in zip(texts, segment_plans): if not segments: outputs.append(None if not original_text.strip() else original_text) continue current = translated_segments[offset:offset + len(segments)] offset += len(segments) if len(segments) == 1: outputs.append(current[0]) continue outputs.append( join_translated_segments( current, target_lang=target_lang, original_text=original_text, ) ) return outputs def translate( self, text: Union[str, Sequence[str]], target_lang: str, source_lang: Optional[str] = None, scene: Optional[str] = None, ) -> Union[Optional[str], List[Optional[str]]]: del scene is_single = isinstance(text, str) texts = self._normalize_texts(text) if not any(item.strip() for item in texts): outputs = [None if not item.strip() else item for item in texts] # type: ignore[list-item] return outputs[0] if is_single else outputs outputs = self._translate_with_segmentation(texts, target_lang=target_lang, source_lang=source_lang) return outputs[0] if is_single else outputs class MarianCTranslate2TranslationBackend(LocalCTranslate2TranslationBackend): """Local backend for Marian/OPUS MT models on CTranslate2.""" def __init__( self, *, name: str, model_id: str, model_dir: str, device: str, torch_dtype: str, batch_size: int, max_input_length: int, max_new_tokens: int, num_beams: int, source_langs: Sequence[str], target_langs: Sequence[str], ct2_model_dir: Optional[str] = None, ct2_compute_type: Optional[str] = None, ct2_auto_convert: bool = True, ct2_conversion_quantization: Optional[str] = None, ct2_inter_threads: int = 1, ct2_intra_threads: int = 0, ct2_max_queued_batches: int = 0, ct2_batch_type: str = "examples", ct2_decoding_length_mode: str = "fixed", ct2_decoding_length_extra: int = 0, ct2_decoding_length_min: int = 1, ) -> None: self.source_langs = {str(lang).strip().lower() for lang in source_langs if str(lang).strip()} self.target_langs = {str(lang).strip().lower() for lang in target_langs if str(lang).strip()} super().__init__( name=name, model_id=model_id, model_dir=model_dir, device=device, torch_dtype=torch_dtype, batch_size=batch_size, max_input_length=max_input_length, max_new_tokens=max_new_tokens, num_beams=num_beams, ct2_model_dir=ct2_model_dir, ct2_compute_type=ct2_compute_type, ct2_auto_convert=ct2_auto_convert, ct2_conversion_quantization=ct2_conversion_quantization, ct2_inter_threads=ct2_inter_threads, ct2_intra_threads=ct2_intra_threads, ct2_max_queued_batches=ct2_max_queued_batches, ct2_batch_type=ct2_batch_type, ct2_decoding_length_mode=ct2_decoding_length_mode, ct2_decoding_length_extra=ct2_decoding_length_extra, ct2_decoding_length_min=ct2_decoding_length_min, ) def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None: src = str(source_lang or "").strip().lower() tgt = str(target_lang or "").strip().lower() if self.source_langs and src not in self.source_langs: raise ValueError( f"Model '{self.model}' only supports source languages: {sorted(self.source_langs)}" ) if self.target_langs and tgt not in self.target_langs: raise ValueError( f"Model '{self.model}' only supports target languages: {sorted(self.target_langs)}" ) def _expected_local_model_types(self) -> Optional[set[str]]: return {"marian"} class NLLBCTranslate2TranslationBackend(LocalCTranslate2TranslationBackend): """Local backend for NLLB models on CTranslate2.""" def __init__( self, *, name: str, model_id: str, model_dir: str, device: str, torch_dtype: str, batch_size: int, max_input_length: int, max_new_tokens: int, num_beams: int, language_codes: Optional[Dict[str, str]] = None, ct2_model_dir: Optional[str] = None, ct2_compute_type: Optional[str] = None, ct2_auto_convert: bool = True, ct2_conversion_quantization: Optional[str] = None, ct2_inter_threads: int = 1, ct2_intra_threads: int = 0, ct2_max_queued_batches: int = 0, ct2_batch_type: str = "examples", ct2_decoding_length_mode: str = "fixed", ct2_decoding_length_extra: int = 0, ct2_decoding_length_min: int = 1, ) -> None: self.language_codes = build_nllb_language_catalog(language_codes) self._tokenizers_by_source: Dict[str, object] = {} super().__init__( name=name, model_id=model_id, model_dir=model_dir, device=device, torch_dtype=torch_dtype, batch_size=batch_size, max_input_length=max_input_length, max_new_tokens=max_new_tokens, num_beams=num_beams, ct2_model_dir=ct2_model_dir, ct2_compute_type=ct2_compute_type, ct2_auto_convert=ct2_auto_convert, ct2_conversion_quantization=ct2_conversion_quantization, ct2_inter_threads=ct2_inter_threads, ct2_intra_threads=ct2_intra_threads, ct2_max_queued_batches=ct2_max_queued_batches, ct2_batch_type=ct2_batch_type, ct2_decoding_length_mode=ct2_decoding_length_mode, ct2_decoding_length_extra=ct2_decoding_length_extra, ct2_decoding_length_min=ct2_decoding_length_min, ) def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None: if not str(source_lang or "").strip(): raise ValueError(f"Model '{self.model}' requires source_lang") if resolve_nllb_language_code(source_lang, self.language_codes) is None: raise ValueError(f"Unsupported NLLB source language: {source_lang}") if resolve_nllb_language_code(target_lang, self.language_codes) is None: raise ValueError(f"Unsupported NLLB target language: {target_lang}") def _expected_local_model_types(self) -> Optional[set[str]]: return {"m2m_100", "nllb_moe"} def _get_tokenizer_for_source(self, source_lang: str): src_code = resolve_nllb_language_code(source_lang, self.language_codes) if src_code is None: raise ValueError(f"Unsupported NLLB source language: {source_lang}") with self._tokenizer_lock: tokenizer = self._tokenizers_by_source.get(src_code) if tokenizer is None: tokenizer = AutoTokenizer.from_pretrained(self._tokenizer_source(), src_lang=src_code) if tokenizer.pad_token is None and tokenizer.eos_token is not None: tokenizer.pad_token = tokenizer.eos_token self._tokenizers_by_source[src_code] = tokenizer return tokenizer def _encode_source_tokens( self, texts: List[str], source_lang: Optional[str], target_lang: str, ) -> List[List[str]]: del target_lang source_key = normalize_language_key(source_lang) tokenizer = self._get_tokenizer_for_source(source_key) encoded = tokenizer( texts, truncation=True, max_length=self.max_input_length, padding=False, ) input_ids = encoded["input_ids"] return [tokenizer.convert_ids_to_tokens(ids) for ids in input_ids] def _target_prefixes( self, count: int, source_lang: Optional[str], target_lang: str, ) -> Optional[List[Optional[List[str]]]]: del source_lang tgt_code = resolve_nllb_language_code(target_lang, self.language_codes) if tgt_code is None: raise ValueError(f"Unsupported NLLB target language: {target_lang}") return [[tgt_code] for _ in range(count)] def _postprocess_hypothesis( self, tokens: List[str], source_lang: Optional[str], target_lang: str, ) -> List[str]: del source_lang tgt_code = resolve_nllb_language_code(target_lang, self.language_codes) if tgt_code is None: raise ValueError(f"Unsupported NLLB target language: {target_lang}") if tokens and tokens[0] == tgt_code: return tokens[1:] return tokens def get_marian_language_direction(model_name: str) -> tuple[str, str]: direction = MARIAN_LANGUAGE_DIRECTIONS.get(model_name) if direction is None: raise ValueError(f"Translation capability '{model_name}' is not registered with Marian language directions") return direction