Commit de98daa3248e2ec397d44865c6d4b866ea8a6adb

Authored by tangwang
1 parent 9b956985

多模态召回优化

config/config.yaml
@@ -213,14 +213,15 @@ query_config: @@ -213,14 +213,15 @@ query_config:
213 knn_text_boost: 4 213 knn_text_boost: 4
214 knn_image_boost: 4 214 knn_image_boost: 4
215 215
216 - knn_text_k: 150  
217 - knn_text_num_candidates: 400 216 + # knn_text_num_candidates = k * 3.5
  217 + knn_text_k: 160
  218 + knn_text_num_candidates: 560
218 219
219 - knn_text_k_long: 300  
220 - knn_text_num_candidates_long: 720 220 + knn_text_k_long: 400
  221 + knn_text_num_candidates_long: 1200
221 222
222 - knn_image_k: 3000  
223 - knn_image_num_candidates: 7200 223 + knn_image_k: 400
  224 + knn_image_num_candidates: 1200
224 225
225 # Function Score配置(ES层打分规则) 226 # Function Score配置(ES层打分规则)
226 function_score: 227 function_score:
@@ -236,6 +237,9 @@ coarse_rank: @@ -236,6 +237,9 @@ coarse_rank:
236 fusion: 237 fusion:
237 text_bias: 0.1 238 text_bias: 0.1
238 text_exponent: 0.35 239 text_exponent: 0.35
  240 + # base_query_trans_* 相对 base_query 的权重(见 search/rerank_client 中文本 dismax 融合)
  241 + # ES 打分阶段已经对翻译子句(trans)做过折扣,所以这里不再重复折扣
  242 + text_translation_weight: 1.0
239 knn_text_weight: 1.0 243 knn_text_weight: 1.0
240 knn_image_weight: 1.0 244 knn_image_weight: 1.0
241 knn_tie_breaker: 0.1 245 knn_tie_breaker: 0.1
@@ -273,6 +277,8 @@ rerank: @@ -273,6 +277,8 @@ rerank:
273 fine_exponent: 1.0 277 fine_exponent: 1.0
274 text_bias: 0.1 278 text_bias: 0.1
275 text_exponent: 0.35 279 text_exponent: 0.35
  280 + # base_query_trans_* 相对 base_query 的权重(见 search/rerank_client 中文本 dismax 融合)
  281 + text_translation_weight: 1.0
276 knn_text_weight: 1.0 282 knn_text_weight: 1.0
277 knn_image_weight: 1.0 283 knn_image_weight: 1.0
278 knn_tie_breaker: 0.1 284 knn_tie_breaker: 0.1
@@ -498,6 +498,9 @@ class AppConfigLoader: @@ -498,6 +498,9 @@ class AppConfigLoader:
498 knn_tie_breaker=float(coarse_fusion_raw.get("knn_tie_breaker", 0.0)), 498 knn_tie_breaker=float(coarse_fusion_raw.get("knn_tie_breaker", 0.0)),
499 knn_bias=float(coarse_fusion_raw.get("knn_bias", 0.6)), 499 knn_bias=float(coarse_fusion_raw.get("knn_bias", 0.6)),
500 knn_exponent=float(coarse_fusion_raw.get("knn_exponent", 0.2)), 500 knn_exponent=float(coarse_fusion_raw.get("knn_exponent", 0.2)),
  501 + text_translation_weight=float(
  502 + coarse_fusion_raw.get("text_translation_weight", 0.8)
  503 + ),
501 ), 504 ),
502 ), 505 ),
503 fine_rank=FineRankConfig( 506 fine_rank=FineRankConfig(
@@ -538,6 +541,9 @@ class AppConfigLoader: @@ -538,6 +541,9 @@ class AppConfigLoader:
538 knn_exponent=float(fusion_raw.get("knn_exponent", 0.2)), 541 knn_exponent=float(fusion_raw.get("knn_exponent", 0.2)),
539 fine_bias=float(fusion_raw.get("fine_bias", 0.00001)), 542 fine_bias=float(fusion_raw.get("fine_bias", 0.00001)),
540 fine_exponent=float(fusion_raw.get("fine_exponent", 1.0)), 543 fine_exponent=float(fusion_raw.get("fine_exponent", 1.0)),
  544 + text_translation_weight=float(
  545 + fusion_raw.get("text_translation_weight", 0.8)
  546 + ),
541 ), 547 ),
542 ), 548 ),
543 spu_config=SPUConfig( 549 spu_config=SPUConfig(
@@ -119,6 +119,8 @@ class RerankFusionConfig: @@ -119,6 +119,8 @@ class RerankFusionConfig:
119 knn_exponent: float = 0.2 119 knn_exponent: float = 0.2
120 fine_bias: float = 0.00001 120 fine_bias: float = 0.00001
121 fine_exponent: float = 1.0 121 fine_exponent: float = 1.0
  122 + #: 翻译子句 named query 分数相对原文 base_query 的权重(加权后再与原文做 dismax 融合)
  123 + text_translation_weight: float = 0.8
122 124
123 125
124 @dataclass(frozen=True) 126 @dataclass(frozen=True)
@@ -136,6 +138,8 @@ class CoarseRankFusionConfig: @@ -136,6 +138,8 @@ class CoarseRankFusionConfig:
136 knn_tie_breaker: float = 0.0 138 knn_tie_breaker: float = 0.0
137 knn_bias: float = 0.6 139 knn_bias: float = 0.6
138 knn_exponent: float = 0.2 140 knn_exponent: float = 0.2
  141 + #: 翻译子句 named query 分数相对原文 base_query 的权重(加权后再与原文做 dismax 融合)
  142 + text_translation_weight: float = 0.8
139 143
140 144
141 @dataclass(frozen=True) 145 @dataclass(frozen=True)
mappings/generate_search_products_mapping.py
@@ -80,7 +80,7 @@ ANALYZERS = { @@ -80,7 +80,7 @@ ANALYZERS = {
80 } 80 }
81 81
82 SETTINGS = { 82 SETTINGS = {
83 - "number_of_shards": 1, 83 + "number_of_shards": 4,
84 "number_of_replicas": 0, 84 "number_of_replicas": 0,
85 "refresh_interval": "30s", 85 "refresh_interval": "30s",
86 "analysis": { 86 "analysis": {
mappings/search_products.json
1 { 1 {
2 "settings": { 2 "settings": {
3 - "number_of_shards": 1, 3 + "number_of_shards": 4,
4 "number_of_replicas": 0, 4 "number_of_replicas": 0,
5 "refresh_interval": "30s", 5 "refresh_interval": "30s",
6 "analysis": { 6 "analysis": {
search/es_query_builder.py
@@ -272,16 +272,30 @@ class ESQueryBuilder: @@ -272,16 +272,30 @@ class ESQueryBuilder:
272 }) 272 })
273 273
274 if has_image_embedding: 274 if has_image_embedding:
275 - recall_clauses.append({  
276 - "knn": {  
277 - "field": self.image_embedding_field,  
278 - "query_vector": image_query_vector.tolist(),  
279 - "k": self.knn_image_k,  
280 - "num_candidates": self.knn_image_num_candidates,  
281 - "boost": self.knn_image_boost,  
282 - "_name": "image_knn_query",  
283 - }  
284 - }) 275 + nested_path, _, _ = str(self.image_embedding_field).rpartition(".")
  276 + image_knn_query = {
  277 + "field": self.image_embedding_field,
  278 + "query_vector": image_query_vector.tolist(),
  279 + "k": self.knn_image_k,
  280 + "num_candidates": self.knn_image_num_candidates,
  281 + "boost": self.knn_image_boost,
  282 + }
  283 + if nested_path:
  284 + recall_clauses.append({
  285 + "nested": {
  286 + "path": nested_path,
  287 + "_name": "image_knn_query",
  288 + "query": {"knn": image_knn_query},
  289 + "score_mode": "max",
  290 + }
  291 + })
  292 + else:
  293 + recall_clauses.append({
  294 + "knn": {
  295 + **image_knn_query,
  296 + "_name": "image_knn_query",
  297 + }
  298 + })
285 299
286 # 4. Build main query structure: filters and recall 300 # 4. Build main query structure: filters and recall
287 if recall_clauses: 301 if recall_clauses:
search/rerank_client.py
@@ -186,7 +186,7 @@ translation_score:所有名字以 base_query_trans_ 开头的 named query 的… @@ -186,7 +186,7 @@ translation_score:所有名字以 base_query_trans_ 开头的 named query 的…
186 186
187 中间å˜é‡ï¼šè®¡ç®—原始query得分和翻译query得分 187 中间å˜é‡ï¼šè®¡ç®—原始query得分和翻译query得分
188 weighted_source : 188 weighted_source :
189 -weighted_translation : 0.8 * translation_score 189 +weighted_translation : text_translation_weight * translation_score(由 fusion.text_translation_weight 配置)
190 190
191 区分主信å·å’Œè¾…助信å·ï¼š 191 区分主信å·å’Œè¾…助信å·ï¼š
192 合成primary_text_score和support_text_score,取 更强 的那一路(原文检索 vs 翻译检索)作为主信号 192 合成primary_text_score和support_text_score,取 更强 的那一路(原文检索 vs 翻译检索)作为主信号
@@ -197,7 +197,12 @@ support_text_score : weighted_source + weighted_translation - primary_text_score @@ -197,7 +197,12 @@ support_text_score : weighted_source + weighted_translation - primary_text_score
197 最终text_scoreï¼šä¸»ä¿¡å· + 0.25 * è¾…åŠ©ä¿¡å· 197 最终text_scoreï¼šä¸»ä¿¡å· + 0.25 * 辅助信å·
198 text_score : primary_text_score + 0.25 * support_text_score 198 text_score : primary_text_score + 0.25 * support_text_score
199 """ 199 """
200 -def _collect_text_score_components(matched_queries: Any, fallback_es_score: float) -> Dict[str, float]: 200 +def _collect_text_score_components(
  201 + matched_queries: Any,
  202 + fallback_es_score: float,
  203 + *,
  204 + translation_weight: float,
  205 +) -> Dict[str, float]:
201 source_score = _extract_named_query_score(matched_queries, "base_query") 206 source_score = _extract_named_query_score(matched_queries, "base_query")
202 translation_score = 0.0 207 translation_score = 0.0
203 208
@@ -216,7 +221,7 @@ def _collect_text_score_components(matched_queries: Any, fallback_es_score: floa @@ -216,7 +221,7 @@ def _collect_text_score_components(matched_queries: Any, fallback_es_score: floa
216 translation_score = 1.0 221 translation_score = 1.0
217 222
218 weighted_source = source_score 223 weighted_source = source_score
219 - weighted_translation = 0.8 * translation_score 224 + weighted_translation = float(translation_weight) * translation_score
220 weighted_components = [weighted_source, weighted_translation] 225 weighted_components = [weighted_source, weighted_translation]
221 primary_text_score = max(weighted_components) 226 primary_text_score = max(weighted_components)
222 support_text_score = sum(weighted_components) - primary_text_score 227 support_text_score = sum(weighted_components) - primary_text_score
@@ -249,7 +254,11 @@ def _build_hit_signal_bundle( @@ -249,7 +254,11 @@ def _build_hit_signal_bundle(
249 ) -> Dict[str, Any]: 254 ) -> Dict[str, Any]:
250 es_score = _to_score(hit.get("_score")) 255 es_score = _to_score(hit.get("_score"))
251 matched_queries = hit.get("matched_queries") 256 matched_queries = hit.get("matched_queries")
252 - text_components = _collect_text_score_components(matched_queries, es_score) 257 + text_components = _collect_text_score_components(
  258 + matched_queries,
  259 + es_score,
  260 + translation_weight=fusion.text_translation_weight,
  261 + )
253 knn_components = _collect_knn_score_components(matched_queries, fusion) 262 knn_components = _collect_knn_score_components(matched_queries, fusion)
254 return { 263 return {
255 "doc_id": hit.get("_id"), 264 "doc_id": hit.get("_id"),
tests/test_es_query_builder.py
@@ -35,6 +35,16 @@ def _recall_should_clauses(es_body: Dict[str, Any]) -> list[Dict[str, Any]]: @@ -35,6 +35,16 @@ def _recall_should_clauses(es_body: Dict[str, Any]) -> list[Dict[str, Any]]:
35 return [root] 35 return [root]
36 36
37 37
  38 +def _recall_clause_name(clause: Dict[str, Any]) -> str | None:
  39 + if "bool" in clause:
  40 + return clause["bool"].get("_name")
  41 + if "knn" in clause:
  42 + return clause["knn"].get("_name")
  43 + if "nested" in clause:
  44 + return clause["nested"].get("_name")
  45 + return None
  46 +
  47 +
38 def test_knn_clause_moves_under_query_should_and_uses_outer_filters(): 48 def test_knn_clause_moves_under_query_should_and_uses_outer_filters():
39 qb = _builder() 49 qb = _builder()
40 q = qb.build_query( 50 q = qb.build_query(
@@ -188,9 +198,11 @@ def test_image_knn_clause_is_added_alongside_base_translation_and_text_knn(): @@ -188,9 +198,11 @@ def test_image_knn_clause_is_added_alongside_base_translation_and_text_knn():
188 198
189 should = _recall_should_clauses(q) 199 should = _recall_should_clauses(q)
190 names = [ 200 names = [
191 - clause["bool"]["_name"] if "bool" in clause else clause["knn"]["_name"] 201 + _recall_clause_name(clause)
192 for clause in should 202 for clause in should
193 ] 203 ]
194 assert names == ["base_query", "base_query_trans_zh", "knn_query", "image_knn_query"] 204 assert names == ["base_query", "base_query_trans_zh", "knn_query", "image_knn_query"]
195 - image_knn = next(clause["knn"] for clause in should if clause.get("knn", {}).get("_name") == "image_knn_query")  
196 - assert image_knn["field"] == "image_embedding.vector" 205 + image_knn = next(clause["nested"] for clause in should if clause.get("nested", {}).get("_name") == "image_knn_query")
  206 + assert image_knn["path"] == "image_embedding"
  207 + assert image_knn["score_mode"] == "max"
  208 + assert image_knn["query"]["knn"]["field"] == "image_embedding.vector"