Commit de98daa3248e2ec397d44865c6d4b866ea8a6adb
1 parent
9b956985
多模态召回优化
Showing
8 changed files
with
76 additions
and
25 deletions
Show diff stats
config/config.yaml
| ... | ... | @@ -213,14 +213,15 @@ query_config: |
| 213 | 213 | knn_text_boost: 4 |
| 214 | 214 | knn_image_boost: 4 |
| 215 | 215 | |
| 216 | - knn_text_k: 150 | |
| 217 | - knn_text_num_candidates: 400 | |
| 216 | + # knn_text_num_candidates = k * 3.5 | |
| 217 | + knn_text_k: 160 | |
| 218 | + knn_text_num_candidates: 560 | |
| 218 | 219 | |
| 219 | - knn_text_k_long: 300 | |
| 220 | - knn_text_num_candidates_long: 720 | |
| 220 | + knn_text_k_long: 400 | |
| 221 | + knn_text_num_candidates_long: 1200 | |
| 221 | 222 | |
| 222 | - knn_image_k: 3000 | |
| 223 | - knn_image_num_candidates: 7200 | |
| 223 | + knn_image_k: 400 | |
| 224 | + knn_image_num_candidates: 1200 | |
| 224 | 225 | |
| 225 | 226 | # Function Score配置(ES层打分规则) |
| 226 | 227 | function_score: |
| ... | ... | @@ -236,6 +237,9 @@ coarse_rank: |
| 236 | 237 | fusion: |
| 237 | 238 | text_bias: 0.1 |
| 238 | 239 | text_exponent: 0.35 |
| 240 | + # base_query_trans_* 相对 base_query 的权重(见 search/rerank_client 中文本 dismax 融合) | |
| 241 | + # ES 的打分已经对翻译子句(trans)做过折扣,所以这里不再重复折扣 | |
| 242 | + text_translation_weight: 1.0 | |
| 239 | 243 | knn_text_weight: 1.0 |
| 240 | 244 | knn_image_weight: 1.0 |
| 241 | 245 | knn_tie_breaker: 0.1 |
| ... | ... | @@ -273,6 +277,8 @@ rerank: |
| 273 | 277 | fine_exponent: 1.0 |
| 274 | 278 | text_bias: 0.1 |
| 275 | 279 | text_exponent: 0.35 |
| 280 | + # base_query_trans_* 相对 base_query 的权重(见 search/rerank_client 中文本 dismax 融合) | |
| 281 | + text_translation_weight: 1.0 | |
| 276 | 282 | knn_text_weight: 1.0 |
| 277 | 283 | knn_image_weight: 1.0 |
| 278 | 284 | knn_tie_breaker: 0.1 | ... | ... |
config/loader.py
| ... | ... | @@ -498,6 +498,9 @@ class AppConfigLoader: |
| 498 | 498 | knn_tie_breaker=float(coarse_fusion_raw.get("knn_tie_breaker", 0.0)), |
| 499 | 499 | knn_bias=float(coarse_fusion_raw.get("knn_bias", 0.6)), |
| 500 | 500 | knn_exponent=float(coarse_fusion_raw.get("knn_exponent", 0.2)), |
| 501 | + text_translation_weight=float( | |
| 502 | + coarse_fusion_raw.get("text_translation_weight", 0.8) | |
| 503 | + ), | |
| 501 | 504 | ), |
| 502 | 505 | ), |
| 503 | 506 | fine_rank=FineRankConfig( |
| ... | ... | @@ -538,6 +541,9 @@ class AppConfigLoader: |
| 538 | 541 | knn_exponent=float(fusion_raw.get("knn_exponent", 0.2)), |
| 539 | 542 | fine_bias=float(fusion_raw.get("fine_bias", 0.00001)), |
| 540 | 543 | fine_exponent=float(fusion_raw.get("fine_exponent", 1.0)), |
| 544 | + text_translation_weight=float( | |
| 545 | + fusion_raw.get("text_translation_weight", 0.8) | |
| 546 | + ), | |
| 541 | 547 | ), |
| 542 | 548 | ), |
| 543 | 549 | spu_config=SPUConfig( | ... | ... |
config/schema.py
| ... | ... | @@ -119,6 +119,8 @@ class RerankFusionConfig: |
| 119 | 119 | knn_exponent: float = 0.2 |
| 120 | 120 | fine_bias: float = 0.00001 |
| 121 | 121 | fine_exponent: float = 1.0 |
| 122 | + #: 翻译子句 named query 分数相对原文 base_query 的权重(加权后再与原文做 dismax 融合) | |
| 123 | + text_translation_weight: float = 0.8 | |
| 122 | 124 | |
| 123 | 125 | |
| 124 | 126 | @dataclass(frozen=True) |
| ... | ... | @@ -136,6 +138,8 @@ class CoarseRankFusionConfig: |
| 136 | 138 | knn_tie_breaker: float = 0.0 |
| 137 | 139 | knn_bias: float = 0.6 |
| 138 | 140 | knn_exponent: float = 0.2 |
| 141 | + #: 翻译子句 named query 分数相对原文 base_query 的权重(加权后再与原文做 dismax 融合) | |
| 142 | + text_translation_weight: float = 0.8 | |
| 139 | 143 | |
| 140 | 144 | |
| 141 | 145 | @dataclass(frozen=True) | ... | ... |
mappings/generate_search_products_mapping.py
mappings/search_products.json
search/es_query_builder.py
| ... | ... | @@ -272,16 +272,30 @@ class ESQueryBuilder: |
| 272 | 272 | }) |
| 273 | 273 | |
| 274 | 274 | if has_image_embedding: |
| 275 | - recall_clauses.append({ | |
| 276 | - "knn": { | |
| 277 | - "field": self.image_embedding_field, | |
| 278 | - "query_vector": image_query_vector.tolist(), | |
| 279 | - "k": self.knn_image_k, | |
| 280 | - "num_candidates": self.knn_image_num_candidates, | |
| 281 | - "boost": self.knn_image_boost, | |
| 282 | - "_name": "image_knn_query", | |
| 283 | - } | |
| 284 | - }) | |
| 275 | + nested_path, _, _ = str(self.image_embedding_field).rpartition(".") | |
| 276 | + image_knn_query = { | |
| 277 | + "field": self.image_embedding_field, | |
| 278 | + "query_vector": image_query_vector.tolist(), | |
| 279 | + "k": self.knn_image_k, | |
| 280 | + "num_candidates": self.knn_image_num_candidates, | |
| 281 | + "boost": self.knn_image_boost, | |
| 282 | + } | |
| 283 | + if nested_path: | |
| 284 | + recall_clauses.append({ | |
| 285 | + "nested": { | |
| 286 | + "path": nested_path, | |
| 287 | + "_name": "image_knn_query", | |
| 288 | + "query": {"knn": image_knn_query}, | |
| 289 | + "score_mode": "max", | |
| 290 | + } | |
| 291 | + }) | |
| 292 | + else: | |
| 293 | + recall_clauses.append({ | |
| 294 | + "knn": { | |
| 295 | + **image_knn_query, | |
| 296 | + "_name": "image_knn_query", | |
| 297 | + } | |
| 298 | + }) | |
| 285 | 299 | |
| 286 | 300 | # 4. Build main query structure: filters and recall |
| 287 | 301 | if recall_clauses: | ... | ... |
search/rerank_client.py
| ... | ... | @@ -186,7 +186,7 @@ translation_score:所有名字以 base_query_trans_ 开头的 named query 的… |
| 186 | 186 | |
| 187 | 187 | 中间变量:计算原始query得分和翻译query得分 |
| 188 | 188 | weighted_source : |
| 189 | -weighted_translation : 0.8 * translation_score | |
| 189 | +weighted_translation : text_translation_weight * translation_score(由 fusion.text_translation_weight 配置) | |
| 190 | 190 | |
| 191 | 191 | 区分主信号和辅助信号: |
| 192 | 192 | 合成primary_text_score和support_text_score,取 更强 的那一路(原文检索 vs 翻译检索)作为主信号 |
| ... | ... | @@ -197,7 +197,12 @@ support_text_score : weighted_source + weighted_translation - primary_text_score |
| 197 | 197 | 最终text_score:主信号 + 0.25 * 辅助信号 |
| 198 | 198 | text_score : primary_text_score + 0.25 * support_text_score |
| 199 | 199 | """ |
| 200 | -def _collect_text_score_components(matched_queries: Any, fallback_es_score: float) -> Dict[str, float]: | |
| 200 | +def _collect_text_score_components( | |
| 201 | + matched_queries: Any, | |
| 202 | + fallback_es_score: float, | |
| 203 | + *, | |
| 204 | + translation_weight: float, | |
| 205 | +) -> Dict[str, float]: | |
| 201 | 206 | source_score = _extract_named_query_score(matched_queries, "base_query") |
| 202 | 207 | translation_score = 0.0 |
| 203 | 208 | |
| ... | ... | @@ -216,7 +221,7 @@ def _collect_text_score_components(matched_queries: Any, fallback_es_score: floa |
| 216 | 221 | translation_score = 1.0 |
| 217 | 222 | |
| 218 | 223 | weighted_source = source_score |
| 219 | - weighted_translation = 0.8 * translation_score | |
| 224 | + weighted_translation = float(translation_weight) * translation_score | |
| 220 | 225 | weighted_components = [weighted_source, weighted_translation] |
| 221 | 226 | primary_text_score = max(weighted_components) |
| 222 | 227 | support_text_score = sum(weighted_components) - primary_text_score |
| ... | ... | @@ -249,7 +254,11 @@ def _build_hit_signal_bundle( |
| 249 | 254 | ) -> Dict[str, Any]: |
| 250 | 255 | es_score = _to_score(hit.get("_score")) |
| 251 | 256 | matched_queries = hit.get("matched_queries") |
| 252 | - text_components = _collect_text_score_components(matched_queries, es_score) | |
| 257 | + text_components = _collect_text_score_components( | |
| 258 | + matched_queries, | |
| 259 | + es_score, | |
| 260 | + translation_weight=fusion.text_translation_weight, | |
| 261 | + ) | |
| 253 | 262 | knn_components = _collect_knn_score_components(matched_queries, fusion) |
| 254 | 263 | return { |
| 255 | 264 | "doc_id": hit.get("_id"), | ... | ... |
tests/test_es_query_builder.py
| ... | ... | @@ -35,6 +35,16 @@ def _recall_should_clauses(es_body: Dict[str, Any]) -> list[Dict[str, Any]]: |
| 35 | 35 | return [root] |
| 36 | 36 | |
| 37 | 37 | |
| 38 | +def _recall_clause_name(clause: Dict[str, Any]) -> str | None: | |
| 39 | + if "bool" in clause: | |
| 40 | + return clause["bool"].get("_name") | |
| 41 | + if "knn" in clause: | |
| 42 | + return clause["knn"].get("_name") | |
| 43 | + if "nested" in clause: | |
| 44 | + return clause["nested"].get("_name") | |
| 45 | + return None | |
| 46 | + | |
| 47 | + | |
| 38 | 48 | def test_knn_clause_moves_under_query_should_and_uses_outer_filters(): |
| 39 | 49 | qb = _builder() |
| 40 | 50 | q = qb.build_query( |
| ... | ... | @@ -188,9 +198,11 @@ def test_image_knn_clause_is_added_alongside_base_translation_and_text_knn(): |
| 188 | 198 | |
| 189 | 199 | should = _recall_should_clauses(q) |
| 190 | 200 | names = [ |
| 191 | - clause["bool"]["_name"] if "bool" in clause else clause["knn"]["_name"] | |
| 201 | + _recall_clause_name(clause) | |
| 192 | 202 | for clause in should |
| 193 | 203 | ] |
| 194 | 204 | assert names == ["base_query", "base_query_trans_zh", "knn_query", "image_knn_query"] |
| 195 | - image_knn = next(clause["knn"] for clause in should if clause.get("knn", {}).get("_name") == "image_knn_query") | |
| 196 | - assert image_knn["field"] == "image_embedding.vector" | |
| 205 | + image_knn = next(clause["nested"] for clause in should if clause.get("nested", {}).get("_name") == "image_knn_query") | |
| 206 | + assert image_knn["path"] == "image_embedding" | |
| 207 | + assert image_knn["score_mode"] == "max" | |
| 208 | + assert image_knn["query"]["knn"]["field"] == "image_embedding.vector" | ... | ... |