"""Local seq2seq translation backends powered by Transformers.""" from __future__ import annotations import logging import os import threading from typing import Dict, List, Optional, Sequence, Union import torch from transformers import AutoModelForSeq2SeqLM, AutoTokenizer from translation.languages import ( MARIAN_LANGUAGE_DIRECTIONS, build_nllb_language_catalog, resolve_nllb_language_code, ) from translation.text_splitter import ( compute_safe_input_token_limit, join_translated_segments, split_text_for_translation, ) logger = logging.getLogger(__name__) def _text_preview(text: Optional[str], limit: int = 32) -> str: return str(text or "").replace("\n", "\\n")[:limit] def _summarize_lengths(values: Sequence[int]) -> str: if not values: return "[]" total = sum(values) return f"min={min(values)} max={max(values)} avg={total / len(values):.1f}" def _resolve_device(device: Optional[str]) -> str: value = str(device or "auto").strip().lower() if value == "auto": return "cuda" if torch.cuda.is_available() else "cpu" return value def _resolve_dtype(dtype: Optional[str], device: str) -> Optional[torch.dtype]: value = str(dtype or "auto").strip().lower() if value == "auto": return torch.float16 if device.startswith("cuda") else None if value in {"float16", "fp16", "half"}: return torch.float16 if device.startswith("cuda") else None if value in {"bfloat16", "bf16"}: return torch.bfloat16 if value in {"float32", "fp32"}: return torch.float32 raise ValueError(f"Unsupported torch dtype: {dtype}") class LocalSeq2SeqTranslationBackend: """Base backend for local Hugging Face seq2seq translation models.""" def __init__( self, *, name: str, model_id: str, model_dir: str, device: str, torch_dtype: str, batch_size: int, max_input_length: int, max_new_tokens: int, num_beams: int, attn_implementation: Optional[str] = None, ) -> None: self.model = name self.model_id = model_id self.model_dir = model_dir self.device = _resolve_device(device) self.torch_dtype = _resolve_dtype(torch_dtype, self.device) self.batch_size = int(batch_size) self.max_input_length = int(max_input_length) self.max_new_tokens = int(max_new_tokens) self.num_beams = int(num_beams) self.attn_implementation = str(attn_implementation or "").strip() or None self._lock = threading.Lock() self._load_model() @property def supports_batch(self) -> bool: return True def _load_model(self) -> None: model_path = self.model_dir if os.path.exists(self.model_dir) else self.model_id logger.info( "Loading local translation model | name=%s source=%s device=%s dtype=%s", self.model, model_path, self.device, self.torch_dtype, ) tokenizer_kwargs = self._tokenizer_kwargs() model_kwargs = self._model_kwargs() self.tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs) self.seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_path, **model_kwargs) self.seq2seq_model.to(self.device) self.seq2seq_model.eval() if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None: self.tokenizer.pad_token = self.tokenizer.eos_token def _tokenizer_kwargs(self) -> Dict[str, object]: return {} def _model_kwargs(self) -> Dict[str, object]: kwargs: Dict[str, object] = {} if self.torch_dtype is not None: kwargs["torch_dtype"] = self.torch_dtype kwargs["low_cpu_mem_usage"] = True if self.attn_implementation: kwargs["attn_implementation"] = self.attn_implementation return kwargs def _normalize_texts(self, text: Union[str, Sequence[str]]) -> List[str]: if isinstance(text, str): return [text] return ["" if item is None else str(item) for item in 
    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
        del source_lang, target_lang

    def _prepare_tokenizer(self, source_lang: Optional[str], target_lang: str) -> Dict[str, object]:
        del source_lang, target_lang
        return {}

    def _build_generate_kwargs(self, source_lang: Optional[str], target_lang: str) -> Dict[str, object]:
        del source_lang, target_lang
        return {
            "num_beams": self.num_beams,
        }

    def _translate_batch(
        self,
        texts: List[str],
        target_lang: str,
        source_lang: Optional[str] = None,
    ) -> List[Optional[str]]:
        self._validate_languages(source_lang, target_lang)
        with self._lock, torch.inference_mode():
            # _prepare_tokenizer may mutate shared tokenizer state (e.g. the
            # NLLB src_lang attribute), so it must run while the lock is held.
            tokenizer_kwargs = self._prepare_tokenizer(source_lang, target_lang)
            encoded = self.tokenizer(
                texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_input_length,
                **tokenizer_kwargs,
            )
            encoded = {
                key: value.to(self.device, non_blocking=self.device.startswith("cuda"))
                for key, value in encoded.items()
            }
            generate_kwargs = self._build_generate_kwargs(source_lang, target_lang)
            input_ids = encoded.get("input_ids")
            if input_ids is not None and "max_length" not in generate_kwargs:
                # Bound decoder output by the input length plus the configured headroom.
                generate_kwargs["max_length"] = int(input_ids.shape[-1]) + self.max_new_tokens
            generated = self.seq2seq_model.generate(
                **encoded,
                **generate_kwargs,
            )
            outputs = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        return [item.strip() if item and item.strip() else None for item in outputs]

    def _token_count(
        self,
        text: str,
        target_lang: str,
        source_lang: Optional[str] = None,
    ) -> int:
        with self._lock:
            # Held under the lock for the same reason as _translate_batch: the
            # language hooks may touch shared tokenizer state.
            tokenizer_kwargs = self._prepare_tokenizer(source_lang, target_lang)
            encoded = self.tokenizer(
                [text],
                truncation=False,
                padding=False,
                **tokenizer_kwargs,
            )
        input_ids = encoded["input_ids"]
        first_item = input_ids[0]
        if hasattr(first_item, "shape"):
            return int(first_item.shape[-1])
        return len(first_item)

    def _effective_input_token_limit(self, target_lang: str, source_lang: Optional[str] = None) -> int:
        del target_lang, source_lang
        return compute_safe_input_token_limit(
            max_input_length=self.max_input_length,
            max_new_tokens=self.max_new_tokens,
        )

    def _split_text_if_needed(
        self,
        text: str,
        target_lang: str,
        source_lang: Optional[str] = None,
    ) -> List[str]:
        limit = self._effective_input_token_limit(target_lang, source_lang)
        # The splitter probes many candidate segments; memoize token counts so
        # repeated candidates are only tokenized once.
        token_count_cache: Dict[str, int] = {}

        def _cached_token_count(value: str) -> int:
            cached = token_count_cache.get(value)
            if cached is not None:
                return cached
            count = self._token_count(
                value,
                target_lang=target_lang,
                source_lang=source_lang,
            )
            token_count_cache[value] = count
            return count

        return split_text_for_translation(
            text,
            max_tokens=limit,
            token_length_fn=_cached_token_count,
        )

    def _log_segmentation_summary(
        self,
        *,
        texts: Sequence[str],
        segment_plans: Sequence[Sequence[str]],
        target_lang: str,
        source_lang: Optional[str],
    ) -> None:
        non_empty_count = sum(1 for text in texts if text.strip())
        segment_counts = [len(segments) for segments in segment_plans if segments]
        total_segments = sum(segment_counts)
        segmented_inputs = sum(1 for count in segment_counts if count > 1)
        logger.info(
            "Translation segmentation summary | model=%s inputs=%s non_empty_inputs=%s "
            "segmented_inputs=%s total_segments=%s batch_size=%s target_lang=%s "
            "source_lang=%s segments_per_input=%s",
            self.model,
            len(texts),
            non_empty_count,
            segmented_inputs,
            total_segments,
            self.batch_size,
            target_lang,
            source_lang or "auto",
            _summarize_lengths(segment_counts),
        )

    def _translate_segment_batches(
        self,
        segments: List[str],
        target_lang: str,
        source_lang: Optional[str] = None,
    ) -> List[Optional[str]]:
        if not segments:
            return []
        outputs: List[Optional[str]] = []
        total_batches = (len(segments) + self.batch_size - 1) // self.batch_size
        for batch_index, start in enumerate(range(0, len(segments), self.batch_size), start=1):
            batch = segments[start:start + self.batch_size]
            logger.info(
                "Translation inference batch | model=%s batch_index=%s total_batches=%s "
                "segment_count=%s char_lengths=%s first_preview=%s target_lang=%s source_lang=%s",
                self.model,
                batch_index,
                total_batches,
                len(batch),
                _summarize_lengths([len(segment) for segment in batch]),
                _text_preview(batch[0] if batch else ""),
                target_lang,
                source_lang or "auto",
            )
            outputs.extend(
                self._translate_batch(batch, target_lang=target_lang, source_lang=source_lang)
            )
        return outputs

    def _translate_with_segmentation(
        self,
        texts: List[str],
        target_lang: str,
        source_lang: Optional[str] = None,
    ) -> List[Optional[str]]:
        segment_plans: List[List[str]] = []
        flat_segments: List[str] = []
        for text in texts:
            if not text.strip():
                segment_plans.append([])
                continue
            segments = self._split_text_if_needed(text, target_lang=target_lang, source_lang=source_lang)
            segment_plans.append(segments)
            flat_segments.extend(segments)
        self._log_segmentation_summary(
            texts=texts,
            segment_plans=segment_plans,
            target_lang=target_lang,
            source_lang=source_lang,
        )
        translated_segments = (
            self._translate_segment_batches(flat_segments, target_lang=target_lang, source_lang=source_lang)
            if flat_segments
            else []
        )
        # Re-assemble per-input results by walking the flat segment list with a
        # running offset; each input consumed len(segments) entries.
        outputs: List[Optional[str]] = []
        offset = 0
        for original_text, segments in zip(texts, segment_plans):
            if not segments:
                outputs.append(None if not original_text.strip() else original_text)
                continue
            current = translated_segments[offset:offset + len(segments)]
            offset += len(segments)
            if len(segments) == 1:
                outputs.append(current[0])
                continue
            outputs.append(
                join_translated_segments(
                    current,
                    target_lang=target_lang,
                    original_text=original_text,
                )
            )
        return outputs

    def translate(
        self,
        text: Union[str, Sequence[str]],
        target_lang: str,
        source_lang: Optional[str] = None,
        scene: Optional[str] = None,
    ) -> Union[Optional[str], List[Optional[str]]]:
        del scene
        # Preserve the caller's shape: a bare string in yields a bare result out.
        is_single = isinstance(text, str)
        texts = self._normalize_texts(text)
        if not any(item.strip() for item in texts):
            outputs = [None if not item.strip() else item for item in texts]  # type: ignore[list-item]
            return outputs[0] if is_single else outputs
        outputs = self._translate_with_segmentation(texts, target_lang=target_lang, source_lang=source_lang)
        return outputs[0] if is_single else outputs


class MarianMTTranslationBackend(LocalSeq2SeqTranslationBackend):
    """Local backend for Marian/OPUS MT models."""

    def __init__(
        self,
        *,
        name: str,
        model_id: str,
        model_dir: str,
        device: str,
        torch_dtype: str,
        batch_size: int,
        max_input_length: int,
        max_new_tokens: int,
        num_beams: int,
        source_langs: Sequence[str],
        target_langs: Sequence[str],
        attn_implementation: Optional[str] = None,
    ) -> None:
        self.source_langs = {str(lang).strip().lower() for lang in source_langs if str(lang).strip()}
        self.target_langs = {str(lang).strip().lower() for lang in target_langs if str(lang).strip()}
        super().__init__(
            name=name,
            model_id=model_id,
            model_dir=model_dir,
            device=device,
            torch_dtype=torch_dtype,
            batch_size=batch_size,
            max_input_length=max_input_length,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            attn_implementation=attn_implementation,
        )

    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
        src = str(source_lang or "").strip().lower()
        tgt = str(target_lang or "").strip().lower()
        if self.source_langs and src not in self.source_langs:
            raise ValueError(
                f"Model '{self.model}' only supports source languages: {sorted(self.source_langs)}"
            )
        if self.target_langs and tgt not in self.target_langs:
            raise ValueError(
                f"Model '{self.model}' only supports target languages: {sorted(self.target_langs)}"
            )


class NLLBTranslationBackend(LocalSeq2SeqTranslationBackend):
    """Local backend for NLLB translation models."""

    def __init__(
        self,
        *,
        name: str,
        model_id: str,
        model_dir: str,
        device: str,
        torch_dtype: str,
        batch_size: int,
        max_input_length: int,
        max_new_tokens: int,
        num_beams: int,
        language_codes: Optional[Dict[str, str]] = None,
        attn_implementation: Optional[str] = None,
    ) -> None:
        self.language_codes = build_nllb_language_catalog(language_codes)
        super().__init__(
            name=name,
            model_id=model_id,
            model_dir=model_dir,
            device=device,
            torch_dtype=torch_dtype,
            batch_size=batch_size,
            max_input_length=max_input_length,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            attn_implementation=attn_implementation,
        )

    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
        if not str(source_lang or "").strip():
            raise ValueError(f"Model '{self.model}' requires source_lang")
        if resolve_nllb_language_code(source_lang, self.language_codes) is None:
            raise ValueError(f"Unsupported NLLB source language: {source_lang}")
        if resolve_nllb_language_code(target_lang, self.language_codes) is None:
            raise ValueError(f"Unsupported NLLB target language: {target_lang}")

    def _prepare_tokenizer(self, source_lang: Optional[str], target_lang: str) -> Dict[str, object]:
        del target_lang
        src_code = resolve_nllb_language_code(source_lang, self.language_codes)
        if src_code is None:
            raise ValueError(f"Unsupported NLLB source language: {source_lang}")
        # NLLB selects the source language by mutating tokenizer state rather
        # than via call-time kwargs; the caller holds the backend lock.
        self.tokenizer.src_lang = src_code
        return {}

    def _build_generate_kwargs(self, source_lang: Optional[str], target_lang: str) -> Dict[str, object]:
        del source_lang
        tgt_code = resolve_nllb_language_code(target_lang, self.language_codes)
        if tgt_code is None:
            raise ValueError(f"Unsupported NLLB target language: {target_lang}")
        forced_bos_token_id = None
        if hasattr(self.tokenizer, "lang_code_to_id"):
            forced_bos_token_id = self.tokenizer.lang_code_to_id.get(tgt_code)
        if forced_bos_token_id is None:
            # Newer tokenizer versions dropped lang_code_to_id; the language
            # codes are ordinary tokens, so convert_tokens_to_ids works instead.
            forced_bos_token_id = self.tokenizer.convert_tokens_to_ids(tgt_code)
        return {
            "num_beams": self.num_beams,
            "forced_bos_token_id": forced_bos_token_id,
        }


def get_marian_language_direction(model_name: str) -> tuple[str, str]:
    direction = MARIAN_LANGUAGE_DIRECTIONS.get(model_name)
    if direction is None:
        raise ValueError(
            f"Translation capability '{model_name}' is not registered with Marian language directions"
        )
    return direction
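

# ---------------------------------------------------------------------------
# Minimal usage sketch, illustrative only. The checkpoint id, local directory,
# and the assumption that "en"/"de" resolve through build_nllb_language_catalog
# are examples, not values shipped with this module. MarianMTTranslationBackend
# is constructed the same way, plus source_langs/target_langs allow-lists.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    backend = NLLBTranslationBackend(
        name="nllb-demo",                             # hypothetical backend name
        model_id="facebook/nllb-200-distilled-600M",  # assumed hub checkpoint
        model_dir="/models/nllb-200-distilled-600M",  # assumed local cache path
        device="auto",
        torch_dtype="auto",
        batch_size=8,
        max_input_length=512,
        max_new_tokens=256,
        num_beams=4,
    )
    # Single string in, single translation (or None) out; lists map to lists.
    print(backend.translate("Hello, world!", target_lang="de", source_lang="en"))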