From dc403578b9f167d06cff131dda36fd488340d99f Mon Sep 17 00:00:00 2001 From: tangwang Date: Fri, 27 Mar 2026 08:11:35 +0800 Subject: [PATCH] 多模态搜索 --- query/query_parser.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-- search/es_query_builder.py | 117 ++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------------------------------------------------- search/rerank_client.py | 9 ++++++++- search/searcher.py | 31 +++++++++++++++++++++++++------ tests/test_embedding_pipeline.py | 34 ++++++++++++++++++++++++++++++++-- tests/test_es_query_builder.py | 98 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------------------- tests/test_rerank_client.py | 19 +++++++++++++++++++ 7 files changed, 276 insertions(+), 107 deletions(-) diff --git a/query/query_parser.py b/query/query_parser.py index 4b7fffc..43cf343 100644 --- a/query/query_parser.py +++ b/query/query_parser.py @@ -14,6 +14,7 @@ import numpy as np import logging from concurrent.futures import ThreadPoolExecutor, wait +from embeddings.image_encoder import CLIPImageEncoder from embeddings.text_encoder import TextEmbeddingEncoder from config import SearchConfig from translation import create_translation_client @@ -66,6 +67,7 @@ class ParsedQuery: detected_language: Optional[str] = None translations: Dict[str, str] = field(default_factory=dict) query_vector: Optional[np.ndarray] = None + image_query_vector: Optional[np.ndarray] = None query_tokens: List[str] = field(default_factory=list) style_intent_profile: Optional[StyleIntentProfile] = None product_title_exclusion_profile: Optional[ProductTitleExclusionProfile] = None @@ -86,6 +88,8 @@ class ParsedQuery: "rewritten_query": self.rewritten_query, "detected_language": self.detected_language, "translations": self.translations, + "has_query_vector": self.query_vector is not None, + "has_image_query_vector": self.image_query_vector is not None, "query_tokens": self.query_tokens, "style_intent_profile": ( self.style_intent_profile.to_dict() if self.style_intent_profile is not None else None @@ -112,6 +116,7 @@ class QueryParser: self, config: SearchConfig, text_encoder: Optional[TextEmbeddingEncoder] = None, + image_encoder: Optional[CLIPImageEncoder] = None, translator: Optional[Any] = None, tokenizer: Optional[Callable[[str], Any]] = None, ): @@ -125,6 +130,7 @@ class QueryParser: """ self.config = config self._text_encoder = text_encoder + self._image_encoder = image_encoder self._translator = translator # Initialize components @@ -149,6 +155,9 @@ class QueryParser: if self.config.query_config.enable_text_embedding and self._text_encoder is None: logger.info("Initializing text encoder at QueryParser construction...") self._text_encoder = TextEmbeddingEncoder() + if self.config.query_config.image_embedding_field and self._image_encoder is None: + logger.info("Initializing image encoder at QueryParser construction...") + self._image_encoder = CLIPImageEncoder() if self._translator is None: from config.services_config import get_translation_config cfg = get_translation_config() @@ -169,6 +178,11 @@ class QueryParser: """Return pre-initialized translator.""" return self._translator + @property + def image_encoder(self) -> Optional[CLIPImageEncoder]: + """Return pre-initialized image encoder for CLIP text embeddings.""" + return self._image_encoder + def _build_tokenizer(self) -> Callable[[str], Any]: """Build the tokenizer used by query parsing. No fallback path by design.""" if hanlp is None: @@ -311,12 +325,21 @@ class QueryParser: # Stage 6: Text embedding - async execution query_vector = None + image_query_vector = None should_generate_embedding = ( generate_vector and self.config.query_config.enable_text_embedding ) + should_generate_image_embedding = ( + generate_vector and + bool(self.config.query_config.image_embedding_field) + ) - task_count = len(translation_targets) + (1 if should_generate_embedding else 0) + task_count = ( + len(translation_targets) + + (1 if should_generate_embedding else 0) + + (1 if should_generate_image_embedding else 0) + ) if task_count > 0: async_executor = ThreadPoolExecutor( max_workers=max(1, min(task_count, 4)), @@ -366,6 +389,28 @@ class QueryParser: future = async_executor.submit(_encode_query_vector) future_to_task[future] = ("embedding", None) + + if should_generate_image_embedding: + if self.image_encoder is None: + raise RuntimeError( + "Image embedding field is configured but image encoder is not initialized" + ) + log_debug("Submitting CLIP text query vector generation") + + def _encode_image_query_vector() -> Optional[np.ndarray]: + vec = self.image_encoder.encode_clip_text( + query_text, + normalize_embeddings=True, + priority=1, + request_id=(context.reqid if context else None), + user_id=(context.uid if context else None), + ) + if vec is None: + return None + return np.asarray(vec, dtype=np.float32) + + future = async_executor.submit(_encode_image_query_vector) + future_to_task[future] = ("image_embedding", None) except Exception as e: error_msg = f"Async query enrichment submission failed | Error: {str(e)}" log_info(error_msg) @@ -424,9 +469,27 @@ class QueryParser: log_info( "Query vector generation completed but result is None, will process without vector" ) + elif task_type == "image_embedding": + image_query_vector = result + if image_query_vector is not None: + log_debug( + f"CLIP text query vector generation completed | Shape: {image_query_vector.shape}" + ) + if context: + context.store_intermediate_result( + "image_query_vector_shape", + image_query_vector.shape, + ) + else: + log_info( + "CLIP text query vector generation completed but result is None, " + "will process without image vector" + ) except Exception as e: if task_type == "translation": error_msg = f"Translation failed | Language: {lang} | Error: {str(e)}" + elif task_type == "image_embedding": + error_msg = f"CLIP text query vector generation failed | Error: {str(e)}" else: error_msg = f"Query vector generation failed | Error: {str(e)}" log_info(error_msg) @@ -441,6 +504,11 @@ class QueryParser: f"Translation timeout (>{budget_ms}ms) | Language: {lang} | " f"Query text: '{query_text}'" ) + elif task_type == "image_embedding": + timeout_msg = ( + f"CLIP text query vector generation timeout (>{budget_ms}ms), " + "proceeding without image embedding result" + ) else: timeout_msg = ( f"Query vector generation timeout (>{budget_ms}ms), proceeding without embedding result" @@ -463,6 +531,7 @@ class QueryParser: detected_language=detected_lang, translations=translations, query_vector=query_vector, + image_query_vector=image_query_vector, query_tokens=query_tokens, ) style_intent_profile = self.style_intent_detector.detect(base_result) @@ -484,6 +553,7 @@ class QueryParser: detected_language=detected_lang, translations=translations, query_vector=query_vector, + image_query_vector=image_query_vector, query_tokens=query_tokens, style_intent_profile=style_intent_profile, product_title_exclusion_profile=product_title_exclusion_profile, @@ -492,7 +562,8 @@ class QueryParser: if context and hasattr(context, 'logger'): context.logger.info( f"Query parsing completed | Original query: '{query}' | Final query: '{rewritten or query_text}' | " - f"Translation count: {len(translations)} | Vector: {'yes' if query_vector is not None else 'no'}", + f"Translation count: {len(translations)} | Vector: {'yes' if query_vector is not None else 'no'} | " + f"Image vector: {'yes' if image_query_vector is not None else 'no'}", extra={'reqid': context.reqid, 'uid': context.uid} ) else: diff --git a/search/es_query_builder.py b/search/es_query_builder.py index e5ffc2c..25778b3 100644 --- a/search/es_query_builder.py +++ b/search/es_query_builder.py @@ -164,6 +164,7 @@ class ESQueryBuilder: self, query_text: str, query_vector: Optional[np.ndarray] = None, + image_query_vector: Optional[np.ndarray] = None, filters: Optional[Dict[str, Any]] = None, range_filters: Optional[Dict[str, Any]] = None, facet_configs: Optional[List[Any]] = None, @@ -212,15 +213,14 @@ class ESQueryBuilder: # 1. Build recall queries (text or embedding) recall_clauses = [] - + # Text recall (always include if query_text exists) if query_text: - # Unified text query strategy - text_query = self._build_advanced_text_query(query_text, parsed_query) - recall_clauses.append(text_query) - - # Embedding recall (KNN - separate from query, handled below) + recall_clauses.extend(self._build_advanced_text_query(query_text, parsed_query)) + + # Embedding recall has_embedding = enable_knn and query_vector is not None and self.text_embedding_field + has_image_embedding = enable_knn and image_query_vector is not None and self.image_embedding_field # 2. Split filters for multi-select faceting conjunctive_filters, disjunctive_filters = self._split_filters_for_faceting( @@ -233,9 +233,48 @@ class ESQueryBuilder: if product_title_exclusion_filter: filter_clauses.append(product_title_exclusion_filter) - # 3. Build main query structure: filters and recall + # 3. Add KNN search clauses alongside lexical clauses under the same bool.should + # Adjust KNN k, num_candidates, boost by query_tokens (short query: less KNN; long: more) + final_knn_k, final_knn_num_candidates = knn_k, knn_num_candidates + if has_embedding: + knn_boost = self.knn_boost + if parsed_query: + query_tokens = getattr(parsed_query, 'query_tokens', None) or [] + token_count = len(query_tokens) + if token_count >= 5: + final_knn_k, final_knn_num_candidates = 160, 500 + knn_boost = self.knn_boost * 1.4 # Higher weight for long queries + else: + final_knn_k, final_knn_num_candidates = 120, 400 + else: + final_knn_k, final_knn_num_candidates = 120, 400 + recall_clauses.append({ + "knn": { + "field": self.text_embedding_field, + "query_vector": query_vector.tolist(), + "k": final_knn_k, + "num_candidates": final_knn_num_candidates, + "boost": knn_boost, + "_name": "knn_query", + } + }) + + if has_image_embedding: + image_knn_k = max(final_knn_k, 120) + image_knn_num_candidates = max(final_knn_num_candidates, 400) + recall_clauses.append({ + "knn": { + "field": self.image_embedding_field, + "query_vector": image_query_vector.tolist(), + "k": image_knn_k, + "num_candidates": image_knn_num_candidates, + "boost": self.knn_boost, + "_name": "image_knn_query", + } + }) + + # 4. Build main query structure: filters and recall if recall_clauses: - # Combine text recalls with OR logic (if multiple) if len(recall_clauses) == 1: recall_query = recall_clauses[0] else: @@ -245,11 +284,9 @@ class ESQueryBuilder: "minimum_should_match": 1 } } - - # Wrap recall with function_score for boosting + recall_query = self._wrap_with_function_score(recall_query) - - # Combine filters and recall + if filter_clauses: es_query["query"] = { "bool": { @@ -260,7 +297,6 @@ class ESQueryBuilder: else: es_query["query"] = recall_query else: - # No recall queries, only filters (match_all filtered) if filter_clauses: es_query["query"] = { "bool": { @@ -271,41 +307,6 @@ class ESQueryBuilder: else: es_query["query"] = {"match_all": {}} - # 4. Add KNN search if enabled (separate from query, ES will combine) - # Adjust KNN k, num_candidates, boost by query_tokens (short query: less KNN; long: more) - if has_embedding: - knn_boost = self.knn_boost - if parsed_query: - query_tokens = getattr(parsed_query, 'query_tokens', None) or [] - token_count = len(query_tokens) - if token_count >= 5: - knn_k, knn_num_candidates = 160, 500 - knn_boost = self.knn_boost * 1.4 # Higher weight for long queries - else: - knn_k, knn_num_candidates = 120, 400 - else: - knn_k, knn_num_candidates = 120, 400 - knn_clause = { - "field": self.text_embedding_field, - "query_vector": query_vector.tolist(), - "k": knn_k, - "num_candidates": knn_num_candidates, - "boost": knn_boost, - "_name": "knn_query", - } - # Top-level knn does not inherit query.bool.filter automatically. - # Apply conjunctive + range filters here so vector recall respects hard filters. - if filter_clauses: - if len(filter_clauses) == 1: - knn_clause["filter"] = filter_clauses[0] - else: - knn_clause["filter"] = { - "bool": { - "filter": filter_clauses - } - } - es_query["knn"] = knn_clause - # 5. Add post_filter for disjunctive (multi-select) filters if disjunctive_filters: post_filter_clauses = self._build_filters(disjunctive_filters, None) @@ -536,21 +537,20 @@ class ESQueryBuilder: self, query_text: str, parsed_query: Optional[Any] = None, - ) -> Dict[str, Any]: + ) -> List[Dict[str, Any]]: """ Build advanced text query using base and translated lexical clauses. Unified implementation: - base_query: source-language clause - translation queries: target-language clauses from translations - - KNN query: added separately in build_query - + Args: query_text: Query text parsed_query: ParsedQuery object with analysis results Returns: - ES bool query with should clauses + Flat recall clauses to be merged with KNN clauses under query.bool.should """ should_clauses = [] source_lang = self.default_language @@ -603,18 +603,9 @@ class ESQueryBuilder: "minimum_should_match": self.base_minimum_should_match, } } - return fallback_lexical + return [fallback_lexical] - # Return bool query with should clauses - if len(should_clauses) == 1: - return should_clauses[0] - - return { - "bool": { - "should": should_clauses, - "minimum_should_match": 1 - } - } + return should_clauses def _build_filters( self, diff --git a/search/rerank_client.py b/search/rerank_client.py index f8433d9..8d49b58 100644 --- a/search/rerank_client.py +++ b/search/rerank_client.py @@ -151,6 +151,13 @@ def _extract_named_query_score(matched_queries: Any, name: str) -> float: return 1.0 if name in matched_queries else 0.0 return 0.0 + +def _extract_combined_knn_score(matched_queries: Any) -> float: + return max( + _extract_named_query_score(matched_queries, "knn_query"), + _extract_named_query_score(matched_queries, "image_knn_query"), + ) + """ 原始变量: ES总分 @@ -272,7 +279,7 @@ def fuse_scores_and_resort( es_score = _to_score(hit.get("_score")) rerank_score = _to_score(rerank_scores[idx]) matched_queries = hit.get("matched_queries") - knn_score = _extract_named_query_score(matched_queries, "knn_query") + knn_score = _extract_combined_knn_score(matched_queries) text_components = _collect_text_score_components(matched_queries, es_score) text_score = text_components["text_score"] rerank_factor, text_factor, knn_factor, fused = _multiply_fusion_factors( diff --git a/search/searcher.py b/search/searcher.py index 718ede2..75d6620 100644 --- a/search/searcher.py +++ b/search/searcher.py @@ -106,14 +106,14 @@ class Searcher: """ self.es_client = es_client self.config = config - # Index name is now generated dynamically per tenant, no longer stored here - self.query_parser = query_parser or QueryParser(config) self.text_embedding_field = config.query_config.text_embedding_field or "title_embedding" self.image_embedding_field = config.query_config.image_embedding_field if self.image_embedding_field and image_encoder is None: self.image_encoder = CLIPImageEncoder() else: self.image_encoder = image_encoder + # Index name is now generated dynamically per tenant, no longer stored here + self.query_parser = query_parser or QueryParser(config, image_encoder=self.image_encoder) self.source_fields = config.query_config.source_fields self.style_intent_registry = StyleIntentRegistry.from_query_config(self.config.query_config) self.style_sku_selector = StyleSkuSelector( @@ -403,7 +403,8 @@ class Searcher: f"查询解析完成 | 原查询: '{parsed_query.original_query}' | " f"重写后: '{parsed_query.rewritten_query}' | " f"语言: {parsed_query.detected_language} | " - f"向量: {'是' if parsed_query.query_vector is not None else '否'}", + f"文本向量: {'是' if parsed_query.query_vector is not None else '否'} | " + f"图片向量: {'是' if getattr(parsed_query, 'image_query_vector', None) is not None else '否'}", extra={'reqid': context.reqid, 'uid': context.uid} ) except Exception as e: @@ -428,12 +429,20 @@ class Searcher: es_query = self.query_builder.build_query( query_text=parsed_query.rewritten_query or parsed_query.query_normalized, query_vector=parsed_query.query_vector if enable_embedding else None, + image_query_vector=( + getattr(parsed_query, "image_query_vector", None) + if enable_embedding + else None + ), filters=filters, range_filters=range_filters, facet_configs=facets, size=es_fetch_size, from_=es_fetch_from, - enable_knn=enable_embedding and parsed_query.query_vector is not None, + enable_knn=enable_embedding and ( + parsed_query.query_vector is not None + or getattr(parsed_query, "image_query_vector", None) is not None + ), min_score=min_score, parsed_query=parsed_query, ) @@ -475,15 +484,24 @@ class Searcher: # Serialize ES query to compute a compact size + stable digest for correlation es_query_compact = json.dumps(es_query_for_fetch, ensure_ascii=False, separators=(",", ":")) es_query_digest = hashlib.sha256(es_query_compact.encode("utf-8")).hexdigest()[:16] - knn_enabled = bool(enable_embedding and parsed_query.query_vector is not None) + knn_enabled = bool(enable_embedding and ( + parsed_query.query_vector is not None + or getattr(parsed_query, "image_query_vector", None) is not None + )) vector_dims = int(len(parsed_query.query_vector)) if parsed_query.query_vector is not None else 0 + image_vector_dims = ( + int(len(parsed_query.image_query_vector)) + if getattr(parsed_query, "image_query_vector", None) is not None + else 0 + ) context.logger.info( - "ES query built | size: %s chars | digest: %s | KNN: %s | vector_dims: %s | facets: %s | rerank_prefetch_source: %s", + "ES query built | size: %s chars | digest: %s | KNN: %s | vector_dims: %s | image_vector_dims: %s | facets: %s | rerank_prefetch_source: %s", len(es_query_compact), es_query_digest, "yes" if knn_enabled else "no", vector_dims, + image_vector_dims, "yes" if facets else "no", rerank_prefetch_source, extra={'reqid': context.reqid, 'uid': context.uid} @@ -497,6 +515,7 @@ class Searcher: "sha256_16": es_query_digest, "knn_enabled": knn_enabled, "vector_dims": vector_dims, + "image_vector_dims": image_vector_dims, "has_facets": bool(facets), "query": es_query_for_fetch, }) diff --git a/tests/test_embedding_pipeline.py b/tests/test_embedding_pipeline.py index bd1734c..48dc0ff 100644 --- a/tests/test_embedding_pipeline.py +++ b/tests/test_embedding_pipeline.py @@ -75,6 +75,15 @@ class _FakeQueryEncoder: return np.array([np.array([0.11, 0.22, 0.33], dtype=np.float32) for _ in sentences], dtype=object) +class _FakeClipTextEncoder: + def __init__(self): + self.calls = [] + + def encode_clip_text(self, text, **kwargs): + self.calls.append({"text": text, "kwargs": dict(kwargs)}) + return np.array([0.44, 0.55, 0.66], dtype=np.float32) + + def _tokenizer(text): return str(text).split() @@ -91,7 +100,7 @@ class _FakeEmbeddingCache: return True -def _build_test_config() -> SearchConfig: +def _build_test_config(*, image_embedding_field: Optional[str] = None) -> SearchConfig: return SearchConfig( field_boosts={"title.en": 3.0}, indexes=[IndexConfig(name="default", label="default", fields=["title.en"], boost=1.0)], @@ -102,7 +111,7 @@ def _build_test_config() -> SearchConfig: enable_query_rewrite=False, rewrite_dictionary={}, text_embedding_field="title_embedding", - image_embedding_field=None, + image_embedding_field=image_embedding_field, ), function_score=FunctionScoreConfig(), rerank=RerankConfig(), @@ -250,6 +259,26 @@ def test_query_parser_generates_query_vector_with_encoder(): assert encoder.calls[0]["kwargs"]["priority"] == 1 +def test_query_parser_generates_image_query_vector_with_clip_text_encoder(): + text_encoder = _FakeQueryEncoder() + image_encoder = _FakeClipTextEncoder() + parser = QueryParser( + config=_build_test_config(image_embedding_field="image_embedding.vector"), + text_encoder=text_encoder, + image_encoder=image_encoder, + translator=_FakeTranslator(), + tokenizer=_tokenizer, + ) + + parsed = parser.parse("red dress", tenant_id="162", generate_vector=True) + assert parsed.query_vector is not None + assert parsed.image_query_vector is not None + assert parsed.image_query_vector.shape == (3,) + assert image_encoder.calls + assert image_encoder.calls[0]["text"] == "red dress" + assert image_encoder.calls[0]["kwargs"]["priority"] == 1 + + def test_query_parser_skips_query_vector_when_disabled(): parser = QueryParser( config=_build_test_config(), @@ -260,6 +289,7 @@ def test_query_parser_skips_query_vector_when_disabled(): parsed = parser.parse("red dress", tenant_id="162", generate_vector=False) assert parsed.query_vector is None + assert parsed.image_query_vector is None def test_tei_text_model_splits_batches_over_client_limit(monkeypatch): diff --git a/tests/test_es_query_builder.py b/tests/test_es_query_builder.py index f4a06bd..1e5789f 100644 --- a/tests/test_es_query_builder.py +++ b/tests/test_es_query_builder.py @@ -13,22 +13,29 @@ def _builder() -> ESQueryBuilder: core_multilingual_fields=["title", "brief"], shared_fields=[], text_embedding_field="title_embedding", + image_embedding_field="image_embedding.vector", default_language="en", ) -def _lexical_clause(query_root: Dict[str, Any]) -> Dict[str, Any]: - """Return the first named lexical bool clause from query_root.""" - if "bool" in query_root and query_root["bool"].get("_name"): - return query_root["bool"] - for clause in query_root.get("bool", {}).get("should", []): - clause_bool = clause.get("bool") or {} - if clause_bool.get("_name"): - return clause_bool - raise AssertionError("no lexical bool clause in query_root") +def _recall_root(es_body: Dict[str, Any]) -> Dict[str, Any]: + query_root = es_body["query"] + if "bool" in query_root and query_root["bool"].get("must"): + query_root = query_root["bool"]["must"][0] + if "function_score" in query_root: + query_root = query_root["function_score"]["query"] + return query_root -def test_knn_prefilter_includes_range_filters(): +def _recall_should_clauses(es_body: Dict[str, Any]) -> list[Dict[str, Any]]: + root = _recall_root(es_body) + should = root.get("bool", {}).get("should") + if should: + return should + return [root] + + +def test_knn_clause_moves_under_query_should_and_uses_outer_filters(): qb = _builder() q = qb.build_query( query_text="bags", @@ -37,11 +44,13 @@ def test_knn_prefilter_includes_range_filters(): enable_knn=True, ) - assert "knn" in q - assert q["knn"]["filter"] == {"range": {"min_price": {"gte": 50, "lt": 100}}} + assert "knn" not in q + should = _recall_should_clauses(q) + assert any(clause.get("knn", {}).get("_name") == "knn_query" for clause in should) + assert q["query"]["bool"]["filter"] == [{"range": {"min_price": {"gte": 50, "lt": 100}}}] -def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present(): +def test_knn_clause_uses_outer_query_filter_when_disjunctive_filters_present(): qb = _builder() facets = [SimpleNamespace(field="category_name", disjunctive=True)] q = qb.build_query( @@ -53,21 +62,15 @@ def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present(): enable_knn=True, ) - assert "knn" in q - assert "filter" in q["knn"] - knn_filter = q["knn"]["filter"] - assert knn_filter == { - "bool": { - "filter": [ - {"term": {"vendor": "Nike"}}, - {"range": {"min_price": {"gte": 50, "lt": 100}}}, - ] - } - } + assert "knn" not in q + assert q["query"]["bool"]["filter"] == [ + {"term": {"vendor": "Nike"}}, + {"range": {"min_price": {"gte": 50, "lt": 100}}}, + ] assert q["post_filter"] == {"terms": {"category_name": ["A", "B"]}} -def test_knn_prefilter_not_added_without_filters(): +def test_knn_clause_has_name_and_no_embedded_filter(): qb = _builder() q = qb.build_query( query_text="bags", @@ -75,9 +78,10 @@ def test_knn_prefilter_not_added_without_filters(): enable_knn=True, ) - assert "knn" in q - assert "filter" not in q["knn"] - assert q["knn"]["_name"] == "knn_query" + should = _recall_should_clauses(q) + knn_clause = next(clause["knn"] for clause in should if clause.get("knn", {}).get("_name") == "knn_query") + assert "filter" not in knn_clause + assert knn_clause["_name"] == "knn_query" def test_text_query_contains_only_base_and_translation_named_queries(): @@ -93,11 +97,11 @@ def test_text_query_contains_only_base_and_translation_named_queries(): parsed_query=parsed_query, enable_knn=False, ) - should = q["query"]["bool"]["should"] + should = _recall_should_clauses(q) names = [clause["bool"]["_name"] for clause in should] assert names == ["base_query", "base_query_trans_zh"] - base_should = q["query"]["bool"]["should"][0]["bool"]["should"] + base_should = should[0]["bool"]["should"] assert [clause["multi_match"]["type"] for clause in base_should] == ["best_fields", "phrase"] @@ -115,12 +119,12 @@ def test_text_query_skips_duplicate_translation_same_as_base(): enable_knn=False, ) - root = q["query"] + root = _recall_root(q) assert root["bool"]["_name"] == "base_query" assert [clause["multi_match"]["type"] for clause in root["bool"]["should"]] == ["best_fields", "phrase"] -def test_product_title_exclusion_filter_is_applied_to_query_and_knn(): +def test_product_title_exclusion_filter_is_applied_once_on_outer_query(): qb = _builder() parsed_query = SimpleNamespace( rewritten_query="fitted dress", @@ -158,4 +162,32 @@ def test_product_title_exclusion_filter_is_applied_to_query_and_knn(): } assert expected_filter in q["query"]["bool"]["filter"] - assert q["knn"]["filter"] == expected_filter + should = _recall_should_clauses(q) + knn_clause = next(clause["knn"] for clause in should if clause.get("knn", {}).get("_name") == "knn_query") + assert "filter" not in knn_clause + + +def test_image_knn_clause_is_added_alongside_base_translation_and_text_knn(): + qb = _builder() + parsed_query = SimpleNamespace( + rewritten_query="street tee", + detected_language="en", + translations={"zh": "街头短袖"}, + ) + + q = qb.build_query( + query_text="street tee", + query_vector=np.array([0.1, 0.2, 0.3]), + image_query_vector=np.array([0.4, 0.5, 0.6]), + parsed_query=parsed_query, + enable_knn=True, + ) + + should = _recall_should_clauses(q) + names = [ + clause["bool"]["_name"] if "bool" in clause else clause["knn"]["_name"] + for clause in should + ] + assert names == ["base_query", "base_query_trans_zh", "knn_query", "image_knn_query"] + image_knn = next(clause["knn"] for clause in should if clause.get("knn", {}).get("_name") == "image_knn_query") + assert image_knn["field"] == "image_embedding.vector" diff --git a/tests/test_rerank_client.py b/tests/test_rerank_client.py index 736336e..e0fccf3 100644 --- a/tests/test_rerank_client.py +++ b/tests/test_rerank_client.py @@ -149,3 +149,22 @@ def test_fuse_scores_and_resort_boosts_hits_with_selected_sku(): assert [h["_id"] for h in hits] == ["style-selected", "plain"] assert debug[0]["style_intent_selected_sku"] is True assert debug[0]["style_intent_selected_sku_boost"] == 1.2 + + +def test_fuse_scores_and_resort_uses_max_of_text_and_image_knn_scores(): + hits = [ + { + "_id": "mm-hit", + "_score": 1.0, + "matched_queries": { + "base_query": 1.5, + "knn_query": 0.2, + "image_knn_query": 0.7, + }, + } + ] + + debug = fuse_scores_and_resort(hits, [0.8], debug=True) + + assert isclose(hits[0]["_knn_score"], 0.7, rel_tol=1e-9) + assert isclose(debug[0]["knn_score"], 0.7, rel_tol=1e-9) -- libgit2 0.21.2