diff --git a/api/translator_app.py b/api/translator_app.py index b4e4f87..50d8927 100644 --- a/api/translator_app.py +++ b/api/translator_app.py @@ -271,16 +271,20 @@ async def lifespan(_: FastAPI): """Initialize all enabled translation backends on process startup.""" logger.info("Starting Translation Service API") service = get_translation_service() + failed_models = list(getattr(service, "failed_models", [])) + backend_errors = dict(getattr(service, "backend_errors", {})) logger.info( - "Translation service ready | default_model=%s default_scene=%s available_models=%s loaded_models=%s", + "Translation service ready | default_model=%s default_scene=%s available_models=%s loaded_models=%s failed_models=%s", service.config["default_model"], service.config["default_scene"], service.available_models, service.loaded_models, + failed_models, ) logger.info( - "Translation backends initialized on startup | models=%s", + "Translation backends initialized on startup | loaded=%s failed=%s", service.loaded_models, + backend_errors, ) verbose_logger.info( "Translation startup detail | capabilities=%s cache_ttl_seconds=%s cache_sliding_expiration=%s", @@ -316,11 +320,14 @@ async def health_check(): """Health check endpoint.""" try: service = get_translation_service() + failed_models = list(getattr(service, "failed_models", [])) + backend_errors = dict(getattr(service, "backend_errors", {})) logger.info( - "Health check | default_model=%s default_scene=%s loaded_models=%s", + "Health check | default_model=%s default_scene=%s loaded_models=%s failed_models=%s", service.config["default_model"], service.config["default_scene"], service.loaded_models, + failed_models, ) return { "status": "healthy", @@ -330,6 +337,8 @@ async def health_check(): "available_models": service.available_models, "enabled_capabilities": get_enabled_translation_models(service.config), "loaded_models": service.loaded_models, + "failed_models": failed_models, + "backend_errors": backend_errors, } except Exception as e: logger.error(f"Health check failed: {e}") @@ -463,6 +472,10 @@ async def translate(request: TranslationRequest, http_request: Request): latency_ms = (time.perf_counter() - request_started) * 1000 logger.warning("Translation validation error | error=%s latency_ms=%.2f", e, latency_ms) raise HTTPException(status_code=400, detail=str(e)) from e + except RuntimeError as e: + latency_ms = (time.perf_counter() - request_started) * 1000 + logger.warning("Translation backend unavailable | error=%s latency_ms=%.2f", e, latency_ms) + raise HTTPException(status_code=503, detail=str(e)) from e except Exception as e: latency_ms = (time.perf_counter() - request_started) * 1000 logger.error("Translation error | error=%s latency_ms=%.2f", e, latency_ms, exc_info=True) diff --git a/config/loader.py b/config/loader.py index 5306f8c..cbe635d 100644 --- a/config/loader.py +++ b/config/loader.py @@ -655,6 +655,14 @@ class AppConfigLoader: translation_raw = raw.get("translation") if isinstance(raw.get("translation"), dict) else {} normalized_translation = build_translation_config(translation_raw) + local_translation_backends = {"local_nllb", "local_marian"} + for capability_name, capability_cfg in normalized_translation["capabilities"].items(): + backend_name = str(capability_cfg.get("backend") or "").strip().lower() + if backend_name not in local_translation_backends: + continue + for path_key in ("model_dir", "ct2_model_dir"): + if capability_cfg.get(path_key) not in (None, ""): + capability_cfg[path_key] = str(self._resolve_project_path_value(capability_cfg[path_key]).resolve()) translation_config = TranslationServiceConfig( endpoint=str(normalized_translation["service_url"]).rstrip("/"), timeout_sec=float(normalized_translation["timeout_sec"]), @@ -749,7 +757,7 @@ class AppConfigLoader: port=port, backend=backend_name, runtime_dir=( - str(v) + str(self._resolve_project_path_value(v).resolve()) if (v := instance_raw.get("runtime_dir")) not in (None, "") else None ), @@ -787,6 +795,12 @@ class AppConfigLoader: rerank=rerank_config, ) + def _resolve_project_path_value(self, value: Any) -> Path: + candidate = Path(str(value)).expanduser() + if candidate.is_absolute(): + return candidate + return self.project_root / candidate + def _build_tenants_config(self, raw: Dict[str, Any]) -> TenantCatalogConfig: if not isinstance(raw, dict): raise ConfigurationError("tenant_config must be a mapping") diff --git a/frontend/static/js/app.js b/frontend/static/js/app.js index 435780b..ea036ba 100644 --- a/frontend/static/js/app.js +++ b/frontend/static/js/app.js @@ -316,7 +316,10 @@ async function performSearch(page = 1) { document.getElementById('productGrid').innerHTML = ''; try { - const response = await fetch(`${API_BASE_URL}/search/`, { + const searchUrl = new URL(`${API_BASE_URL}/search/`, window.location.origin); + searchUrl.searchParams.set('tenant_id', tenantId); + + const response = await fetch(searchUrl.toString(), { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/requirements_translator_service.txt b/requirements_translator_service.txt index e8b8f18..d944e6c 100644 --- a/requirements_translator_service.txt +++ b/requirements_translator_service.txt @@ -13,7 +13,8 @@ httpx>=0.24.0 tqdm>=4.65.0 torch>=2.0.0 -transformers>=4.30.0 +# Keep translator conversions on the last verified NLLB-compatible release line. +transformers>=4.51.0,<4.52.0 ctranslate2>=4.7.0 sentencepiece>=0.2.0 sacremoses>=0.1.1 diff --git a/scripts/download_translation_models.py b/scripts/download_translation_models.py new file mode 100644 index 0000000..0b67f40 --- /dev/null +++ b/scripts/download_translation_models.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 +"""Backward-compatible entrypoint for translation model downloads.""" + +from __future__ import annotations + +import runpy +from pathlib import Path + + +if __name__ == "__main__": + target = Path(__file__).resolve().parent / "translation" / "download_translation_models.py" + runpy.run_path(str(target), run_name="__main__") diff --git a/scripts/frontend/frontend_server.py b/scripts/frontend/frontend_server.py index 15231ca..0d30342 100755 --- a/scripts/frontend/frontend_server.py +++ b/scripts/frontend/frontend_server.py @@ -60,6 +60,8 @@ class RateLimitingMixin: class MyHTTPRequestHandler(http.server.SimpleHTTPRequestHandler, RateLimitingMixin): """Custom request handler with CORS support and robust error handling.""" + _ALLOWED_CORS_HEADERS = "Content-Type, X-Tenant-ID, X-Request-ID, Referer" + def _is_proxy_path(self, path: str) -> bool: """Return True for API paths that should be forwarded to backend service.""" return path.startswith('/search/') or path.startswith('/admin/') or path.startswith('/indexer/') @@ -220,7 +222,7 @@ class MyHTTPRequestHandler(http.server.SimpleHTTPRequestHandler, RateLimitingMix # Add CORS headers self.send_header('Access-Control-Allow-Origin', '*') self.send_header('Access-Control-Allow-Methods', 'GET, POST, OPTIONS') - self.send_header('Access-Control-Allow-Headers', 'Content-Type') + self.send_header('Access-Control-Allow-Headers', self._ALLOWED_CORS_HEADERS) # Add security headers self.send_header('X-Content-Type-Options', 'nosniff') self.send_header('X-Frame-Options', 'DENY') diff --git a/scripts/setup_translator_venv.sh b/scripts/setup_translator_venv.sh index a17abe0..9cf32d6 100755 --- a/scripts/setup_translator_venv.sh +++ b/scripts/setup_translator_venv.sh @@ -8,8 +8,47 @@ PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)" cd "${PROJECT_ROOT}" VENV_DIR="${PROJECT_ROOT}/.venv-translator" -PYTHON_BIN="${PYTHON_BIN:-python3}" TMP_DIR="${TRANSLATOR_PIP_TMPDIR:-${PROJECT_ROOT}/.tmp/translator-pip}" +MIN_PYTHON_MAJOR=3 +MIN_PYTHON_MINOR=10 + +python_meets_minimum() { + local bin="$1" + "${bin}" - <<'PY' "${MIN_PYTHON_MAJOR}" "${MIN_PYTHON_MINOR}" +import sys + +required = tuple(int(value) for value in sys.argv[1:]) +sys.exit(0 if sys.version_info[:2] >= required else 1) +PY +} + +discover_python_bin() { + local candidates=() + + if [[ -n "${PYTHON_BIN:-}" ]]; then + candidates+=("${PYTHON_BIN}") + fi + candidates+=("python3.12" "python3.11" "python3.10" "python3") + + local candidate + for candidate in "${candidates[@]}"; do + if ! command -v "${candidate}" >/dev/null 2>&1; then + continue + fi + if python_meets_minimum "${candidate}"; then + echo "${candidate}" + return 0 + fi + done + + return 1 +} + +if ! PYTHON_BIN="$(discover_python_bin)"; then + echo "ERROR: unable to find Python >= ${MIN_PYTHON_MAJOR}.${MIN_PYTHON_MINOR}." >&2 + echo "Set PYTHON_BIN to a compatible interpreter and rerun." >&2 + exit 1 +fi if ! command -v "${PYTHON_BIN}" >/dev/null 2>&1; then echo "ERROR: python not found: ${PYTHON_BIN}" >&2 @@ -32,6 +71,7 @@ mkdir -p "${TMP_DIR}" export TMPDIR="${TMP_DIR}" PIP_ARGS=(--no-cache-dir) +echo "Using Python=${PYTHON_BIN}" echo "Using TMPDIR=${TMPDIR}" "${VENV_DIR}/bin/python" -m pip install "${PIP_ARGS[@]}" --upgrade pip wheel "${VENV_DIR}/bin/python" -m pip install "${PIP_ARGS[@]}" -r requirements_translator_service.txt diff --git a/scripts/translation/download_translation_models.py b/scripts/translation/download_translation_models.py index 527159f..59f5602 100755 --- a/scripts/translation/download_translation_models.py +++ b/scripts/translation/download_translation_models.py @@ -6,8 +6,6 @@ from __future__ import annotations import argparse import os from pathlib import Path -import shutil -import subprocess import sys from typing import Iterable @@ -19,6 +17,7 @@ if str(PROJECT_ROOT) not in sys.path: os.environ.setdefault("HF_HUB_DISABLE_XET", "1") from config.services_config import get_translation_config +from translation.ct2_conversion import convert_transformers_model LOCAL_BACKENDS = {"local_nllb", "local_marian"} @@ -46,19 +45,6 @@ def _compute_ct2_output_dir(capability: dict) -> Path: return model_dir / f"ctranslate2-{normalized}" -def _resolve_converter_binary() -> str: - candidate = shutil.which("ct2-transformers-converter") - if candidate: - return candidate - venv_candidate = Path(sys.executable).absolute().parent / "ct2-transformers-converter" - if venv_candidate.exists(): - return str(venv_candidate) - raise RuntimeError( - "ct2-transformers-converter was not found. " - "Install ctranslate2 in the active Python environment first." - ) - - def convert_to_ctranslate2(name: str, capability: dict) -> None: model_id = str(capability.get("model_id") or "").strip() model_dir = Path(str(capability.get("model_dir") or "")).expanduser() @@ -75,18 +61,7 @@ def convert_to_ctranslate2(name: str, capability: dict) -> None: ).strip() output_dir.parent.mkdir(parents=True, exist_ok=True) print(f"[convert] {name} -> {output_dir} ({quantization})") - subprocess.run( - [ - _resolve_converter_binary(), - "--model", - model_source, - "--output_dir", - str(output_dir), - "--quantization", - quantization, - ], - check=True, - ) + convert_transformers_model(model_source, str(output_dir), quantization) print(f"[converted] {name}") diff --git a/tests/test_translation_converter_resolution.py b/tests/test_translation_converter_resolution.py new file mode 100644 index 0000000..4213e6a --- /dev/null +++ b/tests/test_translation_converter_resolution.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import sys +import types + +import pytest + +import translation.ct2_conversion as ct2_conversion + + +class _FakeTransformersConverter: + def __init__(self, model_name_or_path): + self.model_name_or_path = model_name_or_path + self.load_calls = [] + + def load_model(self, model_class, resolved_model_name_or_path, **kwargs): + self.load_calls.append( + { + "model_class": model_class, + "resolved_model_name_or_path": resolved_model_name_or_path, + "kwargs": dict(kwargs), + } + ) + if "dtype" in kwargs or "torch_dtype" in kwargs: + raise TypeError("M2M100ForConditionalGeneration.__init__() got an unexpected keyword argument 'dtype'") + return {"loaded": True, "path": resolved_model_name_or_path} + + def convert(self, output_dir, quantization=None, force=False): + loaded = self.load_model("FakeModel", self.model_name_or_path, dtype="float32") + return { + "loaded": loaded, + "output_dir": output_dir, + "quantization": quantization, + "force": force, + "load_calls": list(self.load_calls), + } + + +def _install_fake_ctranslate2(monkeypatch, base_converter): + converters_module = types.ModuleType("ctranslate2.converters") + converters_module.TransformersConverter = base_converter + ctranslate2_module = types.ModuleType("ctranslate2") + ctranslate2_module.converters = converters_module + + monkeypatch.setitem(sys.modules, "ctranslate2", ctranslate2_module) + monkeypatch.setitem(sys.modules, "ctranslate2.converters", converters_module) + + +def test_convert_transformers_model_retries_without_torch_dtype(monkeypatch): + _install_fake_ctranslate2(monkeypatch, _FakeTransformersConverter) + fake_transformers = types.ModuleType("transformers") + fake_transformers.AutoConfig = types.SimpleNamespace( + from_pretrained=lambda path: types.SimpleNamespace(torch_dtype="float32", path=path) + ) + monkeypatch.setitem(sys.modules, "transformers", fake_transformers) + + result = ct2_conversion.convert_transformers_model("fake-model", "/tmp/out", "float16") + + assert result["loaded"] == {"loaded": True, "path": "fake-model"} + assert result["output_dir"] == "/tmp/out" + assert result["quantization"] == "float16" + assert result["force"] is False + assert len(result["load_calls"]) == 2 + assert result["load_calls"][0] == { + "model_class": "FakeModel", + "resolved_model_name_or_path": "fake-model", + "kwargs": {"dtype": "float32"}, + } + assert result["load_calls"][1]["model_class"] == "FakeModel" + assert result["load_calls"][1]["resolved_model_name_or_path"] == "fake-model" + assert getattr(result["load_calls"][1]["kwargs"]["config"], "torch_dtype", "missing") is None + + +def test_convert_transformers_model_preserves_unrelated_type_errors(monkeypatch): + class _AlwaysFailingConverter(_FakeTransformersConverter): + def load_model(self, model_class, resolved_model_name_or_path, **kwargs): + raise TypeError("different constructor error") + + _install_fake_ctranslate2(monkeypatch, _AlwaysFailingConverter) + fake_transformers = types.ModuleType("transformers") + fake_transformers.AutoConfig = types.SimpleNamespace(from_pretrained=lambda path: types.SimpleNamespace(path=path)) + monkeypatch.setitem(sys.modules, "transformers", fake_transformers) + + with pytest.raises(TypeError, match="different constructor error"): + ct2_conversion.convert_transformers_model("fake-model", "/tmp/out", "float16") diff --git a/tests/test_translation_local_backends.py b/tests/test_translation_local_backends.py index d3999c2..1378ca2 100644 --- a/tests/test_translation_local_backends.py +++ b/tests/test_translation_local_backends.py @@ -201,6 +201,51 @@ def test_nllb_ctranslate2_accepts_finnish_short_code(monkeypatch): assert backend.translator.last_translate_batch_kwargs["target_prefix"] == [["zho_Hans"]] +def test_nllb_ctranslate2_falls_back_to_model_id_when_local_dir_is_wrong_type(tmp_path, monkeypatch): + wrong_dir = tmp_path / "wrong-nllb" + wrong_dir.mkdir() + (wrong_dir / "config.json").write_text('{"model_type":"led"}', encoding="utf-8") + + monkeypatch.setattr(NLLBCTranslate2TranslationBackend, "_load_runtime", _stub_load_ct2_runtime) + + backend = NLLBCTranslate2TranslationBackend( + name="nllb-200-distilled-600m", + model_id="facebook/nllb-200-distilled-600M", + model_dir=str(wrong_dir), + device="cpu", + torch_dtype="float32", + batch_size=1, + max_input_length=16, + max_new_tokens=16, + num_beams=1, + ) + + assert backend._model_source() == "facebook/nllb-200-distilled-600M" + assert backend._tokenizer_source() == "facebook/nllb-200-distilled-600M" + + +def test_nllb_ctranslate2_falls_back_to_model_id_when_local_dir_is_incomplete(tmp_path, monkeypatch): + incomplete_dir = tmp_path / "incomplete-nllb" + incomplete_dir.mkdir() + (incomplete_dir / "ctranslate2-float16").mkdir() + + monkeypatch.setattr(NLLBCTranslate2TranslationBackend, "_load_runtime", _stub_load_ct2_runtime) + + backend = NLLBCTranslate2TranslationBackend( + name="nllb-200-distilled-600m", + model_id="facebook/nllb-200-distilled-600M", + model_dir=str(incomplete_dir), + device="cpu", + torch_dtype="float32", + batch_size=1, + max_input_length=16, + max_new_tokens=16, + num_beams=1, + ) + + assert backend._model_source() == "facebook/nllb-200-distilled-600M" + + def test_nllb_resolves_flores_short_tags_and_iso_no(): cat = build_nllb_language_catalog(None) assert resolve_nllb_language_code("ca", cat) == "cat_Latn" diff --git a/tests/test_translator_failure_semantics.py b/tests/test_translator_failure_semantics.py index bfa924e..997a7e7 100644 --- a/tests/test_translator_failure_semantics.py +++ b/tests/test_translator_failure_semantics.py @@ -197,6 +197,73 @@ def test_translation_route_log_focuses_on_routing_decision(monkeypatch, caplog): ] +def test_service_skips_failed_backend_but_keeps_healthy_capabilities(monkeypatch): + monkeypatch.setattr(TranslationCache, "_init_redis_client", staticmethod(lambda: None)) + + def _fake_create_backend(self, *, name, backend_type, cfg): + del self, backend_type, cfg + if name == "broken-nllb": + raise RuntimeError("broken model dir") + + class _Backend: + model = name + + @property + def supports_batch(self): + return True + + def translate(self, text, target_lang, source_lang=None, scene=None): + del target_lang, source_lang, scene + return text + + return _Backend() + + monkeypatch.setattr(TranslationService, "_create_backend", _fake_create_backend) + service = TranslationService( + { + "service_url": "http://127.0.0.1:6006", + "timeout_sec": 10.0, + "default_model": "llm", + "default_scene": "general", + "capabilities": { + "llm": { + "enabled": True, + "backend": "llm", + "model": "dummy-llm", + "base_url": "https://example.com", + "timeout_sec": 10.0, + "use_cache": True, + }, + "broken-nllb": { + "enabled": True, + "backend": "local_nllb", + "model_id": "dummy", + "model_dir": "dummy", + "device": "cpu", + "torch_dtype": "float32", + "batch_size": 8, + "max_input_length": 16, + "max_new_tokens": 16, + "num_beams": 1, + "use_cache": True, + }, + }, + "cache": { + "ttl_seconds": 60, + "sliding_expiration": True, + }, + } + ) + + assert service.available_models == ["llm", "broken-nllb"] + assert service.loaded_models == ["llm"] + assert service.failed_models == ["broken-nllb"] + assert service.backend_errors["broken-nllb"] == "broken model dir" + + with pytest.raises(RuntimeError, match="failed to initialize"): + service.get_backend("broken-nllb") + + def test_translation_cache_probe_models_order(): cfg = {"cache": {"model_quality_tiers": {"low": 10, "high": 50, "mid": 30}}} assert translation_cache_probe_models(cfg, "low") == ["high", "mid", "low"] diff --git a/translation/backends/local_ctranslate2.py b/translation/backends/local_ctranslate2.py index 58de075..06bc61d 100644 --- a/translation/backends/local_ctranslate2.py +++ b/translation/backends/local_ctranslate2.py @@ -4,9 +4,7 @@ from __future__ import annotations import logging import os -import shutil -import subprocess -import sys +import json import threading from pathlib import Path from typing import Dict, List, Optional, Sequence, Union @@ -24,6 +22,7 @@ from translation.text_splitter import ( join_translated_segments, split_text_for_translation, ) +from translation.ct2_conversion import convert_transformers_model logger = logging.getLogger(__name__) @@ -76,17 +75,18 @@ def _derive_ct2_model_dir(model_dir: str, compute_type: str) -> str: return str(Path(model_dir).expanduser() / f"ctranslate2-{normalized}") -def _resolve_converter_binary() -> str: - candidate = shutil.which("ct2-transformers-converter") - if candidate: - return candidate - venv_candidate = Path(sys.executable).absolute().parent / "ct2-transformers-converter" - if venv_candidate.exists(): - return str(venv_candidate) - raise RuntimeError( - "ct2-transformers-converter was not found. " - "Ensure ctranslate2 is installed in the active translator environment." - ) +def _detect_local_model_type(model_dir: str) -> Optional[str]: + config_path = Path(model_dir).expanduser() / "config.json" + if not config_path.exists(): + return None + try: + with open(config_path, "r", encoding="utf-8") as handle: + payload = json.load(handle) or {} + except Exception as exc: + logger.warning("Failed to inspect local translation config %s: %s", config_path, exc) + return None + model_type = str(payload.get("model_type") or "").strip().lower() + return model_type or None class LocalCTranslate2TranslationBackend: @@ -144,6 +144,7 @@ class LocalCTranslate2TranslationBackend: self.ct2_decoding_length_extra = int(ct2_decoding_length_extra) self.ct2_decoding_length_min = max(1, int(ct2_decoding_length_min)) self._tokenizer_lock = threading.Lock() + self._local_model_source = self._resolve_local_model_source() self._load_runtime() @property @@ -151,10 +152,44 @@ class LocalCTranslate2TranslationBackend: return True def _tokenizer_source(self) -> str: - return self.model_dir if os.path.exists(self.model_dir) else self.model_id + return self._local_model_source or self.model_id def _model_source(self) -> str: - return self.model_dir if os.path.exists(self.model_dir) else self.model_id + return self._local_model_source or self.model_id + + def _expected_local_model_types(self) -> Optional[set[str]]: + return None + + def _resolve_local_model_source(self) -> Optional[str]: + model_path = Path(self.model_dir).expanduser() + if not model_path.exists(): + return None + if not (model_path / "config.json").exists(): + logger.warning( + "Local translation model_dir is incomplete | model=%s model_dir=%s missing=config.json fallback=model_id", + self.model, + model_path, + ) + return None + + expected_types = self._expected_local_model_types() + if not expected_types: + return str(model_path) + + detected_type = _detect_local_model_type(str(model_path)) + if detected_type is None: + return str(model_path) + if detected_type in expected_types: + return str(model_path) + + logger.warning( + "Local translation model_dir has unexpected model_type | model=%s model_dir=%s detected=%s expected=%s fallback=model_id", + self.model, + model_path, + detected_type, + sorted(expected_types), + ) + return None def _tokenizer_kwargs(self) -> Dict[str, object]: return {} @@ -204,7 +239,6 @@ class LocalCTranslate2TranslationBackend: ) ct2_path.parent.mkdir(parents=True, exist_ok=True) - converter = _resolve_converter_binary() logger.info( "Converting translation model to CTranslate2 | name=%s source=%s output=%s quantization=%s", self.model, @@ -213,25 +247,14 @@ class LocalCTranslate2TranslationBackend: self.ct2_conversion_quantization, ) try: - subprocess.run( - [ - converter, - "--model", - model_source, - "--output_dir", - str(ct2_path), - "--quantization", - self.ct2_conversion_quantization, - ], - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, + convert_transformers_model( + model_source, + str(ct2_path), + self.ct2_conversion_quantization, ) - except subprocess.CalledProcessError as exc: - stderr = exc.stderr.strip() + except Exception as exc: raise RuntimeError( - f"Failed to convert model '{self.model}' to CTranslate2: {stderr or exc}" + f"Failed to convert model '{self.model}' to CTranslate2: {exc}" ) from exc def _normalize_texts(self, text: Union[str, Sequence[str]]) -> List[str]: @@ -557,6 +580,9 @@ class MarianCTranslate2TranslationBackend(LocalCTranslate2TranslationBackend): f"Model '{self.model}' only supports target languages: {sorted(self.target_langs)}" ) + def _expected_local_model_types(self) -> Optional[set[str]]: + return {"marian"} + class NLLBCTranslate2TranslationBackend(LocalCTranslate2TranslationBackend): """Local backend for NLLB models on CTranslate2.""" @@ -619,6 +645,9 @@ class NLLBCTranslate2TranslationBackend(LocalCTranslate2TranslationBackend): if resolve_nllb_language_code(target_lang, self.language_codes) is None: raise ValueError(f"Unsupported NLLB target language: {target_lang}") + def _expected_local_model_types(self) -> Optional[set[str]]: + return {"m2m_100", "nllb_moe"} + def _get_tokenizer_for_source(self, source_lang: str): src_code = resolve_nllb_language_code(source_lang, self.language_codes) if src_code is None: diff --git a/translation/ct2_conversion.py b/translation/ct2_conversion.py new file mode 100644 index 0000000..63e728b --- /dev/null +++ b/translation/ct2_conversion.py @@ -0,0 +1,52 @@ +"""Helpers for converting Hugging Face translation models to CTranslate2.""" + +from __future__ import annotations + +import copy +import logging + +logger = logging.getLogger(__name__) + + +def convert_transformers_model( + model_name_or_path: str, + output_dir: str, + quantization: str, + *, + force: bool = False, +) -> str: + from ctranslate2.converters import TransformersConverter + from transformers import AutoConfig + + class _CompatibleTransformersConverter(TransformersConverter): + def load_model(self, model_class, resolved_model_name_or_path, **kwargs): + try: + return super().load_model(model_class, resolved_model_name_or_path, **kwargs) + except TypeError as exc: + if "unexpected keyword argument 'dtype'" not in str(exc): + raise + if kwargs.get("dtype") is None and kwargs.get("torch_dtype") is None: + raise + + logger.warning( + "Retrying CTranslate2 model load without dtype hints | model=%s class=%s", + resolved_model_name_or_path, + getattr(model_class, "__name__", model_class), + ) + retry_kwargs = dict(kwargs) + retry_kwargs.pop("dtype", None) + retry_kwargs.pop("torch_dtype", None) + config = retry_kwargs.get("config") + if config is None: + config = AutoConfig.from_pretrained(resolved_model_name_or_path) + else: + config = copy.deepcopy(config) + if hasattr(config, "dtype"): + config.dtype = None + if hasattr(config, "torch_dtype"): + config.torch_dtype = None + retry_kwargs["config"] = config + return super().load_model(model_class, resolved_model_name_or_path, **retry_kwargs) + + converter = _CompatibleTransformersConverter(model_name_or_path) + return converter.convert(output_dir=output_dir, quantization=quantization, force=force) diff --git a/translation/service.py b/translation/service.py index 354b558..e070aef 100644 --- a/translation/service.py +++ b/translation/service.py @@ -31,7 +31,12 @@ class TranslationService: if not self._enabled_capabilities: raise ValueError("No enabled translation backends found in services.translation.capabilities") self._translation_cache = TranslationCache(self.config["cache"]) - self._backends = self._initialize_backends() + self._backends: Dict[str, TranslationBackendProtocol] = {} + self._backend_errors: Dict[str, str] = {} + self._initialize_backends() + if not self._backends: + details = ", ".join(f"{name}: {err}" for name, err in sorted(self._backend_errors.items())) or "unknown error" + raise RuntimeError(f"No translation backends could be initialized: {details}") def _collect_enabled_capabilities(self) -> Dict[str, Dict[str, object]]: enabled: Dict[str, Dict[str, object]] = {} @@ -62,24 +67,47 @@ class TranslationService: raise ValueError(f"Unsupported translation backend '{backend_type}' for capability '{name}'") return factory(name=name, cfg=cfg) - def _initialize_backends(self) -> Dict[str, TranslationBackendProtocol]: - backends: Dict[str, TranslationBackendProtocol] = {} - for name, capability_cfg in self._enabled_capabilities.items(): - backend_type = str(capability_cfg["backend"]) - logger.info("Initializing translation backend | model=%s backend=%s", name, backend_type) - backends[name] = self._create_backend( + def _load_backend(self, name: str) -> Optional[TranslationBackendProtocol]: + capability_cfg = self._enabled_capabilities.get(name) + if capability_cfg is None: + return None + if name in self._backends: + return self._backends[name] + + backend_type = str(capability_cfg["backend"]) + logger.info("Initializing translation backend | model=%s backend=%s", name, backend_type) + try: + backend = self._create_backend( name=name, backend_type=backend_type, cfg=capability_cfg, ) - logger.info( - "Translation backend initialized | model=%s backend=%s use_cache=%s backend_model=%s", + except Exception as exc: + error_text = str(exc).strip() or exc.__class__.__name__ + self._backend_errors[name] = error_text + logger.error( + "Translation backend initialization failed | model=%s backend=%s error=%s", name, backend_type, - bool(capability_cfg.get("use_cache")), - getattr(backends[name], "model", name), + error_text, + exc_info=True, ) - return backends + return None + + self._backends[name] = backend + self._backend_errors.pop(name, None) + logger.info( + "Translation backend initialized | model=%s backend=%s use_cache=%s backend_model=%s", + name, + backend_type, + bool(capability_cfg.get("use_cache")), + getattr(backend, "model", name), + ) + return backend + + def _initialize_backends(self) -> None: + for name, capability_cfg in self._enabled_capabilities.items(): + self._load_backend(name) def _create_qwen_mt_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol: from translation.backends.qwen_mt import QwenMTTranslationBackend @@ -178,13 +206,27 @@ class TranslationService: def loaded_models(self) -> List[str]: return list(self._backends.keys()) + @property + def failed_models(self) -> List[str]: + return list(self._backend_errors.keys()) + + @property + def backend_errors(self) -> Dict[str, str]: + return dict(self._backend_errors) + def get_backend(self, model: Optional[str] = None) -> TranslationBackendProtocol: normalized = normalize_translation_model(self.config, model) - backend = self._backends.get(normalized) + backend = self._backends.get(normalized) or self._load_backend(normalized) if backend is None: - raise ValueError( - f"Translation model '{normalized}' is not enabled. " - f"Available models: {', '.join(self.available_models) or 'none'}" + if normalized not in self._enabled_capabilities: + raise ValueError( + f"Translation model '{normalized}' is not enabled. " + f"Available models: {', '.join(self.available_models) or 'none'}" + ) + error_text = self._backend_errors.get(normalized) or "unknown initialization error" + raise RuntimeError( + f"Translation model '{normalized}' failed to initialize: {error_text}. " + f"Loaded models: {', '.join(self.loaded_models) or 'none'}" ) return backend -- libgit2 0.21.2