diff --git a/api/app.py b/api/app.py index cfcd6c8..b425887 100644 --- a/api/app.py +++ b/api/app.py @@ -41,15 +41,16 @@ limiter = Limiter(key_func=get_remote_address) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from config.env_config import ES_CONFIG +from config import ConfigLoader from utils import ESClient from search import Searcher -from search.query_config import DEFAULT_INDEX_NAME from query import QueryParser # Global instances _es_client: Optional[ESClient] = None _searcher: Optional[Searcher] = None _query_parser: Optional[QueryParser] = None +_config = None def init_service(es_host: str = "http://localhost:9200"): @@ -59,11 +60,17 @@ def init_service(es_host: str = "http://localhost:9200"): Args: es_host: Elasticsearch host URL """ - global _es_client, _searcher, _query_parser + global _es_client, _searcher, _query_parser, _config start_time = time.time() logger.info("Initializing search service (multi-tenant)") + # Load configuration + logger.info("Loading configuration...") + config_loader = ConfigLoader("config/config.yaml") + _config = config_loader.load_config() + logger.info("Configuration loaded") + # Get ES credentials es_username = os.getenv('ES_USERNAME') or ES_CONFIG.get('username') es_password = os.getenv('ES_PASSWORD') or ES_CONFIG.get('password') @@ -81,13 +88,13 @@ def init_service(es_host: str = "http://localhost:9200"): # Initialize components logger.info("Initializing query parser...") - _query_parser = QueryParser() + _query_parser = QueryParser(_config) logger.info("Initializing searcher...") - _searcher = Searcher(_es_client, _query_parser, index_name=DEFAULT_INDEX_NAME) + _searcher = Searcher(_es_client, _config, _query_parser) elapsed = time.time() - start_time - logger.info(f"Search service ready! (took {elapsed:.2f}s) | Index: {DEFAULT_INDEX_NAME}") + logger.info(f"Search service ready! (took {elapsed:.2f}s) | Index: {_config.es_index_name}") @@ -113,6 +120,13 @@ def get_query_parser() -> QueryParser: return _query_parser +def get_config(): + """Get global config instance.""" + if _config is None: + raise RuntimeError("Service not initialized") + return _config + + # Create FastAPI app with enhanced configuration app = FastAPI( title="E-Commerce Search API", diff --git a/config/__init__.py b/config/__init__.py index d8c5d12..28bc0ac 100644 --- a/config/__init__.py +++ b/config/__init__.py @@ -23,6 +23,10 @@ from .config_loader import ( RerankConfig, ConfigurationError ) +from .utils import ( + get_match_fields_for_index, + get_domain_fields +) __all__ = [ # Field types @@ -46,4 +50,6 @@ __all__ = [ 'FunctionScoreConfig', 'RerankConfig', 'ConfigurationError', + 'get_match_fields_for_index', + 'get_domain_fields', ] diff --git a/config/config.yaml b/config/config.yaml index c5c95ab..33e8038 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -412,6 +412,11 @@ query_config: text_embedding_field: "title_embedding" # Field name for text embeddings image_embedding_field: null # Field name for image embeddings (if not set, will auto-detect) + # Embedding disable thresholds (disable vector search for short queries) + embedding_disable_thresholds: + chinese_char_limit: 4 # Disable embedding for Chinese queries with <= 4 characters + english_word_limit: 3 # Disable embedding for English queries with <= 3 words + # Translation API (DeepL) translation_service: "deepl" translation_api_key: null # Set via environment variable diff --git a/config/config_loader.py b/config/config_loader.py index e3f4c24..83edbe5 100644 --- a/config/config_loader.py +++ b/config/config_loader.py @@ -58,6 +58,10 @@ class QueryConfig: text_embedding_field: Optional[str] = None # Field name for text embeddings (e.g., "title_embedding") image_embedding_field: Optional[str] = None # Field name for image embeddings (e.g., "image_embedding") + # Embedding disable thresholds (disable vector search for short queries) + embedding_disable_chinese_char_limit: int = 4 # Disable embedding for Chinese queries with <= this many characters + embedding_disable_english_word_limit: int = 3 # Disable embedding for English queries with <= this many words + # ES source fields configuration - fields to return in search results # If None, auto-collect from field configs (fields with return_in_source=True) # If empty list, return all fields. Otherwise, only return specified fields. @@ -165,15 +169,18 @@ class ConfigLoader: return rewrite_dict - def load_config(self) -> SearchConfig: + def load_config(self, validate: bool = True) -> SearchConfig: """ Load unified configuration from YAML file. + Args: + validate: Whether to validate configuration after loading (default: True) + Returns: SearchConfig object Raises: - ConfigurationError: If config file not found or invalid + ConfigurationError: If config file not found, invalid, or validation fails """ if not self.config_file.exists(): raise ConfigurationError(f"Configuration file not found: {self.config_file}") @@ -184,7 +191,16 @@ class ConfigLoader: except yaml.YAMLError as e: raise ConfigurationError(f"Invalid YAML in {self.config_file}: {e}") - return self._parse_config(config_data) + config = self._parse_config(config_data) + + # Auto-validate configuration + if validate: + errors = self.validate_config(config) + if errors: + error_msg = "Configuration validation failed:\n" + "\n".join(f" - {err}" for err in errors) + raise ConfigurationError(error_msg) + + return config def _parse_config(self, config_data: Dict[str, Any]) -> SearchConfig: """Parse configuration dictionary into SearchConfig object.""" @@ -214,43 +230,48 @@ class ConfigLoader: if field.return_in_source ] + # Parse embedding disable thresholds + embedding_thresholds = query_config_data.get("embedding_disable_thresholds", {}) + query_config = QueryConfig( - supported_languages=query_config_data.get("supported_languages", ["zh", "en"]), - default_language=query_config_data.get("default_language", "zh"), + supported_languages=query_config_data.get("supported_languages") or ["zh", "en"], + default_language=query_config_data.get("default_language") or "zh", enable_translation=query_config_data.get("enable_translation", True), enable_text_embedding=query_config_data.get("enable_text_embedding", True), enable_query_rewrite=query_config_data.get("enable_query_rewrite", True), rewrite_dictionary=rewrite_dictionary, translation_api_key=query_config_data.get("translation_api_key"), - translation_service=query_config_data.get("translation_service", "deepl"), + translation_service=query_config_data.get("translation_service") or "deepl", translation_glossary_id=query_config_data.get("translation_glossary_id"), - translation_context=query_config_data.get("translation_context", "e-commerce product search"), + translation_context=query_config_data.get("translation_context") or "e-commerce product search", text_embedding_field=query_config_data.get("text_embedding_field"), image_embedding_field=query_config_data.get("image_embedding_field"), + embedding_disable_chinese_char_limit=embedding_thresholds.get("chinese_char_limit", 4), + embedding_disable_english_word_limit=embedding_thresholds.get("english_word_limit", 3), source_fields=source_fields ) # Parse ranking config ranking_data = config_data.get("ranking", {}) ranking = RankingConfig( - expression=ranking_data.get("expression", "bm25() + 0.2*text_embedding_relevance()"), - description=ranking_data.get("description", "Default BM25 + text embedding ranking") + expression=ranking_data.get("expression") or "bm25() + 0.2*text_embedding_relevance()", + description=ranking_data.get("description") or "Default BM25 + text embedding ranking" ) # Parse Function Score configuration fs_data = config_data.get("function_score", {}) function_score = FunctionScoreConfig( - score_mode=fs_data.get("score_mode", "sum"), - boost_mode=fs_data.get("boost_mode", "multiply"), - functions=fs_data.get("functions", []) + score_mode=fs_data.get("score_mode") or "sum", + boost_mode=fs_data.get("boost_mode") or "multiply", + functions=fs_data.get("functions") or [] ) # Parse Rerank configuration rerank_data = config_data.get("rerank", {}) rerank = RerankConfig( enabled=rerank_data.get("enabled", False), - expression=rerank_data.get("expression", ""), - description=rerank_data.get("description", "") + expression=rerank_data.get("expression") or "", + description=rerank_data.get("description") or "" ) # Parse SPU config @@ -447,21 +468,43 @@ class ConfigLoader: output_path = Path(output_path) # Convert config back to dictionary format + query_config_dict = { + "supported_languages": config.query_config.supported_languages, + "default_language": config.query_config.default_language, + "enable_translation": config.query_config.enable_translation, + "enable_text_embedding": config.query_config.enable_text_embedding, + "enable_query_rewrite": config.query_config.enable_query_rewrite, + "translation_service": config.query_config.translation_service, + } + + # Add optional fields only if they are set + if config.query_config.translation_api_key: + query_config_dict["translation_api_key"] = config.query_config.translation_api_key + if config.query_config.translation_glossary_id: + query_config_dict["translation_glossary_id"] = config.query_config.translation_glossary_id + if config.query_config.translation_context: + query_config_dict["translation_context"] = config.query_config.translation_context + if config.query_config.text_embedding_field: + query_config_dict["text_embedding_field"] = config.query_config.text_embedding_field + if config.query_config.image_embedding_field: + query_config_dict["image_embedding_field"] = config.query_config.image_embedding_field + if config.query_config.source_fields: + query_config_dict["source_fields"] = config.query_config.source_fields + + # Add embedding disable thresholds + if (config.query_config.embedding_disable_chinese_char_limit != 4 or + config.query_config.embedding_disable_english_word_limit != 3): + query_config_dict["embedding_disable_thresholds"] = { + "chinese_char_limit": config.query_config.embedding_disable_chinese_char_limit, + "english_word_limit": config.query_config.embedding_disable_english_word_limit + } + config_dict = { "es_index_name": config.es_index_name, "es_settings": config.es_settings, "fields": [self._field_to_dict(field) for field in config.fields], "indexes": [self._index_to_dict(index) for index in config.indexes], - "query_config": { - "supported_languages": config.query_config.supported_languages, - "default_language": config.query_config.default_language, - "enable_translation": config.query_config.enable_translation, - "enable_text_embedding": config.query_config.enable_text_embedding, - "enable_query_rewrite": config.query_config.enable_query_rewrite, - # rewrite_dictionary is stored in separate file, not in config - "translation_api_key": config.query_config.translation_api_key, - "translation_service": config.query_config.translation_service, - }, + "query_config": query_config_dict, "ranking": { "expression": config.ranking.expression, "description": config.ranking.description @@ -505,7 +548,7 @@ class ConfigLoader: f.write(f"{key}\t{value}\n") def _field_to_dict(self, field: FieldConfig) -> Dict[str, Any]: - """Convert FieldConfig to dictionary.""" + """Convert FieldConfig to dictionary, preserving all fields.""" result = { "name": field.name, "type": field.field_type.value, @@ -513,36 +556,49 @@ class ConfigLoader: "boost": field.boost, "store": field.store, "index": field.index, + "return_in_source": field.return_in_source, } + # Add optional fields only if they differ from defaults or are set if field.analyzer: result["analyzer"] = field.analyzer.value if field.search_analyzer: result["search_analyzer"] = field.search_analyzer.value if field.multi_language: result["multi_language"] = field.multi_language - result["languages"] = field.languages + if field.languages: + result["languages"] = field.languages if field.embedding_dims != 1024: result["embedding_dims"] = field.embedding_dims if field.embedding_similarity != "dot_product": result["embedding_similarity"] = field.embedding_similarity if field.nested: result["nested"] = field.nested - result["nested_properties"] = field.nested_properties + if field.nested_properties: + result["nested_properties"] = field.nested_properties + if field.keyword_subfield: + result["keyword_subfield"] = field.keyword_subfield + if field.keyword_ignore_above != 256: + result["keyword_ignore_above"] = field.keyword_ignore_above + if field.keyword_normalizer: + result["keyword_normalizer"] = field.keyword_normalizer return result def _index_to_dict(self, index: IndexConfig) -> Dict[str, Any]: - """Convert IndexConfig to dictionary.""" + """Convert IndexConfig to dictionary, preserving all fields.""" result = { "name": index.name, "label": index.label, "fields": index.fields, "analyzer": index.analyzer.value, - "boost": index.boost, - "example": index.example } - + + # Add optional fields only if they differ from defaults or are set + if index.boost != 1.0: + result["boost"] = index.boost + if index.example: + result["example"] = index.example if index.language_field_mapping: result["language_field_mapping"] = index.language_field_mapping diff --git a/config/utils.py b/config/utils.py new file mode 100644 index 0000000..96c0ef1 --- /dev/null +++ b/config/utils.py @@ -0,0 +1,70 @@ +""" +Configuration utility functions. + +Helper functions for working with SearchConfig objects. +""" + +from typing import Dict, List +from .config_loader import SearchConfig + + +def get_match_fields_for_index(config: SearchConfig, index_name: str = "default") -> List[str]: + """ + Generate match fields list with boost from IndexConfig and FieldConfig. + + Args: + config: SearchConfig instance + index_name: Name of the index domain (default: "default") + + Returns: + List of field names with boost, e.g., ["title_zh^3.0", "brief_zh^1.5"] + """ + # Find the index config + index_config = None + for idx in config.indexes: + if idx.name == index_name: + index_config = idx + break + + if not index_config: + return [] + + # Create a field name to FieldConfig mapping + field_map = {field.name: field for field in config.fields} + + # Generate match fields with boost + match_fields = [] + for field_name in index_config.fields: + field_config = field_map.get(field_name) + if field_config: + # Combine index boost and field boost + total_boost = index_config.boost * field_config.boost + if total_boost != 1.0: + match_fields.append(f"{field_name}^{total_boost}") + else: + match_fields.append(field_name) + else: + # Field not found in config, use index boost only + if index_config.boost != 1.0: + match_fields.append(f"{field_name}^{index_config.boost}") + else: + match_fields.append(field_name) + + return match_fields + + +def get_domain_fields(config: SearchConfig) -> Dict[str, List[str]]: + """ + Generate domain-specific match fields from all index configs. + + Args: + config: SearchConfig instance + + Returns: + Dictionary mapping domain name to list of match fields + """ + domain_fields = {} + for index_config in config.indexes: + domain_fields[index_config.name] = get_match_fields_for_index(config, index_config.name) + return domain_fields + diff --git a/frontend/index.html b/frontend/index.html index bca1767..bcb736c 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -100,9 +100,10 @@
diff --git a/main.py b/main.py index e6e4901..2c6e428 100755 --- a/main.py +++ b/main.py @@ -93,7 +93,7 @@ def cmd_search(args): from query import QueryParser query_parser = QueryParser(config) - searcher = Searcher(config, es_client, query_parser) + searcher = Searcher(es_client, config, query_parser) # Execute search print(f"Searching for: '{args.query}' (tenant: {args.tenant_id})") diff --git a/query/query_parser.py b/query/query_parser.py index 1c710e4..b37afdf 100644 --- a/query/query_parser.py +++ b/query/query_parser.py @@ -9,13 +9,7 @@ import numpy as np import logging from embeddings import BgeEncoder -from search.query_config import ( - ENABLE_TEXT_EMBEDDING, - ENABLE_TRANSLATION, - REWRITE_DICTIONARY, - TRANSLATION_API_KEY, - TRANSLATION_SERVICE -) +from config import SearchConfig from .language_detector import LanguageDetector from .translator import Translator from .query_rewriter import QueryRewriter, QueryNormalizer @@ -70,6 +64,7 @@ class QueryParser: def __init__( self, + config: SearchConfig, text_encoder: Optional[BgeEncoder] = None, translator: Optional[Translator] = None ): @@ -77,21 +72,23 @@ class QueryParser: Initialize query parser. Args: + config: SearchConfig instance text_encoder: Text embedding encoder (lazy loaded if not provided) translator: Translator instance (lazy loaded if not provided) """ + self.config = config self._text_encoder = text_encoder self._translator = translator # Initialize components self.normalizer = QueryNormalizer() self.language_detector = LanguageDetector() - self.rewriter = QueryRewriter(REWRITE_DICTIONARY) + self.rewriter = QueryRewriter(config.query_config.rewrite_dictionary) @property def text_encoder(self) -> BgeEncoder: """Lazy load text encoder.""" - if self._text_encoder is None and ENABLE_TEXT_EMBEDDING: + if self._text_encoder is None and self.config.query_config.enable_text_embedding: logger.info("Initializing text encoder (lazy load)...") self._text_encoder = BgeEncoder() return self._text_encoder @@ -99,13 +96,13 @@ class QueryParser: @property def translator(self) -> Translator: """Lazy load translator.""" - if self._translator is None and ENABLE_TRANSLATION: + if self._translator is None and self.config.query_config.enable_translation: logger.info("Initializing translator (lazy load)...") self._translator = Translator( - api_key=TRANSLATION_API_KEY, + api_key=self.config.query_config.translation_api_key, use_cache=True, - glossary_id=None, # Can be added to query_config if needed - translation_context='e-commerce product search' + glossary_id=self.config.query_config.translation_glossary_id, + translation_context=self.config.query_config.translation_context ) return self._translator @@ -156,7 +153,7 @@ class QueryParser: # Stage 2: Query rewriting rewritten = None - if REWRITE_DICTIONARY: # Enable rewrite if dictionary exists + if self.config.query_config.rewrite_dictionary: # Enable rewrite if dictionary exists rewritten = self.rewriter.rewrite(query_text) if rewritten != query_text: log_info(f"查询重写 | '{query_text}' -> '{rewritten}'") @@ -173,7 +170,7 @@ class QueryParser: # Stage 4: Translation translations = {} - if ENABLE_TRANSLATION: + if self.config.query_config.enable_translation: try: # Determine target languages for translation # Simplified: always translate to Chinese and English @@ -210,19 +207,47 @@ class QueryParser: # Stage 5: Text embedding query_vector = None if (generate_vector and - ENABLE_TEXT_EMBEDDING and + self.config.query_config.enable_text_embedding and domain == "default"): # Only generate vector for default domain - try: - log_debug("开始生成查询向量") - query_vector = self.text_encoder.encode([query_text])[0] - log_debug(f"查询向量生成完成 | 形状: {query_vector.shape}") - if context: - context.store_intermediate_result('query_vector_shape', query_vector.shape) - except Exception as e: - error_msg = f"查询向量生成失败 | 错误: {str(e)}" - log_info(error_msg) - if context: - context.add_warning(error_msg) + # Get thresholds from config + chinese_limit = self.config.query_config.embedding_disable_chinese_char_limit + english_limit = self.config.query_config.embedding_disable_english_word_limit + + # Check if embedding should be disabled for short queries + should_disable_embedding = False + disable_reason = None + + if detected_lang == 'zh': + # For Chinese: disable embedding if character count <= threshold + char_count = len(query_text.strip()) + if char_count <= chinese_limit: + should_disable_embedding = True + disable_reason = f"中文查询字数({char_count}) <= {chinese_limit},禁用向量搜索" + log_info(disable_reason) + if context: + context.store_intermediate_result('embedding_disabled_reason', disable_reason) + else: + # For English: disable embedding if word count <= threshold + word_count = len(query_text.strip().split()) + if word_count <= english_limit: + should_disable_embedding = True + disable_reason = f"英文查询单词数({word_count}) <= {english_limit},禁用向量搜索" + log_info(disable_reason) + if context: + context.store_intermediate_result('embedding_disabled_reason', disable_reason) + + if not should_disable_embedding: + try: + log_debug("开始生成查询向量") + query_vector = self.text_encoder.encode([query_text])[0] + log_debug(f"查询向量生成完成 | 形状: {query_vector.shape}") + if context: + context.store_intermediate_result('query_vector_shape', query_vector.shape) + except Exception as e: + error_msg = f"查询向量生成失败 | 错误: {str(e)}" + log_info(error_msg) + if context: + context.add_warning(error_msg) # Build result result = ParsedQuery( diff --git a/search/es_query_builder.py b/search/es_query_builder.py index a78218a..664843f 100644 --- a/search/es_query_builder.py +++ b/search/es_query_builder.py @@ -11,7 +11,7 @@ Simplified architecture: from typing import Dict, Any, List, Optional, Union import numpy as np from .boolean_parser import QueryNode -from .query_config import FUNCTION_SCORE_CONFIG +from config import FunctionScoreConfig class ESQueryBuilder: @@ -23,7 +23,8 @@ class ESQueryBuilder: match_fields: List[str], text_embedding_field: Optional[str] = None, image_embedding_field: Optional[str] = None, - source_fields: Optional[List[str]] = None + source_fields: Optional[List[str]] = None, + function_score_config: Optional[FunctionScoreConfig] = None ): """ Initialize query builder. @@ -34,12 +35,14 @@ class ESQueryBuilder: text_embedding_field: Field name for text embeddings image_embedding_field: Field name for image embeddings source_fields: Fields to return in search results (_source includes) + function_score_config: Function score configuration """ self.index_name = index_name self.match_fields = match_fields self.text_embedding_field = text_embedding_field self.image_embedding_field = image_embedding_field self.source_fields = source_fields + self.function_score_config = function_score_config def build_query( self, @@ -182,12 +185,15 @@ class ESQueryBuilder: return query # Build function_score query + score_mode = self.function_score_config.score_mode if self.function_score_config else "sum" + boost_mode = self.function_score_config.boost_mode if self.function_score_config else "multiply" + function_score_query = { "function_score": { "query": query, "functions": functions, - "score_mode": FUNCTION_SCORE_CONFIG.get("score_mode", "sum"), - "boost_mode": FUNCTION_SCORE_CONFIG.get("boost_mode", "multiply") + "score_mode": score_mode, + "boost_mode": boost_mode } } @@ -201,7 +207,10 @@ class ESQueryBuilder: List of function score functions """ functions = [] - config_functions = FUNCTION_SCORE_CONFIG.get("functions", []) + if not self.function_score_config: + return functions + + config_functions = self.function_score_config.functions or [] for func_config in config_functions: func_type = func_config.get("type") diff --git a/search/query_config.py b/search/query_config.py deleted file mode 100644 index ca77054..0000000 --- a/search/query_config.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -Query configuration constants. - -Since all tenants share the same ES mapping, we can hardcode field lists here. -""" - -import os -from typing import Dict, List - -# Default index name -DEFAULT_INDEX_NAME = "search_products" - -# Text embedding field -TEXT_EMBEDDING_FIELD = "title_embedding" - -# Image embedding field -IMAGE_EMBEDDING_FIELD = "image_embedding" - -# Default match fields for text search (with boost) -# 文本召回:同时搜索中英文字段,两者相互补充 -DEFAULT_MATCH_FIELDS = [ - # 中文字段 - "title_zh^3.0", - "brief_zh^1.5", - "description_zh^1.0", - "vendor_zh^1.5", - "category_path_zh^1.5", - "category_name_zh^1.5", - # 英文字段 - "title_en^3.0", - "brief_en^1.5", - "description_en^1.0", - "vendor_en^1.5", - "category_path_en^1.5", - "category_name_en^1.5", - # 语言无关字段 - "tags^1.0", -] - -# Domain-specific match fields -DOMAIN_FIELDS: Dict[str, List[str]] = { - "default": DEFAULT_MATCH_FIELDS, - "title": ["title_zh^2.0"], - "vendor": ["vendor_zh^1.5"], - "category": ["category_path_zh^1.5", "category_name_zh^1.5"], - "tags": ["tags^1.0"] -} - -# Source fields to return in search results -# 注意:为了在后端做多语言选择,_zh / _en 字段仍然需要从 ES 取出, -# 但不会原样透出给前端,而是统一映射到 title / description / vendor 等字段。 -SOURCE_FIELDS = [ - # 基本标识 - "tenant_id", - "spu_id", - "create_time", - "update_time", - - # 多语言文本字段(仅用于后端选择,不直接返回给前端) - "title_zh", - "title_en", - "brief_zh", - "brief_en", - "description_zh", - "description_en", - "vendor_zh", - "vendor_en", - "category_path_zh", - "category_path_en", - "category_name_zh", - "category_name_en", - - # 语言无关字段(直接返回给前端) - "tags", - "image_url", - "category_id", - "category_name", - "category_level", - "category1_name", - "category2_name", - "category3_name", - "option1_name", - "option2_name", - "option3_name", - "min_price", - "max_price", - "compare_at_price", - "sku_prices", - "sku_weights", - "sku_weight_units", - "total_inventory", - "skus", - "specifications", -] - -# Query processing settings -ENABLE_TRANSLATION = os.environ.get("ENABLE_TRANSLATION", "true").lower() == "true" -ENABLE_TEXT_EMBEDDING = os.environ.get("ENABLE_TEXT_EMBEDDING", "true").lower() == "true" -TRANSLATION_API_KEY = os.environ.get("DEEPL_API_KEY") -TRANSLATION_SERVICE = "deepl" - -# Ranking expression (currently disabled) -RANKING_EXPRESSION = "bm25() + 0.2*text_embedding_relevance()" - -# Function score config -FUNCTION_SCORE_CONFIG = { - "score_mode": "sum", - "boost_mode": "multiply", - "functions": [] -} - -# Load rewrite dictionary from file if exists -def load_rewrite_dictionary() -> Dict[str, str]: - """Load query rewrite dictionary from file.""" - rewrite_file = os.path.join( - os.path.dirname(os.path.dirname(__file__)), - "config", - "query_rewrite.dict" - ) - - if not os.path.exists(rewrite_file): - return {} - - rewrite_dict = {} - try: - with open(rewrite_file, 'r', encoding='utf-8') as f: - for line in f: - line = line.strip() - if not line or line.startswith('#'): - continue - parts = line.split('\t') - if len(parts) == 2: - rewrite_dict[parts[0].strip()] = parts[1].strip() - except Exception as e: - print(f"Warning: Failed to load rewrite dictionary: {e}") - - return rewrite_dict - -REWRITE_DICTIONARY = load_rewrite_dictionary() - -# Default facets for faceted search -# 分类分面:使用category1_name, category2_name, category3_name -# specifications分面:使用嵌套聚合,按name分组,然后按value聚合 -DEFAULT_FACETS = [ - "category1_name", # 一级分类 - "category2_name", # 二级分类 - "category3_name", # 三级分类 - "specifications" # 规格分面(特殊处理:嵌套聚合) -] - diff --git a/search/searcher.py b/search/searcher.py index c81afea..af8faba 100644 --- a/search/searcher.py +++ b/search/searcher.py @@ -14,16 +14,8 @@ from embeddings import CLIPImageEncoder from .boolean_parser import BooleanParser, QueryNode from .es_query_builder import ESQueryBuilder from .rerank_engine import RerankEngine -from .query_config import ( - DEFAULT_INDEX_NAME, - DEFAULT_MATCH_FIELDS, - TEXT_EMBEDDING_FIELD, - IMAGE_EMBEDDING_FIELD, - SOURCE_FIELDS, - ENABLE_TRANSLATION, - ENABLE_TEXT_EMBEDDING, - RANKING_EXPRESSION -) +from config import SearchConfig +from config.utils import get_match_fields_for_index from context.request_context import RequestContext, RequestContextStage, create_request_context from api.models import FacetResult, FacetValue from api.result_formatter import ResultFormatter @@ -87,37 +79,40 @@ class Searcher: def __init__( self, es_client: ESClient, - query_parser: Optional[QueryParser] = None, - index_name: str = DEFAULT_INDEX_NAME + config: SearchConfig, + query_parser: Optional[QueryParser] = None ): """ Initialize searcher. Args: es_client: Elasticsearch client + config: SearchConfig instance query_parser: Query parser (created if not provided) - index_name: ES index name (default: search_products) """ self.es_client = es_client - self.index_name = index_name - self.query_parser = query_parser or QueryParser() + self.config = config + self.index_name = config.es_index_name + self.query_parser = query_parser or QueryParser(config) # Initialize components self.boolean_parser = BooleanParser() - self.rerank_engine = RerankEngine(RANKING_EXPRESSION, enabled=False) + self.rerank_engine = RerankEngine(config.ranking.expression, enabled=False) - # Use constants from query_config - self.match_fields = DEFAULT_MATCH_FIELDS - self.text_embedding_field = TEXT_EMBEDDING_FIELD - self.image_embedding_field = IMAGE_EMBEDDING_FIELD + # Get match fields from config + self.match_fields = get_match_fields_for_index(config, "default") + self.text_embedding_field = config.query_config.text_embedding_field or "title_embedding" + self.image_embedding_field = config.query_config.image_embedding_field or "image_embedding" + self.source_fields = config.query_config.source_fields or [] # Query builder - simplified single-layer architecture self.query_builder = ESQueryBuilder( - index_name=index_name, + index_name=self.index_name, match_fields=self.match_fields, text_embedding_field=self.text_embedding_field, image_embedding_field=self.image_embedding_field, - source_fields=SOURCE_FIELDS + source_fields=self.source_fields, + function_score_config=self.config.function_score ) def search( @@ -162,8 +157,8 @@ class Searcher: context = create_request_context() # Always use config defaults (these are backend configuration, not user parameters) - enable_translation = ENABLE_TRANSLATION - enable_embedding = ENABLE_TEXT_EMBEDDING + enable_translation = self.config.query_config.enable_translation + enable_embedding = self.config.query_config.enable_text_embedding enable_rerank = False # Temporarily disabled # Start timing @@ -508,9 +503,9 @@ class Searcher: } # Add _source filtering if source_fields are configured - if SOURCE_FIELDS: + if self.source_fields: es_query["_source"] = { - "includes": SOURCE_FIELDS + "includes": self.source_fields } if filters or range_filters: diff --git a/tests/conftest.py b/tests/conftest.py index 7e1421d..f7dc9da 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -137,8 +137,8 @@ def mock_es_client() -> Mock: def test_searcher(sample_search_config, mock_es_client) -> Searcher: """测试用Searcher实例""" return Searcher( - config=sample_search_config, - es_client=mock_es_client + es_client=mock_es_client, + config=sample_search_config ) -- libgit2 0.21.2