diff --git a/tests/test_translation_local_backends.py b/tests/test_translation_local_backends.py index d48edca..d3999c2 100644 --- a/tests/test_translation_local_backends.py +++ b/tests/test_translation_local_backends.py @@ -5,6 +5,7 @@ import torch from translation.backends.local_seq2seq import MarianMTTranslationBackend, NLLBTranslationBackend from translation.backends.local_ctranslate2 import NLLBCTranslate2TranslationBackend +from translation.languages import build_nllb_language_catalog, resolve_nllb_language_code from translation.service import TranslationService from translation.text_splitter import compute_safe_input_token_limit, split_text_for_translation @@ -200,6 +201,22 @@ def test_nllb_ctranslate2_accepts_finnish_short_code(monkeypatch): assert backend.translator.last_translate_batch_kwargs["target_prefix"] == [["zho_Hans"]] +def test_nllb_resolves_flores_short_tags_and_iso_no(): + cat = build_nllb_language_catalog(None) + assert resolve_nllb_language_code("ca", cat) == "cat_Latn" + assert resolve_nllb_language_code("da", cat) == "dan_Latn" + assert resolve_nllb_language_code("eu", cat) == "eus_Latn" + assert resolve_nllb_language_code("gl", cat) == "glg_Latn" + assert resolve_nllb_language_code("hu", cat) == "hun_Latn" + assert resolve_nllb_language_code("id", cat) == "ind_Latn" + assert resolve_nllb_language_code("nl", cat) == "nld_Latn" + assert resolve_nllb_language_code("no", cat) == "nob_Latn" + assert resolve_nllb_language_code("ro", cat) == "ron_Latn" + assert resolve_nllb_language_code("SV", cat) == "swe_Latn" + assert resolve_nllb_language_code("tr", cat) == "tur_Latn" + assert resolve_nllb_language_code("deu_Latn", cat) == "deu_Latn" + + def test_translation_service_preloads_enabled_backends(monkeypatch): created = [] diff --git a/translation/languages.py b/translation/languages.py index 815aff0..b7256de 100644 --- a/translation/languages.py +++ b/translation/languages.py @@ -2,8 +2,14 @@ from __future__ import annotations +from functools import lru_cache from typing import Dict, Mapping, Optional, Tuple +from translation.nllb_flores_short_map import ( + NLLB_FLORES_SHORT_TO_CODE, + NLLB_TOKENIZER_LANGUAGE_CODES, +) + LANGUAGE_LABELS: Dict[str, str] = { "zh": "Chinese", @@ -48,6 +54,8 @@ DEEPL_LANGUAGE_CODES: Dict[str, str] = { } +# Sparse overrides on top of ``NLLB_FLORES_SHORT_TO_CODE`` (same keys win later in +# ``build_nllb_language_catalog``). Kept for backward compatibility and explicit defaults. NLLB_LANGUAGE_CODES: Dict[str, str] = { "en": "eng_Latn", "fi": "fin_Latn", @@ -82,14 +90,24 @@ def normalize_language_key(language: Optional[str]) -> str: return str(language or "").strip().lower().replace("-", "_") +@lru_cache(maxsize=1) +def _nllb_tokenizer_code_by_normalized_key() -> Dict[str, str]: + """Map lowercased ``deu_latn``-style keys to canonical tokenizer strings (e.g. ``deu_Latn``).""" + return {normalize_language_key(code): code for code in NLLB_TOKENIZER_LANGUAGE_CODES} + + def build_nllb_language_catalog( overrides: Optional[Mapping[str, str]] = None, ) -> Dict[str, str]: - catalog = { - normalize_language_key(key): str(value).strip() - for key, value in NLLB_LANGUAGE_CODES.items() - if str(key).strip() - } + catalog: Dict[str, str] = {} + for key, value in NLLB_FLORES_SHORT_TO_CODE.items(): + normalized_key = normalize_language_key(key) + if normalized_key: + catalog[normalized_key] = str(value).strip() + for key, value in NLLB_LANGUAGE_CODES.items(): + normalized_key = normalize_language_key(key) + if normalized_key: + catalog[normalized_key] = str(value).strip() for key, value in (overrides or {}).items(): normalized_key = normalize_language_key(key) if normalized_key: @@ -116,6 +134,10 @@ def resolve_nllb_language_code( if aliased is not None: return aliased + tokenizer_hit = _nllb_tokenizer_code_by_normalized_key().get(normalized) + if tokenizer_hit is not None: + return tokenizer_hit + for code in catalog.values(): if normalize_language_key(code) == normalized: return code diff --git a/translation/nllb_flores_short_map.py b/translation/nllb_flores_short_map.py new file mode 100644 index 0000000..356049b --- /dev/null +++ b/translation/nllb_flores_short_map.py @@ -0,0 +1,416 @@ +"""FLORES short language tags and canonical NLLB tokenizer codes. + +``NLLB_FLORES_SHORT_TO_CODE`` maps model-card short tags (ISO 639-1 / FLORES ids) +to NLLB ``src_lang`` tokens: ``_