From de98daa3248e2ec397d44865c6d4b866ea8a6adb Mon Sep 17 00:00:00 2001 From: tangwang Date: Mon, 30 Mar 2026 20:59:37 +0800 Subject: [PATCH] 多模态召回优化 --- config/config.yaml | 18 ++++++++++++------ config/loader.py | 6 ++++++ config/schema.py | 4 ++++ mappings/generate_search_products_mapping.py | 2 +- mappings/search_products.json | 2 +- search/es_query_builder.py | 34 ++++++++++++++++++++++++---------- search/rerank_client.py | 17 +++++++++++++---- tests/test_es_query_builder.py | 18 +++++++++++++++--- 8 files changed, 76 insertions(+), 25 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 7551686..43aeb0a 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -213,14 +213,15 @@ query_config: knn_text_boost: 4 knn_image_boost: 4 - knn_text_k: 150 - knn_text_num_candidates: 400 + # knn_text_num_candidates = k * 3.4 + knn_text_k: 160 + knn_text_num_candidates: 560 - knn_text_k_long: 300 - knn_text_num_candidates_long: 720 + knn_text_k_long: 400 + knn_text_num_candidates_long: 1200 - knn_image_k: 3000 - knn_image_num_candidates: 7200 + knn_image_k: 400 + knn_image_num_candidates: 1200 # Function Score配置(ES层打分规则) function_score: @@ -236,6 +237,9 @@ coarse_rank: fusion: text_bias: 0.1 text_exponent: 0.35 + # base_query_trans_* 相对 base_query 的权重(见 search/rerank_client 中文本 dismax 融合) + # 因为es的打分已经给了trans进行了折扣,所以这里不再继续折扣 + text_translation_weight: 1.0 knn_text_weight: 1.0 knn_image_weight: 1.0 knn_tie_breaker: 0.1 @@ -273,6 +277,8 @@ rerank: fine_exponent: 1.0 text_bias: 0.1 text_exponent: 0.35 + # base_query_trans_* 相对 base_query 的权重(见 search/rerank_client 中文本 dismax 融合) + text_translation_weight: 1.0 knn_text_weight: 1.0 knn_image_weight: 1.0 knn_tie_breaker: 0.1 diff --git a/config/loader.py b/config/loader.py index 316cf0a..c084b3d 100644 --- a/config/loader.py +++ b/config/loader.py @@ -498,6 +498,9 @@ class AppConfigLoader: knn_tie_breaker=float(coarse_fusion_raw.get("knn_tie_breaker", 0.0)), knn_bias=float(coarse_fusion_raw.get("knn_bias", 0.6)), knn_exponent=float(coarse_fusion_raw.get("knn_exponent", 0.2)), + text_translation_weight=float( + coarse_fusion_raw.get("text_translation_weight", 0.8) + ), ), ), fine_rank=FineRankConfig( @@ -538,6 +541,9 @@ class AppConfigLoader: knn_exponent=float(fusion_raw.get("knn_exponent", 0.2)), fine_bias=float(fusion_raw.get("fine_bias", 0.00001)), fine_exponent=float(fusion_raw.get("fine_exponent", 1.0)), + text_translation_weight=float( + fusion_raw.get("text_translation_weight", 0.8) + ), ), ), spu_config=SPUConfig( diff --git a/config/schema.py b/config/schema.py index 6e5f61b..0554aab 100644 --- a/config/schema.py +++ b/config/schema.py @@ -119,6 +119,8 @@ class RerankFusionConfig: knn_exponent: float = 0.2 fine_bias: float = 0.00001 fine_exponent: float = 1.0 + #: 翻译子句 named query 分数相对原文 base_query 的权重(加权后再与原文做 dismax 融合) + text_translation_weight: float = 0.8 @dataclass(frozen=True) @@ -136,6 +138,8 @@ class CoarseRankFusionConfig: knn_tie_breaker: float = 0.0 knn_bias: float = 0.6 knn_exponent: float = 0.2 + #: 翻译子句 named query 分数相对原文 base_query 的权重(加权后再与原文做 dismax 融合) + text_translation_weight: float = 0.8 @dataclass(frozen=True) diff --git a/mappings/generate_search_products_mapping.py b/mappings/generate_search_products_mapping.py index 0579825..5ed7e98 100644 --- a/mappings/generate_search_products_mapping.py +++ b/mappings/generate_search_products_mapping.py @@ -80,7 +80,7 @@ ANALYZERS = { } SETTINGS = { - "number_of_shards": 1, + "number_of_shards": 4, "number_of_replicas": 0, "refresh_interval": "30s", "analysis": { diff --git a/mappings/search_products.json b/mappings/search_products.json index 0fefc62..07173fa 100644 --- a/mappings/search_products.json +++ b/mappings/search_products.json @@ -1,6 +1,6 @@ { "settings": { - "number_of_shards": 1, + "number_of_shards": 4, "number_of_replicas": 0, "refresh_interval": "30s", "analysis": { diff --git a/search/es_query_builder.py b/search/es_query_builder.py index 59d8543..2ad5fac 100644 --- a/search/es_query_builder.py +++ b/search/es_query_builder.py @@ -272,16 +272,30 @@ class ESQueryBuilder: }) if has_image_embedding: - recall_clauses.append({ - "knn": { - "field": self.image_embedding_field, - "query_vector": image_query_vector.tolist(), - "k": self.knn_image_k, - "num_candidates": self.knn_image_num_candidates, - "boost": self.knn_image_boost, - "_name": "image_knn_query", - } - }) + nested_path, _, _ = str(self.image_embedding_field).rpartition(".") + image_knn_query = { + "field": self.image_embedding_field, + "query_vector": image_query_vector.tolist(), + "k": self.knn_image_k, + "num_candidates": self.knn_image_num_candidates, + "boost": self.knn_image_boost, + } + if nested_path: + recall_clauses.append({ + "nested": { + "path": nested_path, + "_name": "image_knn_query", + "query": {"knn": image_knn_query}, + "score_mode": "max", + } + }) + else: + recall_clauses.append({ + "knn": { + **image_knn_query, + "_name": "image_knn_query", + } + }) # 4. Build main query structure: filters and recall if recall_clauses: diff --git a/search/rerank_client.py b/search/rerank_client.py index 560ce1c..4384e50 100644 --- a/search/rerank_client.py +++ b/search/rerank_client.py @@ -186,7 +186,7 @@ translation_score:所有名字以 base_query_trans_ 开头的 named query 的 中间变量:计算原始query得分和翻译query得分 weighted_source : -weighted_translation : 0.8 * translation_score +weighted_translation : text_translation_weight * translation_score(由 fusion.text_translation_weight 配置) 区分主信号和辅助信号: 合成primary_text_score和support_text_score,取 更强 的那一路(原文检索 vs 翻译检索)作为主信号 @@ -197,7 +197,12 @@ support_text_score : weighted_source + weighted_translation - primary_text_score 最终text_score:主信号 + 0.25 * 辅助信号 text_score : primary_text_score + 0.25 * support_text_score """ -def _collect_text_score_components(matched_queries: Any, fallback_es_score: float) -> Dict[str, float]: +def _collect_text_score_components( + matched_queries: Any, + fallback_es_score: float, + *, + translation_weight: float, +) -> Dict[str, float]: source_score = _extract_named_query_score(matched_queries, "base_query") translation_score = 0.0 @@ -216,7 +221,7 @@ def _collect_text_score_components(matched_queries: Any, fallback_es_score: floa translation_score = 1.0 weighted_source = source_score - weighted_translation = 0.8 * translation_score + weighted_translation = float(translation_weight) * translation_score weighted_components = [weighted_source, weighted_translation] primary_text_score = max(weighted_components) support_text_score = sum(weighted_components) - primary_text_score @@ -249,7 +254,11 @@ def _build_hit_signal_bundle( ) -> Dict[str, Any]: es_score = _to_score(hit.get("_score")) matched_queries = hit.get("matched_queries") - text_components = _collect_text_score_components(matched_queries, es_score) + text_components = _collect_text_score_components( + matched_queries, + es_score, + translation_weight=fusion.text_translation_weight, + ) knn_components = _collect_knn_score_components(matched_queries, fusion) return { "doc_id": hit.get("_id"), diff --git a/tests/test_es_query_builder.py b/tests/test_es_query_builder.py index a6cca7a..03c8448 100644 --- a/tests/test_es_query_builder.py +++ b/tests/test_es_query_builder.py @@ -35,6 +35,16 @@ def _recall_should_clauses(es_body: Dict[str, Any]) -> list[Dict[str, Any]]: return [root] +def _recall_clause_name(clause: Dict[str, Any]) -> str | None: + if "bool" in clause: + return clause["bool"].get("_name") + if "knn" in clause: + return clause["knn"].get("_name") + if "nested" in clause: + return clause["nested"].get("_name") + return None + + def test_knn_clause_moves_under_query_should_and_uses_outer_filters(): qb = _builder() q = qb.build_query( @@ -188,9 +198,11 @@ def test_image_knn_clause_is_added_alongside_base_translation_and_text_knn(): should = _recall_should_clauses(q) names = [ - clause["bool"]["_name"] if "bool" in clause else clause["knn"]["_name"] + _recall_clause_name(clause) for clause in should ] assert names == ["base_query", "base_query_trans_zh", "knn_query", "image_knn_query"] - image_knn = next(clause["knn"] for clause in should if clause.get("knn", {}).get("_name") == "image_knn_query") - assert image_knn["field"] == "image_embedding.vector" + image_knn = next(clause["nested"] for clause in should if clause.get("nested", {}).get("_name") == "image_knn_query") + assert image_knn["path"] == "image_embedding" + assert image_knn["score_mode"] == "max" + assert image_knn["query"]["knn"]["field"] == "image_embedding.vector" -- libgit2 0.21.2