Commit 24edc20863244080ed60692f15cdb241d79a06ea
1 parent
dc403578
修改_extract_combined_knn_score相关的代码以及配置,
重排融合:之前有knn的配置bias和exponential。现在,文本和图片的embedding相似需要融合,融合方式是dis_max,因此需要配置: 1)各自的权重和tie_breaker 2)整个向量方面的权重(bias和exponential)
Showing
6 changed files
with
87 additions
and
8 deletions
Show diff stats
config/config.yaml
| ... | ... | @@ -169,7 +169,7 @@ query_config: |
| 169 | 169 | |
| 170 | 170 | # Embedding字段名称 |
| 171 | 171 | text_embedding_field: "title_embedding" |
| 172 | - image_embedding_field: null | |
| 172 | + image_embedding_field: "image_embedding.vector" | |
| 173 | 173 | |
| 174 | 174 | # 返回字段配置(_source includes) |
| 175 | 175 | # null表示返回所有字段,[]表示不返回任何字段,列表表示只返回指定字段 |
| ... | ... | @@ -225,13 +225,19 @@ rerank: |
| 225 | 225 | rerank_query_template: "{query}" |
| 226 | 226 | rerank_doc_template: "{title}" |
| 227 | 227 | # 乘法融合:fused = Π (max(score,0) + bias) ** exponent(rerank / text / knn 三项) |
| 228 | + # 其中 knn_score 先做一层 dis_max: | |
| 229 | + # max(knn_text_weight * text_knn, knn_image_weight * image_knn) | |
| 230 | + # + knn_tie_breaker * 另一侧较弱信号 | |
| 228 | 231 | fusion: |
| 229 | 232 | rerank_bias: 0.00001 |
| 230 | 233 | rerank_exponent: 1.0 |
| 231 | 234 | text_bias: 0.1 |
| 232 | 235 | text_exponent: 0.35 |
| 236 | + knn_text_weight: 1.0 | |
| 237 | + knn_image_weight: 1.0 | |
| 238 | + knn_tie_breaker: 0.1 | |
| 233 | 239 | knn_bias: 0.6 |
| 234 | - knn_exponent: 0.0 | |
| 240 | + knn_exponent: 0.2 | |
| 235 | 241 | |
| 236 | 242 | # 可扩展服务/provider 注册表(单一配置源) |
| 237 | 243 | services: | ... | ... |
config/loader.py
| ... | ... | @@ -477,6 +477,9 @@ class AppConfigLoader: |
| 477 | 477 | rerank_exponent=float(fusion_raw.get("rerank_exponent", 1.0)), |
| 478 | 478 | text_bias=float(fusion_raw.get("text_bias", 0.1)), |
| 479 | 479 | text_exponent=float(fusion_raw.get("text_exponent", 0.35)), |
| 480 | + knn_text_weight=float(fusion_raw.get("knn_text_weight", 1.0)), | |
| 481 | + knn_image_weight=float(fusion_raw.get("knn_image_weight", 1.0)), | |
| 482 | + knn_tie_breaker=float(fusion_raw.get("knn_tie_breaker", 0.0)), | |
| 480 | 483 | knn_bias=float(fusion_raw.get("knn_bias", 0.6)), |
| 481 | 484 | knn_exponent=float(fusion_raw.get("knn_exponent", 0.2)), |
| 482 | 485 | ), | ... | ... |
config/schema.py
| ... | ... | @@ -104,6 +104,9 @@ class RerankFusionConfig: |
| 104 | 104 | rerank_exponent: float = 1.0 |
| 105 | 105 | text_bias: float = 0.1 |
| 106 | 106 | text_exponent: float = 0.35 |
| 107 | + knn_text_weight: float = 1.0 | |
| 108 | + knn_image_weight: float = 1.0 | |
| 109 | + knn_tie_breaker: float = 0.0 | |
| 107 | 110 | knn_bias: float = 0.6 |
| 108 | 111 | knn_exponent: float = 0.2 |
| 109 | 112 | ... | ... |
search/rerank_client.py
| ... | ... | @@ -152,11 +152,30 @@ def _extract_named_query_score(matched_queries: Any, name: str) -> float: |
| 152 | 152 | return 0.0 |
| 153 | 153 | |
| 154 | 154 | |
| 155 | -def _extract_combined_knn_score(matched_queries: Any) -> float: | |
| 156 | - return max( | |
| 157 | - _extract_named_query_score(matched_queries, "knn_query"), | |
| 158 | - _extract_named_query_score(matched_queries, "image_knn_query"), | |
| 159 | - ) | |
| 155 | +def _collect_knn_score_components( | |
| 156 | + matched_queries: Any, | |
| 157 | + fusion: RerankFusionConfig, | |
| 158 | +) -> Dict[str, float]: | |
| 159 | + text_knn_score = _extract_named_query_score(matched_queries, "knn_query") | |
| 160 | + image_knn_score = _extract_named_query_score(matched_queries, "image_knn_query") | |
| 161 | + | |
| 162 | + weighted_text_knn_score = text_knn_score * float(fusion.knn_text_weight) | |
| 163 | + weighted_image_knn_score = image_knn_score * float(fusion.knn_image_weight) | |
| 164 | + weighted_components = [weighted_text_knn_score, weighted_image_knn_score] | |
| 165 | + | |
| 166 | + primary_knn_score = max(weighted_components) | |
| 167 | + support_knn_score = sum(weighted_components) - primary_knn_score | |
| 168 | + knn_score = primary_knn_score + float(fusion.knn_tie_breaker) * support_knn_score | |
| 169 | + | |
| 170 | + return { | |
| 171 | + "text_knn_score": text_knn_score, | |
| 172 | + "image_knn_score": image_knn_score, | |
| 173 | + "weighted_text_knn_score": weighted_text_knn_score, | |
| 174 | + "weighted_image_knn_score": weighted_image_knn_score, | |
| 175 | + "primary_knn_score": primary_knn_score, | |
| 176 | + "support_knn_score": support_knn_score, | |
| 177 | + "knn_score": knn_score, | |
| 178 | + } | |
| 160 | 179 | |
| 161 | 180 | """ |
| 162 | 181 | 原始变量: |
| ... | ... | @@ -279,7 +298,8 @@ def fuse_scores_and_resort( |
| 279 | 298 | es_score = _to_score(hit.get("_score")) |
| 280 | 299 | rerank_score = _to_score(rerank_scores[idx]) |
| 281 | 300 | matched_queries = hit.get("matched_queries") |
| 282 | - knn_score = _extract_combined_knn_score(matched_queries) | |
| 301 | + knn_components = _collect_knn_score_components(matched_queries, f) | |
| 302 | + knn_score = knn_components["knn_score"] | |
| 283 | 303 | text_components = _collect_text_score_components(matched_queries, es_score) |
| 284 | 304 | text_score = text_components["text_score"] |
| 285 | 305 | rerank_factor, text_factor, knn_factor, fused = _multiply_fusion_factors( |
| ... | ... | @@ -293,6 +313,8 @@ def fuse_scores_and_resort( |
| 293 | 313 | hit["_rerank_score"] = rerank_score |
| 294 | 314 | hit["_text_score"] = text_score |
| 295 | 315 | hit["_knn_score"] = knn_score |
| 316 | + hit["_text_knn_score"] = knn_components["text_knn_score"] | |
| 317 | + hit["_image_knn_score"] = knn_components["image_knn_score"] | |
| 296 | 318 | hit["_fused_score"] = fused |
| 297 | 319 | hit["_style_intent_selected_sku_boost"] = style_boost |
| 298 | 320 | if debug: |
| ... | ... | @@ -300,6 +322,8 @@ def fuse_scores_and_resort( |
| 300 | 322 | hit["_text_translation_score"] = text_components["translation_score"] |
| 301 | 323 | hit["_text_primary_score"] = text_components["primary_text_score"] |
| 302 | 324 | hit["_text_support_score"] = text_components["support_text_score"] |
| 325 | + hit["_knn_primary_score"] = knn_components["primary_knn_score"] | |
| 326 | + hit["_knn_support_score"] = knn_components["support_knn_score"] | |
| 303 | 327 | |
| 304 | 328 | if debug: |
| 305 | 329 | debug_entry = { |
| ... | ... | @@ -318,6 +342,12 @@ def fuse_scores_and_resort( |
| 318 | 342 | and text_components["source_score"] <= 0.0 |
| 319 | 343 | and text_components["translation_score"] <= 0.0 |
| 320 | 344 | ), |
| 345 | + "text_knn_score": knn_components["text_knn_score"], | |
| 346 | + "image_knn_score": knn_components["image_knn_score"], | |
| 347 | + "weighted_text_knn_score": knn_components["weighted_text_knn_score"], | |
| 348 | + "weighted_image_knn_score": knn_components["weighted_image_knn_score"], | |
| 349 | + "knn_primary_score": knn_components["primary_knn_score"], | |
| 350 | + "knn_support_score": knn_components["support_knn_score"], | |
| 321 | 351 | "knn_score": knn_score, |
| 322 | 352 | "rerank_factor": rerank_factor, |
| 323 | 353 | "text_factor": text_factor, | ... | ... |
search/searcher.py
| ... | ... | @@ -882,6 +882,7 @@ class Searcher: |
| 882 | 882 | "index_languages": index_langs, |
| 883 | 883 | "translations": context.query_analysis.translations, |
| 884 | 884 | "has_vector": context.query_analysis.query_vector is not None, |
| 885 | + "has_image_vector": getattr(parsed_query, "image_query_vector", None) is not None, | |
| 885 | 886 | "query_tokens": getattr(parsed_query, "query_tokens", []), |
| 886 | 887 | "intent_detection": context.get_intermediate_result("style_intent_profile"), |
| 887 | 888 | }, | ... | ... |
tests/test_rerank_client.py
| ... | ... | @@ -168,3 +168,39 @@ def test_fuse_scores_and_resort_uses_max_of_text_and_image_knn_scores(): |
| 168 | 168 | |
| 169 | 169 | assert isclose(hits[0]["_knn_score"], 0.7, rel_tol=1e-9) |
| 170 | 170 | assert isclose(debug[0]["knn_score"], 0.7, rel_tol=1e-9) |
| 171 | + assert isclose(debug[0]["text_knn_score"], 0.2, rel_tol=1e-9) | |
| 172 | + assert isclose(debug[0]["image_knn_score"], 0.7, rel_tol=1e-9) | |
| 173 | + | |
| 174 | + | |
| 175 | +def test_fuse_scores_and_resort_applies_knn_dismax_weights_and_tie_breaker(): | |
| 176 | + hits = [ | |
| 177 | + { | |
| 178 | + "_id": "mm-hit", | |
| 179 | + "_score": 1.0, | |
| 180 | + "matched_queries": { | |
| 181 | + "base_query": 1.5, | |
| 182 | + "knn_query": 0.4, | |
| 183 | + "image_knn_query": 0.5, | |
| 184 | + }, | |
| 185 | + } | |
| 186 | + ] | |
| 187 | + fusion = RerankFusionConfig( | |
| 188 | + rerank_bias=0.00001, | |
| 189 | + rerank_exponent=1.0, | |
| 190 | + text_bias=0.1, | |
| 191 | + text_exponent=0.35, | |
| 192 | + knn_text_weight=2.0, | |
| 193 | + knn_image_weight=1.0, | |
| 194 | + knn_tie_breaker=0.25, | |
| 195 | + knn_bias=0.0, | |
| 196 | + knn_exponent=1.0, | |
| 197 | + ) | |
| 198 | + | |
| 199 | + debug = fuse_scores_and_resort(hits, [0.8], fusion=fusion, debug=True) | |
| 200 | + | |
| 201 | + expected_knn = 0.8 + 0.25 * 0.5 | |
| 202 | + assert isclose(hits[0]["_knn_score"], expected_knn, rel_tol=1e-9) | |
| 203 | + assert isclose(debug[0]["weighted_text_knn_score"], 0.8, rel_tol=1e-9) | |
| 204 | + assert isclose(debug[0]["weighted_image_knn_score"], 0.5, rel_tol=1e-9) | |
| 205 | + assert isclose(debug[0]["knn_primary_score"], 0.8, rel_tol=1e-9) | |
| 206 | + assert isclose(debug[0]["knn_support_score"], 0.5, rel_tol=1e-9) | ... | ... |