"""Local seq2seq translation backends powered by Transformers."""

from __future__ import annotations

import logging
import os
import threading
from typing import Dict, List, Optional, Sequence, Union

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from translation.languages import MARIAN_LANGUAGE_DIRECTIONS, NLLB_LANGUAGE_CODES

logger = logging.getLogger(__name__)


def _resolve_device(device: Optional[str]) -> str:
    value = str(device or "auto").strip().lower()
    if value == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return value


def _resolve_dtype(dtype: Optional[str], device: str) -> Optional[torch.dtype]:
    value = str(dtype or "auto").strip().lower()
    if value == "auto":
        # Prefer half precision on CUDA; None keeps the checkpoint's default
        # dtype (typically float32) elsewhere.
        return torch.float16 if device.startswith("cuda") else None
    if value in {"float16", "fp16", "half"}:
        # float16 is only honoured on CUDA; CPU inference falls back to the default.
        return torch.float16 if device.startswith("cuda") else None
    if value in {"bfloat16", "bf16"}:
        return torch.bfloat16
    if value in {"float32", "fp32"}:
        return torch.float32
    raise ValueError(f"Unsupported torch dtype: {dtype}")


class LocalSeq2SeqTranslationBackend:
    """Base backend for local Hugging Face seq2seq translation models."""

    def __init__(
        self,
        *,
        name: str,
        model_id: str,
        model_dir: str,
        device: str,
        torch_dtype: str,
        batch_size: int,
        max_input_length: int,
        max_new_tokens: int,
        num_beams: int,
        attn_implementation: Optional[str] = None,
    ) -> None:
        self.model = name  # capability name surfaced in logs and error messages
        self.model_id = model_id
        self.model_dir = model_dir
        self.device = _resolve_device(device)
        self.torch_dtype = _resolve_dtype(torch_dtype, self.device)
        self.batch_size = int(batch_size)
        self.max_input_length = int(max_input_length)
        self.max_new_tokens = int(max_new_tokens)
        self.num_beams = int(num_beams)
        self.attn_implementation = str(attn_implementation or "").strip() or None
        self._lock = threading.Lock()
        self._load_model()

    @property
    def supports_batch(self) -> bool:
        return True

    def _load_model(self) -> None:
        model_path = self.model_dir if os.path.exists(self.model_dir) else self.model_id
        logger.info(
            "Loading local translation model | name=%s source=%s device=%s dtype=%s",
            self.model,
            model_path,
            self.device,
            self.torch_dtype,
        )
        tokenizer_kwargs = self._tokenizer_kwargs()
        model_kwargs = self._model_kwargs()
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs)
        self.seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_path, **model_kwargs)
        self.seq2seq_model.to(self.device)
        self.seq2seq_model.eval()
        # Some tokenizers ship without a pad token; reuse EOS so batch padding works.
        if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _tokenizer_kwargs(self) -> Dict[str, object]:
        return {}

    def _model_kwargs(self) -> Dict[str, object]:
        kwargs: Dict[str, object] = {}
        if self.torch_dtype is not None:
            kwargs["torch_dtype"] = self.torch_dtype
        kwargs["low_cpu_mem_usage"] = True
        if self.attn_implementation:
            kwargs["attn_implementation"] = self.attn_implementation
        return kwargs

    def _normalize_texts(self, text: Union[str, Sequence[str]]) -> List[str]:
        if isinstance(text, str):
            return [text]
        return ["" if item is None else str(item) for item in text]

    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
        del source_lang, target_lang

    def _prepare_tokenizer(self, source_lang: Optional[str], target_lang: str) -> Dict[str, object]:
        del source_lang, target_lang
        return {}

    def _build_generate_kwargs(self, source_lang: Optional[str], target_lang: str) -> Dict[str, object]:
        del source_lang, target_lang
        return {
            "num_beams": self.num_beams,
        }

    def _translate_batch(
        self,
        texts: List[str],
        target_lang: str,
        source_lang: Optional[str] = None,
    ) -> List[Optional[str]]:
        self._validate_languages(source_lang, target_lang)
        with self._lock, torch.inference_mode():
            # Hooks may mutate shared tokenizer state (e.g. NLLB's src_lang),
            # so they must run inside the lock to stay thread-safe.
            tokenizer_kwargs = self._prepare_tokenizer(source_lang, target_lang)
            encoded = self.tokenizer(
                texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_input_length,
                **tokenizer_kwargs,
            )
            encoded = {
                key: value.to(self.device, non_blocking=self.device.startswith("cuda"))
                for key, value in encoded.items()
            }
            generate_kwargs = self._build_generate_kwargs(source_lang, target_lang)
            # Generation length for encoder-decoder models counts decoder tokens
            # only, so cap the output directly rather than adding the input length.
            generate_kwargs.setdefault("max_new_tokens", self.max_new_tokens)
            generated = self.seq2seq_model.generate(
                **encoded,
                **generate_kwargs,
            )
            outputs = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        return [item.strip() if item and item.strip() else None for item in outputs]

    def translate(
        self,
        text: Union[str, Sequence[str]],
        target_lang: str,
        source_lang: Optional[str] = None,
        scene: Optional[str] = None,
    ) -> Union[Optional[str], List[Optional[str]]]:
        del scene
        is_single = isinstance(text, str)
        texts = self._normalize_texts(text)
        # Empty inputs map straight to None; only non-empty texts reach the
        # model, so a mixed batch never feeds blank strings through generate.
        outputs: List[Optional[str]] = [None] * len(texts)
        pending = [(index, item) for index, item in enumerate(texts) if item.strip()]
        for start in range(0, len(pending), self.batch_size):
            chunk = pending[start:start + self.batch_size]
            translated = self._translate_batch(
                [item for _, item in chunk], target_lang=target_lang, source_lang=source_lang
            )
            for (index, _), result in zip(chunk, translated):
                outputs[index] = result
        return outputs[0] if is_single else outputs


class MarianMTTranslationBackend(LocalSeq2SeqTranslationBackend):
    """Local backend for Marian/OPUS MT models."""

    def __init__(
        self,
        *,
        name: str,
        model_id: str,
        model_dir: str,
        device: str,
        torch_dtype: str,
        batch_size: int,
        max_input_length: int,
        max_new_tokens: int,
        num_beams: int,
        source_langs: Sequence[str],
        target_langs: Sequence[str],
        attn_implementation: Optional[str] = None,
    ) -> None:
        self.source_langs = {str(lang).strip().lower() for lang in source_langs if str(lang).strip()}
        self.target_langs = {str(lang).strip().lower() for lang in target_langs if str(lang).strip()}
        super().__init__(
            name=name,
            model_id=model_id,
            model_dir=model_dir,
            device=device,
            torch_dtype=torch_dtype,
            batch_size=batch_size,
            max_input_length=max_input_length,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            attn_implementation=attn_implementation,
        )

    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
        src = str(source_lang or "").strip().lower()
        tgt = str(target_lang or "").strip().lower()
        if self.source_langs and src not in self.source_langs:
            raise ValueError(
                f"Model '{self.model}' does not support source language '{src}'; "
                f"supported: {sorted(self.source_langs)}"
            )
        if self.target_langs and tgt not in self.target_langs:
            raise ValueError(
                f"Model '{self.model}' does not support target language '{tgt}'; "
                f"supported: {sorted(self.target_langs)}"
            )
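
# Example instantiation (illustrative; Helsinki-NLP/opus-mt-en-zh is a public
# OPUS-MT checkpoint, but the paths and settings below are placeholders):
#   MarianMTTranslationBackend(
#       name="marian-en-zh",
#       model_id="Helsinki-NLP/opus-mt-en-zh",
#       model_dir="models/opus-mt-en-zh",
#       device="auto", torch_dtype="auto",
#       batch_size=8, max_input_length=512, max_new_tokens=256, num_beams=4,
#       source_langs=["en"], target_langs=["zh"],
#   )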


class NLLBTranslationBackend(LocalSeq2SeqTranslationBackend):
    """Local backend for NLLB translation models."""

    def __init__(
        self,
        *,
        name: str,
        model_id: str,
        model_dir: str,
        device: str,
        torch_dtype: str,
        batch_size: int,
        max_input_length: int,
        max_new_tokens: int,
        num_beams: int,
        language_codes: Optional[Dict[str, str]] = None,
        attn_implementation: Optional[str] = None,
    ) -> None:
        overrides = language_codes or {}
        self.language_codes = {
            **NLLB_LANGUAGE_CODES,
            **{str(k).strip().lower(): str(v).strip() for k, v in overrides.items() if str(k).strip()},
        }
        super().__init__(
            name=name,
            model_id=model_id,
            model_dir=model_dir,
            device=device,
            torch_dtype=torch_dtype,
            batch_size=batch_size,
            max_input_length=max_input_length,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            attn_implementation=attn_implementation,
        )

    def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None:
        src = str(source_lang or "").strip().lower()
        tgt = str(target_lang or "").strip().lower()
        if not src:
            raise ValueError(f"Model '{self.model}' requires source_lang")
        if src not in self.language_codes:
            raise ValueError(f"Unsupported NLLB source language: {source_lang}")
        if tgt not in self.language_codes:
            raise ValueError(f"Unsupported NLLB target language: {target_lang}")

    def _prepare_tokenizer(self, source_lang: Optional[str], target_lang: str) -> Dict[str, object]:
        del target_lang
        src_code = self.language_codes[str(source_lang).strip().lower()]
        self.tokenizer.src_lang = src_code
        return {}

    def _build_generate_kwargs(self, source_lang: Optional[str], target_lang: str) -> Dict[str, object]:
        del source_lang
        tgt_code = self.language_codes[str(target_lang).strip().lower()]
        forced_bos_token_id = None
        if hasattr(self.tokenizer, "lang_code_to_id"):
            # Older tokenizer versions expose lang_code_to_id; newer ones drop
            # it, in which case the convert_tokens_to_ids fallback applies.
            forced_bos_token_id = self.tokenizer.lang_code_to_id.get(tgt_code)
        if forced_bos_token_id is None:
            forced_bos_token_id = self.tokenizer.convert_tokens_to_ids(tgt_code)
        return {
            "num_beams": self.num_beams,
            "forced_bos_token_id": forced_bos_token_id,
        }


def get_marian_language_direction(model_name: str) -> tuple[str, str]:
    direction = MARIAN_LANGUAGE_DIRECTIONS.get(model_name)
    if direction is None:
        raise ValueError(
            f"Translation capability '{model_name}' is not registered with Marian language directions"
        )
    return direction
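

# Minimal usage sketch (illustrative, not a supported entry point):
# facebook/nllb-200-distilled-600M is a public checkpoint, but the local path
# and settings below are placeholders, and the call assumes the project's
# NLLB_LANGUAGE_CODES maps plain "en"/"fr" onto NLLB codes.
if __name__ == "__main__":
    backend = NLLBTranslationBackend(
        name="nllb-demo",
        model_id="facebook/nllb-200-distilled-600M",
        model_dir="models/nllb-200-distilled-600M",
        device="auto",
        torch_dtype="auto",
        batch_size=8,
        max_input_length=512,
        max_new_tokens=256,
        num_beams=4,
    )
    print(backend.translate("Hello, world!", target_lang="fr", source_lang="en"))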