950a640e
tangwang
embeddings
|
1
2
3
|
from typing import Any, Dict, List, Optional
import numpy as np
|
ed948666
tangwang
tidy
|
4
|
import pytest
|
950a640e
tangwang
embeddings
|
5
6
7
8
9
|
from config import (
FunctionScoreConfig,
IndexConfig,
QueryConfig,
|
950a640e
tangwang
embeddings
|
10
11
12
13
14
|
RerankConfig,
SPUConfig,
SearchConfig,
)
from embeddings.text_encoder import TextEmbeddingEncoder
|
7214c2e7
tangwang
mplemented**
|
15
|
from embeddings.image_encoder import CLIPImageEncoder
|
4a37d233
tangwang
1. embedding cach...
|
16
|
from embeddings.bf16 import encode_embedding_for_redis
|
7214c2e7
tangwang
mplemented**
|
17
|
from embeddings.cache_keys import build_image_cache_key, build_text_cache_key
|
950a640e
tangwang
embeddings
|
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
|
from query import QueryParser
class _FakeRedis:
def __init__(self):
self.store: Dict[str, bytes] = {}
def ping(self):
return True
def get(self, key: str):
return self.store.get(key)
def setex(self, key: str, _expire, value: bytes):
self.store[key] = value
return True
def expire(self, key: str, _expire):
return key in self.store
def delete(self, key: str):
self.store.pop(key, None)
return True
class _FakeResponse:
def __init__(self, payload: List[Optional[List[float]]]):
self._payload = payload
def raise_for_status(self):
return None
def json(self):
return self._payload
class _FakeTranslator:
def translate(
self,
text: str,
target_lang: str,
source_lang: Optional[str] = None,
prompt: Optional[str] = None,
) -> str:
return f"{text}-{target_lang}"
class _FakeQueryEncoder:
|
b754fd41
tangwang
图片向量化支持优先级参数
|
66
67
68
|
def __init__(self):
self.calls = []
|
950a640e
tangwang
embeddings
|
69
|
def encode(self, sentences, **kwargs):
|
b754fd41
tangwang
图片向量化支持优先级参数
|
70
|
self.calls.append({"sentences": sentences, "kwargs": dict(kwargs)})
|
950a640e
tangwang
embeddings
|
71
72
73
74
75
|
if isinstance(sentences, str):
sentences = [sentences]
return np.array([np.array([0.11, 0.22, 0.33], dtype=np.float32) for _ in sentences], dtype=object)
|
ef5baa86
tangwang
混杂语言处理
|
76
77
78
79
|
def _tokenizer(text):
return str(text).split()
|
7214c2e7
tangwang
mplemented**
|
80
81
82
83
84
85
86
87
88
89
90
91
|
class _FakeEmbeddingCache:
def __init__(self):
self.store: Dict[str, np.ndarray] = {}
def get(self, key: str):
return self.store.get(key)
def set(self, key: str, embedding: np.ndarray):
self.store[key] = np.asarray(embedding, dtype=np.float32)
return True
|
950a640e
tangwang
embeddings
|
92
93
94
95
96
97
98
99
100
|
def _build_test_config() -> SearchConfig:
    """Build a minimal SearchConfig for the QueryParser tests.

    Uses a single English title index, text embeddings on ``title_embedding``,
    query rewriting off, and SPU collapsing on ``spu_id``.
    """
    query_settings = QueryConfig(
        supported_languages=["en", "zh"],
        default_language="en",
        enable_text_embedding=True,
        enable_query_rewrite=False,
        rewrite_dictionary={},
        text_embedding_field="title_embedding",
        image_embedding_field=None,
    )
    title_index = IndexConfig(name="default", label="default", fields=["title.en"], boost=1.0)
    spu_settings = SPUConfig(enabled=True, spu_field="spu_id", inner_hits_size=3)
    return SearchConfig(
        field_boosts={"title.en": 3.0},
        indexes=[title_index],
        query_config=query_settings,
        function_score=FunctionScoreConfig(),
        rerank=RerankConfig(),
        spu_config=spu_settings,
        es_index_name="test_products",
        es_settings={},
    )
def test_text_embedding_encoder_response_alignment(monkeypatch):
    """Each vector in the service response must map back to its input, in order."""
    fake_cache = _FakeEmbeddingCache()
    monkeypatch.setattr("embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache)

    def _fake_post(url, json, timeout, **kwargs):
        assert url.endswith("/embed/text")
        assert json == ["hello", "world"]
        # encode() without an explicit priority is expected to send priority 0.
        assert kwargs["params"]["priority"] == 0
        return _FakeResponse([[0.1, 0.2], [0.3, 0.4]])

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post)
    encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005")
    out = encoder.encode(["hello", "world"])
    assert len(out) == 2
    expected_vectors = ([0.1, 0.2], [0.3, 0.4])
    for vec, expected in zip(out, expected_vectors):
        assert isinstance(vec, np.ndarray)
        assert vec.shape == (2,)
        # Shape checks alone cannot detect a swapped or misaligned response,
        # so pin each vector's values to its input position.
        assert np.allclose(vec, np.array(expected, dtype=np.float32))
def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch):
    """A null entry in the service response must surface as ValueError."""
    cache_stub = _FakeEmbeddingCache()
    monkeypatch.setattr("embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: cache_stub)

    def _post_with_gap(url, json, timeout, **kwargs):
        # The second input deliberately comes back without a vector.
        return _FakeResponse([[0.1, 0.2], None])

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _post_with_gap)
    text_encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005")
    with pytest.raises(ValueError):
        text_encoder.encode(["hello", "world"])
|
950a640e
tangwang
embeddings
|
147
148
149
|
def test_text_embedding_encoder_cache_hit(monkeypatch):
    """A cached text is served from the cache; only the uncached text reaches the service."""
    fake_cache = _FakeEmbeddingCache()
    cached = np.array([0.9, 0.8], dtype=np.float32)
    fake_cache.store[build_text_cache_key("cached-text", normalize=True)] = cached
    monkeypatch.setattr("embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache)
    calls = {"count": 0}

    def _fake_post(url, json, timeout, **kwargs):
        calls["count"] += 1
        # The cache hit must be removed from the request payload, not merely
        # overwritten after a full round-trip.
        assert json == ["new-text"]
        return _FakeResponse([[0.3, 0.4]])

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post)
    encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005")
    out = encoder.encode(["cached-text", "new-text"])
    assert calls["count"] == 1
    assert np.allclose(out[0], cached)
    assert np.allclose(out[1], np.array([0.3, 0.4], dtype=np.float32))
|
7214c2e7
tangwang
mplemented**
|
171
172
173
174
175
176
177
178
179
180
181
|
def test_image_embedding_encoder_cache_hit(monkeypatch):
    """A cached image URL is served from cache; only the other URL triggers a request."""
    cached_url = "https://example.com/a.jpg"
    fresh_url = "https://example.com/b.jpg"
    cached = np.array([0.5, 0.6], dtype=np.float32)
    fake_cache = _FakeEmbeddingCache()
    fake_cache.store[build_image_cache_key(cached_url, normalize=True)] = cached
    monkeypatch.setattr("embeddings.image_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache)
    post_calls = []

    def _fake_post(url, params, json, timeout, **kwargs):
        post_calls.append(url)
        # encode_batch() without an explicit priority is expected to send priority 0.
        assert params["priority"] == 0
        return _FakeResponse([[0.1, 0.2]])

    monkeypatch.setattr("embeddings.image_encoder.requests.post", _fake_post)
    encoder = CLIPImageEncoder(service_url="http://127.0.0.1:6008")
    out = encoder.encode_batch([cached_url, fresh_url])
    assert len(post_calls) == 1
    assert np.allclose(out[0], cached)
    assert np.allclose(out[1], np.array([0.1, 0.2], dtype=np.float32))
|
b754fd41
tangwang
图片向量化支持优先级参数
|
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
|
def test_image_embedding_encoder_passes_priority(monkeypatch):
    """An explicit ``priority`` given to encode_batch is forwarded in the request params."""
    stub_cache = _FakeEmbeddingCache()
    monkeypatch.setattr("embeddings.image_encoder.RedisEmbeddingCache", lambda **kwargs: stub_cache)

    def _fake_post(url, params, json, timeout, **kwargs):
        assert params["priority"] == 1
        return _FakeResponse([[0.1, 0.2]])

    monkeypatch.setattr("embeddings.image_encoder.requests.post", _fake_post)
    image_encoder = CLIPImageEncoder(service_url="http://127.0.0.1:6008")
    vectors = image_encoder.encode_batch(["https://example.com/a.jpg"], priority=1)
    expected = np.array([0.1, 0.2], dtype=np.float32)
    assert len(vectors) == 1
    assert np.allclose(vectors[0], expected)
|
950a640e
tangwang
embeddings
|
211
|
def test_query_parser_generates_query_vector_with_encoder():
    """parse(generate_vector=True) embeds the query and passes priority=1 to the encoder."""
    recording_encoder = _FakeQueryEncoder()
    parser = QueryParser(
        config=_build_test_config(),
        text_encoder=recording_encoder,
        translator=_FakeTranslator(),
        tokenizer=_tokenizer,
    )
    parsed = parser.parse("red dress", tenant_id="162", generate_vector=True)
    vector = parsed.query_vector
    assert vector is not None
    assert vector.shape == (3,)
    assert recording_encoder.calls
    first_call_kwargs = recording_encoder.calls[0]["kwargs"]
    assert first_call_kwargs["priority"] == 1
|
950a640e
tangwang
embeddings
|
225
226
227
228
229
230
231
|
def test_query_parser_skips_query_vector_when_disabled():
    """parse(generate_vector=False) must leave ``query_vector`` unset."""
    search_config = _build_test_config()
    parser = QueryParser(
        config=search_config,
        text_encoder=_FakeQueryEncoder(),
        translator=_FakeTranslator(),
        tokenizer=_tokenizer,
    )
    result = parser.parse("red dress", tenant_id="162", generate_vector=False)
    assert result.query_vector is None
|