Commit dc403578b9f167d06cff131dda36fd488340d99f
1 parent: 74116f05
多模态搜索
Showing 7 changed files with 276 additions and 107 deletions
Show diff stats
query/query_parser.py
| @@ -14,6 +14,7 @@ import numpy as np | @@ -14,6 +14,7 @@ import numpy as np | ||
| 14 | import logging | 14 | import logging |
| 15 | from concurrent.futures import ThreadPoolExecutor, wait | 15 | from concurrent.futures import ThreadPoolExecutor, wait |
| 16 | 16 | ||
| 17 | +from embeddings.image_encoder import CLIPImageEncoder | ||
| 17 | from embeddings.text_encoder import TextEmbeddingEncoder | 18 | from embeddings.text_encoder import TextEmbeddingEncoder |
| 18 | from config import SearchConfig | 19 | from config import SearchConfig |
| 19 | from translation import create_translation_client | 20 | from translation import create_translation_client |
| @@ -66,6 +67,7 @@ class ParsedQuery: | @@ -66,6 +67,7 @@ class ParsedQuery: | ||
| 66 | detected_language: Optional[str] = None | 67 | detected_language: Optional[str] = None |
| 67 | translations: Dict[str, str] = field(default_factory=dict) | 68 | translations: Dict[str, str] = field(default_factory=dict) |
| 68 | query_vector: Optional[np.ndarray] = None | 69 | query_vector: Optional[np.ndarray] = None |
| 70 | + image_query_vector: Optional[np.ndarray] = None | ||
| 69 | query_tokens: List[str] = field(default_factory=list) | 71 | query_tokens: List[str] = field(default_factory=list) |
| 70 | style_intent_profile: Optional[StyleIntentProfile] = None | 72 | style_intent_profile: Optional[StyleIntentProfile] = None |
| 71 | product_title_exclusion_profile: Optional[ProductTitleExclusionProfile] = None | 73 | product_title_exclusion_profile: Optional[ProductTitleExclusionProfile] = None |
| @@ -86,6 +88,8 @@ class ParsedQuery: | @@ -86,6 +88,8 @@ class ParsedQuery: | ||
| 86 | "rewritten_query": self.rewritten_query, | 88 | "rewritten_query": self.rewritten_query, |
| 87 | "detected_language": self.detected_language, | 89 | "detected_language": self.detected_language, |
| 88 | "translations": self.translations, | 90 | "translations": self.translations, |
| 91 | + "has_query_vector": self.query_vector is not None, | ||
| 92 | + "has_image_query_vector": self.image_query_vector is not None, | ||
| 89 | "query_tokens": self.query_tokens, | 93 | "query_tokens": self.query_tokens, |
| 90 | "style_intent_profile": ( | 94 | "style_intent_profile": ( |
| 91 | self.style_intent_profile.to_dict() if self.style_intent_profile is not None else None | 95 | self.style_intent_profile.to_dict() if self.style_intent_profile is not None else None |
| @@ -112,6 +116,7 @@ class QueryParser: | @@ -112,6 +116,7 @@ class QueryParser: | ||
| 112 | self, | 116 | self, |
| 113 | config: SearchConfig, | 117 | config: SearchConfig, |
| 114 | text_encoder: Optional[TextEmbeddingEncoder] = None, | 118 | text_encoder: Optional[TextEmbeddingEncoder] = None, |
| 119 | + image_encoder: Optional[CLIPImageEncoder] = None, | ||
| 115 | translator: Optional[Any] = None, | 120 | translator: Optional[Any] = None, |
| 116 | tokenizer: Optional[Callable[[str], Any]] = None, | 121 | tokenizer: Optional[Callable[[str], Any]] = None, |
| 117 | ): | 122 | ): |
| @@ -125,6 +130,7 @@ class QueryParser: | @@ -125,6 +130,7 @@ class QueryParser: | ||
| 125 | """ | 130 | """ |
| 126 | self.config = config | 131 | self.config = config |
| 127 | self._text_encoder = text_encoder | 132 | self._text_encoder = text_encoder |
| 133 | + self._image_encoder = image_encoder | ||
| 128 | self._translator = translator | 134 | self._translator = translator |
| 129 | 135 | ||
| 130 | # Initialize components | 136 | # Initialize components |
| @@ -149,6 +155,9 @@ class QueryParser: | @@ -149,6 +155,9 @@ class QueryParser: | ||
| 149 | if self.config.query_config.enable_text_embedding and self._text_encoder is None: | 155 | if self.config.query_config.enable_text_embedding and self._text_encoder is None: |
| 150 | logger.info("Initializing text encoder at QueryParser construction...") | 156 | logger.info("Initializing text encoder at QueryParser construction...") |
| 151 | self._text_encoder = TextEmbeddingEncoder() | 157 | self._text_encoder = TextEmbeddingEncoder() |
| 158 | + if self.config.query_config.image_embedding_field and self._image_encoder is None: | ||
| 159 | + logger.info("Initializing image encoder at QueryParser construction...") | ||
| 160 | + self._image_encoder = CLIPImageEncoder() | ||
| 152 | if self._translator is None: | 161 | if self._translator is None: |
| 153 | from config.services_config import get_translation_config | 162 | from config.services_config import get_translation_config |
| 154 | cfg = get_translation_config() | 163 | cfg = get_translation_config() |
| @@ -169,6 +178,11 @@ class QueryParser: | @@ -169,6 +178,11 @@ class QueryParser: | ||
| 169 | """Return pre-initialized translator.""" | 178 | """Return pre-initialized translator.""" |
| 170 | return self._translator | 179 | return self._translator |
| 171 | 180 | ||
| 181 | + @property | ||
| 182 | + def image_encoder(self) -> Optional[CLIPImageEncoder]: | ||
| 183 | + """Return pre-initialized image encoder for CLIP text embeddings.""" | ||
| 184 | + return self._image_encoder | ||
| 185 | + | ||
| 172 | def _build_tokenizer(self) -> Callable[[str], Any]: | 186 | def _build_tokenizer(self) -> Callable[[str], Any]: |
| 173 | """Build the tokenizer used by query parsing. No fallback path by design.""" | 187 | """Build the tokenizer used by query parsing. No fallback path by design.""" |
| 174 | if hanlp is None: | 188 | if hanlp is None: |
| @@ -311,12 +325,21 @@ class QueryParser: | @@ -311,12 +325,21 @@ class QueryParser: | ||
| 311 | 325 | ||
| 312 | # Stage 6: Text embedding - async execution | 326 | # Stage 6: Text embedding - async execution |
| 313 | query_vector = None | 327 | query_vector = None |
| 328 | + image_query_vector = None | ||
| 314 | should_generate_embedding = ( | 329 | should_generate_embedding = ( |
| 315 | generate_vector and | 330 | generate_vector and |
| 316 | self.config.query_config.enable_text_embedding | 331 | self.config.query_config.enable_text_embedding |
| 317 | ) | 332 | ) |
| 333 | + should_generate_image_embedding = ( | ||
| 334 | + generate_vector and | ||
| 335 | + bool(self.config.query_config.image_embedding_field) | ||
| 336 | + ) | ||
| 318 | 337 | ||
| 319 | - task_count = len(translation_targets) + (1 if should_generate_embedding else 0) | 338 | + task_count = ( |
| 339 | + len(translation_targets) | ||
| 340 | + + (1 if should_generate_embedding else 0) | ||
| 341 | + + (1 if should_generate_image_embedding else 0) | ||
| 342 | + ) | ||
| 320 | if task_count > 0: | 343 | if task_count > 0: |
| 321 | async_executor = ThreadPoolExecutor( | 344 | async_executor = ThreadPoolExecutor( |
| 322 | max_workers=max(1, min(task_count, 4)), | 345 | max_workers=max(1, min(task_count, 4)), |
| @@ -366,6 +389,28 @@ class QueryParser: | @@ -366,6 +389,28 @@ class QueryParser: | ||
| 366 | 389 | ||
| 367 | future = async_executor.submit(_encode_query_vector) | 390 | future = async_executor.submit(_encode_query_vector) |
| 368 | future_to_task[future] = ("embedding", None) | 391 | future_to_task[future] = ("embedding", None) |
| 392 | + | ||
| 393 | + if should_generate_image_embedding: | ||
| 394 | + if self.image_encoder is None: | ||
| 395 | + raise RuntimeError( | ||
| 396 | + "Image embedding field is configured but image encoder is not initialized" | ||
| 397 | + ) | ||
| 398 | + log_debug("Submitting CLIP text query vector generation") | ||
| 399 | + | ||
| 400 | + def _encode_image_query_vector() -> Optional[np.ndarray]: | ||
| 401 | + vec = self.image_encoder.encode_clip_text( | ||
| 402 | + query_text, | ||
| 403 | + normalize_embeddings=True, | ||
| 404 | + priority=1, | ||
| 405 | + request_id=(context.reqid if context else None), | ||
| 406 | + user_id=(context.uid if context else None), | ||
| 407 | + ) | ||
| 408 | + if vec is None: | ||
| 409 | + return None | ||
| 410 | + return np.asarray(vec, dtype=np.float32) | ||
| 411 | + | ||
| 412 | + future = async_executor.submit(_encode_image_query_vector) | ||
| 413 | + future_to_task[future] = ("image_embedding", None) | ||
| 369 | except Exception as e: | 414 | except Exception as e: |
| 370 | error_msg = f"Async query enrichment submission failed | Error: {str(e)}" | 415 | error_msg = f"Async query enrichment submission failed | Error: {str(e)}" |
| 371 | log_info(error_msg) | 416 | log_info(error_msg) |
| @@ -424,9 +469,27 @@ class QueryParser: | @@ -424,9 +469,27 @@ class QueryParser: | ||
| 424 | log_info( | 469 | log_info( |
| 425 | "Query vector generation completed but result is None, will process without vector" | 470 | "Query vector generation completed but result is None, will process without vector" |
| 426 | ) | 471 | ) |
| 472 | + elif task_type == "image_embedding": | ||
| 473 | + image_query_vector = result | ||
| 474 | + if image_query_vector is not None: | ||
| 475 | + log_debug( | ||
| 476 | + f"CLIP text query vector generation completed | Shape: {image_query_vector.shape}" | ||
| 477 | + ) | ||
| 478 | + if context: | ||
| 479 | + context.store_intermediate_result( | ||
| 480 | + "image_query_vector_shape", | ||
| 481 | + image_query_vector.shape, | ||
| 482 | + ) | ||
| 483 | + else: | ||
| 484 | + log_info( | ||
| 485 | + "CLIP text query vector generation completed but result is None, " | ||
| 486 | + "will process without image vector" | ||
| 487 | + ) | ||
| 427 | except Exception as e: | 488 | except Exception as e: |
| 428 | if task_type == "translation": | 489 | if task_type == "translation": |
| 429 | error_msg = f"Translation failed | Language: {lang} | Error: {str(e)}" | 490 | error_msg = f"Translation failed | Language: {lang} | Error: {str(e)}" |
| 491 | + elif task_type == "image_embedding": | ||
| 492 | + error_msg = f"CLIP text query vector generation failed | Error: {str(e)}" | ||
| 430 | else: | 493 | else: |
| 431 | error_msg = f"Query vector generation failed | Error: {str(e)}" | 494 | error_msg = f"Query vector generation failed | Error: {str(e)}" |
| 432 | log_info(error_msg) | 495 | log_info(error_msg) |
| @@ -441,6 +504,11 @@ class QueryParser: | @@ -441,6 +504,11 @@ class QueryParser: | ||
| 441 | f"Translation timeout (>{budget_ms}ms) | Language: {lang} | " | 504 | f"Translation timeout (>{budget_ms}ms) | Language: {lang} | " |
| 442 | f"Query text: '{query_text}'" | 505 | f"Query text: '{query_text}'" |
| 443 | ) | 506 | ) |
| 507 | + elif task_type == "image_embedding": | ||
| 508 | + timeout_msg = ( | ||
| 509 | + f"CLIP text query vector generation timeout (>{budget_ms}ms), " | ||
| 510 | + "proceeding without image embedding result" | ||
| 511 | + ) | ||
| 444 | else: | 512 | else: |
| 445 | timeout_msg = ( | 513 | timeout_msg = ( |
| 446 | f"Query vector generation timeout (>{budget_ms}ms), proceeding without embedding result" | 514 | f"Query vector generation timeout (>{budget_ms}ms), proceeding without embedding result" |
| @@ -463,6 +531,7 @@ class QueryParser: | @@ -463,6 +531,7 @@ class QueryParser: | ||
| 463 | detected_language=detected_lang, | 531 | detected_language=detected_lang, |
| 464 | translations=translations, | 532 | translations=translations, |
| 465 | query_vector=query_vector, | 533 | query_vector=query_vector, |
| 534 | + image_query_vector=image_query_vector, | ||
| 466 | query_tokens=query_tokens, | 535 | query_tokens=query_tokens, |
| 467 | ) | 536 | ) |
| 468 | style_intent_profile = self.style_intent_detector.detect(base_result) | 537 | style_intent_profile = self.style_intent_detector.detect(base_result) |
| @@ -484,6 +553,7 @@ class QueryParser: | @@ -484,6 +553,7 @@ class QueryParser: | ||
| 484 | detected_language=detected_lang, | 553 | detected_language=detected_lang, |
| 485 | translations=translations, | 554 | translations=translations, |
| 486 | query_vector=query_vector, | 555 | query_vector=query_vector, |
| 556 | + image_query_vector=image_query_vector, | ||
| 487 | query_tokens=query_tokens, | 557 | query_tokens=query_tokens, |
| 488 | style_intent_profile=style_intent_profile, | 558 | style_intent_profile=style_intent_profile, |
| 489 | product_title_exclusion_profile=product_title_exclusion_profile, | 559 | product_title_exclusion_profile=product_title_exclusion_profile, |
| @@ -492,7 +562,8 @@ class QueryParser: | @@ -492,7 +562,8 @@ class QueryParser: | ||
| 492 | if context and hasattr(context, 'logger'): | 562 | if context and hasattr(context, 'logger'): |
| 493 | context.logger.info( | 563 | context.logger.info( |
| 494 | f"Query parsing completed | Original query: '{query}' | Final query: '{rewritten or query_text}' | " | 564 | f"Query parsing completed | Original query: '{query}' | Final query: '{rewritten or query_text}' | " |
| 495 | - f"Translation count: {len(translations)} | Vector: {'yes' if query_vector is not None else 'no'}", | 565 | + f"Translation count: {len(translations)} | Vector: {'yes' if query_vector is not None else 'no'} | " |
| 566 | + f"Image vector: {'yes' if image_query_vector is not None else 'no'}", | ||
| 496 | extra={'reqid': context.reqid, 'uid': context.uid} | 567 | extra={'reqid': context.reqid, 'uid': context.uid} |
| 497 | ) | 568 | ) |
| 498 | else: | 569 | else: |
search/es_query_builder.py
| @@ -164,6 +164,7 @@ class ESQueryBuilder: | @@ -164,6 +164,7 @@ class ESQueryBuilder: | ||
| 164 | self, | 164 | self, |
| 165 | query_text: str, | 165 | query_text: str, |
| 166 | query_vector: Optional[np.ndarray] = None, | 166 | query_vector: Optional[np.ndarray] = None, |
| 167 | + image_query_vector: Optional[np.ndarray] = None, | ||
| 167 | filters: Optional[Dict[str, Any]] = None, | 168 | filters: Optional[Dict[str, Any]] = None, |
| 168 | range_filters: Optional[Dict[str, Any]] = None, | 169 | range_filters: Optional[Dict[str, Any]] = None, |
| 169 | facet_configs: Optional[List[Any]] = None, | 170 | facet_configs: Optional[List[Any]] = None, |
| @@ -212,15 +213,14 @@ class ESQueryBuilder: | @@ -212,15 +213,14 @@ class ESQueryBuilder: | ||
| 212 | 213 | ||
| 213 | # 1. Build recall queries (text or embedding) | 214 | # 1. Build recall queries (text or embedding) |
| 214 | recall_clauses = [] | 215 | recall_clauses = [] |
| 215 | - | 216 | + |
| 216 | # Text recall (always include if query_text exists) | 217 | # Text recall (always include if query_text exists) |
| 217 | if query_text: | 218 | if query_text: |
| 218 | - # Unified text query strategy | ||
| 219 | - text_query = self._build_advanced_text_query(query_text, parsed_query) | ||
| 220 | - recall_clauses.append(text_query) | ||
| 221 | - | ||
| 222 | - # Embedding recall (KNN - separate from query, handled below) | 219 | + recall_clauses.extend(self._build_advanced_text_query(query_text, parsed_query)) |
| 220 | + | ||
| 221 | + # Embedding recall | ||
| 223 | has_embedding = enable_knn and query_vector is not None and self.text_embedding_field | 222 | has_embedding = enable_knn and query_vector is not None and self.text_embedding_field |
| 223 | + has_image_embedding = enable_knn and image_query_vector is not None and self.image_embedding_field | ||
| 224 | 224 | ||
| 225 | # 2. Split filters for multi-select faceting | 225 | # 2. Split filters for multi-select faceting |
| 226 | conjunctive_filters, disjunctive_filters = self._split_filters_for_faceting( | 226 | conjunctive_filters, disjunctive_filters = self._split_filters_for_faceting( |
| @@ -233,9 +233,48 @@ class ESQueryBuilder: | @@ -233,9 +233,48 @@ class ESQueryBuilder: | ||
| 233 | if product_title_exclusion_filter: | 233 | if product_title_exclusion_filter: |
| 234 | filter_clauses.append(product_title_exclusion_filter) | 234 | filter_clauses.append(product_title_exclusion_filter) |
| 235 | 235 | ||
| 236 | - # 3. Build main query structure: filters and recall | 236 | + # 3. Add KNN search clauses alongside lexical clauses under the same bool.should |
| 237 | + # Adjust KNN k, num_candidates, boost by query_tokens (short query: less KNN; long: more) | ||
| 238 | + final_knn_k, final_knn_num_candidates = knn_k, knn_num_candidates | ||
| 239 | + if has_embedding: | ||
| 240 | + knn_boost = self.knn_boost | ||
| 241 | + if parsed_query: | ||
| 242 | + query_tokens = getattr(parsed_query, 'query_tokens', None) or [] | ||
| 243 | + token_count = len(query_tokens) | ||
| 244 | + if token_count >= 5: | ||
| 245 | + final_knn_k, final_knn_num_candidates = 160, 500 | ||
| 246 | + knn_boost = self.knn_boost * 1.4 # Higher weight for long queries | ||
| 247 | + else: | ||
| 248 | + final_knn_k, final_knn_num_candidates = 120, 400 | ||
| 249 | + else: | ||
| 250 | + final_knn_k, final_knn_num_candidates = 120, 400 | ||
| 251 | + recall_clauses.append({ | ||
| 252 | + "knn": { | ||
| 253 | + "field": self.text_embedding_field, | ||
| 254 | + "query_vector": query_vector.tolist(), | ||
| 255 | + "k": final_knn_k, | ||
| 256 | + "num_candidates": final_knn_num_candidates, | ||
| 257 | + "boost": knn_boost, | ||
| 258 | + "_name": "knn_query", | ||
| 259 | + } | ||
| 260 | + }) | ||
| 261 | + | ||
| 262 | + if has_image_embedding: | ||
| 263 | + image_knn_k = max(final_knn_k, 120) | ||
| 264 | + image_knn_num_candidates = max(final_knn_num_candidates, 400) | ||
| 265 | + recall_clauses.append({ | ||
| 266 | + "knn": { | ||
| 267 | + "field": self.image_embedding_field, | ||
| 268 | + "query_vector": image_query_vector.tolist(), | ||
| 269 | + "k": image_knn_k, | ||
| 270 | + "num_candidates": image_knn_num_candidates, | ||
| 271 | + "boost": self.knn_boost, | ||
| 272 | + "_name": "image_knn_query", | ||
| 273 | + } | ||
| 274 | + }) | ||
| 275 | + | ||
| 276 | + # 4. Build main query structure: filters and recall | ||
| 237 | if recall_clauses: | 277 | if recall_clauses: |
| 238 | - # Combine text recalls with OR logic (if multiple) | ||
| 239 | if len(recall_clauses) == 1: | 278 | if len(recall_clauses) == 1: |
| 240 | recall_query = recall_clauses[0] | 279 | recall_query = recall_clauses[0] |
| 241 | else: | 280 | else: |
| @@ -245,11 +284,9 @@ class ESQueryBuilder: | @@ -245,11 +284,9 @@ class ESQueryBuilder: | ||
| 245 | "minimum_should_match": 1 | 284 | "minimum_should_match": 1 |
| 246 | } | 285 | } |
| 247 | } | 286 | } |
| 248 | - | ||
| 249 | - # Wrap recall with function_score for boosting | 287 | + |
| 250 | recall_query = self._wrap_with_function_score(recall_query) | 288 | recall_query = self._wrap_with_function_score(recall_query) |
| 251 | - | ||
| 252 | - # Combine filters and recall | 289 | + |
| 253 | if filter_clauses: | 290 | if filter_clauses: |
| 254 | es_query["query"] = { | 291 | es_query["query"] = { |
| 255 | "bool": { | 292 | "bool": { |
| @@ -260,7 +297,6 @@ class ESQueryBuilder: | @@ -260,7 +297,6 @@ class ESQueryBuilder: | ||
| 260 | else: | 297 | else: |
| 261 | es_query["query"] = recall_query | 298 | es_query["query"] = recall_query |
| 262 | else: | 299 | else: |
| 263 | - # No recall queries, only filters (match_all filtered) | ||
| 264 | if filter_clauses: | 300 | if filter_clauses: |
| 265 | es_query["query"] = { | 301 | es_query["query"] = { |
| 266 | "bool": { | 302 | "bool": { |
| @@ -271,41 +307,6 @@ class ESQueryBuilder: | @@ -271,41 +307,6 @@ class ESQueryBuilder: | ||
| 271 | else: | 307 | else: |
| 272 | es_query["query"] = {"match_all": {}} | 308 | es_query["query"] = {"match_all": {}} |
| 273 | 309 | ||
| 274 | - # 4. Add KNN search if enabled (separate from query, ES will combine) | ||
| 275 | - # Adjust KNN k, num_candidates, boost by query_tokens (short query: less KNN; long: more) | ||
| 276 | - if has_embedding: | ||
| 277 | - knn_boost = self.knn_boost | ||
| 278 | - if parsed_query: | ||
| 279 | - query_tokens = getattr(parsed_query, 'query_tokens', None) or [] | ||
| 280 | - token_count = len(query_tokens) | ||
| 281 | - if token_count >= 5: | ||
| 282 | - knn_k, knn_num_candidates = 160, 500 | ||
| 283 | - knn_boost = self.knn_boost * 1.4 # Higher weight for long queries | ||
| 284 | - else: | ||
| 285 | - knn_k, knn_num_candidates = 120, 400 | ||
| 286 | - else: | ||
| 287 | - knn_k, knn_num_candidates = 120, 400 | ||
| 288 | - knn_clause = { | ||
| 289 | - "field": self.text_embedding_field, | ||
| 290 | - "query_vector": query_vector.tolist(), | ||
| 291 | - "k": knn_k, | ||
| 292 | - "num_candidates": knn_num_candidates, | ||
| 293 | - "boost": knn_boost, | ||
| 294 | - "_name": "knn_query", | ||
| 295 | - } | ||
| 296 | - # Top-level knn does not inherit query.bool.filter automatically. | ||
| 297 | - # Apply conjunctive + range filters here so vector recall respects hard filters. | ||
| 298 | - if filter_clauses: | ||
| 299 | - if len(filter_clauses) == 1: | ||
| 300 | - knn_clause["filter"] = filter_clauses[0] | ||
| 301 | - else: | ||
| 302 | - knn_clause["filter"] = { | ||
| 303 | - "bool": { | ||
| 304 | - "filter": filter_clauses | ||
| 305 | - } | ||
| 306 | - } | ||
| 307 | - es_query["knn"] = knn_clause | ||
| 308 | - | ||
| 309 | # 5. Add post_filter for disjunctive (multi-select) filters | 310 | # 5. Add post_filter for disjunctive (multi-select) filters |
| 310 | if disjunctive_filters: | 311 | if disjunctive_filters: |
| 311 | post_filter_clauses = self._build_filters(disjunctive_filters, None) | 312 | post_filter_clauses = self._build_filters(disjunctive_filters, None) |
| @@ -536,21 +537,20 @@ class ESQueryBuilder: | @@ -536,21 +537,20 @@ class ESQueryBuilder: | ||
| 536 | self, | 537 | self, |
| 537 | query_text: str, | 538 | query_text: str, |
| 538 | parsed_query: Optional[Any] = None, | 539 | parsed_query: Optional[Any] = None, |
| 539 | - ) -> Dict[str, Any]: | 540 | + ) -> List[Dict[str, Any]]: |
| 540 | """ | 541 | """ |
| 541 | Build advanced text query using base and translated lexical clauses. | 542 | Build advanced text query using base and translated lexical clauses. |
| 542 | 543 | ||
| 543 | Unified implementation: | 544 | Unified implementation: |
| 544 | - base_query: source-language clause | 545 | - base_query: source-language clause |
| 545 | - translation queries: target-language clauses from translations | 546 | - translation queries: target-language clauses from translations |
| 546 | - - KNN query: added separately in build_query | ||
| 547 | - | 547 | + |
| 548 | Args: | 548 | Args: |
| 549 | query_text: Query text | 549 | query_text: Query text |
| 550 | parsed_query: ParsedQuery object with analysis results | 550 | parsed_query: ParsedQuery object with analysis results |
| 551 | 551 | ||
| 552 | Returns: | 552 | Returns: |
| 553 | - ES bool query with should clauses | 553 | + Flat recall clauses to be merged with KNN clauses under query.bool.should |
| 554 | """ | 554 | """ |
| 555 | should_clauses = [] | 555 | should_clauses = [] |
| 556 | source_lang = self.default_language | 556 | source_lang = self.default_language |
| @@ -603,18 +603,9 @@ class ESQueryBuilder: | @@ -603,18 +603,9 @@ class ESQueryBuilder: | ||
| 603 | "minimum_should_match": self.base_minimum_should_match, | 603 | "minimum_should_match": self.base_minimum_should_match, |
| 604 | } | 604 | } |
| 605 | } | 605 | } |
| 606 | - return fallback_lexical | 606 | + return [fallback_lexical] |
| 607 | 607 | ||
| 608 | - # Return bool query with should clauses | ||
| 609 | - if len(should_clauses) == 1: | ||
| 610 | - return should_clauses[0] | ||
| 611 | - | ||
| 612 | - return { | ||
| 613 | - "bool": { | ||
| 614 | - "should": should_clauses, | ||
| 615 | - "minimum_should_match": 1 | ||
| 616 | - } | ||
| 617 | - } | 608 | + return should_clauses |
| 618 | 609 | ||
| 619 | def _build_filters( | 610 | def _build_filters( |
| 620 | self, | 611 | self, |
search/rerank_client.py
| @@ -151,6 +151,13 @@ def _extract_named_query_score(matched_queries: Any, name: str) -> float: | @@ -151,6 +151,13 @@ def _extract_named_query_score(matched_queries: Any, name: str) -> float: | ||
| 151 | return 1.0 if name in matched_queries else 0.0 | 151 | return 1.0 if name in matched_queries else 0.0 |
| 152 | return 0.0 | 152 | return 0.0 |
| 153 | 153 | ||
| 154 | + | ||
| 155 | +def _extract_combined_knn_score(matched_queries: Any) -> float: | ||
| 156 | + return max( | ||
| 157 | + _extract_named_query_score(matched_queries, "knn_query"), | ||
| 158 | + _extract_named_query_score(matched_queries, "image_knn_query"), | ||
| 159 | + ) | ||
| 160 | + | ||
| 154 | """ | 161 | """ |
| 155 | 原始变量: | 162 | 原始变量: |
| 156 | ES总分 | 163 | ES总分 |
| @@ -272,7 +279,7 @@ def fuse_scores_and_resort( | @@ -272,7 +279,7 @@ def fuse_scores_and_resort( | ||
| 272 | es_score = _to_score(hit.get("_score")) | 279 | es_score = _to_score(hit.get("_score")) |
| 273 | rerank_score = _to_score(rerank_scores[idx]) | 280 | rerank_score = _to_score(rerank_scores[idx]) |
| 274 | matched_queries = hit.get("matched_queries") | 281 | matched_queries = hit.get("matched_queries") |
| 275 | - knn_score = _extract_named_query_score(matched_queries, "knn_query") | 282 | + knn_score = _extract_combined_knn_score(matched_queries) |
| 276 | text_components = _collect_text_score_components(matched_queries, es_score) | 283 | text_components = _collect_text_score_components(matched_queries, es_score) |
| 277 | text_score = text_components["text_score"] | 284 | text_score = text_components["text_score"] |
| 278 | rerank_factor, text_factor, knn_factor, fused = _multiply_fusion_factors( | 285 | rerank_factor, text_factor, knn_factor, fused = _multiply_fusion_factors( |
search/searcher.py
| @@ -106,14 +106,14 @@ class Searcher: | @@ -106,14 +106,14 @@ class Searcher: | ||
| 106 | """ | 106 | """ |
| 107 | self.es_client = es_client | 107 | self.es_client = es_client |
| 108 | self.config = config | 108 | self.config = config |
| 109 | - # Index name is now generated dynamically per tenant, no longer stored here | ||
| 110 | - self.query_parser = query_parser or QueryParser(config) | ||
| 111 | self.text_embedding_field = config.query_config.text_embedding_field or "title_embedding" | 109 | self.text_embedding_field = config.query_config.text_embedding_field or "title_embedding" |
| 112 | self.image_embedding_field = config.query_config.image_embedding_field | 110 | self.image_embedding_field = config.query_config.image_embedding_field |
| 113 | if self.image_embedding_field and image_encoder is None: | 111 | if self.image_embedding_field and image_encoder is None: |
| 114 | self.image_encoder = CLIPImageEncoder() | 112 | self.image_encoder = CLIPImageEncoder() |
| 115 | else: | 113 | else: |
| 116 | self.image_encoder = image_encoder | 114 | self.image_encoder = image_encoder |
| 115 | + # Index name is now generated dynamically per tenant, no longer stored here | ||
| 116 | + self.query_parser = query_parser or QueryParser(config, image_encoder=self.image_encoder) | ||
| 117 | self.source_fields = config.query_config.source_fields | 117 | self.source_fields = config.query_config.source_fields |
| 118 | self.style_intent_registry = StyleIntentRegistry.from_query_config(self.config.query_config) | 118 | self.style_intent_registry = StyleIntentRegistry.from_query_config(self.config.query_config) |
| 119 | self.style_sku_selector = StyleSkuSelector( | 119 | self.style_sku_selector = StyleSkuSelector( |
| @@ -403,7 +403,8 @@ class Searcher: | @@ -403,7 +403,8 @@ class Searcher: | ||
| 403 | f"查询解析完成 | 原查询: '{parsed_query.original_query}' | " | 403 | f"查询解析完成 | 原查询: '{parsed_query.original_query}' | " |
| 404 | f"重写后: '{parsed_query.rewritten_query}' | " | 404 | f"重写后: '{parsed_query.rewritten_query}' | " |
| 405 | f"语言: {parsed_query.detected_language} | " | 405 | f"语言: {parsed_query.detected_language} | " |
| 406 | - f"向量: {'是' if parsed_query.query_vector is not None else '否'}", | 406 | + f"文本向量: {'是' if parsed_query.query_vector is not None else '否'} | " |
| 407 | + f"图片向量: {'是' if getattr(parsed_query, 'image_query_vector', None) is not None else '否'}", | ||
| 407 | extra={'reqid': context.reqid, 'uid': context.uid} | 408 | extra={'reqid': context.reqid, 'uid': context.uid} |
| 408 | ) | 409 | ) |
| 409 | except Exception as e: | 410 | except Exception as e: |
| @@ -428,12 +429,20 @@ class Searcher: | @@ -428,12 +429,20 @@ class Searcher: | ||
| 428 | es_query = self.query_builder.build_query( | 429 | es_query = self.query_builder.build_query( |
| 429 | query_text=parsed_query.rewritten_query or parsed_query.query_normalized, | 430 | query_text=parsed_query.rewritten_query or parsed_query.query_normalized, |
| 430 | query_vector=parsed_query.query_vector if enable_embedding else None, | 431 | query_vector=parsed_query.query_vector if enable_embedding else None, |
| 432 | + image_query_vector=( | ||
| 433 | + getattr(parsed_query, "image_query_vector", None) | ||
| 434 | + if enable_embedding | ||
| 435 | + else None | ||
| 436 | + ), | ||
| 431 | filters=filters, | 437 | filters=filters, |
| 432 | range_filters=range_filters, | 438 | range_filters=range_filters, |
| 433 | facet_configs=facets, | 439 | facet_configs=facets, |
| 434 | size=es_fetch_size, | 440 | size=es_fetch_size, |
| 435 | from_=es_fetch_from, | 441 | from_=es_fetch_from, |
| 436 | - enable_knn=enable_embedding and parsed_query.query_vector is not None, | 442 | + enable_knn=enable_embedding and ( |
| 443 | + parsed_query.query_vector is not None | ||
| 444 | + or getattr(parsed_query, "image_query_vector", None) is not None | ||
| 445 | + ), | ||
| 437 | min_score=min_score, | 446 | min_score=min_score, |
| 438 | parsed_query=parsed_query, | 447 | parsed_query=parsed_query, |
| 439 | ) | 448 | ) |
| @@ -475,15 +484,24 @@ class Searcher: | @@ -475,15 +484,24 @@ class Searcher: | ||
| 475 | # Serialize ES query to compute a compact size + stable digest for correlation | 484 | # Serialize ES query to compute a compact size + stable digest for correlation |
| 476 | es_query_compact = json.dumps(es_query_for_fetch, ensure_ascii=False, separators=(",", ":")) | 485 | es_query_compact = json.dumps(es_query_for_fetch, ensure_ascii=False, separators=(",", ":")) |
| 477 | es_query_digest = hashlib.sha256(es_query_compact.encode("utf-8")).hexdigest()[:16] | 486 | es_query_digest = hashlib.sha256(es_query_compact.encode("utf-8")).hexdigest()[:16] |
| 478 | - knn_enabled = bool(enable_embedding and parsed_query.query_vector is not None) | 487 | + knn_enabled = bool(enable_embedding and ( |
| 488 | + parsed_query.query_vector is not None | ||
| 489 | + or getattr(parsed_query, "image_query_vector", None) is not None | ||
| 490 | + )) | ||
| 479 | vector_dims = int(len(parsed_query.query_vector)) if parsed_query.query_vector is not None else 0 | 491 | vector_dims = int(len(parsed_query.query_vector)) if parsed_query.query_vector is not None else 0 |
| 492 | + image_vector_dims = ( | ||
| 493 | + int(len(parsed_query.image_query_vector)) | ||
| 494 | + if getattr(parsed_query, "image_query_vector", None) is not None | ||
| 495 | + else 0 | ||
| 496 | + ) | ||
| 480 | 497 | ||
| 481 | context.logger.info( | 498 | context.logger.info( |
| 482 | - "ES query built | size: %s chars | digest: %s | KNN: %s | vector_dims: %s | facets: %s | rerank_prefetch_source: %s", | 499 | + "ES query built | size: %s chars | digest: %s | KNN: %s | vector_dims: %s | image_vector_dims: %s | facets: %s | rerank_prefetch_source: %s", |
| 483 | len(es_query_compact), | 500 | len(es_query_compact), |
| 484 | es_query_digest, | 501 | es_query_digest, |
| 485 | "yes" if knn_enabled else "no", | 502 | "yes" if knn_enabled else "no", |
| 486 | vector_dims, | 503 | vector_dims, |
| 504 | + image_vector_dims, | ||
| 487 | "yes" if facets else "no", | 505 | "yes" if facets else "no", |
| 488 | rerank_prefetch_source, | 506 | rerank_prefetch_source, |
| 489 | extra={'reqid': context.reqid, 'uid': context.uid} | 507 | extra={'reqid': context.reqid, 'uid': context.uid} |
| @@ -497,6 +515,7 @@ class Searcher: | @@ -497,6 +515,7 @@ class Searcher: | ||
| 497 | "sha256_16": es_query_digest, | 515 | "sha256_16": es_query_digest, |
| 498 | "knn_enabled": knn_enabled, | 516 | "knn_enabled": knn_enabled, |
| 499 | "vector_dims": vector_dims, | 517 | "vector_dims": vector_dims, |
| 518 | + "image_vector_dims": image_vector_dims, | ||
| 500 | "has_facets": bool(facets), | 519 | "has_facets": bool(facets), |
| 501 | "query": es_query_for_fetch, | 520 | "query": es_query_for_fetch, |
| 502 | }) | 521 | }) |
tests/test_embedding_pipeline.py
| @@ -75,6 +75,15 @@ class _FakeQueryEncoder: | @@ -75,6 +75,15 @@ class _FakeQueryEncoder: | ||
| 75 | return np.array([np.array([0.11, 0.22, 0.33], dtype=np.float32) for _ in sentences], dtype=object) | 75 | return np.array([np.array([0.11, 0.22, 0.33], dtype=np.float32) for _ in sentences], dtype=object) |
| 76 | 76 | ||
| 77 | 77 | ||
| 78 | +class _FakeClipTextEncoder: | ||
| 79 | + def __init__(self): | ||
| 80 | + self.calls = [] | ||
| 81 | + | ||
| 82 | + def encode_clip_text(self, text, **kwargs): | ||
| 83 | + self.calls.append({"text": text, "kwargs": dict(kwargs)}) | ||
| 84 | + return np.array([0.44, 0.55, 0.66], dtype=np.float32) | ||
| 85 | + | ||
| 86 | + | ||
| 78 | def _tokenizer(text): | 87 | def _tokenizer(text): |
| 79 | return str(text).split() | 88 | return str(text).split() |
| 80 | 89 | ||
| @@ -91,7 +100,7 @@ class _FakeEmbeddingCache: | @@ -91,7 +100,7 @@ class _FakeEmbeddingCache: | ||
| 91 | return True | 100 | return True |
| 92 | 101 | ||
| 93 | 102 | ||
| 94 | -def _build_test_config() -> SearchConfig: | 103 | +def _build_test_config(*, image_embedding_field: Optional[str] = None) -> SearchConfig: |
| 95 | return SearchConfig( | 104 | return SearchConfig( |
| 96 | field_boosts={"title.en": 3.0}, | 105 | field_boosts={"title.en": 3.0}, |
| 97 | indexes=[IndexConfig(name="default", label="default", fields=["title.en"], boost=1.0)], | 106 | indexes=[IndexConfig(name="default", label="default", fields=["title.en"], boost=1.0)], |
| @@ -102,7 +111,7 @@ def _build_test_config() -> SearchConfig: | @@ -102,7 +111,7 @@ def _build_test_config() -> SearchConfig: | ||
| 102 | enable_query_rewrite=False, | 111 | enable_query_rewrite=False, |
| 103 | rewrite_dictionary={}, | 112 | rewrite_dictionary={}, |
| 104 | text_embedding_field="title_embedding", | 113 | text_embedding_field="title_embedding", |
| 105 | - image_embedding_field=None, | 114 | + image_embedding_field=image_embedding_field, |
| 106 | ), | 115 | ), |
| 107 | function_score=FunctionScoreConfig(), | 116 | function_score=FunctionScoreConfig(), |
| 108 | rerank=RerankConfig(), | 117 | rerank=RerankConfig(), |
| @@ -250,6 +259,26 @@ def test_query_parser_generates_query_vector_with_encoder(): | @@ -250,6 +259,26 @@ def test_query_parser_generates_query_vector_with_encoder(): | ||
| 250 | assert encoder.calls[0]["kwargs"]["priority"] == 1 | 259 | assert encoder.calls[0]["kwargs"]["priority"] == 1 |
| 251 | 260 | ||
| 252 | 261 | ||
| 262 | +def test_query_parser_generates_image_query_vector_with_clip_text_encoder(): | ||
| 263 | + text_encoder = _FakeQueryEncoder() | ||
| 264 | + image_encoder = _FakeClipTextEncoder() | ||
| 265 | + parser = QueryParser( | ||
| 266 | + config=_build_test_config(image_embedding_field="image_embedding.vector"), | ||
| 267 | + text_encoder=text_encoder, | ||
| 268 | + image_encoder=image_encoder, | ||
| 269 | + translator=_FakeTranslator(), | ||
| 270 | + tokenizer=_tokenizer, | ||
| 271 | + ) | ||
| 272 | + | ||
| 273 | + parsed = parser.parse("red dress", tenant_id="162", generate_vector=True) | ||
| 274 | + assert parsed.query_vector is not None | ||
| 275 | + assert parsed.image_query_vector is not None | ||
| 276 | + assert parsed.image_query_vector.shape == (3,) | ||
| 277 | + assert image_encoder.calls | ||
| 278 | + assert image_encoder.calls[0]["text"] == "red dress" | ||
| 279 | + assert image_encoder.calls[0]["kwargs"]["priority"] == 1 | ||
| 280 | + | ||
| 281 | + | ||
| 253 | def test_query_parser_skips_query_vector_when_disabled(): | 282 | def test_query_parser_skips_query_vector_when_disabled(): |
| 254 | parser = QueryParser( | 283 | parser = QueryParser( |
| 255 | config=_build_test_config(), | 284 | config=_build_test_config(), |
| @@ -260,6 +289,7 @@ def test_query_parser_skips_query_vector_when_disabled(): | @@ -260,6 +289,7 @@ def test_query_parser_skips_query_vector_when_disabled(): | ||
| 260 | 289 | ||
| 261 | parsed = parser.parse("red dress", tenant_id="162", generate_vector=False) | 290 | parsed = parser.parse("red dress", tenant_id="162", generate_vector=False) |
| 262 | assert parsed.query_vector is None | 291 | assert parsed.query_vector is None |
| 292 | + assert parsed.image_query_vector is None | ||
| 263 | 293 | ||
| 264 | 294 | ||
| 265 | def test_tei_text_model_splits_batches_over_client_limit(monkeypatch): | 295 | def test_tei_text_model_splits_batches_over_client_limit(monkeypatch): |
tests/test_es_query_builder.py
| @@ -13,22 +13,29 @@ def _builder() -> ESQueryBuilder: | @@ -13,22 +13,29 @@ def _builder() -> ESQueryBuilder: | ||
| 13 | core_multilingual_fields=["title", "brief"], | 13 | core_multilingual_fields=["title", "brief"], |
| 14 | shared_fields=[], | 14 | shared_fields=[], |
| 15 | text_embedding_field="title_embedding", | 15 | text_embedding_field="title_embedding", |
| 16 | + image_embedding_field="image_embedding.vector", | ||
| 16 | default_language="en", | 17 | default_language="en", |
| 17 | ) | 18 | ) |
| 18 | 19 | ||
| 19 | 20 | ||
| 20 | -def _lexical_clause(query_root: Dict[str, Any]) -> Dict[str, Any]: | ||
| 21 | - """Return the first named lexical bool clause from query_root.""" | ||
| 22 | - if "bool" in query_root and query_root["bool"].get("_name"): | ||
| 23 | - return query_root["bool"] | ||
| 24 | - for clause in query_root.get("bool", {}).get("should", []): | ||
| 25 | - clause_bool = clause.get("bool") or {} | ||
| 26 | - if clause_bool.get("_name"): | ||
| 27 | - return clause_bool | ||
| 28 | - raise AssertionError("no lexical bool clause in query_root") | 21 | +def _recall_root(es_body: Dict[str, Any]) -> Dict[str, Any]: |
| 22 | + query_root = es_body["query"] | ||
| 23 | + if "bool" in query_root and query_root["bool"].get("must"): | ||
| 24 | + query_root = query_root["bool"]["must"][0] | ||
| 25 | + if "function_score" in query_root: | ||
| 26 | + query_root = query_root["function_score"]["query"] | ||
| 27 | + return query_root | ||
| 29 | 28 | ||
| 30 | 29 | ||
| 31 | -def test_knn_prefilter_includes_range_filters(): | 30 | +def _recall_should_clauses(es_body: Dict[str, Any]) -> list[Dict[str, Any]]: |
| 31 | + root = _recall_root(es_body) | ||
| 32 | + should = root.get("bool", {}).get("should") | ||
| 33 | + if should: | ||
| 34 | + return should | ||
| 35 | + return [root] | ||
| 36 | + | ||
| 37 | + | ||
| 38 | +def test_knn_clause_moves_under_query_should_and_uses_outer_filters(): | ||
| 32 | qb = _builder() | 39 | qb = _builder() |
| 33 | q = qb.build_query( | 40 | q = qb.build_query( |
| 34 | query_text="bags", | 41 | query_text="bags", |
| @@ -37,11 +44,13 @@ def test_knn_prefilter_includes_range_filters(): | @@ -37,11 +44,13 @@ def test_knn_prefilter_includes_range_filters(): | ||
| 37 | enable_knn=True, | 44 | enable_knn=True, |
| 38 | ) | 45 | ) |
| 39 | 46 | ||
| 40 | - assert "knn" in q | ||
| 41 | - assert q["knn"]["filter"] == {"range": {"min_price": {"gte": 50, "lt": 100}}} | 47 | + assert "knn" not in q |
| 48 | + should = _recall_should_clauses(q) | ||
| 49 | + assert any(clause.get("knn", {}).get("_name") == "knn_query" for clause in should) | ||
| 50 | + assert q["query"]["bool"]["filter"] == [{"range": {"min_price": {"gte": 50, "lt": 100}}}] | ||
| 42 | 51 | ||
| 43 | 52 | ||
| 44 | -def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present(): | 53 | +def test_knn_clause_uses_outer_query_filter_when_disjunctive_filters_present(): |
| 45 | qb = _builder() | 54 | qb = _builder() |
| 46 | facets = [SimpleNamespace(field="category_name", disjunctive=True)] | 55 | facets = [SimpleNamespace(field="category_name", disjunctive=True)] |
| 47 | q = qb.build_query( | 56 | q = qb.build_query( |
| @@ -53,21 +62,15 @@ def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present(): | @@ -53,21 +62,15 @@ def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present(): | ||
| 53 | enable_knn=True, | 62 | enable_knn=True, |
| 54 | ) | 63 | ) |
| 55 | 64 | ||
| 56 | - assert "knn" in q | ||
| 57 | - assert "filter" in q["knn"] | ||
| 58 | - knn_filter = q["knn"]["filter"] | ||
| 59 | - assert knn_filter == { | ||
| 60 | - "bool": { | ||
| 61 | - "filter": [ | ||
| 62 | - {"term": {"vendor": "Nike"}}, | ||
| 63 | - {"range": {"min_price": {"gte": 50, "lt": 100}}}, | ||
| 64 | - ] | ||
| 65 | - } | ||
| 66 | - } | 65 | + assert "knn" not in q |
| 66 | + assert q["query"]["bool"]["filter"] == [ | ||
| 67 | + {"term": {"vendor": "Nike"}}, | ||
| 68 | + {"range": {"min_price": {"gte": 50, "lt": 100}}}, | ||
| 69 | + ] | ||
| 67 | assert q["post_filter"] == {"terms": {"category_name": ["A", "B"]}} | 70 | assert q["post_filter"] == {"terms": {"category_name": ["A", "B"]}} |
| 68 | 71 | ||
| 69 | 72 | ||
| 70 | -def test_knn_prefilter_not_added_without_filters(): | 73 | +def test_knn_clause_has_name_and_no_embedded_filter(): |
| 71 | qb = _builder() | 74 | qb = _builder() |
| 72 | q = qb.build_query( | 75 | q = qb.build_query( |
| 73 | query_text="bags", | 76 | query_text="bags", |
| @@ -75,9 +78,10 @@ def test_knn_prefilter_not_added_without_filters(): | @@ -75,9 +78,10 @@ def test_knn_prefilter_not_added_without_filters(): | ||
| 75 | enable_knn=True, | 78 | enable_knn=True, |
| 76 | ) | 79 | ) |
| 77 | 80 | ||
| 78 | - assert "knn" in q | ||
| 79 | - assert "filter" not in q["knn"] | ||
| 80 | - assert q["knn"]["_name"] == "knn_query" | 81 | + should = _recall_should_clauses(q) |
| 82 | + knn_clause = next(clause["knn"] for clause in should if clause.get("knn", {}).get("_name") == "knn_query") | ||
| 83 | + assert "filter" not in knn_clause | ||
| 84 | + assert knn_clause["_name"] == "knn_query" | ||
| 81 | 85 | ||
| 82 | 86 | ||
| 83 | def test_text_query_contains_only_base_and_translation_named_queries(): | 87 | def test_text_query_contains_only_base_and_translation_named_queries(): |
| @@ -93,11 +97,11 @@ def test_text_query_contains_only_base_and_translation_named_queries(): | @@ -93,11 +97,11 @@ def test_text_query_contains_only_base_and_translation_named_queries(): | ||
| 93 | parsed_query=parsed_query, | 97 | parsed_query=parsed_query, |
| 94 | enable_knn=False, | 98 | enable_knn=False, |
| 95 | ) | 99 | ) |
| 96 | - should = q["query"]["bool"]["should"] | 100 | + should = _recall_should_clauses(q) |
| 97 | names = [clause["bool"]["_name"] for clause in should] | 101 | names = [clause["bool"]["_name"] for clause in should] |
| 98 | 102 | ||
| 99 | assert names == ["base_query", "base_query_trans_zh"] | 103 | assert names == ["base_query", "base_query_trans_zh"] |
| 100 | - base_should = q["query"]["bool"]["should"][0]["bool"]["should"] | 104 | + base_should = should[0]["bool"]["should"] |
| 101 | assert [clause["multi_match"]["type"] for clause in base_should] == ["best_fields", "phrase"] | 105 | assert [clause["multi_match"]["type"] for clause in base_should] == ["best_fields", "phrase"] |
| 102 | 106 | ||
| 103 | 107 | ||
| @@ -115,12 +119,12 @@ def test_text_query_skips_duplicate_translation_same_as_base(): | @@ -115,12 +119,12 @@ def test_text_query_skips_duplicate_translation_same_as_base(): | ||
| 115 | enable_knn=False, | 119 | enable_knn=False, |
| 116 | ) | 120 | ) |
| 117 | 121 | ||
| 118 | - root = q["query"] | 122 | + root = _recall_root(q) |
| 119 | assert root["bool"]["_name"] == "base_query" | 123 | assert root["bool"]["_name"] == "base_query" |
| 120 | assert [clause["multi_match"]["type"] for clause in root["bool"]["should"]] == ["best_fields", "phrase"] | 124 | assert [clause["multi_match"]["type"] for clause in root["bool"]["should"]] == ["best_fields", "phrase"] |
| 121 | 125 | ||
| 122 | 126 | ||
| 123 | -def test_product_title_exclusion_filter_is_applied_to_query_and_knn(): | 127 | +def test_product_title_exclusion_filter_is_applied_once_on_outer_query(): |
| 124 | qb = _builder() | 128 | qb = _builder() |
| 125 | parsed_query = SimpleNamespace( | 129 | parsed_query = SimpleNamespace( |
| 126 | rewritten_query="fitted dress", | 130 | rewritten_query="fitted dress", |
| @@ -158,4 +162,32 @@ def test_product_title_exclusion_filter_is_applied_to_query_and_knn(): | @@ -158,4 +162,32 @@ def test_product_title_exclusion_filter_is_applied_to_query_and_knn(): | ||
| 158 | } | 162 | } |
| 159 | 163 | ||
| 160 | assert expected_filter in q["query"]["bool"]["filter"] | 164 | assert expected_filter in q["query"]["bool"]["filter"] |
| 161 | - assert q["knn"]["filter"] == expected_filter | 165 | + should = _recall_should_clauses(q) |
| 166 | + knn_clause = next(clause["knn"] for clause in should if clause.get("knn", {}).get("_name") == "knn_query") | ||
| 167 | + assert "filter" not in knn_clause | ||
| 168 | + | ||
| 169 | + | ||
| 170 | +def test_image_knn_clause_is_added_alongside_base_translation_and_text_knn(): | ||
| 171 | + qb = _builder() | ||
| 172 | + parsed_query = SimpleNamespace( | ||
| 173 | + rewritten_query="street tee", | ||
| 174 | + detected_language="en", | ||
| 175 | + translations={"zh": "街头短袖"}, | ||
| 176 | + ) | ||
| 177 | + | ||
| 178 | + q = qb.build_query( | ||
| 179 | + query_text="street tee", | ||
| 180 | + query_vector=np.array([0.1, 0.2, 0.3]), | ||
| 181 | + image_query_vector=np.array([0.4, 0.5, 0.6]), | ||
| 182 | + parsed_query=parsed_query, | ||
| 183 | + enable_knn=True, | ||
| 184 | + ) | ||
| 185 | + | ||
| 186 | + should = _recall_should_clauses(q) | ||
| 187 | + names = [ | ||
| 188 | + clause["bool"]["_name"] if "bool" in clause else clause["knn"]["_name"] | ||
| 189 | + for clause in should | ||
| 190 | + ] | ||
| 191 | + assert names == ["base_query", "base_query_trans_zh", "knn_query", "image_knn_query"] | ||
| 192 | + image_knn = next(clause["knn"] for clause in should if clause.get("knn", {}).get("_name") == "image_knn_query") | ||
| 193 | + assert image_knn["field"] == "image_embedding.vector" |
tests/test_rerank_client.py
| @@ -149,3 +149,22 @@ def test_fuse_scores_and_resort_boosts_hits_with_selected_sku(): | @@ -149,3 +149,22 @@ def test_fuse_scores_and_resort_boosts_hits_with_selected_sku(): | ||
| 149 | assert [h["_id"] for h in hits] == ["style-selected", "plain"] | 149 | assert [h["_id"] for h in hits] == ["style-selected", "plain"] |
| 150 | assert debug[0]["style_intent_selected_sku"] is True | 150 | assert debug[0]["style_intent_selected_sku"] is True |
| 151 | assert debug[0]["style_intent_selected_sku_boost"] == 1.2 | 151 | assert debug[0]["style_intent_selected_sku_boost"] == 1.2 |
| 152 | + | ||
| 153 | + | ||
| 154 | +def test_fuse_scores_and_resort_uses_max_of_text_and_image_knn_scores(): | ||
| 155 | + hits = [ | ||
| 156 | + { | ||
| 157 | + "_id": "mm-hit", | ||
| 158 | + "_score": 1.0, | ||
| 159 | + "matched_queries": { | ||
| 160 | + "base_query": 1.5, | ||
| 161 | + "knn_query": 0.2, | ||
| 162 | + "image_knn_query": 0.7, | ||
| 163 | + }, | ||
| 164 | + } | ||
| 165 | + ] | ||
| 166 | + | ||
| 167 | + debug = fuse_scores_and_resort(hits, [0.8], debug=True) | ||
| 168 | + | ||
| 169 | + assert isclose(hits[0]["_knn_score"], 0.7, rel_tol=1e-9) | ||
| 170 | + assert isclose(debug[0]["knn_score"], 0.7, rel_tol=1e-9) |