# llm.py
"""LLM-based translation backend."""

from __future__ import annotations

import logging
import re
import time
from typing import List, Optional, Sequence, Union

from openai import OpenAI

from translation.languages import LANGUAGE_LABELS
from translation.prompts import BATCH_TRANSLATION_PROMPTS, TRANSLATION_PROMPTS
from translation.scenes import normalize_scene_name

logger = logging.getLogger(__name__)
_NUMBERED_LINE_RE = re.compile(r"^\s*(\d+)[\.\uFF0E]\s*(.*)\s*$")


def _resolve_prompt_template(
    prompt_groups: dict[str, dict[str, str]],
    *,
    target_lang: str,
    scene: Optional[str],
) -> tuple[str, str, str]:
    """Resolve the prompt template for a scene/target-language pair.

    Args:
        prompt_groups: Mapping of normalized scene name -> {lang_code: template}.
        target_lang: Target language code (normalized to lowercase here).
        scene: Scene name; normalized via ``normalize_scene_name``.

    Returns:
        Tuple of (normalized target lang, normalized scene, template). Falls
        back to the English template when the target language has none.

    Raises:
        ValueError: If the scene is unknown or no template exists for the
            scene/language pair. (Previously an unknown scene surfaced as a
            bare ``KeyError``; both failure modes now raise a descriptive
            ``ValueError`` for consistency.)
    """
    tgt = str(target_lang or "").strip().lower()
    normalized_scene = normalize_scene_name(scene)
    group = prompt_groups.get(normalized_scene)
    if group is None:
        raise ValueError(f"Unknown llm translation scene '{normalized_scene}'")
    template = group.get(tgt) or group.get("en")
    if template is None:
        raise ValueError(f"Missing llm translation prompt for scene='{normalized_scene}' target_lang='{tgt}'")
    return tgt, normalized_scene, template


def _build_prompt(
    text: str,
    *,
    source_lang: Optional[str],
    target_lang: str,
    scene: Optional[str],
) -> str:
    """Render the single-item translation prompt for ``text``.

    The source language defaults to "auto" when absent; the template is
    selected per scene/target language by ``_resolve_prompt_template``.
    """
    src_code = str(source_lang or "auto").strip().lower() or "auto"
    tgt_code, _scene, template = _resolve_prompt_template(
        TRANSLATION_PROMPTS,
        target_lang=target_lang,
        scene=scene,
    )
    # Substitute human-readable labels when known, otherwise the raw codes.
    return template.format(
        source_lang=LANGUAGE_LABELS.get(src_code, src_code),
        src_lang_code=src_code,
        target_lang=LANGUAGE_LABELS.get(tgt_code, tgt_code),
        tgt_lang_code=tgt_code,
        text=text,
    )


def _build_batch_prompt(
    texts: Sequence[str],
    *,
    source_lang: Optional[str],
    target_lang: str,
    scene: Optional[str],
) -> str:
    """Render the numbered-list prompt for a batch of texts.

    Inputs are serialized as "1. text" lines; the template also receives a
    matching "1. translation" answer skeleton so the model mirrors the format.
    """
    src_code = str(source_lang or "auto").strip().lower() or "auto"
    tgt_code, _scene, template = _resolve_prompt_template(
        BATCH_TRANSLATION_PROMPTS,
        target_lang=target_lang,
        scene=scene,
    )
    input_lines: List[str] = []
    example_lines: List[str] = []
    for position, item in enumerate(texts, start=1):
        input_lines.append(f"{position}. {item}")
        example_lines.append(f"{position}. translation")
    return template.format(
        source_lang=LANGUAGE_LABELS.get(src_code, src_code),
        src_lang_code=src_code,
        target_lang=LANGUAGE_LABELS.get(tgt_code, tgt_code),
        tgt_lang_code=tgt_code,
        item_count=len(texts),
        format_example="\n".join(example_lines),
        text="\n".join(input_lines),
    )


def _parse_batch_translation_output(content: str, *, expected_count: int) -> Optional[List[str]]:
    """Parse numbered "N. text" response lines into an ordered list.

    Returns None (caller falls back to serial translation) when any
    non-blank, non-fence line is malformed, an index repeats, or the index
    set is not exactly 1..expected_count. Blank lines and Markdown ```
    fences are ignored.
    """
    by_index: dict[int, str] = {}
    for raw in content.splitlines():
        candidate = raw.strip()
        if not candidate or candidate.startswith("```"):
            continue  # skip blanks and code-fence delimiters around the list
        parsed = _NUMBERED_LINE_RE.match(candidate)
        if parsed is None:
            logger.warning("[llm] Invalid batch line format | line=%s", raw)
            return None
        position = int(parsed.group(1))
        if position in by_index:
            logger.warning("[llm] Duplicate batch line index | index=%s", position)
            return None
        by_index[position] = parsed.group(2).strip()

    wanted = set(range(1, expected_count + 1))
    got = set(by_index)
    if got != wanted:
        logger.warning(
            "[llm] Batch line indices mismatch | expected=%s actual=%s",
            sorted(wanted),
            sorted(got),
        )
        return None
    return [by_index[position] for position in range(1, expected_count + 1)]


class LLMTranslationBackend:
    """Translation backend backed by an OpenAI-compatible chat-completion API.

    Supports single-string and batched translation. A batch is sent as one
    numbered-list prompt; if the response cannot be parsed back into exactly
    one line per item, the batch falls back to one request per item. Network
    and API failures are logged and surfaced as ``None`` results rather than
    raised to callers.
    """

    def __init__(
        self,
        *,
        capability_name: str,
        model: str,
        timeout_sec: float,
        base_url: str,
        api_key: Optional[str],
    ) -> None:
        """Store configuration and eagerly build the API client.

        Client creation is best-effort: ``self.client`` stays ``None`` when
        the API key is missing or SDK initialization fails, and translate
        calls then return ``None`` instead of raising.
        """
        self.capability_name = capability_name
        self.model = model
        self.timeout_sec = float(timeout_sec)
        self.base_url = base_url
        self.api_key = api_key
        self.client = self._create_client()

    @property
    def supports_batch(self) -> bool:
        # Advertises that translate() accepts list/tuple input.
        return True

    def _create_client(self) -> Optional[OpenAI]:
        """Build the OpenAI client, or return None when it cannot be built."""
        if not self.api_key:
            # NOTE(review): message names DASHSCOPE_API_KEY although the key
            # arrives as a constructor argument — confirm that env var is the
            # only source of api_key in callers.
            logger.warning("DASHSCOPE_API_KEY not set; llm translation unavailable")
            return None
        try:
            return OpenAI(api_key=self.api_key, base_url=self.base_url)
        except Exception as exc:
            logger.error("Failed to initialize llm translation client: %s", exc, exc_info=True)
            return None

    def _translate_single(
        self,
        text: str,
        target_lang: str,
        source_lang: Optional[str] = None,
        scene: Optional[str] = None,
    ) -> Optional[str]:
        """Translate one string via a single chat completion.

        Returns:
            The (stripped) model output, the input unchanged when it is
            empty/whitespace-only, or ``None`` when no client is available,
            the model returns an empty result, or the request fails.

        Raises:
            ValueError: If ``scene`` is None (a scene is mandatory).
        """
        # Pass empty/whitespace input through untouched — nothing to translate.
        if not text or not str(text).strip():
            return text
        if not self.client:
            return None

        tgt = str(target_lang or "").strip().lower()
        src = str(source_lang or "auto").strip().lower() or "auto"
        if scene is None:
            raise ValueError("llm translation scene is required")
        # Scene is normalized again inside _build_prompt; normalization is
        # presumably idempotent — harmless double call.
        normalized_scene = normalize_scene_name(scene)
        user_prompt = _build_prompt(
            text=text,
            source_lang=src,
            target_lang=tgt,
            scene=normalized_scene,
        )
        start = time.time()
        try:
            logger.info(
                "[llm] Request | src=%s tgt=%s model=%s prompt=%s",
                src,
                tgt,
                self.model,
                user_prompt,
            )
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": user_prompt}],
                timeout=self.timeout_sec,
            )
            # Model content may be None; coalesce to "" before stripping.
            content = (completion.choices[0].message.content or "").strip()
            latency_ms = (time.time() - start) * 1000
            if not content:
                logger.warning("[llm] Empty result | src=%s tgt=%s latency=%.1fms", src, tgt, latency_ms)
                return None
            logger.info(
                "[llm] Success | src=%s tgt=%s src_text=%s response=%s latency=%.1fms",
                src,
                tgt,
                text,
                content,
                latency_ms,
            )
            return content
        except Exception as exc:
            # Any SDK/network error degrades to None for this item.
            latency_ms = (time.time() - start) * 1000
            logger.warning(
                "[llm] Failed | src=%s tgt=%s latency=%.1fms error=%s",
                src,
                tgt,
                latency_ms,
                exc,
                exc_info=True,
            )
            return None

    def _translate_batch_serial_fallback(
        self,
        texts: Sequence[Optional[str]],
        target_lang: str,
        source_lang: Optional[str] = None,
        scene: Optional[str] = None,
    ) -> List[Optional[str]]:
        """Translate a batch one item at a time (used when batching fails).

        ``None`` items stay ``None``; whitespace-only items are returned
        unchanged; everything else goes through ``_translate_single``.
        """
        results: List[Optional[str]] = []
        for item in texts:
            if item is None:
                results.append(None)
                continue
            normalized = str(item)
            if not normalized.strip():
                results.append(normalized)
                continue
            results.append(
                self._translate_single(
                    text=normalized,
                    target_lang=target_lang,
                    source_lang=source_lang,
                    scene=scene,
                )
            )
        return results

    def _translate_batch(
        self,
        texts: Sequence[Optional[str]],
        target_lang: str,
        source_lang: Optional[str] = None,
        scene: Optional[str] = None,
    ) -> List[Optional[str]]:
        """Translate a batch with one numbered-list request.

        Falls back to serial translation when any item is multiline (which
        would break the line-numbered protocol), when the model returns an
        empty response, or when the response cannot be parsed. Returns a
        list aligned with ``texts``: ``None`` for untranslated items.
        """
        results: List[Optional[str]] = [None] * len(texts)
        prompt_texts: List[str] = []
        # Positions into `results` for each entry of prompt_texts, so parsed
        # translations can be scattered back to their original slots.
        prompt_positions: List[int] = []

        for idx, item in enumerate(texts):
            if item is None:
                continue
            normalized = str(item)
            if not normalized.strip():
                results[idx] = normalized
                continue
            # Embedded newlines would collide with the "N. text" line format.
            if "\n" in normalized or "\r" in normalized:
                logger.info("[llm] Batch fallback to serial | reason=multiline_input item_index=%s", idx)
                return self._translate_batch_serial_fallback(
                    texts=texts,
                    target_lang=target_lang,
                    source_lang=source_lang,
                    scene=scene,
                )
            prompt_texts.append(normalized)
            prompt_positions.append(idx)

        if not prompt_texts:
            return results
        if not self.client:
            return results

        tgt = str(target_lang or "").strip().lower()
        src = str(source_lang or "auto").strip().lower() or "auto"
        if scene is None:
            raise ValueError("llm translation scene is required")
        normalized_scene = normalize_scene_name(scene)
        user_prompt = _build_batch_prompt(
            texts=prompt_texts,
            source_lang=src,
            target_lang=tgt,
            scene=normalized_scene,
        )

        start = time.time()
        try:
            logger.info(
                "[llm] Batch request | src=%s tgt=%s model=%s item_count=%s prompt=%s",
                src,
                tgt,
                self.model,
                len(prompt_texts),
                user_prompt,
            )
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": user_prompt}],
                timeout=self.timeout_sec,
            )
            content = (completion.choices[0].message.content or "").strip()
            latency_ms = (time.time() - start) * 1000
            if not content:
                logger.warning(
                    "[llm] Empty batch result | src=%s tgt=%s item_count=%s latency=%.1fms",
                    src,
                    tgt,
                    len(prompt_texts),
                    latency_ms,
                )
                return self._translate_batch_serial_fallback(
                    texts=texts,
                    target_lang=target_lang,
                    source_lang=source_lang,
                    scene=scene,
                )

            parsed = _parse_batch_translation_output(content, expected_count=len(prompt_texts))
            if parsed is None:
                logger.warning(
                    "[llm] Batch parse failed, fallback to serial | src=%s tgt=%s item_count=%s response=%s",
                    src,
                    tgt,
                    len(prompt_texts),
                    content,
                )
                return self._translate_batch_serial_fallback(
                    texts=texts,
                    target_lang=target_lang,
                    source_lang=source_lang,
                    scene=scene,
                )

            # Scatter parsed translations back to their original positions.
            for position, translated in zip(prompt_positions, parsed):
                results[position] = translated
            logger.info(
                "[llm] Batch success | src=%s tgt=%s item_count=%s response=%s latency=%.1fms",
                src,
                tgt,
                len(prompt_texts),
                content,
                latency_ms,
            )
            return results
        except Exception as exc:
            # NOTE(review): unlike the empty-response and parse-failure paths,
            # an API exception does NOT fall back to serial — it returns the
            # partially-filled results (all None for unsent items). Confirm
            # this asymmetry is intended (e.g. to avoid N retries when the
            # API itself is down).
            latency_ms = (time.time() - start) * 1000
            logger.warning(
                "[llm] Batch failed | src=%s tgt=%s item_count=%s latency=%.1fms error=%s",
                src,
                tgt,
                len(prompt_texts),
                latency_ms,
                exc,
                exc_info=True,
            )
            return results

    def translate(
        self,
        text: Union[str, Sequence[str]],
        target_lang: str,
        source_lang: Optional[str] = None,
        scene: Optional[str] = None,
    ) -> Union[Optional[str], List[Optional[str]]]:
        """Public entry point: dispatch to batch or single translation.

        A ``list`` or ``tuple`` input returns a list of per-item results;
        any other input is coerced with ``str()`` and translated as one
        string.
        """
        if isinstance(text, (list, tuple)):
            return self._translate_batch(
                text,
                target_lang=target_lang,
                source_lang=source_lang,
                scene=scene,
            )

        return self._translate_single(
            text=str(text),
            target_lang=target_lang,
            source_lang=source_lang,
            scene=scene,
        )