from __future__ import annotations import sys import types import pytest import translation.ct2_conversion as ct2_conversion class _FakeTransformersConverter: def __init__(self, model_name_or_path): self.model_name_or_path = model_name_or_path self.load_calls = [] def load_model(self, model_class, resolved_model_name_or_path, **kwargs): self.load_calls.append( { "model_class": model_class, "resolved_model_name_or_path": resolved_model_name_or_path, "kwargs": dict(kwargs), } ) if "dtype" in kwargs or "torch_dtype" in kwargs: raise TypeError("M2M100ForConditionalGeneration.__init__() got an unexpected keyword argument 'dtype'") return {"loaded": True, "path": resolved_model_name_or_path} def convert(self, output_dir, quantization=None, force=False): loaded = self.load_model("FakeModel", self.model_name_or_path, dtype="float32") return { "loaded": loaded, "output_dir": output_dir, "quantization": quantization, "force": force, "load_calls": list(self.load_calls), } def _install_fake_ctranslate2(monkeypatch, base_converter): converters_module = types.ModuleType("ctranslate2.converters") converters_module.TransformersConverter = base_converter ctranslate2_module = types.ModuleType("ctranslate2") ctranslate2_module.converters = converters_module monkeypatch.setitem(sys.modules, "ctranslate2", ctranslate2_module) monkeypatch.setitem(sys.modules, "ctranslate2.converters", converters_module) def test_convert_transformers_model_retries_without_torch_dtype(monkeypatch): _install_fake_ctranslate2(monkeypatch, _FakeTransformersConverter) fake_transformers = types.ModuleType("transformers") fake_transformers.AutoConfig = types.SimpleNamespace( from_pretrained=lambda path: types.SimpleNamespace(torch_dtype="float32", path=path) ) monkeypatch.setitem(sys.modules, "transformers", fake_transformers) result = ct2_conversion.convert_transformers_model("fake-model", "/tmp/out", "float16") assert result["loaded"] == {"loaded": True, "path": "fake-model"} assert result["output_dir"] == "/tmp/out" assert result["quantization"] == "float16" assert result["force"] is False assert len(result["load_calls"]) == 2 assert result["load_calls"][0] == { "model_class": "FakeModel", "resolved_model_name_or_path": "fake-model", "kwargs": {"dtype": "float32"}, } assert result["load_calls"][1]["model_class"] == "FakeModel" assert result["load_calls"][1]["resolved_model_name_or_path"] == "fake-model" assert getattr(result["load_calls"][1]["kwargs"]["config"], "torch_dtype", "missing") is None def test_convert_transformers_model_preserves_unrelated_type_errors(monkeypatch): class _AlwaysFailingConverter(_FakeTransformersConverter): def load_model(self, model_class, resolved_model_name_or_path, **kwargs): raise TypeError("different constructor error") _install_fake_ctranslate2(monkeypatch, _AlwaysFailingConverter) fake_transformers = types.ModuleType("transformers") fake_transformers.AutoConfig = types.SimpleNamespace(from_pretrained=lambda path: types.SimpleNamespace(path=path)) monkeypatch.setitem(sys.modules, "transformers", fake_transformers) with pytest.raises(TypeError, match="different constructor error"): ct2_conversion.convert_transformers_model("fake-model", "/tmp/out", "float16")