translator_app.py 9.75 KB
"""Translator service HTTP app."""

import argparse
import logging
from contextlib import asynccontextmanager
from functools import lru_cache
from typing import List, Optional, Union

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, ConfigDict, Field

from config.services_config import get_translation_config
from translation.service import TranslationService
from translation.settings import (
    get_enabled_translation_models,
    normalize_translation_model,
    normalize_translation_scene,
)

# Root logging setup: INFO level, and every record carries a timestamp, the
# logger name and the level. logging.basicConfig does nothing if the root
# logger already has handlers configured.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)


@lru_cache(maxsize=1)
def get_translation_service() -> TranslationService:
    """Return the process-wide TranslationService, creating it on first use.

    The function takes no arguments, so ``lru_cache(maxsize=1)`` memoizes a
    single instance and every later call returns that same object.
    """
    translation_config = get_translation_config()
    return TranslationService(translation_config)


# Request/Response models
class TranslationRequest(BaseModel):
    """Translation request model."""

    # Example payload for the generated OpenAPI docs ("商品名称" = "product name").
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "text": "商品名称",
                "target_lang": "en",
                "source_lang": "zh",
                "model": "llm",
                "scene": "sku_name",
            }
        }
    )

    # A single string, or a list of strings that /translate routes to _translate_batch.
    text: Union[str, List[str]] = Field(..., description="Text to translate (string or list of strings)")
    # Required by the schema; the /translate handler additionally rejects "" with HTTP 400.
    target_lang: str = Field(..., description="Target language code (zh, en, ru, etc.)")
    # Passed to the service as-is and echoed back unchanged in the response.
    source_lang: Optional[str] = Field(None, description="Source language code (optional, auto-detect if not provided)")
    # None/empty falls back to config["default_model"] in _normalize_model.
    model: Optional[str] = Field(None, description="Enabled translation capability name")
    # Resolved by normalize_translation_scene (None is passed through to it).
    scene: Optional[str] = Field(None, description="Translation scene, paired with model routing")


class TranslationResponse(BaseModel):
    """Translation response model."""
    # Echo of the request text (string or list), unchanged.
    text: Union[str, List[str]] = Field(..., description="Original text (string or list)")
    # Echo of the requested target language.
    target_lang: str = Field(..., description="Target language code")
    # NOTE(review): the /translate handler in this file only echoes
    # request.source_lang; no detected language is filled in here.
    source_lang: Optional[str] = Field(None, description="Source language code (detected or provided)")
    # str for a single input; for a list input, one entry per element,
    # None where that element could not be translated.
    translated_text: Union[str, List[Optional[str]]] = Field(
        ...,
        description="Translated text (string or list; list elements may be null on failure)",
    )
    # The /translate handler sets "success", even when some batch entries are None.
    status: str = Field(..., description="Translation status")
    # The backend's ``model`` attribute when it has one, else the resolved capability name.
    model: str = Field(..., description="Translation model used")
    # The resolved scene name.
    scene: str = Field(..., description="Translation scene used")


def _normalize_scene(service: TranslationService, scene: Optional[str]) -> str:
    """Resolve a requested scene name against the service configuration.

    ``scene`` may be None; how that is handled is up to
    ``normalize_translation_scene``.
    """
    config = service.config
    return normalize_translation_scene(config, scene)


def _normalize_model(service: TranslationService, model: Optional[str]) -> str:
    """Resolve a requested model name, defaulting to the configured one.

    A missing or empty ``model`` is replaced by ``config["default_model"]``
    before being normalized by ``normalize_translation_model``.
    """
    config = service.config
    requested = model if model else config["default_model"]
    return normalize_translation_model(config, requested)


def _ensure_valid_text(text: Union[str, List[str]]) -> None:
    """Reject empty input with an HTTP 400.

    A list is accepted as long as it has at least one element; individual
    blank elements are allowed. A string must contain at least one
    non-whitespace character.

    Raises:
        HTTPException: 400 for an empty list or an empty/blank string.
    """
    is_batch = isinstance(text, list)
    if is_batch and not text:
        raise HTTPException(status_code=400, detail="Text list cannot be empty")
    if not is_batch and not (text and text.strip()):
        raise HTTPException(status_code=400, detail="Text cannot be empty")


def _normalize_batch_result(
    original: List[str],
    translated: Union[str, List[Optional[str]], None],
) -> List[Optional[str]]:
    """Align a batch provider result with the input list.

    Always returns a new list with exactly ``len(original)`` entries:
    a ``None`` result becomes all ``None``; a shorter list is padded with
    ``None``; a longer list is truncated.

    Raises:
        HTTPException: 500 if ``translated`` is neither None nor a list.
    """
    expected = len(original)
    if translated is None:
        return [None] * expected
    if not isinstance(translated, list):
        raise HTTPException(status_code=500, detail="Batch translation provider returned non-list result")
    aligned: List[Optional[str]] = translated[:expected]
    aligned.extend([None] * (expected - len(aligned)))
    return aligned


def _translate_batch(
    service: TranslationService,
    raw_text: List[str],
    *,
    target_lang: str,
    source_lang: Optional[str],
    model: str,
    scene: str,
) -> List[Optional[str]]:
    """Translate a list of strings, one output entry per input entry.

    If the backend for ``model`` advertises ``supports_batch``, the whole list
    is sent in one ``service.translate`` call and aligned with
    ``_normalize_batch_result``. A ``ValueError`` from that attempt
    propagates. Any other exception is logged and the code falls back to
    per-item translation. That includes the HTTPException that
    ``_normalize_batch_result`` raises for a non-list result.

    In the per-item path, None/blank items are returned as-is without being
    translated. An item whose translation raises anything other than
    ``ValueError`` becomes ``None``, and a ``ValueError`` propagates.
    """
    call_kwargs = dict(
        target_lang=target_lang,
        source_lang=source_lang,
        model=model,
        scene=scene,
    )

    if getattr(service.get_backend(model), "supports_batch", False):
        try:
            return _normalize_batch_result(raw_text, service.translate(text=raw_text, **call_kwargs))
        except ValueError:
            # Validation-type errors are the caller's problem (mapped to 400 upstream).
            raise
        except Exception as exc:
            logger.error("Batch translation failed: %s", exc, exc_info=True)

    def _translate_one(item: str) -> Optional[str]:
        # Blank entries are echoed back untranslated.
        if item is None or not str(item).strip():
            return item
        try:
            return service.translate(text=str(item), **call_kwargs)
        except ValueError:
            raise
        except Exception as exc:
            logger.warning("Per-item translation failed: %s", exc, exc_info=True)
            return None

    return [_translate_one(item) for item in raw_text]


@asynccontextmanager
async def lifespan(_: FastAPI):
    """Warm the default backend on process startup.

    Builds (or fetches) the cached TranslationService, asks it for the
    default model's backend, logs what is available/loaded, then yields
    control to the running application. Nothing is done on shutdown.
    """
    logger.info("Starting Translation Service API")
    service = get_translation_service()
    default_model = service.config["default_model"]
    default_backend = service.get_backend(default_model)

    logger.info(
        "Translation service ready | default_model=%s available_models=%s loaded_models=%s",
        default_model,
        service.available_models,
        service.loaded_models,
    )
    warmed_name = getattr(default_backend, "model", default_model)
    logger.info("Default translation backend warmed up | model=%s", warmed_name)
    yield


# Application object. The metadata feeds the generated OpenAPI docs, served
# at /docs (Swagger UI) and /redoc; ``lifespan`` runs once at startup.
app = FastAPI(
    lifespan=lifespan,
    title="Translation Service API",
    version="1.0.0",
    description="Translation service with pluggable capabilities and scene routing",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Permissive CORS: every origin, method and header is allowed.
# NOTE(review): a wildcard origin together with allow_credentials=True may let
# arbitrary sites make credentialed cross-origin calls; narrow allow_origins
# if this service is reachable by untrusted browsers.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)


@app.get("/health")
async def health_check():
    """Report whether the translation service is usable.

    Returns:
        On success, a dict with ``status: "healthy"``, the configured default
        model and scene, and the available, enabled and loaded models.
        If obtaining or inspecting the service raises, a 503 ``JSONResponse``
        with ``status: "unhealthy"`` and the error text.
    """
    try:
        service = get_translation_service()
        return {
            "status": "healthy",
            "service": "translation",
            "default_model": service.config["default_model"],
            "default_scene": service.config["default_scene"],
            "available_models": service.available_models,
            "enabled_capabilities": get_enabled_translation_models(service.config),
            "loaded_models": service.loaded_models,
        }
    except Exception as e:
        # Lazy %-formatting, and keep the traceback so the cause is visible in
        # the logs rather than only the one-line message.
        logger.error("Health check failed: %s", e, exc_info=True)
        return JSONResponse(
            status_code=503,
            content={
                "status": "unhealthy",
                "error": str(e),
            },
        )


@app.post("/translate", response_model=TranslationResponse)
async def translate(request: TranslationRequest):
    """Translate a single string or a list of strings.

    Resolves the scene and model from the request (the model falls back to
    the configured default). A single string goes through
    ``service.translate``. A list goes through ``_translate_batch``, where
    individual entries may come back as ``None``.

    Raises:
        HTTPException: 400 for empty text, an empty ``target_lang``, or a
            ``ValueError`` raised while resolving or translating; 500 when a
            single-string translation returns ``None`` or any other error
            occurs during translation.
    """
    _ensure_valid_text(request.text)

    if not request.target_lang:
        raise HTTPException(status_code=400, detail="target_lang is required")

    # NOTE(review): service.translate is synchronous and is called directly
    # from this ``async def`` handler, so it runs on the event loop and blocks
    # other requests while it works. Consider a plain ``def`` handler or
    # ``starlette.concurrency.run_in_threadpool`` once the backends are known
    # to be safe to call from worker threads.
    raw_text = request.text
    try:
        service = get_translation_service()
        scene = _normalize_scene(service, request.scene)
        model = _normalize_model(service, request.model)
        translator = service.get_backend(model)

        if isinstance(raw_text, list):
            translated_text = _translate_batch(
                service,
                raw_text,
                target_lang=request.target_lang,
                source_lang=request.source_lang,
                model=model,
                scene=scene,
            )
        else:
            translated_text = service.translate(
                text=raw_text,
                target_lang=request.target_lang,
                source_lang=request.source_lang,
                model=model,
                scene=scene,
            )
            if translated_text is None:
                raise HTTPException(status_code=500, detail="Translation failed")
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Translation error: %s", e)
        raise HTTPException(status_code=500, detail=f"Translation error: {e}") from e

    # Built outside the try on purpose: pydantic's ValidationError subclasses
    # ValueError, so a malformed backend result must not be reported to the
    # client as a 400 by the handler above.
    return TranslationResponse(
        text=raw_text,
        target_lang=request.target_lang,
        source_lang=request.source_lang,
        translated_text=translated_text,
        status="success",
        model=str(getattr(translator, "model", model)),
        scene=scene,
    )


@app.get("/")
async def root():
    """Describe the API: its name, version, run status and main endpoints."""
    endpoints = {
        "translate": "POST /translate",
        "health": "GET /health",
        "docs": "GET /docs",
    }
    return {
        "service": "Translation Service API",
        "version": "1.0.0",
        "status": "running",
        "endpoints": endpoints,
    }


if __name__ == "__main__":
    # Command-line entry point: parse bind options and start uvicorn.
    cli = argparse.ArgumentParser(description='Start translation API service')
    cli.add_argument('--host', default='0.0.0.0', help='Host to bind to')
    cli.add_argument('--port', type=int, default=6006, help='Port to bind to')
    cli.add_argument('--reload', action='store_true', help='Enable auto-reload')
    options = cli.parse_args()

    # The app is given as an import string because uvicorn requires that form
    # for --reload; it assumes this module is importable as api.translator_app.
    uvicorn.run(
        "api.translator_app:app",
        reload=options.reload,
        host=options.host,
        port=options.port,
    )