Commit dc403578b9f167d06cff131dda36fd488340d99f
1 parent
74116f05
多模态搜索
Showing
7 changed files
with
276 additions
and
107 deletions
Show diff stats
query/query_parser.py
| ... | ... | @@ -14,6 +14,7 @@ import numpy as np |
| 14 | 14 | import logging |
| 15 | 15 | from concurrent.futures import ThreadPoolExecutor, wait |
| 16 | 16 | |
| 17 | +from embeddings.image_encoder import CLIPImageEncoder | |
| 17 | 18 | from embeddings.text_encoder import TextEmbeddingEncoder |
| 18 | 19 | from config import SearchConfig |
| 19 | 20 | from translation import create_translation_client |
| ... | ... | @@ -66,6 +67,7 @@ class ParsedQuery: |
| 66 | 67 | detected_language: Optional[str] = None |
| 67 | 68 | translations: Dict[str, str] = field(default_factory=dict) |
| 68 | 69 | query_vector: Optional[np.ndarray] = None |
| 70 | + image_query_vector: Optional[np.ndarray] = None | |
| 69 | 71 | query_tokens: List[str] = field(default_factory=list) |
| 70 | 72 | style_intent_profile: Optional[StyleIntentProfile] = None |
| 71 | 73 | product_title_exclusion_profile: Optional[ProductTitleExclusionProfile] = None |
| ... | ... | @@ -86,6 +88,8 @@ class ParsedQuery: |
| 86 | 88 | "rewritten_query": self.rewritten_query, |
| 87 | 89 | "detected_language": self.detected_language, |
| 88 | 90 | "translations": self.translations, |
| 91 | + "has_query_vector": self.query_vector is not None, | |
| 92 | + "has_image_query_vector": self.image_query_vector is not None, | |
| 89 | 93 | "query_tokens": self.query_tokens, |
| 90 | 94 | "style_intent_profile": ( |
| 91 | 95 | self.style_intent_profile.to_dict() if self.style_intent_profile is not None else None |
| ... | ... | @@ -112,6 +116,7 @@ class QueryParser: |
| 112 | 116 | self, |
| 113 | 117 | config: SearchConfig, |
| 114 | 118 | text_encoder: Optional[TextEmbeddingEncoder] = None, |
| 119 | + image_encoder: Optional[CLIPImageEncoder] = None, | |
| 115 | 120 | translator: Optional[Any] = None, |
| 116 | 121 | tokenizer: Optional[Callable[[str], Any]] = None, |
| 117 | 122 | ): |
| ... | ... | @@ -125,6 +130,7 @@ class QueryParser: |
| 125 | 130 | """ |
| 126 | 131 | self.config = config |
| 127 | 132 | self._text_encoder = text_encoder |
| 133 | + self._image_encoder = image_encoder | |
| 128 | 134 | self._translator = translator |
| 129 | 135 | |
| 130 | 136 | # Initialize components |
| ... | ... | @@ -149,6 +155,9 @@ class QueryParser: |
| 149 | 155 | if self.config.query_config.enable_text_embedding and self._text_encoder is None: |
| 150 | 156 | logger.info("Initializing text encoder at QueryParser construction...") |
| 151 | 157 | self._text_encoder = TextEmbeddingEncoder() |
| 158 | + if self.config.query_config.image_embedding_field and self._image_encoder is None: | |
| 159 | + logger.info("Initializing image encoder at QueryParser construction...") | |
| 160 | + self._image_encoder = CLIPImageEncoder() | |
| 152 | 161 | if self._translator is None: |
| 153 | 162 | from config.services_config import get_translation_config |
| 154 | 163 | cfg = get_translation_config() |
| ... | ... | @@ -169,6 +178,11 @@ class QueryParser: |
| 169 | 178 | """Return pre-initialized translator.""" |
| 170 | 179 | return self._translator |
| 171 | 180 | |
| 181 | + @property | |
| 182 | + def image_encoder(self) -> Optional[CLIPImageEncoder]: | |
| 183 | + """Return pre-initialized image encoder for CLIP text embeddings.""" | |
| 184 | + return self._image_encoder | |
| 185 | + | |
| 172 | 186 | def _build_tokenizer(self) -> Callable[[str], Any]: |
| 173 | 187 | """Build the tokenizer used by query parsing. No fallback path by design.""" |
| 174 | 188 | if hanlp is None: |
| ... | ... | @@ -311,12 +325,21 @@ class QueryParser: |
| 311 | 325 | |
| 312 | 326 | # Stage 6: Text embedding - async execution |
| 313 | 327 | query_vector = None |
| 328 | + image_query_vector = None | |
| 314 | 329 | should_generate_embedding = ( |
| 315 | 330 | generate_vector and |
| 316 | 331 | self.config.query_config.enable_text_embedding |
| 317 | 332 | ) |
| 333 | + should_generate_image_embedding = ( | |
| 334 | + generate_vector and | |
| 335 | + bool(self.config.query_config.image_embedding_field) | |
| 336 | + ) | |
| 318 | 337 | |
| 319 | - task_count = len(translation_targets) + (1 if should_generate_embedding else 0) | |
| 338 | + task_count = ( | |
| 339 | + len(translation_targets) | |
| 340 | + + (1 if should_generate_embedding else 0) | |
| 341 | + + (1 if should_generate_image_embedding else 0) | |
| 342 | + ) | |
| 320 | 343 | if task_count > 0: |
| 321 | 344 | async_executor = ThreadPoolExecutor( |
| 322 | 345 | max_workers=max(1, min(task_count, 4)), |
| ... | ... | @@ -366,6 +389,28 @@ class QueryParser: |
| 366 | 389 | |
| 367 | 390 | future = async_executor.submit(_encode_query_vector) |
| 368 | 391 | future_to_task[future] = ("embedding", None) |
| 392 | + | |
| 393 | + if should_generate_image_embedding: | |
| 394 | + if self.image_encoder is None: | |
| 395 | + raise RuntimeError( | |
| 396 | + "Image embedding field is configured but image encoder is not initialized" | |
| 397 | + ) | |
| 398 | + log_debug("Submitting CLIP text query vector generation") | |
| 399 | + | |
| 400 | + def _encode_image_query_vector() -> Optional[np.ndarray]: | |
| 401 | + vec = self.image_encoder.encode_clip_text( | |
| 402 | + query_text, | |
| 403 | + normalize_embeddings=True, | |
| 404 | + priority=1, | |
| 405 | + request_id=(context.reqid if context else None), | |
| 406 | + user_id=(context.uid if context else None), | |
| 407 | + ) | |
| 408 | + if vec is None: | |
| 409 | + return None | |
| 410 | + return np.asarray(vec, dtype=np.float32) | |
| 411 | + | |
| 412 | + future = async_executor.submit(_encode_image_query_vector) | |
| 413 | + future_to_task[future] = ("image_embedding", None) | |
| 369 | 414 | except Exception as e: |
| 370 | 415 | error_msg = f"Async query enrichment submission failed | Error: {str(e)}" |
| 371 | 416 | log_info(error_msg) |
| ... | ... | @@ -424,9 +469,27 @@ class QueryParser: |
| 424 | 469 | log_info( |
| 425 | 470 | "Query vector generation completed but result is None, will process without vector" |
| 426 | 471 | ) |
| 472 | + elif task_type == "image_embedding": | |
| 473 | + image_query_vector = result | |
| 474 | + if image_query_vector is not None: | |
| 475 | + log_debug( | |
| 476 | + f"CLIP text query vector generation completed | Shape: {image_query_vector.shape}" | |
| 477 | + ) | |
| 478 | + if context: | |
| 479 | + context.store_intermediate_result( | |
| 480 | + "image_query_vector_shape", | |
| 481 | + image_query_vector.shape, | |
| 482 | + ) | |
| 483 | + else: | |
| 484 | + log_info( | |
| 485 | + "CLIP text query vector generation completed but result is None, " | |
| 486 | + "will process without image vector" | |
| 487 | + ) | |
| 427 | 488 | except Exception as e: |
| 428 | 489 | if task_type == "translation": |
| 429 | 490 | error_msg = f"Translation failed | Language: {lang} | Error: {str(e)}" |
| 491 | + elif task_type == "image_embedding": | |
| 492 | + error_msg = f"CLIP text query vector generation failed | Error: {str(e)}" | |
| 430 | 493 | else: |
| 431 | 494 | error_msg = f"Query vector generation failed | Error: {str(e)}" |
| 432 | 495 | log_info(error_msg) |
| ... | ... | @@ -441,6 +504,11 @@ class QueryParser: |
| 441 | 504 | f"Translation timeout (>{budget_ms}ms) | Language: {lang} | " |
| 442 | 505 | f"Query text: '{query_text}'" |
| 443 | 506 | ) |
| 507 | + elif task_type == "image_embedding": | |
| 508 | + timeout_msg = ( | |
| 509 | + f"CLIP text query vector generation timeout (>{budget_ms}ms), " | |
| 510 | + "proceeding without image embedding result" | |
| 511 | + ) | |
| 444 | 512 | else: |
| 445 | 513 | timeout_msg = ( |
| 446 | 514 | f"Query vector generation timeout (>{budget_ms}ms), proceeding without embedding result" |
| ... | ... | @@ -463,6 +531,7 @@ class QueryParser: |
| 463 | 531 | detected_language=detected_lang, |
| 464 | 532 | translations=translations, |
| 465 | 533 | query_vector=query_vector, |
| 534 | + image_query_vector=image_query_vector, | |
| 466 | 535 | query_tokens=query_tokens, |
| 467 | 536 | ) |
| 468 | 537 | style_intent_profile = self.style_intent_detector.detect(base_result) |
| ... | ... | @@ -484,6 +553,7 @@ class QueryParser: |
| 484 | 553 | detected_language=detected_lang, |
| 485 | 554 | translations=translations, |
| 486 | 555 | query_vector=query_vector, |
| 556 | + image_query_vector=image_query_vector, | |
| 487 | 557 | query_tokens=query_tokens, |
| 488 | 558 | style_intent_profile=style_intent_profile, |
| 489 | 559 | product_title_exclusion_profile=product_title_exclusion_profile, |
| ... | ... | @@ -492,7 +562,8 @@ class QueryParser: |
| 492 | 562 | if context and hasattr(context, 'logger'): |
| 493 | 563 | context.logger.info( |
| 494 | 564 | f"Query parsing completed | Original query: '{query}' | Final query: '{rewritten or query_text}' | " |
| 495 | - f"Translation count: {len(translations)} | Vector: {'yes' if query_vector is not None else 'no'}", | |
| 565 | + f"Translation count: {len(translations)} | Vector: {'yes' if query_vector is not None else 'no'} | " | |
| 566 | + f"Image vector: {'yes' if image_query_vector is not None else 'no'}", | |
| 496 | 567 | extra={'reqid': context.reqid, 'uid': context.uid} |
| 497 | 568 | ) |
| 498 | 569 | else: | ... | ... |
search/es_query_builder.py
| ... | ... | @@ -164,6 +164,7 @@ class ESQueryBuilder: |
| 164 | 164 | self, |
| 165 | 165 | query_text: str, |
| 166 | 166 | query_vector: Optional[np.ndarray] = None, |
| 167 | + image_query_vector: Optional[np.ndarray] = None, | |
| 167 | 168 | filters: Optional[Dict[str, Any]] = None, |
| 168 | 169 | range_filters: Optional[Dict[str, Any]] = None, |
| 169 | 170 | facet_configs: Optional[List[Any]] = None, |
| ... | ... | @@ -212,15 +213,14 @@ class ESQueryBuilder: |
| 212 | 213 | |
| 213 | 214 | # 1. Build recall queries (text or embedding) |
| 214 | 215 | recall_clauses = [] |
| 215 | - | |
| 216 | + | |
| 216 | 217 | # Text recall (always include if query_text exists) |
| 217 | 218 | if query_text: |
| 218 | - # Unified text query strategy | |
| 219 | - text_query = self._build_advanced_text_query(query_text, parsed_query) | |
| 220 | - recall_clauses.append(text_query) | |
| 221 | - | |
| 222 | - # Embedding recall (KNN - separate from query, handled below) | |
| 219 | + recall_clauses.extend(self._build_advanced_text_query(query_text, parsed_query)) | |
| 220 | + | |
| 221 | + # Embedding recall | |
| 223 | 222 | has_embedding = enable_knn and query_vector is not None and self.text_embedding_field |
| 223 | + has_image_embedding = enable_knn and image_query_vector is not None and self.image_embedding_field | |
| 224 | 224 | |
| 225 | 225 | # 2. Split filters for multi-select faceting |
| 226 | 226 | conjunctive_filters, disjunctive_filters = self._split_filters_for_faceting( |
| ... | ... | @@ -233,9 +233,48 @@ class ESQueryBuilder: |
| 233 | 233 | if product_title_exclusion_filter: |
| 234 | 234 | filter_clauses.append(product_title_exclusion_filter) |
| 235 | 235 | |
| 236 | - # 3. Build main query structure: filters and recall | |
| 236 | + # 3. Add KNN search clauses alongside lexical clauses under the same bool.should | |
| 237 | + # Adjust KNN k, num_candidates, boost by query_tokens (short query: less KNN; long: more) | |
| 238 | + final_knn_k, final_knn_num_candidates = knn_k, knn_num_candidates | |
| 239 | + if has_embedding: | |
| 240 | + knn_boost = self.knn_boost | |
| 241 | + if parsed_query: | |
| 242 | + query_tokens = getattr(parsed_query, 'query_tokens', None) or [] | |
| 243 | + token_count = len(query_tokens) | |
| 244 | + if token_count >= 5: | |
| 245 | + final_knn_k, final_knn_num_candidates = 160, 500 | |
| 246 | + knn_boost = self.knn_boost * 1.4 # Higher weight for long queries | |
| 247 | + else: | |
| 248 | + final_knn_k, final_knn_num_candidates = 120, 400 | |
| 249 | + else: | |
| 250 | + final_knn_k, final_knn_num_candidates = 120, 400 | |
| 251 | + recall_clauses.append({ | |
| 252 | + "knn": { | |
| 253 | + "field": self.text_embedding_field, | |
| 254 | + "query_vector": query_vector.tolist(), | |
| 255 | + "k": final_knn_k, | |
| 256 | + "num_candidates": final_knn_num_candidates, | |
| 257 | + "boost": knn_boost, | |
| 258 | + "_name": "knn_query", | |
| 259 | + } | |
| 260 | + }) | |
| 261 | + | |
| 262 | + if has_image_embedding: | |
| 263 | + image_knn_k = max(final_knn_k, 120) | |
| 264 | + image_knn_num_candidates = max(final_knn_num_candidates, 400) | |
| 265 | + recall_clauses.append({ | |
| 266 | + "knn": { | |
| 267 | + "field": self.image_embedding_field, | |
| 268 | + "query_vector": image_query_vector.tolist(), | |
| 269 | + "k": image_knn_k, | |
| 270 | + "num_candidates": image_knn_num_candidates, | |
| 271 | + "boost": self.knn_boost, | |
| 272 | + "_name": "image_knn_query", | |
| 273 | + } | |
| 274 | + }) | |
| 275 | + | |
| 276 | + # 4. Build main query structure: filters and recall | |
| 237 | 277 | if recall_clauses: |
| 238 | - # Combine text recalls with OR logic (if multiple) | |
| 239 | 278 | if len(recall_clauses) == 1: |
| 240 | 279 | recall_query = recall_clauses[0] |
| 241 | 280 | else: |
| ... | ... | @@ -245,11 +284,9 @@ class ESQueryBuilder: |
| 245 | 284 | "minimum_should_match": 1 |
| 246 | 285 | } |
| 247 | 286 | } |
| 248 | - | |
| 249 | - # Wrap recall with function_score for boosting | |
| 287 | + | |
| 250 | 288 | recall_query = self._wrap_with_function_score(recall_query) |
| 251 | - | |
| 252 | - # Combine filters and recall | |
| 289 | + | |
| 253 | 290 | if filter_clauses: |
| 254 | 291 | es_query["query"] = { |
| 255 | 292 | "bool": { |
| ... | ... | @@ -260,7 +297,6 @@ class ESQueryBuilder: |
| 260 | 297 | else: |
| 261 | 298 | es_query["query"] = recall_query |
| 262 | 299 | else: |
| 263 | - # No recall queries, only filters (match_all filtered) | |
| 264 | 300 | if filter_clauses: |
| 265 | 301 | es_query["query"] = { |
| 266 | 302 | "bool": { |
| ... | ... | @@ -271,41 +307,6 @@ class ESQueryBuilder: |
| 271 | 307 | else: |
| 272 | 308 | es_query["query"] = {"match_all": {}} |
| 273 | 309 | |
| 274 | - # 4. Add KNN search if enabled (separate from query, ES will combine) | |
| 275 | - # Adjust KNN k, num_candidates, boost by query_tokens (short query: less KNN; long: more) | |
| 276 | - if has_embedding: | |
| 277 | - knn_boost = self.knn_boost | |
| 278 | - if parsed_query: | |
| 279 | - query_tokens = getattr(parsed_query, 'query_tokens', None) or [] | |
| 280 | - token_count = len(query_tokens) | |
| 281 | - if token_count >= 5: | |
| 282 | - knn_k, knn_num_candidates = 160, 500 | |
| 283 | - knn_boost = self.knn_boost * 1.4 # Higher weight for long queries | |
| 284 | - else: | |
| 285 | - knn_k, knn_num_candidates = 120, 400 | |
| 286 | - else: | |
| 287 | - knn_k, knn_num_candidates = 120, 400 | |
| 288 | - knn_clause = { | |
| 289 | - "field": self.text_embedding_field, | |
| 290 | - "query_vector": query_vector.tolist(), | |
| 291 | - "k": knn_k, | |
| 292 | - "num_candidates": knn_num_candidates, | |
| 293 | - "boost": knn_boost, | |
| 294 | - "_name": "knn_query", | |
| 295 | - } | |
| 296 | - # Top-level knn does not inherit query.bool.filter automatically. | |
| 297 | - # Apply conjunctive + range filters here so vector recall respects hard filters. | |
| 298 | - if filter_clauses: | |
| 299 | - if len(filter_clauses) == 1: | |
| 300 | - knn_clause["filter"] = filter_clauses[0] | |
| 301 | - else: | |
| 302 | - knn_clause["filter"] = { | |
| 303 | - "bool": { | |
| 304 | - "filter": filter_clauses | |
| 305 | - } | |
| 306 | - } | |
| 307 | - es_query["knn"] = knn_clause | |
| 308 | - | |
| 309 | 310 | # 5. Add post_filter for disjunctive (multi-select) filters |
| 310 | 311 | if disjunctive_filters: |
| 311 | 312 | post_filter_clauses = self._build_filters(disjunctive_filters, None) |
| ... | ... | @@ -536,21 +537,20 @@ class ESQueryBuilder: |
| 536 | 537 | self, |
| 537 | 538 | query_text: str, |
| 538 | 539 | parsed_query: Optional[Any] = None, |
| 539 | - ) -> Dict[str, Any]: | |
| 540 | + ) -> List[Dict[str, Any]]: | |
| 540 | 541 | """ |
| 541 | 542 | Build advanced text query using base and translated lexical clauses. |
| 542 | 543 | |
| 543 | 544 | Unified implementation: |
| 544 | 545 | - base_query: source-language clause |
| 545 | 546 | - translation queries: target-language clauses from translations |
| 546 | - - KNN query: added separately in build_query | |
| 547 | - | |
| 547 | + | |
| 548 | 548 | Args: |
| 549 | 549 | query_text: Query text |
| 550 | 550 | parsed_query: ParsedQuery object with analysis results |
| 551 | 551 | |
| 552 | 552 | Returns: |
| 553 | - ES bool query with should clauses | |
| 553 | + Flat recall clauses to be merged with KNN clauses under query.bool.should | |
| 554 | 554 | """ |
| 555 | 555 | should_clauses = [] |
| 556 | 556 | source_lang = self.default_language |
| ... | ... | @@ -603,18 +603,9 @@ class ESQueryBuilder: |
| 603 | 603 | "minimum_should_match": self.base_minimum_should_match, |
| 604 | 604 | } |
| 605 | 605 | } |
| 606 | - return fallback_lexical | |
| 606 | + return [fallback_lexical] | |
| 607 | 607 | |
| 608 | - # Return bool query with should clauses | |
| 609 | - if len(should_clauses) == 1: | |
| 610 | - return should_clauses[0] | |
| 611 | - | |
| 612 | - return { | |
| 613 | - "bool": { | |
| 614 | - "should": should_clauses, | |
| 615 | - "minimum_should_match": 1 | |
| 616 | - } | |
| 617 | - } | |
| 608 | + return should_clauses | |
| 618 | 609 | |
| 619 | 610 | def _build_filters( |
| 620 | 611 | self, | ... | ... |
search/rerank_client.py
| ... | ... | @@ -151,6 +151,13 @@ def _extract_named_query_score(matched_queries: Any, name: str) -> float: |
| 151 | 151 | return 1.0 if name in matched_queries else 0.0 |
| 152 | 152 | return 0.0 |
| 153 | 153 | |
| 154 | + | |
| 155 | +def _extract_combined_knn_score(matched_queries: Any) -> float: | |
| 156 | + return max( | |
| 157 | + _extract_named_query_score(matched_queries, "knn_query"), | |
| 158 | + _extract_named_query_score(matched_queries, "image_knn_query"), | |
| 159 | + ) | |
| 160 | + | |
| 154 | 161 | """ |
| 155 | 162 | 原始变量: |
| 156 | 163 | ES总分 |
| ... | ... | @@ -272,7 +279,7 @@ def fuse_scores_and_resort( |
| 272 | 279 | es_score = _to_score(hit.get("_score")) |
| 273 | 280 | rerank_score = _to_score(rerank_scores[idx]) |
| 274 | 281 | matched_queries = hit.get("matched_queries") |
| 275 | - knn_score = _extract_named_query_score(matched_queries, "knn_query") | |
| 282 | + knn_score = _extract_combined_knn_score(matched_queries) | |
| 276 | 283 | text_components = _collect_text_score_components(matched_queries, es_score) |
| 277 | 284 | text_score = text_components["text_score"] |
| 278 | 285 | rerank_factor, text_factor, knn_factor, fused = _multiply_fusion_factors( | ... | ... |
search/searcher.py
| ... | ... | @@ -106,14 +106,14 @@ class Searcher: |
| 106 | 106 | """ |
| 107 | 107 | self.es_client = es_client |
| 108 | 108 | self.config = config |
| 109 | - # Index name is now generated dynamically per tenant, no longer stored here | |
| 110 | - self.query_parser = query_parser or QueryParser(config) | |
| 111 | 109 | self.text_embedding_field = config.query_config.text_embedding_field or "title_embedding" |
| 112 | 110 | self.image_embedding_field = config.query_config.image_embedding_field |
| 113 | 111 | if self.image_embedding_field and image_encoder is None: |
| 114 | 112 | self.image_encoder = CLIPImageEncoder() |
| 115 | 113 | else: |
| 116 | 114 | self.image_encoder = image_encoder |
| 115 | + # Index name is now generated dynamically per tenant, no longer stored here | |
| 116 | + self.query_parser = query_parser or QueryParser(config, image_encoder=self.image_encoder) | |
| 117 | 117 | self.source_fields = config.query_config.source_fields |
| 118 | 118 | self.style_intent_registry = StyleIntentRegistry.from_query_config(self.config.query_config) |
| 119 | 119 | self.style_sku_selector = StyleSkuSelector( |
| ... | ... | @@ -403,7 +403,8 @@ class Searcher: |
| 403 | 403 | f"查询解析完成 | 原查询: '{parsed_query.original_query}' | " |
| 404 | 404 | f"重写后: '{parsed_query.rewritten_query}' | " |
| 405 | 405 | f"语言: {parsed_query.detected_language} | " |
| 406 | - f"向量: {'是' if parsed_query.query_vector is not None else '否'}", | |
| 406 | + f"文本向量: {'是' if parsed_query.query_vector is not None else '否'} | " | |
| 407 | + f"图片向量: {'是' if getattr(parsed_query, 'image_query_vector', None) is not None else '否'}", | |
| 407 | 408 | extra={'reqid': context.reqid, 'uid': context.uid} |
| 408 | 409 | ) |
| 409 | 410 | except Exception as e: |
| ... | ... | @@ -428,12 +429,20 @@ class Searcher: |
| 428 | 429 | es_query = self.query_builder.build_query( |
| 429 | 430 | query_text=parsed_query.rewritten_query or parsed_query.query_normalized, |
| 430 | 431 | query_vector=parsed_query.query_vector if enable_embedding else None, |
| 432 | + image_query_vector=( | |
| 433 | + getattr(parsed_query, "image_query_vector", None) | |
| 434 | + if enable_embedding | |
| 435 | + else None | |
| 436 | + ), | |
| 431 | 437 | filters=filters, |
| 432 | 438 | range_filters=range_filters, |
| 433 | 439 | facet_configs=facets, |
| 434 | 440 | size=es_fetch_size, |
| 435 | 441 | from_=es_fetch_from, |
| 436 | - enable_knn=enable_embedding and parsed_query.query_vector is not None, | |
| 442 | + enable_knn=enable_embedding and ( | |
| 443 | + parsed_query.query_vector is not None | |
| 444 | + or getattr(parsed_query, "image_query_vector", None) is not None | |
| 445 | + ), | |
| 437 | 446 | min_score=min_score, |
| 438 | 447 | parsed_query=parsed_query, |
| 439 | 448 | ) |
| ... | ... | @@ -475,15 +484,24 @@ class Searcher: |
| 475 | 484 | # Serialize ES query to compute a compact size + stable digest for correlation |
| 476 | 485 | es_query_compact = json.dumps(es_query_for_fetch, ensure_ascii=False, separators=(",", ":")) |
| 477 | 486 | es_query_digest = hashlib.sha256(es_query_compact.encode("utf-8")).hexdigest()[:16] |
| 478 | - knn_enabled = bool(enable_embedding and parsed_query.query_vector is not None) | |
| 487 | + knn_enabled = bool(enable_embedding and ( | |
| 488 | + parsed_query.query_vector is not None | |
| 489 | + or getattr(parsed_query, "image_query_vector", None) is not None | |
| 490 | + )) | |
| 479 | 491 | vector_dims = int(len(parsed_query.query_vector)) if parsed_query.query_vector is not None else 0 |
| 492 | + image_vector_dims = ( | |
| 493 | + int(len(parsed_query.image_query_vector)) | |
| 494 | + if getattr(parsed_query, "image_query_vector", None) is not None | |
| 495 | + else 0 | |
| 496 | + ) | |
| 480 | 497 | |
| 481 | 498 | context.logger.info( |
| 482 | - "ES query built | size: %s chars | digest: %s | KNN: %s | vector_dims: %s | facets: %s | rerank_prefetch_source: %s", | |
| 499 | + "ES query built | size: %s chars | digest: %s | KNN: %s | vector_dims: %s | image_vector_dims: %s | facets: %s | rerank_prefetch_source: %s", | |
| 483 | 500 | len(es_query_compact), |
| 484 | 501 | es_query_digest, |
| 485 | 502 | "yes" if knn_enabled else "no", |
| 486 | 503 | vector_dims, |
| 504 | + image_vector_dims, | |
| 487 | 505 | "yes" if facets else "no", |
| 488 | 506 | rerank_prefetch_source, |
| 489 | 507 | extra={'reqid': context.reqid, 'uid': context.uid} |
| ... | ... | @@ -497,6 +515,7 @@ class Searcher: |
| 497 | 515 | "sha256_16": es_query_digest, |
| 498 | 516 | "knn_enabled": knn_enabled, |
| 499 | 517 | "vector_dims": vector_dims, |
| 518 | + "image_vector_dims": image_vector_dims, | |
| 500 | 519 | "has_facets": bool(facets), |
| 501 | 520 | "query": es_query_for_fetch, |
| 502 | 521 | }) | ... | ... |
tests/test_embedding_pipeline.py
| ... | ... | @@ -75,6 +75,15 @@ class _FakeQueryEncoder: |
| 75 | 75 | return np.array([np.array([0.11, 0.22, 0.33], dtype=np.float32) for _ in sentences], dtype=object) |
| 76 | 76 | |
| 77 | 77 | |
| 78 | +class _FakeClipTextEncoder: | |
| 79 | + def __init__(self): | |
| 80 | + self.calls = [] | |
| 81 | + | |
| 82 | + def encode_clip_text(self, text, **kwargs): | |
| 83 | + self.calls.append({"text": text, "kwargs": dict(kwargs)}) | |
| 84 | + return np.array([0.44, 0.55, 0.66], dtype=np.float32) | |
| 85 | + | |
| 86 | + | |
| 78 | 87 | def _tokenizer(text): |
| 79 | 88 | return str(text).split() |
| 80 | 89 | |
| ... | ... | @@ -91,7 +100,7 @@ class _FakeEmbeddingCache: |
| 91 | 100 | return True |
| 92 | 101 | |
| 93 | 102 | |
| 94 | -def _build_test_config() -> SearchConfig: | |
| 103 | +def _build_test_config(*, image_embedding_field: Optional[str] = None) -> SearchConfig: | |
| 95 | 104 | return SearchConfig( |
| 96 | 105 | field_boosts={"title.en": 3.0}, |
| 97 | 106 | indexes=[IndexConfig(name="default", label="default", fields=["title.en"], boost=1.0)], |
| ... | ... | @@ -102,7 +111,7 @@ def _build_test_config() -> SearchConfig: |
| 102 | 111 | enable_query_rewrite=False, |
| 103 | 112 | rewrite_dictionary={}, |
| 104 | 113 | text_embedding_field="title_embedding", |
| 105 | - image_embedding_field=None, | |
| 114 | + image_embedding_field=image_embedding_field, | |
| 106 | 115 | ), |
| 107 | 116 | function_score=FunctionScoreConfig(), |
| 108 | 117 | rerank=RerankConfig(), |
| ... | ... | @@ -250,6 +259,26 @@ def test_query_parser_generates_query_vector_with_encoder(): |
| 250 | 259 | assert encoder.calls[0]["kwargs"]["priority"] == 1 |
| 251 | 260 | |
| 252 | 261 | |
| 262 | +def test_query_parser_generates_image_query_vector_with_clip_text_encoder(): | |
| 263 | + text_encoder = _FakeQueryEncoder() | |
| 264 | + image_encoder = _FakeClipTextEncoder() | |
| 265 | + parser = QueryParser( | |
| 266 | + config=_build_test_config(image_embedding_field="image_embedding.vector"), | |
| 267 | + text_encoder=text_encoder, | |
| 268 | + image_encoder=image_encoder, | |
| 269 | + translator=_FakeTranslator(), | |
| 270 | + tokenizer=_tokenizer, | |
| 271 | + ) | |
| 272 | + | |
| 273 | + parsed = parser.parse("red dress", tenant_id="162", generate_vector=True) | |
| 274 | + assert parsed.query_vector is not None | |
| 275 | + assert parsed.image_query_vector is not None | |
| 276 | + assert parsed.image_query_vector.shape == (3,) | |
| 277 | + assert image_encoder.calls | |
| 278 | + assert image_encoder.calls[0]["text"] == "red dress" | |
| 279 | + assert image_encoder.calls[0]["kwargs"]["priority"] == 1 | |
| 280 | + | |
| 281 | + | |
| 253 | 282 | def test_query_parser_skips_query_vector_when_disabled(): |
| 254 | 283 | parser = QueryParser( |
| 255 | 284 | config=_build_test_config(), |
| ... | ... | @@ -260,6 +289,7 @@ def test_query_parser_skips_query_vector_when_disabled(): |
| 260 | 289 | |
| 261 | 290 | parsed = parser.parse("red dress", tenant_id="162", generate_vector=False) |
| 262 | 291 | assert parsed.query_vector is None |
| 292 | + assert parsed.image_query_vector is None | |
| 263 | 293 | |
| 264 | 294 | |
| 265 | 295 | def test_tei_text_model_splits_batches_over_client_limit(monkeypatch): | ... | ... |
tests/test_es_query_builder.py
| ... | ... | @@ -13,22 +13,29 @@ def _builder() -> ESQueryBuilder: |
| 13 | 13 | core_multilingual_fields=["title", "brief"], |
| 14 | 14 | shared_fields=[], |
| 15 | 15 | text_embedding_field="title_embedding", |
| 16 | + image_embedding_field="image_embedding.vector", | |
| 16 | 17 | default_language="en", |
| 17 | 18 | ) |
| 18 | 19 | |
| 19 | 20 | |
| 20 | -def _lexical_clause(query_root: Dict[str, Any]) -> Dict[str, Any]: | |
| 21 | - """Return the first named lexical bool clause from query_root.""" | |
| 22 | - if "bool" in query_root and query_root["bool"].get("_name"): | |
| 23 | - return query_root["bool"] | |
| 24 | - for clause in query_root.get("bool", {}).get("should", []): | |
| 25 | - clause_bool = clause.get("bool") or {} | |
| 26 | - if clause_bool.get("_name"): | |
| 27 | - return clause_bool | |
| 28 | - raise AssertionError("no lexical bool clause in query_root") | |
| 21 | +def _recall_root(es_body: Dict[str, Any]) -> Dict[str, Any]: | |
| 22 | + query_root = es_body["query"] | |
| 23 | + if "bool" in query_root and query_root["bool"].get("must"): | |
| 24 | + query_root = query_root["bool"]["must"][0] | |
| 25 | + if "function_score" in query_root: | |
| 26 | + query_root = query_root["function_score"]["query"] | |
| 27 | + return query_root | |
| 29 | 28 | |
| 30 | 29 | |
| 31 | -def test_knn_prefilter_includes_range_filters(): | |
| 30 | +def _recall_should_clauses(es_body: Dict[str, Any]) -> list[Dict[str, Any]]: | |
| 31 | + root = _recall_root(es_body) | |
| 32 | + should = root.get("bool", {}).get("should") | |
| 33 | + if should: | |
| 34 | + return should | |
| 35 | + return [root] | |
| 36 | + | |
| 37 | + | |
| 38 | +def test_knn_clause_moves_under_query_should_and_uses_outer_filters(): | |
| 32 | 39 | qb = _builder() |
| 33 | 40 | q = qb.build_query( |
| 34 | 41 | query_text="bags", |
| ... | ... | @@ -37,11 +44,13 @@ def test_knn_prefilter_includes_range_filters(): |
| 37 | 44 | enable_knn=True, |
| 38 | 45 | ) |
| 39 | 46 | |
| 40 | - assert "knn" in q | |
| 41 | - assert q["knn"]["filter"] == {"range": {"min_price": {"gte": 50, "lt": 100}}} | |
| 47 | + assert "knn" not in q | |
| 48 | + should = _recall_should_clauses(q) | |
| 49 | + assert any(clause.get("knn", {}).get("_name") == "knn_query" for clause in should) | |
| 50 | + assert q["query"]["bool"]["filter"] == [{"range": {"min_price": {"gte": 50, "lt": 100}}}] | |
| 42 | 51 | |
| 43 | 52 | |
| 44 | -def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present(): | |
| 53 | +def test_knn_clause_uses_outer_query_filter_when_disjunctive_filters_present(): | |
| 45 | 54 | qb = _builder() |
| 46 | 55 | facets = [SimpleNamespace(field="category_name", disjunctive=True)] |
| 47 | 56 | q = qb.build_query( |
| ... | ... | @@ -53,21 +62,15 @@ def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present(): |
| 53 | 62 | enable_knn=True, |
| 54 | 63 | ) |
| 55 | 64 | |
| 56 | - assert "knn" in q | |
| 57 | - assert "filter" in q["knn"] | |
| 58 | - knn_filter = q["knn"]["filter"] | |
| 59 | - assert knn_filter == { | |
| 60 | - "bool": { | |
| 61 | - "filter": [ | |
| 62 | - {"term": {"vendor": "Nike"}}, | |
| 63 | - {"range": {"min_price": {"gte": 50, "lt": 100}}}, | |
| 64 | - ] | |
| 65 | - } | |
| 66 | - } | |
| 65 | + assert "knn" not in q | |
| 66 | + assert q["query"]["bool"]["filter"] == [ | |
| 67 | + {"term": {"vendor": "Nike"}}, | |
| 68 | + {"range": {"min_price": {"gte": 50, "lt": 100}}}, | |
| 69 | + ] | |
| 67 | 70 | assert q["post_filter"] == {"terms": {"category_name": ["A", "B"]}} |
| 68 | 71 | |
| 69 | 72 | |
| 70 | -def test_knn_prefilter_not_added_without_filters(): | |
| 73 | +def test_knn_clause_has_name_and_no_embedded_filter(): | |
| 71 | 74 | qb = _builder() |
| 72 | 75 | q = qb.build_query( |
| 73 | 76 | query_text="bags", |
| ... | ... | @@ -75,9 +78,10 @@ def test_knn_prefilter_not_added_without_filters(): |
| 75 | 78 | enable_knn=True, |
| 76 | 79 | ) |
| 77 | 80 | |
| 78 | - assert "knn" in q | |
| 79 | - assert "filter" not in q["knn"] | |
| 80 | - assert q["knn"]["_name"] == "knn_query" | |
| 81 | + should = _recall_should_clauses(q) | |
| 82 | + knn_clause = next(clause["knn"] for clause in should if clause.get("knn", {}).get("_name") == "knn_query") | |
| 83 | + assert "filter" not in knn_clause | |
| 84 | + assert knn_clause["_name"] == "knn_query" | |
| 81 | 85 | |
| 82 | 86 | |
| 83 | 87 | def test_text_query_contains_only_base_and_translation_named_queries(): |
| ... | ... | @@ -93,11 +97,11 @@ def test_text_query_contains_only_base_and_translation_named_queries(): |
| 93 | 97 | parsed_query=parsed_query, |
| 94 | 98 | enable_knn=False, |
| 95 | 99 | ) |
| 96 | - should = q["query"]["bool"]["should"] | |
| 100 | + should = _recall_should_clauses(q) | |
| 97 | 101 | names = [clause["bool"]["_name"] for clause in should] |
| 98 | 102 | |
| 99 | 103 | assert names == ["base_query", "base_query_trans_zh"] |
| 100 | - base_should = q["query"]["bool"]["should"][0]["bool"]["should"] | |
| 104 | + base_should = should[0]["bool"]["should"] | |
| 101 | 105 | assert [clause["multi_match"]["type"] for clause in base_should] == ["best_fields", "phrase"] |
| 102 | 106 | |
| 103 | 107 | |
| ... | ... | @@ -115,12 +119,12 @@ def test_text_query_skips_duplicate_translation_same_as_base(): |
| 115 | 119 | enable_knn=False, |
| 116 | 120 | ) |
| 117 | 121 | |
| 118 | - root = q["query"] | |
| 122 | + root = _recall_root(q) | |
| 119 | 123 | assert root["bool"]["_name"] == "base_query" |
| 120 | 124 | assert [clause["multi_match"]["type"] for clause in root["bool"]["should"]] == ["best_fields", "phrase"] |
| 121 | 125 | |
| 122 | 126 | |
| 123 | -def test_product_title_exclusion_filter_is_applied_to_query_and_knn(): | |
| 127 | +def test_product_title_exclusion_filter_is_applied_once_on_outer_query(): | |
| 124 | 128 | qb = _builder() |
| 125 | 129 | parsed_query = SimpleNamespace( |
| 126 | 130 | rewritten_query="fitted dress", |
| ... | ... | @@ -158,4 +162,32 @@ def test_product_title_exclusion_filter_is_applied_to_query_and_knn(): |
| 158 | 162 | } |
| 159 | 163 | |
| 160 | 164 | assert expected_filter in q["query"]["bool"]["filter"] |
| 161 | - assert q["knn"]["filter"] == expected_filter | |
| 165 | + should = _recall_should_clauses(q) | |
| 166 | + knn_clause = next(clause["knn"] for clause in should if clause.get("knn", {}).get("_name") == "knn_query") | |
| 167 | + assert "filter" not in knn_clause | |
| 168 | + | |
| 169 | + | |
| 170 | +def test_image_knn_clause_is_added_alongside_base_translation_and_text_knn(): | |
| 171 | + qb = _builder() | |
| 172 | + parsed_query = SimpleNamespace( | |
| 173 | + rewritten_query="street tee", | |
| 174 | + detected_language="en", | |
| 175 | + translations={"zh": "街头短袖"}, | |
| 176 | + ) | |
| 177 | + | |
| 178 | + q = qb.build_query( | |
| 179 | + query_text="street tee", | |
| 180 | + query_vector=np.array([0.1, 0.2, 0.3]), | |
| 181 | + image_query_vector=np.array([0.4, 0.5, 0.6]), | |
| 182 | + parsed_query=parsed_query, | |
| 183 | + enable_knn=True, | |
| 184 | + ) | |
| 185 | + | |
| 186 | + should = _recall_should_clauses(q) | |
| 187 | + names = [ | |
| 188 | + clause["bool"]["_name"] if "bool" in clause else clause["knn"]["_name"] | |
| 189 | + for clause in should | |
| 190 | + ] | |
| 191 | + assert names == ["base_query", "base_query_trans_zh", "knn_query", "image_knn_query"] | |
| 192 | + image_knn = next(clause["knn"] for clause in should if clause.get("knn", {}).get("_name") == "image_knn_query") | |
| 193 | + assert image_knn["field"] == "image_embedding.vector" | ... | ... |
tests/test_rerank_client.py
| ... | ... | @@ -149,3 +149,22 @@ def test_fuse_scores_and_resort_boosts_hits_with_selected_sku(): |
| 149 | 149 | assert [h["_id"] for h in hits] == ["style-selected", "plain"] |
| 150 | 150 | assert debug[0]["style_intent_selected_sku"] is True |
| 151 | 151 | assert debug[0]["style_intent_selected_sku_boost"] == 1.2 |
| 152 | + | |
| 153 | + | |
| 154 | +def test_fuse_scores_and_resort_uses_max_of_text_and_image_knn_scores(): | |
| 155 | + hits = [ | |
| 156 | + { | |
| 157 | + "_id": "mm-hit", | |
| 158 | + "_score": 1.0, | |
| 159 | + "matched_queries": { | |
| 160 | + "base_query": 1.5, | |
| 161 | + "knn_query": 0.2, | |
| 162 | + "image_knn_query": 0.7, | |
| 163 | + }, | |
| 164 | + } | |
| 165 | + ] | |
| 166 | + | |
| 167 | + debug = fuse_scores_and_resort(hits, [0.8], debug=True) | |
| 168 | + | |
| 169 | + assert isclose(hits[0]["_knn_score"], 0.7, rel_tol=1e-9) | |
| 170 | + assert isclose(debug[0]["knn_score"], 0.7, rel_tol=1e-9) | ... | ... |