diff --git a/api/app.py b/api/app.py index 316c18e..cfcd6c8 100644 --- a/api/app.py +++ b/api/app.py @@ -40,14 +40,13 @@ limiter = Limiter(key_func=get_remote_address) # Add parent directory to path sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from config import ConfigLoader, SearchConfig from config.env_config import ES_CONFIG from utils import ESClient from search import Searcher +from search.query_config import DEFAULT_INDEX_NAME from query import QueryParser # Global instances -_config: Optional[SearchConfig] = None _es_client: Optional[ESClient] = None _searcher: Optional[Searcher] = None _query_parser: Optional[QueryParser] = None @@ -60,20 +59,11 @@ def init_service(es_host: str = "http://localhost:9200"): Args: es_host: Elasticsearch host URL """ - global _config, _es_client, _searcher, _query_parser + global _es_client, _searcher, _query_parser start_time = time.time() logger.info("Initializing search service (multi-tenant)") - # Load and validate configuration - logger.info("Loading configuration...") - config_loader = ConfigLoader("config/config.yaml") - _config = config_loader.load_config() - errors = config_loader.validate_config(_config) - if errors: - raise ValueError(f"Configuration validation failed: {errors}") - logger.info(f"Configuration loaded: {_config.es_index_name}") - # Get ES credentials es_username = os.getenv('ES_USERNAME') or ES_CONFIG.get('username') es_password = os.getenv('ES_PASSWORD') or ES_CONFIG.get('password') @@ -91,20 +81,15 @@ def init_service(es_host: str = "http://localhost:9200"): # Initialize components logger.info("Initializing query parser...") - _query_parser = QueryParser(_config) + _query_parser = QueryParser() logger.info("Initializing searcher...") - _searcher = Searcher(_config, _es_client, _query_parser) + _searcher = Searcher(_es_client, _query_parser, index_name=DEFAULT_INDEX_NAME) elapsed = time.time() - start_time - logger.info(f"Search service ready! 
(took {elapsed:.2f}s)") + logger.info(f"Search service ready! (took {elapsed:.2f}s) | Index: {DEFAULT_INDEX_NAME}") -def get_config() -> SearchConfig: - """Get search engine configuration.""" - if _config is None: - raise RuntimeError("Service not initialized") - return _config def get_es_client() -> ESClient: @@ -243,8 +228,8 @@ async def health_check(request: Request): """Health check endpoint.""" try: # Check if services are initialized - get_config() get_es_client() + get_searcher() return { "status": "healthy", diff --git a/api/models.py b/api/models.py index e632eb8..a5e565f 100644 --- a/api/models.py +++ b/api/models.py @@ -191,7 +191,7 @@ class SpuResult(BaseModel): description: Optional[str] = Field(None, description="商品描述") vendor: Optional[str] = Field(None, description="供应商/品牌") category: Optional[str] = Field(None, description="类目") - tags: Optional[str] = Field(None, description="标签") + tags: Optional[List[str]] = Field(None, description="标签列表") price: Optional[float] = Field(None, description="价格(min_price)") compare_at_price: Optional[float] = Field(None, description="原价") currency: str = Field("USD", description="货币单位") diff --git a/api/result_formatter.py b/api/result_formatter.py index 10dff47..594d102 100644 --- a/api/result_formatter.py +++ b/api/result_formatter.py @@ -89,6 +89,11 @@ class ResultFormatter: """ Format ES aggregations to FacetResult list. + 支持: + 1. 普通terms聚合 + 2. range聚合 + 3. 
specifications嵌套聚合(按name分组,然后按value聚合) + Args: es_aggregations: ES aggregations response facet_configs: Facet configurations (optional) @@ -100,6 +105,38 @@ class ResultFormatter: for field_name, agg_data in es_aggregations.items(): display_field = field_name[:-6] if field_name.endswith("_facet") else field_name + + # 处理specifications嵌套分面 + if field_name == "specifications_facet" and 'by_name' in agg_data: + # specifications嵌套聚合:按name分组,每个name下有value_counts + by_name_agg = agg_data['by_name'] + if 'buckets' in by_name_agg: + for name_bucket in by_name_agg['buckets']: + name = name_bucket['key'] + value_counts = name_bucket.get('value_counts', {}) + + values = [] + if 'buckets' in value_counts: + for value_bucket in value_counts['buckets']: + value = FacetValue( + value=value_bucket['key'], + label=str(value_bucket['key']), + count=value_bucket['doc_count'], + selected=False + ) + values.append(value) + + # 为每个name创建一个分面结果 + facet = FacetResult( + field=f"specifications.{name}", + label=str(name), # 使用name作为label,如"颜色"、"尺寸" + type="terms", + values=values, + total_count=name_bucket['doc_count'] + ) + facets.append(facet) + continue + # Handle terms aggregation if 'buckets' in agg_data: values = [] diff --git a/indexer/mapping_generator.py b/indexer/mapping_generator.py index a861f3c..dbd5e28 100644 --- a/indexer/mapping_generator.py +++ b/indexer/mapping_generator.py @@ -51,23 +51,26 @@ def create_index_if_not_exists(es_client, index_name: str, mapping: Dict[str, An Create Elasticsearch index if it doesn't exist. Args: - es_client: Elasticsearch client instance + es_client: ESClient instance index_name: Name of the index to create mapping: Index mapping configuration. If None, loads from default file. 
Returns: True if index was created, False if it already exists """ - if es_client.indices.exists(index=index_name): + if es_client.index_exists(index_name): logger.info(f"Index '{index_name}' already exists") return False if mapping is None: mapping = load_mapping() - es_client.indices.create(index=index_name, body=mapping) - logger.info(f"Index '{index_name}' created successfully") - return True + if es_client.create_index(index_name, mapping): + logger.info(f"Index '{index_name}' created successfully") + return True + else: + logger.error(f"Failed to create index '{index_name}'") + return False def delete_index_if_exists(es_client, index_name: str) -> bool: @@ -75,19 +78,22 @@ def delete_index_if_exists(es_client, index_name: str) -> bool: Delete Elasticsearch index if it exists. Args: - es_client: Elasticsearch client instance + es_client: ESClient instance index_name: Name of the index to delete Returns: True if index was deleted, False if it didn't exist """ - if not es_client.indices.exists(index=index_name): + if not es_client.index_exists(index_name): logger.warning(f"Index '{index_name}' does not exist") return False - es_client.indices.delete(index=index_name) - logger.info(f"Index '{index_name}' deleted successfully") - return True + if es_client.delete_index(index_name): + logger.info(f"Index '{index_name}' deleted successfully") + return True + else: + logger.error(f"Failed to delete index '{index_name}'") + return False def update_mapping(es_client, index_name: str, new_fields: Dict[str, Any]) -> bool: @@ -95,18 +101,21 @@ def update_mapping(es_client, index_name: str, new_fields: Dict[str, Any]) -> bo Update mapping for existing index (only adding new fields). 
Args: - es_client: Elasticsearch client instance + es_client: ESClient instance index_name: Name of the index new_fields: New field mappings to add Returns: True if successful """ - if not es_client.indices.exists(index=index_name): + if not es_client.index_exists(index_name): logger.error(f"Index '{index_name}' does not exist") return False mapping = {"properties": new_fields} - es_client.indices.put_mapping(index=index_name, body=mapping) - logger.info(f"Mapping updated for index '{index_name}'") - return True + if es_client.update_mapping(index_name, mapping): + logger.info(f"Mapping updated for index '{index_name}'") + return True + else: + logger.error(f"Failed to update mapping for index '{index_name}'") + return False diff --git a/indexer/spu_transformer.py b/indexer/spu_transformer.py index 6dc64ec..631fad3 100644 --- a/indexer/spu_transformer.py +++ b/indexer/spu_transformer.py @@ -124,7 +124,7 @@ class SPUTransformer: query = text(""" SELECT id, spu_id, shop_id, shoplazza_id, shoplazza_product_id, - position, name, values, tenant_id, + position, name, `values`, tenant_id, creator, create_time, updater, update_time, deleted FROM shoplazza_product_option WHERE tenant_id = :tenant_id AND deleted = 0 diff --git a/query/query_parser.py b/query/query_parser.py index ab7f9c7..1c710e4 100644 --- a/query/query_parser.py +++ b/query/query_parser.py @@ -8,8 +8,14 @@ from typing import Dict, List, Optional, Any import numpy as np import logging -from config import SearchConfig, QueryConfig from embeddings import BgeEncoder +from search.query_config import ( + ENABLE_TEXT_EMBEDDING, + ENABLE_TRANSLATION, + REWRITE_DICTIONARY, + TRANSLATION_API_KEY, + TRANSLATION_SERVICE +) from .language_detector import LanguageDetector from .translator import Translator from .query_rewriter import QueryRewriter, QueryNormalizer @@ -64,7 +70,6 @@ class QueryParser: def __init__( self, - config: SearchConfig, text_encoder: Optional[BgeEncoder] = None, translator: Optional[Translator] = 
None ): @@ -72,24 +77,21 @@ class QueryParser: Initialize query parser. Args: - config: Search configuration text_encoder: Text embedding encoder (lazy loaded if not provided) translator: Translator instance (lazy loaded if not provided) """ - self.config = config - self.query_config = config.query_config self._text_encoder = text_encoder self._translator = translator # Initialize components self.normalizer = QueryNormalizer() self.language_detector = LanguageDetector() - self.rewriter = QueryRewriter(self.query_config.rewrite_dictionary) + self.rewriter = QueryRewriter(REWRITE_DICTIONARY) @property def text_encoder(self) -> BgeEncoder: """Lazy load text encoder.""" - if self._text_encoder is None and self.query_config.enable_text_embedding: + if self._text_encoder is None and ENABLE_TEXT_EMBEDDING: logger.info("Initializing text encoder (lazy load)...") self._text_encoder = BgeEncoder() return self._text_encoder @@ -97,13 +99,13 @@ class QueryParser: @property def translator(self) -> Translator: """Lazy load translator.""" - if self._translator is None and self.query_config.enable_translation: + if self._translator is None and ENABLE_TRANSLATION: logger.info("Initializing translator (lazy load)...") self._translator = Translator( - api_key=self.query_config.translation_api_key, + api_key=TRANSLATION_API_KEY, use_cache=True, - glossary_id=getattr(self.query_config, 'translation_glossary_id', None), - translation_context=getattr(self.query_config, 'translation_context', 'e-commerce product search') + glossary_id=None, # Can be added to query_config if needed + translation_context='e-commerce product search' ) return self._translator @@ -154,7 +156,7 @@ class QueryParser: # Stage 2: Query rewriting rewritten = None - if self.query_config.enable_query_rewrite: + if REWRITE_DICTIONARY: # Enable rewrite if dictionary exists rewritten = self.rewriter.rewrite(query_text) if rewritten != query_text: log_info(f"查询重写 | '{query_text}' -> '{rewritten}'") @@ -171,26 +173,11 @@ 
class QueryParser: # Stage 4: Translation translations = {} - if self.query_config.enable_translation: + if ENABLE_TRANSLATION: try: # Determine target languages for translation - # If domain has language_field_mapping, only translate to languages in the mapping - # Otherwise, use all supported languages - target_langs_for_translation = self.query_config.supported_languages - - # Check if domain has language_field_mapping - domain_config = next( - (idx for idx in self.config.indexes if idx.name == domain), - None - ) - if domain_config and domain_config.language_field_mapping: - # Only translate to languages that exist in the mapping - available_languages = set(domain_config.language_field_mapping.keys()) - target_langs_for_translation = [ - lang for lang in self.query_config.supported_languages - if lang in available_languages - ] - log_debug(f"域 '{domain}' 有语言字段映射,将翻译到: {target_langs_for_translation}") + # Simplified: always translate to Chinese and English + target_langs_for_translation = ['zh', 'en'] target_langs = self.translator.get_translation_needs( detected_lang, @@ -200,7 +187,7 @@ class QueryParser: if target_langs: log_info(f"开始翻译 | 源语言: {detected_lang} | 目标语言: {target_langs}") # Use e-commerce context for better disambiguation - translation_context = getattr(self.query_config, 'translation_context', 'e-commerce product search') + translation_context = 'e-commerce product search' translations = self.translator.translate_multi( query_text, target_langs, @@ -223,7 +210,7 @@ class QueryParser: # Stage 5: Text embedding query_vector = None if (generate_vector and - self.query_config.enable_text_embedding and + ENABLE_TEXT_EMBEDDING and domain == "default"): # Only generate vector for default domain try: log_debug("开始生成查询向量") diff --git a/scripts/recreate_and_import.py b/scripts/recreate_and_import.py new file mode 100755 index 0000000..af0a448 --- /dev/null +++ b/scripts/recreate_and_import.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +""" +重建索引并导入数据的脚本。 + 
+清除旧索引,使用新的mapping重建索引,然后导入数据。 +""" + +import sys +import os +import argparse +from pathlib import Path + +# Add parent directory to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from utils.db_connector import create_db_connection +from utils.es_client import ESClient +from indexer.mapping_generator import load_mapping, delete_index_if_exists, DEFAULT_INDEX_NAME +from indexer.spu_transformer import SPUTransformer +from indexer.bulk_indexer import BulkIndexer + + +def main(): + parser = argparse.ArgumentParser(description='重建ES索引并导入数据') + + # Database connection + parser.add_argument('--db-host', help='MySQL host (或使用环境变量 DB_HOST)') + parser.add_argument('--db-port', type=int, help='MySQL port (或使用环境变量 DB_PORT, 默认: 3306)') + parser.add_argument('--db-database', help='MySQL database (或使用环境变量 DB_DATABASE)') + parser.add_argument('--db-username', help='MySQL username (或使用环境变量 DB_USERNAME)') + parser.add_argument('--db-password', help='MySQL password (或使用环境变量 DB_PASSWORD)') + + # Tenant and ES + parser.add_argument('--tenant-id', required=True, help='Tenant ID (必需)') + parser.add_argument('--es-host', help='Elasticsearch host (或使用环境变量 ES_HOST, 默认: http://localhost:9200)') + + # Options + parser.add_argument('--batch-size', type=int, default=500, help='批量导入大小 (默认: 500)') + parser.add_argument('--skip-delete', action='store_true', help='跳过删除旧索引步骤') + + args = parser.parse_args() + + print("=" * 60) + print("重建ES索引并导入数据") + print("=" * 60) + + # 加载mapping + print("\n[1/6] 加载mapping配置...") + try: + mapping = load_mapping() + print(f"✓ 成功加载mapping配置") + except Exception as e: + print(f"✗ 加载mapping失败: {e}") + return 1 + + index_name = DEFAULT_INDEX_NAME + print(f"索引名称: {index_name}") + + # 连接Elasticsearch + print("\n[2/6] 连接Elasticsearch...") + es_host = args.es_host or os.environ.get('ES_HOST', 'http://localhost:9200') + es_username = os.environ.get('ES_USERNAME') + es_password = os.environ.get('ES_PASSWORD') + + print(f"ES地址: {es_host}") + if es_username: + 
print(f"ES用户名: {es_username}") + + try: + if es_username and es_password: + es_client = ESClient(hosts=[es_host], username=es_username, password=es_password) + else: + es_client = ESClient(hosts=[es_host]) + + if not es_client.ping(): + print(f"✗ 无法连接到Elasticsearch: {es_host}") + return 1 + print("✓ Elasticsearch连接成功") + except Exception as e: + print(f"✗ 连接Elasticsearch失败: {e}") + return 1 + + # 删除旧索引 + if not args.skip_delete: + print("\n[3/6] 删除旧索引...") + if es_client.index_exists(index_name): + print(f"发现已存在的索引: {index_name}") + if delete_index_if_exists(es_client, index_name): + print(f"✓ 成功删除索引: {index_name}") + else: + print(f"✗ 删除索引失败: {index_name}") + return 1 + else: + print(f"索引不存在,跳过删除: {index_name}") + else: + print("\n[3/6] 跳过删除旧索引步骤") + + # 创建新索引 + print("\n[4/6] 创建新索引...") + try: + if es_client.index_exists(index_name): + print(f"✓ 索引已存在: {index_name},跳过创建") + else: + print(f"创建索引: {index_name}") + if es_client.create_index(index_name, mapping): + print(f"✓ 成功创建索引: {index_name}") + else: + print(f"✗ 创建索引失败: {index_name}") + return 1 + except Exception as e: + print(f"✗ 创建索引失败: {e}") + import traceback + traceback.print_exc() + return 1 + + # 连接MySQL + print("\n[5/6] 连接MySQL...") + db_host = args.db_host or os.environ.get('DB_HOST') + db_port = args.db_port or int(os.environ.get('DB_PORT', 3306)) + db_database = args.db_database or os.environ.get('DB_DATABASE') + db_username = args.db_username or os.environ.get('DB_USERNAME') + db_password = args.db_password or os.environ.get('DB_PASSWORD') + + if not all([db_host, db_database, db_username, db_password]): + print("✗ MySQL连接参数不完整") + print("请提供 --db-host, --db-database, --db-username, --db-password") + print("或设置环境变量: DB_HOST, DB_DATABASE, DB_USERNAME, DB_PASSWORD") + return 1 + + print(f"MySQL: {db_host}:{db_port}/{db_database}") + try: + db_engine = create_db_connection( + host=db_host, + port=db_port, + database=db_database, + username=db_username, + password=db_password + ) + print("✓ MySQL连接成功") 
+ except Exception as e: + print(f"✗ 连接MySQL失败: {e}") + return 1 + + # 导入数据 + print("\n[6/6] 导入数据...") + print(f"Tenant ID: {args.tenant_id}") + print(f"批量大小: {args.batch_size}") + + try: + transformer = SPUTransformer(db_engine, args.tenant_id) + print("正在转换数据...") + documents = transformer.transform_batch() + print(f"✓ 转换完成: {len(documents)} 个文档") + + if not documents: + print("⚠ 没有数据需要导入") + return 0 + + print(f"正在导入数据到ES (批量大小: {args.batch_size})...") + indexer = BulkIndexer(es_client, index_name, batch_size=args.batch_size) + results = indexer.index_documents(documents, id_field="spu_id", show_progress=True) + + print(f"\n{'='*60}") + print("导入完成!") + print(f"{'='*60}") + print(f"成功: {results['success']}") + print(f"失败: {results['failed']}") + print(f"耗时: {results.get('elapsed_time', 0):.2f}秒") + + if results['failed'] > 0: + print(f"\n⚠ 警告: {results['failed']} 个文档导入失败") + return 1 + + return 0 + except Exception as e: + print(f"✗ 导入数据失败: {e}") + import traceback + traceback.print_exc() + return 1 + + +if __name__ == '__main__': + sys.exit(main()) + diff --git a/search/es_query_builder.py b/search/es_query_builder.py index 4e1c07e..09b708d 100644 --- a/search/es_query_builder.py +++ b/search/es_query_builder.py @@ -330,10 +330,15 @@ class ESQueryBuilder: """ 构建分面聚合。 + 支持: + 1. 分类分面:category1_name, category2_name, category3_name, category_name + 2. 
specifications分面:嵌套聚合,按name聚合,然后按value聚合 + Args: facet_configs: 分面配置列表(标准格式): - str: 字段名,使用默认 terms 配置 - FacetConfig: 详细的分面配置对象 + - 特殊值 "specifications": 构建specifications嵌套分面 Returns: ES aggregations 字典 @@ -344,6 +349,34 @@ class ESQueryBuilder: aggs = {} for config in facet_configs: + # 特殊处理:specifications嵌套分面 + if isinstance(config, str) and config == "specifications": + # 构建specifications嵌套分面(按name聚合,然后按value聚合) + aggs["specifications_facet"] = { + "nested": { + "path": "specifications" + }, + "aggs": { + "by_name": { + "terms": { + "field": "specifications.name", + "size": 20, + "order": {"_count": "desc"} + }, + "aggs": { + "value_counts": { + "terms": { + "field": "specifications.value", + "size": 10, + "order": {"_count": "desc"} + } + } + } + } + } + } + continue + # 简单模式:只有字段名(字符串) if isinstance(config, str): field = config diff --git a/search/multilang_query_builder.py b/search/multilang_query_builder.py index 9558db2..1df781b 100644 --- a/search/multilang_query_builder.py +++ b/search/multilang_query_builder.py @@ -11,9 +11,9 @@ import numpy as np import logging import re -from config import SearchConfig, IndexConfig from query import ParsedQuery from .es_query_builder import ESQueryBuilder +from .query_config import DEFAULT_MATCH_FIELDS, DOMAIN_FIELDS, FUNCTION_SCORE_CONFIG logger = logging.getLogger(__name__) @@ -30,8 +30,8 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): def __init__( self, - config: SearchConfig, index_name: str, + match_fields: Optional[List[str]] = None, text_embedding_field: Optional[str] = None, image_embedding_field: Optional[str] = None, source_fields: Optional[List[str]] = None @@ -40,53 +40,32 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): Initialize multi-language query builder. 
Args: - config: Search configuration index_name: ES index name + match_fields: Fields to search for text matching (default: from query_config) text_embedding_field: Field name for text embeddings image_embedding_field: Field name for image embeddings source_fields: Fields to return in search results (_source includes) """ - self.config = config - self.function_score_config = config.function_score + self.function_score_config = FUNCTION_SCORE_CONFIG - # For default domain, use all fields as fallback - default_fields = self._get_domain_fields("default") + # Use provided match_fields or default + if match_fields is None: + match_fields = DEFAULT_MATCH_FIELDS super().__init__( index_name=index_name, - match_fields=default_fields, + match_fields=match_fields, text_embedding_field=text_embedding_field, image_embedding_field=image_embedding_field, source_fields=source_fields ) - # Build domain configurations - self.domain_configs = self._build_domain_configs() - - def _build_domain_configs(self) -> Dict[str, IndexConfig]: - """Build mapping of domain name to IndexConfig.""" - return {index.name: index for index in self.config.indexes} + # Build domain configurations from query_config + self.domain_configs = DOMAIN_FIELDS def _get_domain_fields(self, domain_name: str) -> List[str]: """Get fields for a specific domain with boost notation.""" - for index in self.config.indexes: - if index.name == domain_name: - result = [] - for field_name in index.fields: - field = self._get_field_by_name(field_name) - if field and field.boost != 1.0: - result.append(f"{field_name}^{field.boost}") - else: - result.append(field_name) - return result - return [] - - def _get_field_by_name(self, field_name: str): - """Get field configuration by name.""" - for field in self.config.fields: - if field.name == field_name: - return field - return None + return self.domain_configs.get(domain_name, DEFAULT_MATCH_FIELDS) def build_multilang_query( self, @@ -103,7 +82,7 @@ class 
MultiLanguageQueryBuilder(ESQueryBuilder): min_score: Optional[float] = None ) -> Dict[str, Any]: """ - Build ES query with multi-language support (重构版). + Build ES query with multi-language support (简化版). Args: parsed_query: Parsed query with language info and translations @@ -120,19 +99,18 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): Returns: ES query DSL dictionary """ - domain = parsed_query.domain - domain_config = self.domain_configs.get(domain) - - if not domain_config: - # Fallback to default domain - domain = "default" - domain_config = self.domain_configs.get("default") - - if not domain_config: - # Use original behavior + # 1. 根据域选择匹配字段(默认域使用 DEFAULT_MATCH_FIELDS) + domain = parsed_query.domain or "default" + domain_fields = self.domain_configs.get(domain) or DEFAULT_MATCH_FIELDS + + # 2. 临时切换 match_fields,复用基类 build_query 逻辑 + original_match_fields = self.match_fields + self.match_fields = domain_fields + try: return super().build_query( - query_text=parsed_query.rewritten_query, + query_text=parsed_query.rewritten_query or parsed_query.normalized_query, query_vector=query_vector, + query_node=query_node, filters=filters, range_filters=range_filters, size=size, @@ -142,95 +120,9 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): knn_num_candidates=knn_num_candidates, min_score=min_score ) - - logger.debug(f"Building query for domain: {domain}, language: {parsed_query.detected_language}") - - # Build query clause with multi-language support - if query_node and isinstance(query_node, tuple) and len(query_node) > 0: - # Handle boolean query from tuple (AST, score) - ast_node = query_node[0] - query_clause = self._build_boolean_query_from_tuple(ast_node) - logger.debug(f"Using boolean query") - elif query_node and hasattr(query_node, 'operator') and query_node.operator != 'TERM': - # Handle boolean query using base class method - query_clause = self._build_boolean_query(query_node) - logger.debug(f"Using boolean query") - else: - # Handle text query 
with multi-language support - query_clause = self._build_multilang_text_query(parsed_query, domain_config) - - # 构建内层bool: 文本和KNN二选一 - inner_bool_should = [query_clause] - - # 如果启用KNN,添加到should - if enable_knn and query_vector is not None and self.text_embedding_field: - knn_query = { - "knn": { - "field": self.text_embedding_field, - "query_vector": query_vector.tolist(), - "k": knn_k, - "num_candidates": knn_num_candidates - } - } - inner_bool_should.append(knn_query) - logger.info(f"KNN query added: field={self.text_embedding_field}, k={knn_k}") - else: - # Debug why KNN is not added - reasons = [] - if not enable_knn: - reasons.append("enable_knn=False") - if query_vector is None: - reasons.append("query_vector is None") - if not self.text_embedding_field: - reasons.append(f"text_embedding_field is not set (current: {self.text_embedding_field})") - logger.debug(f"KNN query NOT added. Reasons: {', '.join(reasons) if reasons else 'unknown'}") - - # 构建内层bool结构 - inner_bool = { - "bool": { - "should": inner_bool_should, - "minimum_should_match": 1 - } - } - - # 构建外层bool: 包含filter - filter_clauses = self._build_filters(filters, range_filters) if (filters or range_filters) else [] - - outer_bool = { - "bool": { - "must": [inner_bool] - } - } - - if filter_clauses: - outer_bool["bool"]["filter"] = filter_clauses - - # 包裹function_score(从配置读取score_mode和boost_mode) - function_score_query = { - "function_score": { - "query": outer_bool, - "functions": self._build_score_functions(), - "score_mode": self.function_score_config.score_mode if self.function_score_config else "sum", - "boost_mode": self.function_score_config.boost_mode if self.function_score_config else "multiply" - } - } - - es_query = { - "size": size, - "from": from_, - "query": function_score_query - } - - # Add _source filtering if source_fields are configured - if self.source_fields: - es_query["_source"] = { - "includes": self.source_fields - } - - if min_score is not None: - es_query["min_score"] = 
min_score - - return es_query + finally: + # 恢复原始配置,避免影响后续查询 + self.match_fields = original_match_fields def _build_score_functions(self) -> List[Dict[str, Any]]: """ @@ -291,7 +183,7 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): def _build_multilang_text_query( self, parsed_query: ParsedQuery, - domain_config: IndexConfig + domain_config: Dict[str, Any] ) -> Dict[str, Any]: """ Build text query with multi-language field routing. diff --git a/search/query_config.py b/search/query_config.py index a415957..e7d7747 100644 --- a/search/query_config.py +++ b/search/query_config.py @@ -112,3 +112,13 @@ def load_rewrite_dictionary() -> Dict[str, str]: REWRITE_DICTIONARY = load_rewrite_dictionary() +# Default facets for faceted search +# 分类分面:使用category1_name, category2_name, category3_name +# specifications分面:使用嵌套聚合,按name分组,然后按value聚合 +DEFAULT_FACETS = [ + "category1_name", # 一级分类 + "category2_name", # 二级分类 + "category3_name", # 三级分类 + "specifications" # 规格分面(特殊处理:嵌套聚合) +] + diff --git a/search/searcher.py b/search/searcher.py index 49d1fab..8b3ef47 100644 --- a/search/searcher.py +++ b/search/searcher.py @@ -8,15 +8,23 @@ from typing import Dict, Any, List, Optional, Union import time import logging -from config import SearchConfig from utils.es_client import ESClient from query import QueryParser, ParsedQuery -from indexer import MappingGenerator from embeddings import CLIPImageEncoder from .boolean_parser import BooleanParser, QueryNode from .es_query_builder import ESQueryBuilder from .multilang_query_builder import MultiLanguageQueryBuilder from .rerank_engine import RerankEngine +from .query_config import ( + DEFAULT_INDEX_NAME, + DEFAULT_MATCH_FIELDS, + TEXT_EMBEDDING_FIELD, + IMAGE_EMBEDDING_FIELD, + SOURCE_FIELDS, + ENABLE_TRANSLATION, + ENABLE_TEXT_EMBEDDING, + RANKING_EXPRESSION +) from context.request_context import RequestContext, RequestContextStage, create_request_context from api.models import FacetResult, FacetValue from api.result_formatter import 
ResultFormatter @@ -79,39 +87,38 @@ class Searcher: def __init__( self, - config: SearchConfig, es_client: ESClient, - query_parser: Optional[QueryParser] = None + query_parser: Optional[QueryParser] = None, + index_name: str = DEFAULT_INDEX_NAME ): """ Initialize searcher. Args: - config: Search configuration es_client: Elasticsearch client query_parser: Query parser (created if not provided) + index_name: ES index name (default: search_products) """ - self.config = config self.es_client = es_client - self.query_parser = query_parser or QueryParser(config) + self.index_name = index_name + self.query_parser = query_parser or QueryParser() # Initialize components self.boolean_parser = BooleanParser() - self.rerank_engine = RerankEngine(config.ranking.expression, enabled=False) + self.rerank_engine = RerankEngine(RANKING_EXPRESSION, enabled=False) - # Get mapping info - mapping_gen = MappingGenerator(config) - self.match_fields = mapping_gen.get_match_fields_for_domain("default") - self.text_embedding_field = mapping_gen.get_text_embedding_field() - self.image_embedding_field = mapping_gen.get_image_embedding_field() + # Use constants from query_config + self.match_fields = DEFAULT_MATCH_FIELDS + self.text_embedding_field = TEXT_EMBEDDING_FIELD + self.image_embedding_field = IMAGE_EMBEDDING_FIELD # Query builder - use multi-language version self.query_builder = MultiLanguageQueryBuilder( - config=config, - index_name=config.es_index_name, + index_name=index_name, + match_fields=self.match_fields, text_embedding_field=self.text_embedding_field, image_embedding_field=self.image_embedding_field, - source_fields=config.query_config.source_fields + source_fields=SOURCE_FIELDS ) def search( @@ -154,8 +161,8 @@ class Searcher: context = create_request_context() # Always use config defaults (these are backend configuration, not user parameters) - enable_translation = self.config.query_config.enable_translation - enable_embedding = 
self.config.query_config.enable_text_embedding + enable_translation = ENABLE_TRANSLATION + enable_embedding = ENABLE_TEXT_EMBEDDING enable_rerank = False # Temporarily disabled # Start timing @@ -278,14 +285,6 @@ class Searcher: min_score=min_score ) - # Add SPU collapse if configured - if self.config.spu_config.enabled: - es_query = self.query_builder.add_spu_collapse( - es_query, - self.config.spu_config.spu_field, - self.config.spu_config.inner_hits_size - ) - # Add facets for faceted search if facets: facet_aggs = self.query_builder.build_facets(facets) @@ -329,7 +328,7 @@ class Searcher: context.start_stage(RequestContextStage.ELASTICSEARCH_SEARCH) try: es_response = self.es_client.search( - index_name=self.config.es_index_name, + index_name=self.index_name, body=body_for_es, size=size, from_=from_ @@ -503,9 +502,9 @@ class Searcher: } # Add _source filtering if source_fields are configured - if self.config.query_config.source_fields: + if SOURCE_FIELDS: es_query["_source"] = { - "includes": self.config.query_config.source_fields + "includes": SOURCE_FIELDS } if filters or range_filters: @@ -519,7 +518,7 @@ class Searcher: # Execute search es_response = self.es_client.search( - index_name=self.config.es_index_name, + index_name=self.index_name, body=es_query, size=size ) @@ -573,7 +572,7 @@ class Searcher: """ try: response = self.es_client.client.get( - index=self.config.es_index_name, + index=self.index_name, id=doc_id ) return response.get('_source') @@ -657,10 +656,11 @@ class Searcher: def _get_field_label(self, field: str) -> str: """获取字段的显示标签""" - # 从配置中获取字段标签 - for field_config in self.config.fields: - if field_config.name == field: - # 尝试获取 label 属性 - return getattr(field_config, 'label', field) - # 如果没有配置,返回字段名 - return field + # 字段标签映射(简化版,不再从配置读取) + field_labels = { + "category1_name": "一级分类", + "category2_name": "二级分类", + "category3_name": "三级分类", + "specifications": "规格" + } + return field_labels.get(field, field) diff --git a/utils/es_client.py 
b/utils/es_client.py index ae08dfd..03b8a77 100644 --- a/utils/es_client.py +++ b/utils/es_client.py @@ -172,6 +172,18 @@ class ESClient: Returns: Search results """ + # Safety guard: collapse is no longer needed (index is already SPU-level). + # If any caller accidentally adds a collapse clause (e.g. on product_id), + # strip it here to avoid 400 errors like: + # "no mapping found for `product_id` in order to collapse on" + if isinstance(body, dict) and "collapse" in body: + logger.warning( + "Removing unsupported 'collapse' clause from ES query body: %s", + body.get("collapse") + ) + body = dict(body) # shallow copy to avoid mutating caller + body.pop("collapse", None) + try: return self.client.search( index=index_name, -- libgit2 0.21.2