Commit dc403578b9f167d06cff131dda36fd488340d99f

Authored by tangwang
1 parent 74116f05

多模态搜索

query/query_parser.py
... ... @@ -14,6 +14,7 @@ import numpy as np
14 14 import logging
15 15 from concurrent.futures import ThreadPoolExecutor, wait
16 16  
  17 +from embeddings.image_encoder import CLIPImageEncoder
17 18 from embeddings.text_encoder import TextEmbeddingEncoder
18 19 from config import SearchConfig
19 20 from translation import create_translation_client
... ... @@ -66,6 +67,7 @@ class ParsedQuery:
66 67 detected_language: Optional[str] = None
67 68 translations: Dict[str, str] = field(default_factory=dict)
68 69 query_vector: Optional[np.ndarray] = None
  70 + image_query_vector: Optional[np.ndarray] = None
69 71 query_tokens: List[str] = field(default_factory=list)
70 72 style_intent_profile: Optional[StyleIntentProfile] = None
71 73 product_title_exclusion_profile: Optional[ProductTitleExclusionProfile] = None
... ... @@ -86,6 +88,8 @@ class ParsedQuery:
86 88 "rewritten_query": self.rewritten_query,
87 89 "detected_language": self.detected_language,
88 90 "translations": self.translations,
  91 + "has_query_vector": self.query_vector is not None,
  92 + "has_image_query_vector": self.image_query_vector is not None,
89 93 "query_tokens": self.query_tokens,
90 94 "style_intent_profile": (
91 95 self.style_intent_profile.to_dict() if self.style_intent_profile is not None else None
... ... @@ -112,6 +116,7 @@ class QueryParser:
112 116 self,
113 117 config: SearchConfig,
114 118 text_encoder: Optional[TextEmbeddingEncoder] = None,
  119 + image_encoder: Optional[CLIPImageEncoder] = None,
115 120 translator: Optional[Any] = None,
116 121 tokenizer: Optional[Callable[[str], Any]] = None,
117 122 ):
... ... @@ -125,6 +130,7 @@ class QueryParser:
125 130 """
126 131 self.config = config
127 132 self._text_encoder = text_encoder
  133 + self._image_encoder = image_encoder
128 134 self._translator = translator
129 135  
130 136 # Initialize components
... ... @@ -149,6 +155,9 @@ class QueryParser:
149 155 if self.config.query_config.enable_text_embedding and self._text_encoder is None:
150 156 logger.info("Initializing text encoder at QueryParser construction...")
151 157 self._text_encoder = TextEmbeddingEncoder()
  158 + if self.config.query_config.image_embedding_field and self._image_encoder is None:
  159 + logger.info("Initializing image encoder at QueryParser construction...")
  160 + self._image_encoder = CLIPImageEncoder()
152 161 if self._translator is None:
153 162 from config.services_config import get_translation_config
154 163 cfg = get_translation_config()
... ... @@ -169,6 +178,11 @@ class QueryParser:
169 178 """Return pre-initialized translator."""
170 179 return self._translator
171 180  
  181 + @property
  182 + def image_encoder(self) -> Optional[CLIPImageEncoder]:
  183 + """Return pre-initialized image encoder for CLIP text embeddings."""
  184 + return self._image_encoder
  185 +
172 186 def _build_tokenizer(self) -> Callable[[str], Any]:
173 187 """Build the tokenizer used by query parsing. No fallback path by design."""
174 188 if hanlp is None:
... ... @@ -311,12 +325,21 @@ class QueryParser:
311 325  
312 326 # Stage 6: Text embedding - async execution
313 327 query_vector = None
  328 + image_query_vector = None
314 329 should_generate_embedding = (
315 330 generate_vector and
316 331 self.config.query_config.enable_text_embedding
317 332 )
  333 + should_generate_image_embedding = (
  334 + generate_vector and
  335 + bool(self.config.query_config.image_embedding_field)
  336 + )
318 337  
319   - task_count = len(translation_targets) + (1 if should_generate_embedding else 0)
  338 + task_count = (
  339 + len(translation_targets)
  340 + + (1 if should_generate_embedding else 0)
  341 + + (1 if should_generate_image_embedding else 0)
  342 + )
320 343 if task_count > 0:
321 344 async_executor = ThreadPoolExecutor(
322 345 max_workers=max(1, min(task_count, 4)),
... ... @@ -366,6 +389,28 @@ class QueryParser:
366 389  
367 390 future = async_executor.submit(_encode_query_vector)
368 391 future_to_task[future] = ("embedding", None)
  392 +
  393 + if should_generate_image_embedding:
  394 + if self.image_encoder is None:
  395 + raise RuntimeError(
  396 + "Image embedding field is configured but image encoder is not initialized"
  397 + )
  398 + log_debug("Submitting CLIP text query vector generation")
  399 +
  400 + def _encode_image_query_vector() -> Optional[np.ndarray]:
  401 + vec = self.image_encoder.encode_clip_text(
  402 + query_text,
  403 + normalize_embeddings=True,
  404 + priority=1,
  405 + request_id=(context.reqid if context else None),
  406 + user_id=(context.uid if context else None),
  407 + )
  408 + if vec is None:
  409 + return None
  410 + return np.asarray(vec, dtype=np.float32)
  411 +
  412 + future = async_executor.submit(_encode_image_query_vector)
  413 + future_to_task[future] = ("image_embedding", None)
369 414 except Exception as e:
370 415 error_msg = f"Async query enrichment submission failed | Error: {str(e)}"
371 416 log_info(error_msg)
... ... @@ -424,9 +469,27 @@ class QueryParser:
424 469 log_info(
425 470 "Query vector generation completed but result is None, will process without vector"
426 471 )
  472 + elif task_type == "image_embedding":
  473 + image_query_vector = result
  474 + if image_query_vector is not None:
  475 + log_debug(
  476 + f"CLIP text query vector generation completed | Shape: {image_query_vector.shape}"
  477 + )
  478 + if context:
  479 + context.store_intermediate_result(
  480 + "image_query_vector_shape",
  481 + image_query_vector.shape,
  482 + )
  483 + else:
  484 + log_info(
  485 + "CLIP text query vector generation completed but result is None, "
  486 + "will process without image vector"
  487 + )
427 488 except Exception as e:
428 489 if task_type == "translation":
429 490 error_msg = f"Translation failed | Language: {lang} | Error: {str(e)}"
  491 + elif task_type == "image_embedding":
  492 + error_msg = f"CLIP text query vector generation failed | Error: {str(e)}"
430 493 else:
431 494 error_msg = f"Query vector generation failed | Error: {str(e)}"
432 495 log_info(error_msg)
... ... @@ -441,6 +504,11 @@ class QueryParser:
441 504 f"Translation timeout (>{budget_ms}ms) | Language: {lang} | "
442 505 f"Query text: '{query_text}'"
443 506 )
  507 + elif task_type == "image_embedding":
  508 + timeout_msg = (
  509 + f"CLIP text query vector generation timeout (>{budget_ms}ms), "
  510 + "proceeding without image embedding result"
  511 + )
444 512 else:
445 513 timeout_msg = (
446 514 f"Query vector generation timeout (>{budget_ms}ms), proceeding without embedding result"
... ... @@ -463,6 +531,7 @@ class QueryParser:
463 531 detected_language=detected_lang,
464 532 translations=translations,
465 533 query_vector=query_vector,
  534 + image_query_vector=image_query_vector,
466 535 query_tokens=query_tokens,
467 536 )
468 537 style_intent_profile = self.style_intent_detector.detect(base_result)
... ... @@ -484,6 +553,7 @@ class QueryParser:
484 553 detected_language=detected_lang,
485 554 translations=translations,
486 555 query_vector=query_vector,
  556 + image_query_vector=image_query_vector,
487 557 query_tokens=query_tokens,
488 558 style_intent_profile=style_intent_profile,
489 559 product_title_exclusion_profile=product_title_exclusion_profile,
... ... @@ -492,7 +562,8 @@ class QueryParser:
492 562 if context and hasattr(context, 'logger'):
493 563 context.logger.info(
494 564 f"Query parsing completed | Original query: '{query}' | Final query: '{rewritten or query_text}' | "
495   - f"Translation count: {len(translations)} | Vector: {'yes' if query_vector is not None else 'no'}",
  565 + f"Translation count: {len(translations)} | Vector: {'yes' if query_vector is not None else 'no'} | "
  566 + f"Image vector: {'yes' if image_query_vector is not None else 'no'}",
496 567 extra={'reqid': context.reqid, 'uid': context.uid}
497 568 )
498 569 else:
... ...
search/es_query_builder.py
... ... @@ -164,6 +164,7 @@ class ESQueryBuilder:
164 164 self,
165 165 query_text: str,
166 166 query_vector: Optional[np.ndarray] = None,
  167 + image_query_vector: Optional[np.ndarray] = None,
167 168 filters: Optional[Dict[str, Any]] = None,
168 169 range_filters: Optional[Dict[str, Any]] = None,
169 170 facet_configs: Optional[List[Any]] = None,
... ... @@ -212,15 +213,14 @@ class ESQueryBuilder:
212 213  
213 214 # 1. Build recall queries (text or embedding)
214 215 recall_clauses = []
215   -
  216 +
216 217 # Text recall (always include if query_text exists)
217 218 if query_text:
218   - # Unified text query strategy
219   - text_query = self._build_advanced_text_query(query_text, parsed_query)
220   - recall_clauses.append(text_query)
221   -
222   - # Embedding recall (KNN - separate from query, handled below)
  219 + recall_clauses.extend(self._build_advanced_text_query(query_text, parsed_query))
  220 +
  221 + # Embedding recall
223 222 has_embedding = enable_knn and query_vector is not None and self.text_embedding_field
  223 + has_image_embedding = enable_knn and image_query_vector is not None and self.image_embedding_field
224 224  
225 225 # 2. Split filters for multi-select faceting
226 226 conjunctive_filters, disjunctive_filters = self._split_filters_for_faceting(
... ... @@ -233,9 +233,48 @@ class ESQueryBuilder:
233 233 if product_title_exclusion_filter:
234 234 filter_clauses.append(product_title_exclusion_filter)
235 235  
236   - # 3. Build main query structure: filters and recall
  236 + # 3. Add KNN search clauses alongside lexical clauses under the same bool.should
  237 + # Adjust KNN k, num_candidates, boost by query_tokens (short query: less KNN; long: more)
  238 + final_knn_k, final_knn_num_candidates = knn_k, knn_num_candidates
  239 + if has_embedding:
  240 + knn_boost = self.knn_boost
  241 + if parsed_query:
  242 + query_tokens = getattr(parsed_query, 'query_tokens', None) or []
  243 + token_count = len(query_tokens)
  244 + if token_count >= 5:
  245 + final_knn_k, final_knn_num_candidates = 160, 500
  246 + knn_boost = self.knn_boost * 1.4 # Higher weight for long queries
  247 + else:
  248 + final_knn_k, final_knn_num_candidates = 120, 400
  249 + else:
  250 + final_knn_k, final_knn_num_candidates = 120, 400
  251 + recall_clauses.append({
  252 + "knn": {
  253 + "field": self.text_embedding_field,
  254 + "query_vector": query_vector.tolist(),
  255 + "k": final_knn_k,
  256 + "num_candidates": final_knn_num_candidates,
  257 + "boost": knn_boost,
  258 + "_name": "knn_query",
  259 + }
  260 + })
  261 +
  262 + if has_image_embedding:
  263 + image_knn_k = max(final_knn_k, 120)
  264 + image_knn_num_candidates = max(final_knn_num_candidates, 400)
  265 + recall_clauses.append({
  266 + "knn": {
  267 + "field": self.image_embedding_field,
  268 + "query_vector": image_query_vector.tolist(),
  269 + "k": image_knn_k,
  270 + "num_candidates": image_knn_num_candidates,
  271 + "boost": self.knn_boost,
  272 + "_name": "image_knn_query",
  273 + }
  274 + })
  275 +
  276 + # 4. Build main query structure: filters and recall
237 277 if recall_clauses:
238   - # Combine text recalls with OR logic (if multiple)
239 278 if len(recall_clauses) == 1:
240 279 recall_query = recall_clauses[0]
241 280 else:
... ... @@ -245,11 +284,9 @@ class ESQueryBuilder:
245 284 "minimum_should_match": 1
246 285 }
247 286 }
248   -
249   - # Wrap recall with function_score for boosting
  287 +
250 288 recall_query = self._wrap_with_function_score(recall_query)
251   -
252   - # Combine filters and recall
  289 +
253 290 if filter_clauses:
254 291 es_query["query"] = {
255 292 "bool": {
... ... @@ -260,7 +297,6 @@ class ESQueryBuilder:
260 297 else:
261 298 es_query["query"] = recall_query
262 299 else:
263   - # No recall queries, only filters (match_all filtered)
264 300 if filter_clauses:
265 301 es_query["query"] = {
266 302 "bool": {
... ... @@ -271,41 +307,6 @@ class ESQueryBuilder:
271 307 else:
272 308 es_query["query"] = {"match_all": {}}
273 309  
274   - # 4. Add KNN search if enabled (separate from query, ES will combine)
275   - # Adjust KNN k, num_candidates, boost by query_tokens (short query: less KNN; long: more)
276   - if has_embedding:
277   - knn_boost = self.knn_boost
278   - if parsed_query:
279   - query_tokens = getattr(parsed_query, 'query_tokens', None) or []
280   - token_count = len(query_tokens)
281   - if token_count >= 5:
282   - knn_k, knn_num_candidates = 160, 500
283   - knn_boost = self.knn_boost * 1.4 # Higher weight for long queries
284   - else:
285   - knn_k, knn_num_candidates = 120, 400
286   - else:
287   - knn_k, knn_num_candidates = 120, 400
288   - knn_clause = {
289   - "field": self.text_embedding_field,
290   - "query_vector": query_vector.tolist(),
291   - "k": knn_k,
292   - "num_candidates": knn_num_candidates,
293   - "boost": knn_boost,
294   - "_name": "knn_query",
295   - }
296   - # Top-level knn does not inherit query.bool.filter automatically.
297   - # Apply conjunctive + range filters here so vector recall respects hard filters.
298   - if filter_clauses:
299   - if len(filter_clauses) == 1:
300   - knn_clause["filter"] = filter_clauses[0]
301   - else:
302   - knn_clause["filter"] = {
303   - "bool": {
304   - "filter": filter_clauses
305   - }
306   - }
307   - es_query["knn"] = knn_clause
308   -
309 310 # 5. Add post_filter for disjunctive (multi-select) filters
310 311 if disjunctive_filters:
311 312 post_filter_clauses = self._build_filters(disjunctive_filters, None)
... ... @@ -536,21 +537,20 @@ class ESQueryBuilder:
536 537 self,
537 538 query_text: str,
538 539 parsed_query: Optional[Any] = None,
539   - ) -> Dict[str, Any]:
  540 + ) -> List[Dict[str, Any]]:
540 541 """
541 542 Build advanced text query using base and translated lexical clauses.
542 543  
543 544 Unified implementation:
544 545 - base_query: source-language clause
545 546 - translation queries: target-language clauses from translations
546   - - KNN query: added separately in build_query
547   -
  547 +
548 548 Args:
549 549 query_text: Query text
550 550 parsed_query: ParsedQuery object with analysis results
551 551  
552 552 Returns:
553   - ES bool query with should clauses
  553 + Flat recall clauses to be merged with KNN clauses under query.bool.should
554 554 """
555 555 should_clauses = []
556 556 source_lang = self.default_language
... ... @@ -603,18 +603,9 @@ class ESQueryBuilder:
603 603 "minimum_should_match": self.base_minimum_should_match,
604 604 }
605 605 }
606   - return fallback_lexical
  606 + return [fallback_lexical]
607 607  
608   - # Return bool query with should clauses
609   - if len(should_clauses) == 1:
610   - return should_clauses[0]
611   -
612   - return {
613   - "bool": {
614   - "should": should_clauses,
615   - "minimum_should_match": 1
616   - }
617   - }
  608 + return should_clauses
618 609  
619 610 def _build_filters(
620 611 self,
... ...
search/rerank_client.py
... ... @@ -151,6 +151,13 @@ def _extract_named_query_score(matched_queries: Any, name: str) -> float:
151 151 return 1.0 if name in matched_queries else 0.0
152 152 return 0.0
153 153  
  154 +
def _extract_combined_knn_score(matched_queries: Any) -> float:
    """Return the larger of the text-KNN and image-KNN named-query scores.

    Args:
        matched_queries: The ``matched_queries`` value taken from an ES hit;
            passed through unchanged to ``_extract_named_query_score``.

    Returns:
        ``max`` of the scores extracted for ``"knn_query"`` and
        ``"image_knn_query"``.
    """
    text_knn_score = _extract_named_query_score(matched_queries, "knn_query")
    image_knn_score = _extract_named_query_score(matched_queries, "image_knn_query")
    return max(text_knn_score, image_knn_score)
  160 +
154 161 """
155 162 原始变量:
156 163 ES总分
... ... @@ -272,7 +279,7 @@ def fuse_scores_and_resort(
272 279 es_score = _to_score(hit.get("_score"))
273 280 rerank_score = _to_score(rerank_scores[idx])
274 281 matched_queries = hit.get("matched_queries")
275   - knn_score = _extract_named_query_score(matched_queries, "knn_query")
  282 + knn_score = _extract_combined_knn_score(matched_queries)
276 283 text_components = _collect_text_score_components(matched_queries, es_score)
277 284 text_score = text_components["text_score"]
278 285 rerank_factor, text_factor, knn_factor, fused = _multiply_fusion_factors(
... ...
search/searcher.py
... ... @@ -106,14 +106,14 @@ class Searcher:
106 106 """
107 107 self.es_client = es_client
108 108 self.config = config
109   - # Index name is now generated dynamically per tenant, no longer stored here
110   - self.query_parser = query_parser or QueryParser(config)
111 109 self.text_embedding_field = config.query_config.text_embedding_field or "title_embedding"
112 110 self.image_embedding_field = config.query_config.image_embedding_field
113 111 if self.image_embedding_field and image_encoder is None:
114 112 self.image_encoder = CLIPImageEncoder()
115 113 else:
116 114 self.image_encoder = image_encoder
  115 + # Index name is now generated dynamically per tenant, no longer stored here
  116 + self.query_parser = query_parser or QueryParser(config, image_encoder=self.image_encoder)
117 117 self.source_fields = config.query_config.source_fields
118 118 self.style_intent_registry = StyleIntentRegistry.from_query_config(self.config.query_config)
119 119 self.style_sku_selector = StyleSkuSelector(
... ... @@ -403,7 +403,8 @@ class Searcher:
403 403 f"查询解析完成 | 原查询: '{parsed_query.original_query}' | "
404 404 f"重写后: '{parsed_query.rewritten_query}' | "
405 405 f"语言: {parsed_query.detected_language} | "
406   - f"向量: {'是' if parsed_query.query_vector is not None else '否'}",
  406 + f"文本向量: {'是' if parsed_query.query_vector is not None else '否'} | "
  407 + f"图片向量: {'是' if getattr(parsed_query, 'image_query_vector', None) is not None else '否'}",
407 408 extra={'reqid': context.reqid, 'uid': context.uid}
408 409 )
409 410 except Exception as e:
... ... @@ -428,12 +429,20 @@ class Searcher:
428 429 es_query = self.query_builder.build_query(
429 430 query_text=parsed_query.rewritten_query or parsed_query.query_normalized,
430 431 query_vector=parsed_query.query_vector if enable_embedding else None,
  432 + image_query_vector=(
  433 + getattr(parsed_query, "image_query_vector", None)
  434 + if enable_embedding
  435 + else None
  436 + ),
431 437 filters=filters,
432 438 range_filters=range_filters,
433 439 facet_configs=facets,
434 440 size=es_fetch_size,
435 441 from_=es_fetch_from,
436   - enable_knn=enable_embedding and parsed_query.query_vector is not None,
  442 + enable_knn=enable_embedding and (
  443 + parsed_query.query_vector is not None
  444 + or getattr(parsed_query, "image_query_vector", None) is not None
  445 + ),
437 446 min_score=min_score,
438 447 parsed_query=parsed_query,
439 448 )
... ... @@ -475,15 +484,24 @@ class Searcher:
475 484 # Serialize ES query to compute a compact size + stable digest for correlation
476 485 es_query_compact = json.dumps(es_query_for_fetch, ensure_ascii=False, separators=(",", ":"))
477 486 es_query_digest = hashlib.sha256(es_query_compact.encode("utf-8")).hexdigest()[:16]
478   - knn_enabled = bool(enable_embedding and parsed_query.query_vector is not None)
  487 + knn_enabled = bool(enable_embedding and (
  488 + parsed_query.query_vector is not None
  489 + or getattr(parsed_query, "image_query_vector", None) is not None
  490 + ))
479 491 vector_dims = int(len(parsed_query.query_vector)) if parsed_query.query_vector is not None else 0
  492 + image_vector_dims = (
  493 + int(len(parsed_query.image_query_vector))
  494 + if getattr(parsed_query, "image_query_vector", None) is not None
  495 + else 0
  496 + )
480 497  
481 498 context.logger.info(
482   - "ES query built | size: %s chars | digest: %s | KNN: %s | vector_dims: %s | facets: %s | rerank_prefetch_source: %s",
  499 + "ES query built | size: %s chars | digest: %s | KNN: %s | vector_dims: %s | image_vector_dims: %s | facets: %s | rerank_prefetch_source: %s",
483 500 len(es_query_compact),
484 501 es_query_digest,
485 502 "yes" if knn_enabled else "no",
486 503 vector_dims,
  504 + image_vector_dims,
487 505 "yes" if facets else "no",
488 506 rerank_prefetch_source,
489 507 extra={'reqid': context.reqid, 'uid': context.uid}
... ... @@ -497,6 +515,7 @@ class Searcher:
497 515 "sha256_16": es_query_digest,
498 516 "knn_enabled": knn_enabled,
499 517 "vector_dims": vector_dims,
  518 + "image_vector_dims": image_vector_dims,
500 519 "has_facets": bool(facets),
501 520 "query": es_query_for_fetch,
502 521 })
... ...
tests/test_embedding_pipeline.py
... ... @@ -75,6 +75,15 @@ class _FakeQueryEncoder:
75 75 return np.array([np.array([0.11, 0.22, 0.33], dtype=np.float32) for _ in sentences], dtype=object)
76 76  
77 77  
  78 +class _FakeClipTextEncoder:
  79 + def __init__(self):
  80 + self.calls = []
  81 +
  82 + def encode_clip_text(self, text, **kwargs):
  83 + self.calls.append({"text": text, "kwargs": dict(kwargs)})
  84 + return np.array([0.44, 0.55, 0.66], dtype=np.float32)
  85 +
  86 +
78 87 def _tokenizer(text):
79 88 return str(text).split()
80 89  
... ... @@ -91,7 +100,7 @@ class _FakeEmbeddingCache:
91 100 return True
92 101  
93 102  
94   -def _build_test_config() -> SearchConfig:
  103 +def _build_test_config(*, image_embedding_field: Optional[str] = None) -> SearchConfig:
95 104 return SearchConfig(
96 105 field_boosts={"title.en": 3.0},
97 106 indexes=[IndexConfig(name="default", label="default", fields=["title.en"], boost=1.0)],
... ... @@ -102,7 +111,7 @@ def _build_test_config() -> SearchConfig:
102 111 enable_query_rewrite=False,
103 112 rewrite_dictionary={},
104 113 text_embedding_field="title_embedding",
105   - image_embedding_field=None,
  114 + image_embedding_field=image_embedding_field,
106 115 ),
107 116 function_score=FunctionScoreConfig(),
108 117 rerank=RerankConfig(),
... ... @@ -250,6 +259,26 @@ def test_query_parser_generates_query_vector_with_encoder():
250 259 assert encoder.calls[0]["kwargs"]["priority"] == 1
251 260  
252 261  
def test_query_parser_generates_image_query_vector_with_clip_text_encoder():
    """With an image embedding field configured, parsing yields a CLIP-text query vector."""
    query_encoder = _FakeQueryEncoder()
    clip_encoder = _FakeClipTextEncoder()
    config = _build_test_config(image_embedding_field="image_embedding.vector")
    parser = QueryParser(
        config=config,
        text_encoder=query_encoder,
        image_encoder=clip_encoder,
        translator=_FakeTranslator(),
        tokenizer=_tokenizer,
    )

    result = parser.parse("red dress", tenant_id="162", generate_vector=True)

    # Both the text vector and the CLIP-text (image space) vector are produced.
    assert result.query_vector is not None
    assert result.image_query_vector is not None
    assert result.image_query_vector.shape == (3,)

    # The fake encoder saw the query text and the priority the parser passes.
    assert clip_encoder.calls
    first_call = clip_encoder.calls[0]
    assert first_call["text"] == "red dress"
    assert first_call["kwargs"]["priority"] == 1
  280 +
  281 +
253 282 def test_query_parser_skips_query_vector_when_disabled():
254 283 parser = QueryParser(
255 284 config=_build_test_config(),
... ... @@ -260,6 +289,7 @@ def test_query_parser_skips_query_vector_when_disabled():
260 289  
261 290 parsed = parser.parse("red dress", tenant_id="162", generate_vector=False)
262 291 assert parsed.query_vector is None
  292 + assert parsed.image_query_vector is None
263 293  
264 294  
265 295 def test_tei_text_model_splits_batches_over_client_limit(monkeypatch):
... ...
tests/test_es_query_builder.py
... ... @@ -13,22 +13,29 @@ def _builder() -> ESQueryBuilder:
13 13 core_multilingual_fields=["title", "brief"],
14 14 shared_fields=[],
15 15 text_embedding_field="title_embedding",
  16 + image_embedding_field="image_embedding.vector",
16 17 default_language="en",
17 18 )
18 19  
19 20  
20   -def _lexical_clause(query_root: Dict[str, Any]) -> Dict[str, Any]:
21   - """Return the first named lexical bool clause from query_root."""
22   - if "bool" in query_root and query_root["bool"].get("_name"):
23   - return query_root["bool"]
24   - for clause in query_root.get("bool", {}).get("should", []):
25   - clause_bool = clause.get("bool") or {}
26   - if clause_bool.get("_name"):
27   - return clause_bool
28   - raise AssertionError("no lexical bool clause in query_root")
  21 +def _recall_root(es_body: Dict[str, Any]) -> Dict[str, Any]:
  22 + query_root = es_body["query"]
  23 + if "bool" in query_root and query_root["bool"].get("must"):
  24 + query_root = query_root["bool"]["must"][0]
  25 + if "function_score" in query_root:
  26 + query_root = query_root["function_score"]["query"]
  27 + return query_root
29 28  
30 29  
31   -def test_knn_prefilter_includes_range_filters():
  30 +def _recall_should_clauses(es_body: Dict[str, Any]) -> list[Dict[str, Any]]:
  31 + root = _recall_root(es_body)
  32 + should = root.get("bool", {}).get("should")
  33 + if should:
  34 + return should
  35 + return [root]
  36 +
  37 +
  38 +def test_knn_clause_moves_under_query_should_and_uses_outer_filters():
32 39 qb = _builder()
33 40 q = qb.build_query(
34 41 query_text="bags",
... ... @@ -37,11 +44,13 @@ def test_knn_prefilter_includes_range_filters():
37 44 enable_knn=True,
38 45 )
39 46  
40   - assert "knn" in q
41   - assert q["knn"]["filter"] == {"range": {"min_price": {"gte": 50, "lt": 100}}}
  47 + assert "knn" not in q
  48 + should = _recall_should_clauses(q)
  49 + assert any(clause.get("knn", {}).get("_name") == "knn_query" for clause in should)
  50 + assert q["query"]["bool"]["filter"] == [{"range": {"min_price": {"gte": 50, "lt": 100}}}]
42 51  
43 52  
44   -def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present():
  53 +def test_knn_clause_uses_outer_query_filter_when_disjunctive_filters_present():
45 54 qb = _builder()
46 55 facets = [SimpleNamespace(field="category_name", disjunctive=True)]
47 56 q = qb.build_query(
... ... @@ -53,21 +62,15 @@ def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present():
53 62 enable_knn=True,
54 63 )
55 64  
56   - assert "knn" in q
57   - assert "filter" in q["knn"]
58   - knn_filter = q["knn"]["filter"]
59   - assert knn_filter == {
60   - "bool": {
61   - "filter": [
62   - {"term": {"vendor": "Nike"}},
63   - {"range": {"min_price": {"gte": 50, "lt": 100}}},
64   - ]
65   - }
66   - }
  65 + assert "knn" not in q
  66 + assert q["query"]["bool"]["filter"] == [
  67 + {"term": {"vendor": "Nike"}},
  68 + {"range": {"min_price": {"gte": 50, "lt": 100}}},
  69 + ]
67 70 assert q["post_filter"] == {"terms": {"category_name": ["A", "B"]}}
68 71  
69 72  
70   -def test_knn_prefilter_not_added_without_filters():
  73 +def test_knn_clause_has_name_and_no_embedded_filter():
71 74 qb = _builder()
72 75 q = qb.build_query(
73 76 query_text="bags",
... ... @@ -75,9 +78,10 @@ def test_knn_prefilter_not_added_without_filters():
75 78 enable_knn=True,
76 79 )
77 80  
78   - assert "knn" in q
79   - assert "filter" not in q["knn"]
80   - assert q["knn"]["_name"] == "knn_query"
  81 + should = _recall_should_clauses(q)
  82 + knn_clause = next(clause["knn"] for clause in should if clause.get("knn", {}).get("_name") == "knn_query")
  83 + assert "filter" not in knn_clause
  84 + assert knn_clause["_name"] == "knn_query"
81 85  
82 86  
83 87 def test_text_query_contains_only_base_and_translation_named_queries():
... ... @@ -93,11 +97,11 @@ def test_text_query_contains_only_base_and_translation_named_queries():
93 97 parsed_query=parsed_query,
94 98 enable_knn=False,
95 99 )
96   - should = q["query"]["bool"]["should"]
  100 + should = _recall_should_clauses(q)
97 101 names = [clause["bool"]["_name"] for clause in should]
98 102  
99 103 assert names == ["base_query", "base_query_trans_zh"]
100   - base_should = q["query"]["bool"]["should"][0]["bool"]["should"]
  104 + base_should = should[0]["bool"]["should"]
101 105 assert [clause["multi_match"]["type"] for clause in base_should] == ["best_fields", "phrase"]
102 106  
103 107  
... ... @@ -115,12 +119,12 @@ def test_text_query_skips_duplicate_translation_same_as_base():
115 119 enable_knn=False,
116 120 )
117 121  
118   - root = q["query"]
  122 + root = _recall_root(q)
119 123 assert root["bool"]["_name"] == "base_query"
120 124 assert [clause["multi_match"]["type"] for clause in root["bool"]["should"]] == ["best_fields", "phrase"]
121 125  
122 126  
123   -def test_product_title_exclusion_filter_is_applied_to_query_and_knn():
  127 +def test_product_title_exclusion_filter_is_applied_once_on_outer_query():
124 128 qb = _builder()
125 129 parsed_query = SimpleNamespace(
126 130 rewritten_query="fitted dress",
... ... @@ -158,4 +162,32 @@ def test_product_title_exclusion_filter_is_applied_to_query_and_knn():
158 162 }
159 163  
160 164 assert expected_filter in q["query"]["bool"]["filter"]
161   - assert q["knn"]["filter"] == expected_filter
  165 + should = _recall_should_clauses(q)
  166 + knn_clause = next(clause["knn"] for clause in should if clause.get("knn", {}).get("_name") == "knn_query")
  167 + assert "filter" not in knn_clause
  168 +
  169 +
def test_image_knn_clause_is_added_alongside_base_translation_and_text_knn():
    """Recall clauses are ordered: base, translation, text KNN, image KNN."""
    builder = _builder()
    parsed = SimpleNamespace(
        rewritten_query="street tee",
        detected_language="en",
        translations={"zh": "街头短袖"},
    )

    es_body = builder.build_query(
        query_text="street tee",
        query_vector=np.array([0.1, 0.2, 0.3]),
        image_query_vector=np.array([0.4, 0.5, 0.6]),
        parsed_query=parsed,
        enable_knn=True,
    )

    clauses = _recall_should_clauses(es_body)

    # Lexical clauses keep their name under "bool", vector clauses under "knn".
    clause_names = []
    for clause in clauses:
        named_body = clause["bool"] if "bool" in clause else clause["knn"]
        clause_names.append(named_body["_name"])
    assert clause_names == ["base_query", "base_query_trans_zh", "knn_query", "image_knn_query"]

    image_knn_clause = next(
        clause["knn"]
        for clause in clauses
        if clause.get("knn", {}).get("_name") == "image_knn_query"
    )
    assert image_knn_clause["field"] == "image_embedding.vector"
... ...
tests/test_rerank_client.py
... ... @@ -149,3 +149,22 @@ def test_fuse_scores_and_resort_boosts_hits_with_selected_sku():
149 149 assert [h["_id"] for h in hits] == ["style-selected", "plain"]
150 150 assert debug[0]["style_intent_selected_sku"] is True
151 151 assert debug[0]["style_intent_selected_sku_boost"] == 1.2
  152 +
  153 +
def test_fuse_scores_and_resort_uses_max_of_text_and_image_knn_scores():
    """The fused KNN score is the larger of the text and image KNN named-query scores."""
    multimodal_hit = {
        "_id": "mm-hit",
        "_score": 1.0,
        "matched_queries": {
            "base_query": 1.5,
            "knn_query": 0.2,
            "image_knn_query": 0.7,
        },
    }
    hits = [multimodal_hit]

    debug_rows = fuse_scores_and_resort(hits, [0.8], debug=True)

    # image_knn_query (0.7) beats knn_query (0.2), so 0.7 is the combined score.
    expected_knn = 0.7
    assert isclose(hits[0]["_knn_score"], expected_knn, rel_tol=1e-9)
    assert isclose(debug_rows[0]["knn_score"], expected_knn, rel_tol=1e-9)
... ...