# test_translation_converter_resolution.py
from __future__ import annotations

import sys
import types

import pytest

import translation.ct2_conversion as ct2_conversion


class _FakeTransformersConverter:
    def __init__(self, model_name_or_path):
        self.model_name_or_path = model_name_or_path
        self.load_calls = []

    def load_model(self, model_class, resolved_model_name_or_path, **kwargs):
        self.load_calls.append(
            {
                "model_class": model_class,
                "resolved_model_name_or_path": resolved_model_name_or_path,
                "kwargs": dict(kwargs),
            }
        )
        if "dtype" in kwargs or "torch_dtype" in kwargs:
            raise TypeError("M2M100ForConditionalGeneration.__init__() got an unexpected keyword argument 'dtype'")
        return {"loaded": True, "path": resolved_model_name_or_path}

    def convert(self, output_dir, quantization=None, force=False):
        loaded = self.load_model("FakeModel", self.model_name_or_path, dtype="float32")
        return {
            "loaded": loaded,
            "output_dir": output_dir,
            "quantization": quantization,
            "force": force,
            "load_calls": list(self.load_calls),
        }


def _install_fake_ctranslate2(monkeypatch, base_converter):
    converters_module = types.ModuleType("ctranslate2.converters")
    converters_module.TransformersConverter = base_converter
    ctranslate2_module = types.ModuleType("ctranslate2")
    ctranslate2_module.converters = converters_module

    monkeypatch.setitem(sys.modules, "ctranslate2", ctranslate2_module)
    monkeypatch.setitem(sys.modules, "ctranslate2.converters", converters_module)


def test_convert_transformers_model_retries_without_torch_dtype(monkeypatch):
    """The converter should retry load_model with torch_dtype cleared on the
    config after the first attempt fails with the dtype-keyword TypeError."""
    _install_fake_ctranslate2(monkeypatch, _FakeTransformersConverter)
    stub_transformers = types.ModuleType("transformers")
    stub_transformers.AutoConfig = types.SimpleNamespace(
        from_pretrained=lambda path: types.SimpleNamespace(torch_dtype="float32", path=path)
    )
    monkeypatch.setitem(sys.modules, "transformers", stub_transformers)

    result = ct2_conversion.convert_transformers_model("fake-model", "/tmp/out", "float16")

    # The second (retried) load succeeded and its value is surfaced.
    assert result["loaded"] == {"loaded": True, "path": "fake-model"}
    assert result["output_dir"] == "/tmp/out"
    assert result["quantization"] == "float16"
    assert result["force"] is False
    # Exactly one failed attempt plus one retry.
    assert len(result["load_calls"]) == 2
    first_call, retry_call = result["load_calls"]
    assert first_call == {
        "model_class": "FakeModel",
        "resolved_model_name_or_path": "fake-model",
        "kwargs": {"dtype": "float32"},
    }
    assert retry_call["model_class"] == "FakeModel"
    assert retry_call["resolved_model_name_or_path"] == "fake-model"
    # The retry must carry a config whose torch_dtype was nulled out.
    assert getattr(retry_call["kwargs"]["config"], "torch_dtype", "missing") is None


def test_convert_transformers_model_preserves_unrelated_type_errors(monkeypatch):
    """TypeErrors that are not the dtype-keyword failure must propagate
    unchanged rather than triggering a retry."""

    class _AlwaysFailingConverter(_FakeTransformersConverter):
        # Simulate a constructor failure unrelated to the dtype keyword.
        def load_model(self, model_class, resolved_model_name_or_path, **kwargs):
            raise TypeError("different constructor error")

    _install_fake_ctranslate2(monkeypatch, _AlwaysFailingConverter)
    stub_transformers = types.ModuleType("transformers")
    stub_transformers.AutoConfig = types.SimpleNamespace(
        from_pretrained=lambda path: types.SimpleNamespace(path=path)
    )
    monkeypatch.setitem(sys.modules, "transformers", stub_transformers)

    with pytest.raises(TypeError, match="different constructor error"):
        ct2_conversion.convert_transformers_model("fake-model", "/tmp/out", "float16")