"""
Configuration loader and validator for customer-specific search configurations.

This module handles loading, parsing, and validating YAML configuration files
that define how each customer's data should be indexed and searched.
"""

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional

import yaml

from .field_types import (
    ANALYZER_MAP,
    FIELD_TYPE_MAP,
    AnalyzerType,
    FieldConfig,
    FieldType,
)


@dataclass
class IndexConfig:
    """Configuration for an index domain (e.g., default, title, brand)."""

    name: str
    label: str
    # Names of the fields that belong to this index domain.
    fields: List[str]
    analyzer: AnalyzerType
    boost: float = 1.0
    example: Optional[str] = None
    # Optional per-language field lists, for example:
    #   {"zh": ["name"], "en": ["enSpuName"], "ru": ["ruSkuName"]}
    language_field_mapping: Optional[Dict[str, List[str]]] = None


@dataclass
class RankingConfig:
    """Configuration for ranking expressions."""

    # Ranking expression, e.g. "bm25() + 0.2*text_embedding_relevance()".
    expression: str
    description: str


@dataclass
class QueryConfig:
    """Configuration for query processing."""

    # A fresh ["zh", "en"] list is built per instance so instances never
    # share (and accidentally mutate) the same default list.
    supported_languages: List[str] = field(default_factory=lambda: ["zh", "en"])
    default_language: str = "zh"
    enable_translation: bool = True
    enable_text_embedding: bool = True
    enable_query_rewrite: bool = True
    rewrite_dictionary: Dict[str, str] = field(default_factory=dict)

    # Translation API settings.
    translation_api_key: Optional[str] = None
    translation_service: str = "deepl"  # deepl, google, etc.
@dataclass
class SPUConfig:
    """Configuration for SPU aggregation."""

    enabled: bool = False
    # Name of the field containing the SPU ID.
    spu_field: Optional[str] = None
    inner_hits_size: int = 3


@dataclass
class FunctionScoreConfig:
    """Function score configuration (scoring rules applied at the ES layer)."""

    score_mode: str = "sum"  # multiply, sum, avg, first, max, min
    boost_mode: str = "multiply"  # multiply, replace, sum, avg, max, min
    functions: List[Dict[str, Any]] = field(default_factory=list)


@dataclass
class RerankConfig:
    """Local rerank configuration (currently disabled)."""

    enabled: bool = False
    expression: str = ""
    description: str = ""


@dataclass
class CustomerConfig:
    """Complete configuration for a customer."""

    customer_id: str
    customer_name: str

    # Database settings
    mysql_config: Dict[str, Any]

    # Field definitions
    fields: List[FieldConfig]

    # Index structure (query domains)
    indexes: List[IndexConfig]

    # Query processing
    query_config: QueryConfig

    # Ranking configuration
    ranking: RankingConfig

    # Function score configuration (scoring at the ES layer)
    function_score: FunctionScoreConfig

    # Rerank configuration (local reranking)
    rerank: RerankConfig

    # SPU configuration
    spu_config: SPUConfig

    # ES index settings
    es_index_name: str

    # Optional fields with defaults
    main_table: str = "shoplazza_product_sku"
    extension_table: Optional[str] = None
    es_settings: Dict[str, Any] = field(default_factory=dict)


class ConfigurationError(Exception):
    """Raised when configuration validation fails."""


class ConfigLoader:
    """Loads and validates customer configurations from YAML files."""

    # Keys that validate_config requires to be present in ``mysql_config``.
    _REQUIRED_MYSQL_KEYS = ("host", "username", "password", "database")

    # Accepted values for a field's ``embedding_similarity``.
    _VALID_EMBEDDING_SIMILARITIES = ("dot_product", "cosine", "l2_norm")

    # Substrings expected in a field's analyzer name for each language.  Only
    # used for the advisory (warning-only) analyzer check in validate_config.
    _EXPECTED_ANALYZERS = {
        'zh': ['chinese', 'index_ansj', 'query_ansj'],
        'en': ['english'],
        'ru': ['russian'],
        'ar': ['arabic'],
        'es': ['spanish'],
        'ja': ['japanese'],
    }

    def __init__(self, config_dir: str = "config/schema"):
        """
        Create a loader rooted at ``config_dir``.

        Args:
            config_dir: Directory that holds the customer configuration files
        """
        self.config_dir = Path(config_dir)

    def _load_rewrite_dictionary(self, customer_id: str) -> Dict[str, str]:
        """
        Load query rewrite dictionary from external file.

        The file is ``{config_dir}/{customer_id}/query_rewrite.dict`` and holds
        one ``key<TAB>value`` pair per line.  Blank lines and lines starting
        with ``#`` are skipped; malformed lines are reported and skipped.

        Args:
            customer_id: Customer identifier

        Returns:
            Dictionary mapping query terms to rewritten queries.  Empty when
            the file does not exist or cannot be read.
        """
        dict_file = self.config_dir / customer_id / "query_rewrite.dict"

        if not dict_file.exists():
            # Dictionary file is optional, return empty dict if not found
            return {}

        rewrite_dict: Dict[str, str] = {}

        try:
            with open(dict_file, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()

                    # Skip empty lines and comments
                    if not line or line.startswith('#'):
                        continue

                    # Parse tab-separated format
                    parts = line.split('\t')
                    if len(parts) != 2:
                        print(f"Warning: Invalid format in {dict_file} line {line_num}: {line}")
                        continue

                    key, value = parts
                    rewrite_dict[key.strip()] = value.strip()
        except (OSError, UnicodeError) as e:
            # Best effort: the dictionary is optional, so an unreadable or
            # badly encoded file disables rewriting instead of failing the
            # whole configuration load.
            print(f"Error loading rewrite dictionary from {dict_file}: {e}")
            return {}

        return rewrite_dict

    def load_customer_config(self, customer_id: str) -> CustomerConfig:
        """
        Load customer configuration from YAML file.

        Supports two directory structures:
        1. New structure: config/schema/{customer_id}/config.yaml
        2. Old structure: config/schema/{customer_id}_config.yaml (for backward compatibility)

        Args:
            customer_id: Customer identifier (used to find config file)

        Returns:
            CustomerConfig object

        Raises:
            ConfigurationError: If config file not found or invalid
        """
        # Try new directory structure first
        config_file = self.config_dir / customer_id / "config.yaml"

        # Fall back to old structure if new one doesn't exist
        if not config_file.exists():
            config_file = self.config_dir / f"{customer_id}_config.yaml"

        if not config_file.exists():
            raise ConfigurationError(f"Configuration file not found: {config_file}")

        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                config_data = yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise ConfigurationError(f"Invalid YAML in {config_file}: {e}") from e

        # safe_load returns None for an empty document and may return a list
        # or scalar; everything below needs a mapping.
        if not isinstance(config_data, dict):
            raise ConfigurationError(
                f"Invalid configuration in {config_file}: expected a mapping at the "
                f"top level, got {type(config_data).__name__}"
            )

        return self._parse_config(config_data, customer_id)

    @staticmethod
    def _section(config_data: Dict[str, Any], key: str, default_factory=dict) -> Any:
        """Return ``config_data[key]``, or a fresh default when missing or null."""
        value = config_data.get(key)
        return default_factory() if value is None else value

    def _parse_config(self, config_data: Dict[str, Any], customer_id: str) -> CustomerConfig:
        """
        Parse configuration dictionary into CustomerConfig object.

        Sections that are missing or explicitly null in ``config_data`` fall
        back to their defaults.

        Args:
            config_data: Top-level mapping read from the YAML file
            customer_id: Customer identifier

        Returns:
            CustomerConfig object

        Raises:
            ConfigurationError: If a field or index definition is invalid
        """
        # Parse fields
        fields = [
            self._parse_field_config(field_data)
            for field_data in self._section(config_data, "fields", list)
        ]

        # Parse indexes
        indexes = [
            self._parse_index_config(index_data)
            for index_data in self._section(config_data, "indexes", list)
        ]

        # Parse query config
        query_config_data = self._section(config_data, "query_config")

        # Load rewrite dictionary from external file instead of config
        rewrite_dictionary = self._load_rewrite_dictionary(customer_id)

        query_config = QueryConfig(
            supported_languages=query_config_data.get("supported_languages", ["zh", "en"]),
            default_language=query_config_data.get("default_language", "zh"),
            enable_translation=query_config_data.get("enable_translation", True),
            enable_text_embedding=query_config_data.get("enable_text_embedding", True),
            enable_query_rewrite=query_config_data.get("enable_query_rewrite", True),
            rewrite_dictionary=rewrite_dictionary,
            translation_api_key=query_config_data.get("translation_api_key"),
            translation_service=query_config_data.get("translation_service", "deepl")
        )

        # Parse ranking config
        ranking_data = self._section(config_data, "ranking")
        ranking = RankingConfig(
            expression=ranking_data.get("expression", "bm25() + 0.2*text_embedding_relevance()"),
            description=ranking_data.get("description", "Default BM25 + text embedding ranking")
        )

        # Parse function score configuration
        fs_data = self._section(config_data, "function_score")
        function_score = FunctionScoreConfig(
            score_mode=fs_data.get("score_mode", "sum"),
            boost_mode=fs_data.get("boost_mode", "multiply"),
            functions=self._section(fs_data, "functions", list)
        )

        # Parse rerank configuration
        rerank_data = self._section(config_data, "rerank")
        rerank = RerankConfig(
            enabled=rerank_data.get("enabled", False),
            expression=rerank_data.get("expression", ""),
            description=rerank_data.get("description", "")
        )

        # Parse SPU config
        spu_data = self._section(config_data, "spu_config")
        spu_config = SPUConfig(
            enabled=spu_data.get("enabled", False),
            spu_field=spu_data.get("spu_field"),
            inner_hits_size=spu_data.get("inner_hits_size", 3)
        )

        return CustomerConfig(
            customer_id=customer_id,
            customer_name=config_data.get("customer_name", customer_id),
            mysql_config=self._section(config_data, "mysql_config"),
            main_table=config_data.get("main_table", "shoplazza_product_sku"),
            extension_table=config_data.get("extension_table"),
            fields=fields,
            indexes=indexes,
            query_config=query_config,
            ranking=ranking,
            function_score=function_score,
            rerank=rerank,
            spu_config=spu_config,
            es_index_name=config_data.get("es_index_name", f"search_{customer_id}"),
            es_settings=self._section(config_data, "es_settings")
        )

    @staticmethod
    def _lookup_field_analyzer(field_name: str, key: str, analyzer_str: Optional[str]):
        """Map a field's analyzer name to its enum; warn and return None if unknown."""
        if not analyzer_str:
            return None
        if analyzer_str not in ANALYZER_MAP:
            # Kept non-fatal for backward compatibility, but no longer silent:
            # a typo here would otherwise quietly drop the analyzer.
            print(f"Warning: Field '{field_name}': unknown {key} '{analyzer_str}' ignored")
            return None
        return ANALYZER_MAP[analyzer_str]

    def _parse_field_config(self, field_data: Dict[str, Any]) -> FieldConfig:
        """
        Parse field configuration from dictionary.

        Args:
            field_data: Mapping describing one field; ``name`` and ``type``
                are required, every other key is optional

        Returns:
            FieldConfig object

        Raises:
            ConfigurationError: If a required key is missing or the field type
                is unknown
        """
        try:
            name = field_data["name"]
            field_type_str = field_data["type"]
        except KeyError as e:
            raise ConfigurationError(
                f"Field definition missing required key {e}: {field_data}"
            ) from e

        # Map field type string to enum
        if field_type_str not in FIELD_TYPE_MAP:
            raise ConfigurationError(f"Unknown field type: {field_type_str}")
        field_type = FIELD_TYPE_MAP[field_type_str]

        # Map analyzer strings to enums (if provided)
        analyzer = self._lookup_field_analyzer(
            name, "analyzer", field_data.get("analyzer")
        )
        search_analyzer = self._lookup_field_analyzer(
            name, "search_analyzer", field_data.get("search_analyzer")
        )

        return FieldConfig(
            name=name,
            field_type=field_type,
            source_table=field_data.get("source_table"),
            # The source column defaults to the field's own name.
            source_column=field_data.get("source_column", name),
            analyzer=analyzer,
            search_analyzer=search_analyzer,
            required=field_data.get("required", False),
            multi_language=field_data.get("multi_language", False),
            languages=field_data.get("languages"),
            boost=field_data.get("boost", 1.0),
            store=field_data.get("store", False),
            index=field_data.get("index", True),
            embedding_dims=field_data.get("embedding_dims", 1024),
            embedding_similarity=field_data.get("embedding_similarity", "dot_product"),
            nested=field_data.get("nested", False),
            nested_properties=field_data.get("nested_properties")
        )

    def _parse_index_config(self, index_data: Dict[str, Any]) -> IndexConfig:
        """
        Parse index configuration from dictionary.

        Args:
            index_data: Mapping describing one index domain; ``name`` and
                ``fields`` are required

        Returns:
            IndexConfig object

        Raises:
            ConfigurationError: If a required key is missing or the analyzer
                is unknown
        """
        analyzer_str = index_data.get("analyzer", "chinese_ecommerce")
        if analyzer_str not in ANALYZER_MAP:
            raise ConfigurationError(f"Unknown analyzer: {analyzer_str}")

        try:
            name = index_data["name"]
            index_fields = index_data["fields"]
        except KeyError as e:
            raise ConfigurationError(
                f"Index definition missing required key {e}: {index_data}"
            ) from e

        return IndexConfig(
            name=name,
            label=index_data.get("label", name),
            fields=index_fields,
            analyzer=ANALYZER_MAP[analyzer_str],
            boost=index_data.get("boost", 1.0),
            example=index_data.get("example"),
            # Language field mapping is optional and passed through as-is;
            # its contents are checked in validate_config.
            language_field_mapping=index_data.get("language_field_mapping")
        )

    def validate_config(self, config: CustomerConfig) -> List[str]:
        """
        Validate customer configuration.

        Checks that indexes and the SPU config only reference declared fields,
        that ``language_field_mapping`` entries are lists of TEXT fields, that
        embedding fields have usable settings, and that the MySQL settings
        contain the required keys.  A mismatch between a language and its
        field analyzer is only printed as a warning.

        Args:
            config: Customer configuration to validate

        Returns:
            List of validation error messages (empty if valid)
        """
        errors: List[str] = []

        field_map = {fc.name: fc for fc in config.fields}
        field_names = set(field_map)

        # Validate field references in indexes
        for index in config.indexes:
            for field_name in index.fields:
                if field_name not in field_names:
                    errors.append(f"Index '{index.name}' references unknown field '{field_name}'")

            # Validate language_field_mapping if present
            if not index.language_field_mapping:
                continue
            if not isinstance(index.language_field_mapping, dict):
                errors.append(f"Index '{index.name}': language_field_mapping must be a mapping")
                continue

            for lang, field_list in index.language_field_mapping.items():
                if not isinstance(field_list, list):
                    errors.append(f"Index '{index.name}': language_field_mapping['{lang}'] must be a list")
                    continue

                for field_name in field_list:
                    # Check if field exists
                    if field_name not in field_names:
                        errors.append(
                            f"Index '{index.name}': language_field_mapping['{lang}'] "
                            f"references unknown field '{field_name}'"
                        )
                        continue

                    # Multi-language fields should be text fields
                    field_cfg = field_map[field_name]
                    if field_cfg.field_type != FieldType.TEXT:
                        errors.append(
                            f"Index '{index.name}': language_field_mapping['{lang}'] "
                            f"field '{field_name}' must be of type TEXT, got {field_cfg.field_type.value}"
                        )

                    # Soft check: warn (never fail) when the analyzer does not
                    # look like a match for the language.
                    expected = self._EXPECTED_ANALYZERS.get(lang)
                    if field_cfg.analyzer and expected is not None:
                        analyzer_name = field_cfg.analyzer.value.lower()
                        if not any(exp in analyzer_name for exp in expected):
                            print(
                                f"Warning: Index '{index.name}': field '{field_name}' for language '{lang}' "
                                f"uses analyzer '{analyzer_name}', which may not be optimal for '{lang}'"
                            )

        # Validate SPU config
        if config.spu_config.enabled:
            if not config.spu_config.spu_field:
                errors.append("SPU aggregation enabled but no spu_field specified")
            elif config.spu_config.spu_field not in field_names:
                errors.append(f"SPU field '{config.spu_config.spu_field}' not found in fields")

        # Validate embedding fields have proper configuration
        for field_cfg in config.fields:
            if field_cfg.field_type in [FieldType.TEXT_EMBEDDING, FieldType.IMAGE_EMBEDDING]:
                if field_cfg.embedding_dims <= 0:
                    errors.append(f"Field '{field_cfg.name}': embedding_dims must be positive")
                if field_cfg.embedding_similarity not in self._VALID_EMBEDDING_SIMILARITIES:
                    errors.append(f"Field '{field_cfg.name}': invalid embedding_similarity")

        # Validate MySQL config
        for key in self._REQUIRED_MYSQL_KEYS:
            if key not in config.mysql_config:
                errors.append(f"MySQL configuration missing '{key}'")

        return errors

    def save_config(self, config: CustomerConfig, output_path: Optional[str] = None) -> None:
        """
        Save customer configuration to YAML file.

        Note: rewrite_dictionary is saved separately to query_rewrite.dict file

        Args:
            config: Configuration to save
            output_path: Optional output path (defaults to new directory structure)
        """
        if output_path is None:
            # Use new directory structure by default
            customer_dir = self.config_dir / config.customer_id
            customer_dir.mkdir(parents=True, exist_ok=True)
            output_path = customer_dir / "config.yaml"

        # Convert config back to dictionary format
        config_dict = {
            "customer_name": config.customer_name,
            "mysql_config": config.mysql_config,
            "main_table": config.main_table,
            "extension_table": config.extension_table,
            "es_index_name": config.es_index_name,
            "es_settings": config.es_settings,
            "fields": [self._field_to_dict(field_cfg) for field_cfg in config.fields],
            "indexes": [self._index_to_dict(index) for index in config.indexes],
            "query_config": {
                "supported_languages": config.query_config.supported_languages,
                "default_language": config.query_config.default_language,
                "enable_translation": config.query_config.enable_translation,
                "enable_text_embedding": config.query_config.enable_text_embedding,
                "enable_query_rewrite": config.query_config.enable_query_rewrite,
                # rewrite_dictionary is stored in separate file, not in config
                "translation_api_key": config.query_config.translation_api_key,
                "translation_service": config.query_config.translation_service,
            },
            "ranking": {
                "expression": config.ranking.expression,
                "description": config.ranking.description
            },
            # function_score and rerank are read back by _parse_config, so
            # they must be written here or a save/load round trip resets them
            # to their defaults.
            "function_score": {
                "score_mode": config.function_score.score_mode,
                "boost_mode": config.function_score.boost_mode,
                "functions": config.function_score.functions
            },
            "rerank": {
                "enabled": config.rerank.enabled,
                "expression": config.rerank.expression,
                "description": config.rerank.description
            },
            "spu_config": {
                "enabled": config.spu_config.enabled,
                "spu_field": config.spu_config.spu_field,
                "inner_hits_size": config.spu_config.inner_hits_size
            }
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            yaml.dump(config_dict, f, default_flow_style=False, allow_unicode=True)

        # Save rewrite dictionary to separate file
        self._save_rewrite_dictionary(config.customer_id, config.query_config.rewrite_dictionary)

    def _save_rewrite_dictionary(self, customer_id: str, rewrite_dict: Dict[str, str]) -> None:
        """
        Save rewrite dictionary to external file.

        Writes one ``key<TAB>value`` line per entry to
        ``{config_dir}/{customer_id}/query_rewrite.dict``, creating the
        customer directory if needed.

        Args:
            customer_id: Customer identifier
            rewrite_dict: Dictionary to save
        """
        customer_dir = self.config_dir / customer_id
        customer_dir.mkdir(parents=True, exist_ok=True)

        dict_file = customer_dir / "query_rewrite.dict"
        with open(dict_file, 'w', encoding='utf-8') as f:
            f.writelines(f"{key}\t{value}\n" for key, value in rewrite_dict.items())

    def _field_to_dict(self, field: FieldConfig) -> Dict[str, Any]:
        """
        Convert FieldConfig to dictionary.

        Optional settings are only emitted when they are set or differ from
        the defaults that _parse_field_config applies on load.

        Args:
            field: Field configuration to convert

        Returns:
            Dictionary in the format accepted by _parse_field_config
        """
        result = {
            "name": field.name,
            "type": field.field_type.value,
            "source_table": field.source_table,
            "source_column": field.source_column,
            "required": field.required,
            "boost": field.boost,
            "store": field.store,
            "index": field.index,
        }

        if field.analyzer:
            result["analyzer"] = field.analyzer.value
        if field.search_analyzer:
            result["search_analyzer"] = field.search_analyzer.value

        if field.multi_language:
            result["multi_language"] = field.multi_language
            result["languages"] = field.languages

        if field.embedding_dims != 1024:
            result["embedding_dims"] = field.embedding_dims
        if field.embedding_similarity != "dot_product":
            result["embedding_similarity"] = field.embedding_similarity

        if field.nested:
            result["nested"] = field.nested
            result["nested_properties"] = field.nested_properties

        return result

    def _index_to_dict(self, index: IndexConfig) -> Dict[str, Any]:
        """
        Convert IndexConfig to dictionary.

        Args:
            index: Index configuration to convert

        Returns:
            Dictionary in the format accepted by _parse_index_config
        """
        result = {
            "name": index.name,
            "label": index.label,
            "fields": index.fields,
            "analyzer": index.analyzer.value,
            "boost": index.boost,
            "example": index.example
        }

        # Only emitted when a mapping is actually configured.
        if index.language_field_mapping:
            result["language_field_mapping"] = index.language_field_mapping

        return result