From 24edc20863244080ed60692f15cdb241d79a06ea Mon Sep 17 00:00:00 2001 From: tangwang Date: Fri, 27 Mar 2026 08:33:16 +0800 Subject: [PATCH] 修改_extract_combined_knn_score相关的代码以及配置, 重排融合:之前有knn的配置bias和exponential。现在,文本和图片的embedding相似需要融合,融合方式是dis_max,因此需要配置: 1)各自的权重和tie_breaker 2)整个向量方面的权重(bias和exponential) --- config/config.yaml | 10 ++++++++-- config/loader.py | 3 +++ config/schema.py | 3 +++ search/rerank_client.py | 42 ++++++++++++++++++++++++++++++++++++------ search/searcher.py | 1 + tests/test_rerank_client.py | 36 ++++++++++++++++++++++++++++++++++++ 6 files changed, 87 insertions(+), 8 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index b569ff1..6d0cbd6 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -169,7 +169,7 @@ query_config: # Embedding字段名称 text_embedding_field: "title_embedding" - image_embedding_field: null + image_embedding_field: "image_embedding.vector" # 返回字段配置(_source includes) # null表示返回所有字段,[]表示不返回任何字段,列表表示只返回指定字段 @@ -225,13 +225,19 @@ rerank: rerank_query_template: "{query}" rerank_doc_template: "{title}" # 乘法融合:fused = Π (max(score,0) + bias) ** exponent(rerank / text / knn 三项) + # 其中 knn_score 先做一层 dis_max: + # max(knn_text_weight * text_knn, knn_image_weight * image_knn) + # + knn_tie_breaker * 另一侧较弱信号 fusion: rerank_bias: 0.00001 rerank_exponent: 1.0 text_bias: 0.1 text_exponent: 0.35 + knn_text_weight: 1.0 + knn_image_weight: 1.0 + knn_tie_breaker: 0.1 knn_bias: 0.6 - knn_exponent: 0.0 + knn_exponent: 0.2 # 可扩展服务/provider 注册表(单一配置源) services: diff --git a/config/loader.py b/config/loader.py index f01b495..c0cd5c2 100644 --- a/config/loader.py +++ b/config/loader.py @@ -477,6 +477,9 @@ class AppConfigLoader: rerank_exponent=float(fusion_raw.get("rerank_exponent", 1.0)), text_bias=float(fusion_raw.get("text_bias", 0.1)), text_exponent=float(fusion_raw.get("text_exponent", 0.35)), + knn_text_weight=float(fusion_raw.get("knn_text_weight", 1.0)), + knn_image_weight=float(fusion_raw.get("knn_image_weight", 
def _collect_knn_score_components(
    matched_queries: Any,
    fusion: RerankFusionConfig,
) -> Dict[str, float]:
    """Fuse the text and image KNN scores of one hit with a dis_max rule.

    The raw scores of the named queries ``knn_query`` (text embedding) and
    ``image_knn_query`` (image embedding) are each multiplied by their
    configured weight.  The stronger weighted signal becomes the primary
    score and the weaker one contributes ``knn_tie_breaker`` times its value:

        knn_score = max(wt, wi) + knn_tie_breaker * min(wt, wi)

    Args:
        matched_queries: The hit's ``matched_queries`` payload, passed through
            unchanged to ``_extract_named_query_score``.
        fusion: Fusion settings; ``knn_text_weight``, ``knn_image_weight`` and
            ``knn_tie_breaker`` are read from it.

    Returns:
        A dict with the keys ``text_knn_score``, ``image_knn_score``,
        ``weighted_text_knn_score``, ``weighted_image_knn_score``,
        ``primary_knn_score``, ``support_knn_score`` and ``knn_score``.
    """
    # NOTE(review): presumably the helper yields 0.0 when the named query did
    # not match this hit - confirm against _extract_named_query_score.
    text_knn_score = _extract_named_query_score(matched_queries, "knn_query")
    image_knn_score = _extract_named_query_score(matched_queries, "image_knn_query")

    weighted_text_knn_score = text_knn_score * float(fusion.knn_text_weight)
    weighted_image_knn_score = image_knn_score * float(fusion.knn_image_weight)

    # With exactly two signals the "other side" is simply the smaller one.
    # Taking min() directly is exact, whereas the equivalent "sum - max"
    # picks up floating-point rounding (0.1 + 0.7 - 0.7 != 0.1) and can lose
    # the weaker signal entirely when the stronger one is much larger.
    primary_knn_score = max(weighted_text_knn_score, weighted_image_knn_score)
    support_knn_score = min(weighted_text_knn_score, weighted_image_knn_score)
    knn_score = primary_knn_score + float(fusion.knn_tie_breaker) * support_knn_score

    return {
        "text_knn_score": text_knn_score,
        "image_knn_score": image_knn_score,
        "weighted_text_knn_score": weighted_text_knn_score,
        "weighted_image_knn_score": weighted_image_knn_score,
        "primary_knn_score": primary_knn_score,
        "support_knn_score": support_knn_score,
        "knn_score": knn_score,
    }
def test_fuse_scores_and_resort_applies_knn_dismax_weights_and_tie_breaker():
    """Check the KNN dis_max fusion with unequal weights and a tie breaker.

    Text KNN 0.4 with weight 2.0 gives 0.8 and image KNN 0.5 with weight 1.0
    gives 0.5, so the text side is primary, the image side is support, and
    the combined score is expected to be 0.8 + 0.25 * 0.5.
    """
    matched = {
        "base_query": 1.5,
        "knn_query": 0.4,
        "image_knn_query": 0.5,
    }
    candidates = [{"_id": "mm-hit", "_score": 1.0, "matched_queries": matched}]
    fusion_cfg = RerankFusionConfig(
        rerank_bias=1e-5,
        rerank_exponent=1.0,
        text_bias=0.1,
        text_exponent=0.35,
        knn_text_weight=2.0,
        knn_image_weight=1.0,
        knn_tie_breaker=0.25,
        knn_bias=0.0,
        knn_exponent=1.0,
    )

    debug_rows = fuse_scores_and_resort(
        candidates, [0.8], fusion=fusion_cfg, debug=True
    )

    # Fused KNN score written back onto the hit itself.
    assert isclose(candidates[0]["_knn_score"], 0.8 + 0.25 * 0.5, rel_tol=1e-9)

    # Per-component values exposed in the debug entry for the same hit.
    first_row = debug_rows[0]
    expected_debug = (
        ("weighted_text_knn_score", 0.8),
        ("weighted_image_knn_score", 0.5),
        ("knn_primary_score", 0.8),
        ("knn_support_score", 0.5),
    )
    for key, expected in expected_debug:
        assert isclose(first_row[key], expected, rel_tol=1e-9)