diff --git a/search/es_query_builder.py b/search/es_query_builder.py index 09b708d..b661e40 100644 --- a/search/es_query_builder.py +++ b/search/es_query_builder.py @@ -2,11 +2,16 @@ Elasticsearch query builder. Converts parsed queries and search parameters into ES DSL queries. + +Simplified architecture: +- filters and (text_recall or embedding_recall) +- function_score wrapper for boosting fields """ from typing import Dict, Any, List, Optional, Union import numpy as np from .boolean_parser import QueryNode +from .query_config import FUNCTION_SCORE_CONFIG class ESQueryBuilder: @@ -51,14 +56,20 @@ class ESQueryBuilder: min_score: Optional[float] = None ) -> Dict[str, Any]: """ - Build complete ES query (重构版). + Build complete ES query (简化版). + + 结构:filters and (text_recall or embedding_recall) + - filters: 前端传递的过滤条件永远起作用 + - text_recall: 文本相关性召回(中英文字段都用) + - embedding_recall: 向量召回(KNN) + - function_score: 包装召回部分,支持提权字段 Args: query_text: Query text for BM25 matching query_vector: Query embedding for KNN search query_node: Parsed boolean expression tree - filters: Exact match filters - range_filters: Range filters for numeric fields + filters: Exact match filters (always applied) + range_filters: Range filters for numeric fields (always applied) size: Number of results from_: Offset for pagination enable_knn: Whether to use KNN search @@ -80,44 +91,161 @@ class ESQueryBuilder: "includes": self.source_fields } - # Build main query - if query_node and query_node.operator != 'TERM': - # Complex boolean query - query_clause = self._build_boolean_query(query_node) - else: - # Simple text query - query_clause = self._build_text_query(query_text) - - # Add filters if provided - if filters or range_filters: - filter_clauses = self._build_filters(filters, range_filters) + # 1. Build recall queries (text or embedding) + recall_clauses = [] + + # Text recall (always include if query_text exists) + if query_text: + if query_node and query_node.operator != 'TERM': + # Complex boolean query + text_query = self._build_boolean_query(query_node) + else: + # Simple text query + text_query = self._build_text_query(query_text) + recall_clauses.append(text_query) + + # Embedding recall (KNN - separate from query, handled below) + has_embedding = enable_knn and query_vector is not None and self.text_embedding_field + + # 2. Build filter clauses (always applied) + filter_clauses = self._build_filters(filters, range_filters) + + # 3. Build main query structure: filters and recall + if recall_clauses: + # Combine text recalls with OR logic (if multiple) + if len(recall_clauses) == 1: + recall_query = recall_clauses[0] + else: + recall_query = { + "bool": { + "should": recall_clauses, + "minimum_should_match": 1 + } + } + + # Wrap recall with function_score for boosting + recall_query = self._wrap_with_function_score(recall_query) + + # Combine filters and recall if filter_clauses: es_query["query"] = { "bool": { - "must": [query_clause], + "must": [recall_query], "filter": filter_clauses } } else: - es_query["query"] = query_clause + es_query["query"] = recall_query else: - es_query["query"] = query_clause + # No recall queries, only filters (match_all filtered) + if filter_clauses: + es_query["query"] = { + "bool": { + "must": [{"match_all": {}}], + "filter": filter_clauses + } + } + else: + es_query["query"] = {"match_all": {}} - # Add KNN search if enabled and vector provided - if enable_knn and query_vector is not None and self.text_embedding_field: + # 4. Add KNN search if enabled (separate from query, ES will combine) + if has_embedding: knn_clause = { "field": self.text_embedding_field, "query_vector": query_vector.tolist(), "k": knn_k, - "num_candidates": knn_num_candidates + "num_candidates": knn_num_candidates, + "boost": 0.2 # Lower boost for embedding recall } es_query["knn"] = knn_clause - # Add minimum score filter + # 5. Add minimum score filter if min_score is not None: es_query["min_score"] = min_score return es_query + + def _wrap_with_function_score(self, query: Dict[str, Any]) -> Dict[str, Any]: + """ + Wrap query with function_score for boosting fields. + + Args: + query: Base query to wrap + + Returns: + Function score query or original query if no functions configured + """ + functions = self._build_score_functions() + + # If no functions configured, return original query + if not functions: + return query + + # Build function_score query + function_score_query = { + "function_score": { + "query": query, + "functions": functions, + "score_mode": FUNCTION_SCORE_CONFIG.get("score_mode", "sum"), + "boost_mode": FUNCTION_SCORE_CONFIG.get("boost_mode", "multiply") + } + } + + return function_score_query + + def _build_score_functions(self) -> List[Dict[str, Any]]: + """ + Build function_score functions from config. + + Returns: + List of function score functions + """ + functions = [] + config_functions = FUNCTION_SCORE_CONFIG.get("functions", []) + + for func_config in config_functions: + func_type = func_config.get("type") + + if func_type == "filter_weight": + # Filter + Weight + functions.append({ + "filter": func_config["filter"], + "weight": func_config.get("weight", 1.0) + }) + + elif func_type == "field_value_factor": + # Field Value Factor + functions.append({ + "field_value_factor": { + "field": func_config["field"], + "factor": func_config.get("factor", 1.0), + "modifier": func_config.get("modifier", "none"), + "missing": func_config.get("missing", 1.0) + } + }) + + elif func_type == "decay": + # Decay Function (gauss/exp/linear) + decay_func = func_config.get("function", "gauss") + field = func_config["field"] + + decay_params = { + "origin": func_config.get("origin", "now"), + "scale": func_config["scale"] + } + + if "offset" in func_config: + decay_params["offset"] = func_config["offset"] + if "decay" in func_config: + decay_params["decay"] = func_config["decay"] + + functions.append({ + decay_func: { + field: decay_params + } + }) + + return functions def _build_text_query(self, query_text: str) -> Dict[str, Any]: """ @@ -235,11 +363,19 @@ class ESQueryBuilder: "term": {field: value} }) - # 2. 处理范围过滤(RangeFilter Pydantic 模型) + # 2. 处理范围过滤(支持 RangeFilter Pydantic 模型或字典) if range_filters: for field, range_filter in range_filters.items(): - # 将 RangeFilter 模型转换为字典 - range_dict = range_filter.model_dump(exclude_none=True) + # 支持 Pydantic 模型或字典格式 + if hasattr(range_filter, 'model_dump'): + # Pydantic 模型 + range_dict = range_filter.model_dump(exclude_none=True) + elif isinstance(range_filter, dict): + # 已经是字典格式 + range_dict = {k: v for k, v in range_filter.items() if v is not None} + else: + # 其他格式,跳过 + continue if range_dict: filter_clauses.append({ diff --git a/search/multilang_query_builder.py b/search/multilang_query_builder.py deleted file mode 100644 index 1df781b..0000000 --- a/search/multilang_query_builder.py +++ /dev/null @@ -1,459 +0,0 @@ -""" -Multi-language query builder for handling domain-specific searches. - -This module extends the ESQueryBuilder to support multi-language field mappings, -allowing queries to be routed to appropriate language-specific fields while -maintaining a unified external interface. -""" - -from typing import Dict, Any, List, Optional -import numpy as np -import logging -import re - -from query import ParsedQuery -from .es_query_builder import ESQueryBuilder -from .query_config import DEFAULT_MATCH_FIELDS, DOMAIN_FIELDS, FUNCTION_SCORE_CONFIG - -logger = logging.getLogger(__name__) - - -class MultiLanguageQueryBuilder(ESQueryBuilder): - """ - Enhanced query builder with multi-language support. - - Handles routing queries to appropriate language-specific fields based on: - 1. Detected query language - 2. Available translations - 3. Domain configuration (language_field_mapping) - """ - - def __init__( - self, - index_name: str, - match_fields: Optional[List[str]] = None, - text_embedding_field: Optional[str] = None, - image_embedding_field: Optional[str] = None, - source_fields: Optional[List[str]] = None - ): - """ - Initialize multi-language query builder. - - Args: - index_name: ES index name - match_fields: Fields to search for text matching (default: from query_config) - text_embedding_field: Field name for text embeddings - image_embedding_field: Field name for image embeddings - source_fields: Fields to return in search results (_source includes) - """ - self.function_score_config = FUNCTION_SCORE_CONFIG - - # Use provided match_fields or default - if match_fields is None: - match_fields = DEFAULT_MATCH_FIELDS - - super().__init__( - index_name=index_name, - match_fields=match_fields, - text_embedding_field=text_embedding_field, - image_embedding_field=image_embedding_field, - source_fields=source_fields - ) - - # Build domain configurations from query_config - self.domain_configs = DOMAIN_FIELDS - - def _get_domain_fields(self, domain_name: str) -> List[str]: - """Get fields for a specific domain with boost notation.""" - return self.domain_configs.get(domain_name, DEFAULT_MATCH_FIELDS) - - def build_multilang_query( - self, - parsed_query: ParsedQuery, - query_vector: Optional[np.ndarray] = None, - query_node: Optional[Any] = None, - filters: Optional[Dict[str, Any]] = None, - range_filters: Optional[Dict[str, Any]] = None, - size: int = 10, - from_: int = 0, - enable_knn: bool = True, - knn_k: int = 50, - knn_num_candidates: int = 200, - min_score: Optional[float] = None - ) -> Dict[str, Any]: - """ - Build ES query with multi-language support (简化版). - - Args: - parsed_query: Parsed query with language info and translations - query_vector: Query embedding for KNN search - filters: Exact match filters - range_filters: Range filters for numeric fields - size: Number of results - from_: Offset for pagination - enable_knn: Whether to use KNN search - knn_k: K value for KNN - knn_num_candidates: Number of candidates for KNN - min_score: Minimum score threshold - - Returns: - ES query DSL dictionary - """ - # 1. 根据域选择匹配字段(默认域使用 DEFAULT_MATCH_FIELDS) - domain = parsed_query.domain or "default" - domain_fields = self.domain_configs.get(domain) or DEFAULT_MATCH_FIELDS - - # 2. 临时切换 match_fields,复用基类 build_query 逻辑 - original_match_fields = self.match_fields - self.match_fields = domain_fields - try: - return super().build_query( - query_text=parsed_query.rewritten_query or parsed_query.normalized_query, - query_vector=query_vector, - query_node=query_node, - filters=filters, - range_filters=range_filters, - size=size, - from_=from_, - enable_knn=enable_knn, - knn_k=knn_k, - knn_num_candidates=knn_num_candidates, - min_score=min_score - ) - finally: - # 恢复原始配置,避免影响后续查询 - self.match_fields = original_match_fields - - def _build_score_functions(self) -> List[Dict[str, Any]]: - """ - 从配置构建 function_score 的打分函数列表 - - Returns: - 打分函数列表(ES原生格式) - """ - if not self.function_score_config or not self.function_score_config.functions: - return [] - - functions = [] - - for func_config in self.function_score_config.functions: - func_type = func_config.get('type') - - if func_type == 'filter_weight': - # Filter + Weight - functions.append({ - "filter": func_config['filter'], - "weight": func_config.get('weight', 1.0) - }) - - elif func_type == 'field_value_factor': - # Field Value Factor - functions.append({ - "field_value_factor": { - "field": func_config['field'], - "factor": func_config.get('factor', 1.0), - "modifier": func_config.get('modifier', 'none'), - "missing": func_config.get('missing', 1.0) - } - }) - - elif func_type == 'decay': - # Decay Function (gauss/exp/linear) - decay_func = func_config.get('function', 'gauss') - field = func_config['field'] - - decay_params = { - "origin": func_config.get('origin', 'now'), - "scale": func_config['scale'] - } - - if 'offset' in func_config: - decay_params['offset'] = func_config['offset'] - if 'decay' in func_config: - decay_params['decay'] = func_config['decay'] - - functions.append({ - decay_func: { - field: decay_params - } - }) - - return functions - - def _build_multilang_text_query( - self, - parsed_query: ParsedQuery, - domain_config: Dict[str, Any] - ) -> Dict[str, Any]: - """ - Build text query with multi-language field routing. - - Args: - parsed_query: Parsed query with language info - domain_config: Domain configuration - - Returns: - ES query clause - """ - if not domain_config.language_field_mapping: - # No multi-language mapping, use all fields with default analyzer - fields_with_boost = [] - for field_name in domain_config.fields: - field = self._get_field_by_name(field_name) - if field and field.boost != 1.0: - fields_with_boost.append(f"{field_name}^{field.boost}") - else: - fields_with_boost.append(field_name) - - return { - "multi_match": { - "query": parsed_query.rewritten_query, - "fields": fields_with_boost, - "minimum_should_match": "67%", - "tie_breaker": 0.9, - "boost": domain_config.boost, - "_name": f"{domain_config.name}_query" - } - } - - # Multi-language mapping exists - build targeted queries - should_clauses = [] - available_languages = set(domain_config.language_field_mapping.keys()) - - # 1. Query in detected language (if it exists in mapping) - detected_lang = parsed_query.detected_language - if detected_lang in available_languages: - target_fields = domain_config.language_field_mapping[detected_lang] - fields_with_boost = self._apply_field_boosts(target_fields) - - should_clauses.append({ - "multi_match": { - "query": parsed_query.rewritten_query, - "fields": fields_with_boost, - "minimum_should_match": "67%", - "tie_breaker": 0.9, - "boost": domain_config.boost * 1.5, # Higher boost for detected language - "_name": f"{domain_config.name}_{detected_lang}_query" - } - }) - logger.debug(f"Added query for detected language '{detected_lang}'") - - # 2. Query in translated languages (only for languages in mapping) - for lang, translation in parsed_query.translations.items(): - # Only use translations for languages that exist in the mapping - if lang in available_languages and translation and translation.strip(): - target_fields = domain_config.language_field_mapping[lang] - fields_with_boost = self._apply_field_boosts(target_fields) - - should_clauses.append({ - "multi_match": { - "query": translation, - "fields": fields_with_boost, - "minimum_should_match": "67%", - "tie_breaker": 0.9, - "boost": domain_config.boost, - "_name": f"{domain_config.name}_{lang}_translated_query" - } - }) - logger.debug(f"Added translated query for language '{lang}'") - - # 3. Fallback: query all fields in mapping if no language-specific query was built - if not should_clauses: - logger.debug("No language mapping matched, using all fields from mapping") - # Use all fields from all languages in the mapping - all_mapped_fields = [] - for lang_fields in domain_config.language_field_mapping.values(): - all_mapped_fields.extend(lang_fields) - # Remove duplicates while preserving order - unique_fields = list(dict.fromkeys(all_mapped_fields)) - fields_with_boost = self._apply_field_boosts(unique_fields) - - should_clauses.append({ - "multi_match": { - "query": parsed_query.rewritten_query, - "fields": fields_with_boost, - "minimum_should_match": "67%", - "tie_breaker": 0.9, - "boost": domain_config.boost * 0.8, # Lower boost for fallback - "_name": f"{domain_config.name}_fallback_query" - } - }) - - if len(should_clauses) == 1: - return should_clauses[0] - else: - return { - "bool": { - "should": should_clauses, - "minimum_should_match": 1 - } - } - - def _apply_field_boosts(self, field_names: List[str]) -> List[str]: - """Apply boost values to field names.""" - result = [] - for field_name in field_names: - field = self._get_field_by_name(field_name) - if field and field.boost != 1.0: - result.append(f"{field_name}^{field.boost}") - else: - result.append(field_name) - return result - - def _build_boolean_query_from_tuple(self, node) -> Dict[str, Any]: - """ - Build query from boolean expression tuple. - - Args: - node: Boolean expression tuple (operator, terms...) - - Returns: - ES query clause - """ - if not node: - return {"match_all": {}} - - # Handle different node types from boolean parser - if hasattr(node, 'operator'): - # QueryNode object - operator = node.operator - terms = node.terms if hasattr(node, 'terms') else None - - # For TERM nodes, check if there's a value - if operator == 'TERM' and hasattr(node, 'value') and node.value: - terms = node.value - elif isinstance(node, tuple) and len(node) > 0: - # Tuple format from boolean parser - if hasattr(node[0], 'operator'): - # Nested tuple with QueryNode - operator = node[0].operator - terms = node[0].terms - elif isinstance(node[0], str): - # Simple tuple like ('TERM', 'field:value') - operator = node[0] - terms = node[1] if len(node) > 1 else '' - else: - # Complex tuple like (OR( TERM(...), TERM(...) ), score) - if hasattr(node[0], '__class__') and hasattr(node[0], '__name__'): - # Constructor call like OR(...) - operator = node[0].__name__ - elif str(node[0]).startswith('('): - # String representation of constructor call - match = re.match(r'(\w+)\(', str(node[0])) - if match: - operator = match.group(1) - else: - return {"match_all": {}} - else: - operator = str(node[0]) - - # Extract terms from nested structure - terms = [] - if len(node) > 1 and isinstance(node[1], tuple): - terms = node[1] - else: - return {"match_all": {}} - - - if operator == 'TERM': - # Leaf node - handle field:query format - if isinstance(terms, str) and ':' in terms: - field, value = terms.split(':', 1) - return { - "term": { - field: value - } - } - elif isinstance(terms, str): - # Simple text term - create match query - return { - "multi_match": { - "query": terms, - "fields": self.match_fields, - "type": "best_fields", - "operator": "AND" - } - } - else: - # Invalid TERM node - return empty match - return { - "match_none": {} - } - - elif operator == 'OR': - # Any term must match - should_clauses = [] - if terms: - for term in terms: - clause = self._build_boolean_query_from_tuple(term) - if clause and clause.get("match_none") is None: - should_clauses.append(clause) - - if should_clauses: - return { - "bool": { - "should": should_clauses, - "minimum_should_match": 1 - } - } - else: - return {"match_none": {}} - - elif operator == 'AND': - # All terms must match - must_clauses = [] - if terms: - for term in terms: - clause = self._build_boolean_query_from_tuple(term) - if clause and clause.get("match_none") is None: - must_clauses.append(clause) - - if must_clauses: - return { - "bool": { - "must": must_clauses - } - } - else: - return {"match_none": {}} - - elif operator == 'ANDNOT': - # First term must match, second must not - if len(terms) >= 2: - return { - "bool": { - "must": [self._build_boolean_query_from_tuple(terms[0])], - "must_not": [self._build_boolean_query_from_tuple(terms[1])] - } - } - else: - return self._build_boolean_query_from_tuple(terms[0]) - - elif operator == 'RANK': - # Like OR but for ranking (all terms contribute to score) - should_clauses = [] - for term in terms: - should_clauses.append(self._build_boolean_query_from_tuple(term)) - return { - "bool": { - "should": should_clauses - } - } - - else: - # Unknown operator - return {"match_all": {}} - - def get_domain_summary(self) -> Dict[str, Any]: - """Get summary of all configured domains.""" - summary = {} - for domain_name, domain_config in self.domain_configs.items(): - summary[domain_name] = { - "label": domain_config.label, - "fields": domain_config.fields, - "analyzer": domain_config.analyzer.value, - "boost": domain_config.boost, - "has_multilang_mapping": domain_config.language_field_mapping is not None, - "supported_languages": list(domain_config.language_field_mapping.keys()) if domain_config.language_field_mapping else [] - } - return summary \ No newline at end of file diff --git a/search/query_config.py b/search/query_config.py index 1088e2d..ca77054 100644 --- a/search/query_config.py +++ b/search/query_config.py @@ -17,14 +17,24 @@ TEXT_EMBEDDING_FIELD = "title_embedding" IMAGE_EMBEDDING_FIELD = "image_embedding" # Default match fields for text search (with boost) +# 文本召回:同时搜索中英文字段,两者相互补充 DEFAULT_MATCH_FIELDS = [ + # 中文字段 "title_zh^3.0", "brief_zh^1.5", "description_zh^1.0", "vendor_zh^1.5", - "tags^1.0", "category_path_zh^1.5", - "category_name_zh^1.5" + "category_name_zh^1.5", + # 英文字段 + "title_en^3.0", + "brief_en^1.5", + "description_en^1.0", + "vendor_en^1.5", + "category_path_en^1.5", + "category_name_en^1.5", + # 语言无关字段 + "tags^1.0", ] # Domain-specific match fields diff --git a/search/searcher.py b/search/searcher.py index 238bcae..76375b9 100644 --- a/search/searcher.py +++ b/search/searcher.py @@ -13,7 +13,6 @@ from query import QueryParser, ParsedQuery from embeddings import CLIPImageEncoder from .boolean_parser import BooleanParser, QueryNode from .es_query_builder import ESQueryBuilder -from .multilang_query_builder import MultiLanguageQueryBuilder from .rerank_engine import RerankEngine from .query_config import ( DEFAULT_INDEX_NAME, @@ -112,8 +111,8 @@ class Searcher: self.text_embedding_field = TEXT_EMBEDDING_FIELD self.image_embedding_field = IMAGE_EMBEDDING_FIELD - # Query builder - use multi-language version - self.query_builder = MultiLanguageQueryBuilder( + # Query builder - simplified single-layer architecture + self.query_builder = ESQueryBuilder( index_name=index_name, match_fields=self.match_fields, text_embedding_field=self.text_embedding_field, @@ -274,8 +273,8 @@ class Searcher: filters = {} filters['tenant_id'] = tenant_id - es_query = self.query_builder.build_multilang_query( - parsed_query=parsed_query, + es_query = self.query_builder.build_query( + query_text=parsed_query.rewritten_query or parsed_query.normalized_query, query_vector=parsed_query.query_vector if enable_embedding else None, query_node=query_node, filters=filters, -- libgit2 0.21.2