Commit de98daa3248e2ec397d44865c6d4b866ea8a6adb

Authored by tangwang
1 parent 9b956985

多模态召回优化

config/config.yaml
... ... @@ -213,14 +213,15 @@ query_config:
213 213 knn_text_boost: 4
214 214 knn_image_boost: 4
215 215  
216   - knn_text_k: 150
217   - knn_text_num_candidates: 400
  216 + # knn_text_num_candidates = k * 3.5
  217 + knn_text_k: 160
  218 + knn_text_num_candidates: 560
218 219  
219   - knn_text_k_long: 300
220   - knn_text_num_candidates_long: 720
  220 + knn_text_k_long: 400
  221 + knn_text_num_candidates_long: 1200
221 222  
222   - knn_image_k: 3000
223   - knn_image_num_candidates: 7200
  223 + knn_image_k: 400
  224 + knn_image_num_candidates: 1200
224 225  
225 226 # Function Score配置(ES层打分规则)
226 227 function_score:
... ... @@ -236,6 +237,9 @@ coarse_rank:
236 237 fusion:
237 238 text_bias: 0.1
238 239 text_exponent: 0.35
  240 + # base_query_trans_* 相对 base_query 的权重(见 search/rerank_client 中文本 dismax 融合)
  241 + # ES 打分阶段已经对 trans 子句做过折扣,所以这里不再重复折扣
  242 + text_translation_weight: 1.0
239 243 knn_text_weight: 1.0
240 244 knn_image_weight: 1.0
241 245 knn_tie_breaker: 0.1
... ... @@ -273,6 +277,8 @@ rerank:
273 277 fine_exponent: 1.0
274 278 text_bias: 0.1
275 279 text_exponent: 0.35
  280 + # base_query_trans_* 相对 base_query 的权重(见 search/rerank_client 中文本 dismax 融合)
  281 + text_translation_weight: 1.0
276 282 knn_text_weight: 1.0
277 283 knn_image_weight: 1.0
278 284 knn_tie_breaker: 0.1
... ...
config/loader.py
... ... @@ -498,6 +498,9 @@ class AppConfigLoader:
498 498 knn_tie_breaker=float(coarse_fusion_raw.get("knn_tie_breaker", 0.0)),
499 499 knn_bias=float(coarse_fusion_raw.get("knn_bias", 0.6)),
500 500 knn_exponent=float(coarse_fusion_raw.get("knn_exponent", 0.2)),
  501 + text_translation_weight=float(
  502 + coarse_fusion_raw.get("text_translation_weight", 0.8)
  503 + ),
501 504 ),
502 505 ),
503 506 fine_rank=FineRankConfig(
... ... @@ -538,6 +541,9 @@ class AppConfigLoader:
538 541 knn_exponent=float(fusion_raw.get("knn_exponent", 0.2)),
539 542 fine_bias=float(fusion_raw.get("fine_bias", 0.00001)),
540 543 fine_exponent=float(fusion_raw.get("fine_exponent", 1.0)),
  544 + text_translation_weight=float(
  545 + fusion_raw.get("text_translation_weight", 0.8)
  546 + ),
541 547 ),
542 548 ),
543 549 spu_config=SPUConfig(
... ...
config/schema.py
... ... @@ -119,6 +119,8 @@ class RerankFusionConfig:
119 119 knn_exponent: float = 0.2
120 120 fine_bias: float = 0.00001
121 121 fine_exponent: float = 1.0
  122 + #: 翻译子句 named query 分数相对原文 base_query 的权重(加权后再与原文做 dismax 融合)
  123 + text_translation_weight: float = 0.8
122 124  
123 125  
124 126 @dataclass(frozen=True)
... ... @@ -136,6 +138,8 @@ class CoarseRankFusionConfig:
136 138 knn_tie_breaker: float = 0.0
137 139 knn_bias: float = 0.6
138 140 knn_exponent: float = 0.2
  141 + #: 翻译子句 named query 分数相对原文 base_query 的权重(加权后再与原文做 dismax 融合)
  142 + text_translation_weight: float = 0.8
139 143  
140 144  
141 145 @dataclass(frozen=True)
... ...
mappings/generate_search_products_mapping.py
... ... @@ -80,7 +80,7 @@ ANALYZERS = {
80 80 }
81 81  
82 82 SETTINGS = {
83   - "number_of_shards": 1,
  83 + "number_of_shards": 4,
84 84 "number_of_replicas": 0,
85 85 "refresh_interval": "30s",
86 86 "analysis": {
... ...
mappings/search_products.json
1 1 {
2 2 "settings": {
3   - "number_of_shards": 1,
  3 + "number_of_shards": 4,
4 4 "number_of_replicas": 0,
5 5 "refresh_interval": "30s",
6 6 "analysis": {
... ...
search/es_query_builder.py
... ... @@ -272,16 +272,30 @@ class ESQueryBuilder:
272 272 })
273 273  
274 274 if has_image_embedding:
275   - recall_clauses.append({
276   - "knn": {
277   - "field": self.image_embedding_field,
278   - "query_vector": image_query_vector.tolist(),
279   - "k": self.knn_image_k,
280   - "num_candidates": self.knn_image_num_candidates,
281   - "boost": self.knn_image_boost,
282   - "_name": "image_knn_query",
283   - }
284   - })
  275 + nested_path, _, _ = str(self.image_embedding_field).rpartition(".")
  276 + image_knn_query = {
  277 + "field": self.image_embedding_field,
  278 + "query_vector": image_query_vector.tolist(),
  279 + "k": self.knn_image_k,
  280 + "num_candidates": self.knn_image_num_candidates,
  281 + "boost": self.knn_image_boost,
  282 + }
  283 + if nested_path:
  284 + recall_clauses.append({
  285 + "nested": {
  286 + "path": nested_path,
  287 + "_name": "image_knn_query",
  288 + "query": {"knn": image_knn_query},
  289 + "score_mode": "max",
  290 + }
  291 + })
  292 + else:
  293 + recall_clauses.append({
  294 + "knn": {
  295 + **image_knn_query,
  296 + "_name": "image_knn_query",
  297 + }
  298 + })
285 299  
286 300 # 4. Build main query structure: filters and recall
287 301 if recall_clauses:
... ...
search/rerank_client.py
... ... @@ -186,7 +186,7 @@ translation_score:所有名字以 base_query_trans_ 开头的 named query 的…
186 186  
187 187 中间变量:计算原始query得分和翻译query得分
188 188 weighted_source :
189   -weighted_translation : 0.8 * translation_score
  189 +weighted_translation : text_translation_weight * translation_score(由 fusion.text_translation_weight 配置)
190 190  
191 191 区分主信号和辅助信号:
192 192 合成primary_text_score和support_text_score,取 更强 的那一路(原文检索 vs 翻译检索)作为主信号
... ... @@ -197,7 +197,12 @@ support_text_score : weighted_source + weighted_translation - primary_text_score
197 197 最终text_score:主信号 + 0.25 * 辅助信号
198 198 text_score : primary_text_score + 0.25 * support_text_score
199 199 """
200   -def _collect_text_score_components(matched_queries: Any, fallback_es_score: float) -> Dict[str, float]:
  200 +def _collect_text_score_components(
  201 + matched_queries: Any,
  202 + fallback_es_score: float,
  203 + *,
  204 + translation_weight: float,
  205 +) -> Dict[str, float]:
201 206 source_score = _extract_named_query_score(matched_queries, "base_query")
202 207 translation_score = 0.0
203 208  
... ... @@ -216,7 +221,7 @@ def _collect_text_score_components(matched_queries: Any, fallback_es_score: floa
216 221 translation_score = 1.0
217 222  
218 223 weighted_source = source_score
219   - weighted_translation = 0.8 * translation_score
  224 + weighted_translation = float(translation_weight) * translation_score
220 225 weighted_components = [weighted_source, weighted_translation]
221 226 primary_text_score = max(weighted_components)
222 227 support_text_score = sum(weighted_components) - primary_text_score
... ... @@ -249,7 +254,11 @@ def _build_hit_signal_bundle(
249 254 ) -> Dict[str, Any]:
250 255 es_score = _to_score(hit.get("_score"))
251 256 matched_queries = hit.get("matched_queries")
252   - text_components = _collect_text_score_components(matched_queries, es_score)
  257 + text_components = _collect_text_score_components(
  258 + matched_queries,
  259 + es_score,
  260 + translation_weight=fusion.text_translation_weight,
  261 + )
253 262 knn_components = _collect_knn_score_components(matched_queries, fusion)
254 263 return {
255 264 "doc_id": hit.get("_id"),
... ...
tests/test_es_query_builder.py
... ... @@ -35,6 +35,16 @@ def _recall_should_clauses(es_body: Dict[str, Any]) -> list[Dict[str, Any]]:
35 35 return [root]
36 36  
37 37  
  38 +def _recall_clause_name(clause: Dict[str, Any]) -> str | None:
  39 + if "bool" in clause:
  40 + return clause["bool"].get("_name")
  41 + if "knn" in clause:
  42 + return clause["knn"].get("_name")
  43 + if "nested" in clause:
  44 + return clause["nested"].get("_name")
  45 + return None
  46 +
  47 +
38 48 def test_knn_clause_moves_under_query_should_and_uses_outer_filters():
39 49 qb = _builder()
40 50 q = qb.build_query(
... ... @@ -188,9 +198,11 @@ def test_image_knn_clause_is_added_alongside_base_translation_and_text_knn():
188 198  
189 199 should = _recall_should_clauses(q)
190 200 names = [
191   - clause["bool"]["_name"] if "bool" in clause else clause["knn"]["_name"]
  201 + _recall_clause_name(clause)
192 202 for clause in should
193 203 ]
194 204 assert names == ["base_query", "base_query_trans_zh", "knn_query", "image_knn_query"]
195   - image_knn = next(clause["knn"] for clause in should if clause.get("knn", {}).get("_name") == "image_knn_query")
196   - assert image_knn["field"] == "image_embedding.vector"
  205 + image_knn = next(clause["nested"] for clause in should if clause.get("nested", {}).get("_name") == "image_knn_query")
  206 + assert image_knn["path"] == "image_embedding"
  207 + assert image_knn["score_mode"] == "max"
  208 + assert image_knn["query"]["knn"]["field"] == "image_embedding.vector"
... ...