Commit 325eec032f3afde0f658f0b687d7132340027b68
1 parent
3bb1af6b
1. 日志、配置基础设施,使用优化
2. 向量服务不用本地预估,改用网络服务
Showing
16 changed files
with
645 additions
and
326 deletions
Show diff stats
api/app.py
| @@ -9,6 +9,8 @@ import os | @@ -9,6 +9,8 @@ import os | ||
| 9 | import sys | 9 | import sys |
| 10 | import logging | 10 | import logging |
| 11 | import time | 11 | import time |
| 12 | +import argparse | ||
| 13 | +import uvicorn | ||
| 12 | from collections import defaultdict, deque | 14 | from collections import defaultdict, deque |
| 13 | from typing import Optional | 15 | from typing import Optional |
| 14 | from fastapi import FastAPI, Request, HTTPException | 16 | from fastapi import FastAPI, Request, HTTPException |
| @@ -20,7 +22,6 @@ from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware | @@ -20,7 +22,6 @@ from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware | ||
| 20 | from slowapi import Limiter, _rate_limit_exceeded_handler | 22 | from slowapi import Limiter, _rate_limit_exceeded_handler |
| 21 | from slowapi.util import get_remote_address | 23 | from slowapi.util import get_remote_address |
| 22 | from slowapi.errors import RateLimitExceeded | 24 | from slowapi.errors import RateLimitExceeded |
| 23 | -import argparse | ||
| 24 | 25 | ||
| 25 | # Configure logging with better formatting | 26 | # Configure logging with better formatting |
| 26 | logging.basicConfig( | 27 | logging.basicConfig( |
| @@ -40,6 +41,7 @@ limiter = Limiter(key_func=get_remote_address) | @@ -40,6 +41,7 @@ limiter = Limiter(key_func=get_remote_address) | ||
| 40 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | 41 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| 41 | 42 | ||
| 42 | from config import ConfigLoader, SearchConfig | 43 | from config import ConfigLoader, SearchConfig |
| 44 | +from config.env_config import ES_CONFIG | ||
| 43 | from utils import ESClient | 45 | from utils import ESClient |
| 44 | from search import Searcher | 46 | from search import Searcher |
| 45 | from query import QueryParser | 47 | from query import QueryParser |
| @@ -60,55 +62,42 @@ def init_service(es_host: str = "http://localhost:9200"): | @@ -60,55 +62,42 @@ def init_service(es_host: str = "http://localhost:9200"): | ||
| 60 | """ | 62 | """ |
| 61 | global _config, _es_client, _searcher, _query_parser | 63 | global _config, _es_client, _searcher, _query_parser |
| 62 | 64 | ||
| 63 | - print("Initializing search service (multi-tenant)") | 65 | + start_time = time.time() |
| 66 | + logger.info("Initializing search service (multi-tenant)") | ||
| 64 | 67 | ||
| 65 | - # Load unified configuration | 68 | + # Load and validate configuration |
| 69 | + logger.info("Loading configuration...") | ||
| 66 | config_loader = ConfigLoader("config/config.yaml") | 70 | config_loader = ConfigLoader("config/config.yaml") |
| 67 | _config = config_loader.load_config() | 71 | _config = config_loader.load_config() |
| 68 | - | ||
| 69 | - # Validate configuration | ||
| 70 | errors = config_loader.validate_config(_config) | 72 | errors = config_loader.validate_config(_config) |
| 71 | if errors: | 73 | if errors: |
| 72 | raise ValueError(f"Configuration validation failed: {errors}") | 74 | raise ValueError(f"Configuration validation failed: {errors}") |
| 75 | + logger.info(f"Configuration loaded: {_config.es_index_name}") | ||
| 73 | 76 | ||
| 74 | - print(f"Configuration loaded: {_config.es_index_name}") | 77 | + # Get ES credentials |
| 78 | + es_username = os.getenv('ES_USERNAME') or ES_CONFIG.get('username') | ||
| 79 | + es_password = os.getenv('ES_PASSWORD') or ES_CONFIG.get('password') | ||
| 75 | 80 | ||
| 76 | - # Get ES credentials from environment variables or .env file | ||
| 77 | - es_username = os.getenv('ES_USERNAME') | ||
| 78 | - es_password = os.getenv('ES_PASSWORD') | ||
| 79 | - | ||
| 80 | - # Try to load from config if not in env | ||
| 81 | - if not es_username or not es_password: | ||
| 82 | - try: | ||
| 83 | - from config.env_config import get_es_config | ||
| 84 | - es_config = get_es_config() | ||
| 85 | - es_username = es_username or es_config.get('username') | ||
| 86 | - es_password = es_password or es_config.get('password') | ||
| 87 | - except Exception: | ||
| 88 | - pass | ||
| 89 | - | ||
| 90 | - # Initialize ES client with authentication if credentials are available | 81 | + # Connect to Elasticsearch |
| 82 | + logger.info(f"Connecting to Elasticsearch at {es_host}...") | ||
| 91 | if es_username and es_password: | 83 | if es_username and es_password: |
| 92 | - print(f"Connecting to Elasticsearch with authentication: {es_username}") | ||
| 93 | _es_client = ESClient(hosts=[es_host], username=es_username, password=es_password) | 84 | _es_client = ESClient(hosts=[es_host], username=es_username, password=es_password) |
| 94 | else: | 85 | else: |
| 95 | - print(f"Connecting to Elasticsearch without authentication") | ||
| 96 | _es_client = ESClient(hosts=[es_host]) | 86 | _es_client = ESClient(hosts=[es_host]) |
| 97 | 87 | ||
| 98 | if not _es_client.ping(): | 88 | if not _es_client.ping(): |
| 99 | raise ConnectionError(f"Failed to connect to Elasticsearch at {es_host}") | 89 | raise ConnectionError(f"Failed to connect to Elasticsearch at {es_host}") |
| 90 | + logger.info("Elasticsearch connected") | ||
| 100 | 91 | ||
| 101 | - print(f"Connected to Elasticsearch: {es_host}") | ||
| 102 | - | ||
| 103 | - # Initialize query parser | 92 | + # Initialize components |
| 93 | + logger.info("Initializing query parser...") | ||
| 104 | _query_parser = QueryParser(_config) | 94 | _query_parser = QueryParser(_config) |
| 105 | - print("Query parser initialized") | ||
| 106 | - | ||
| 107 | - # Initialize searcher | 95 | + |
| 96 | + logger.info("Initializing searcher...") | ||
| 108 | _searcher = Searcher(_config, _es_client, _query_parser) | 97 | _searcher = Searcher(_config, _es_client, _query_parser) |
| 109 | - print("Searcher initialized") | ||
| 110 | - | ||
| 111 | - print("Search service ready!") | 98 | + |
| 99 | + elapsed = time.time() - start_time | ||
| 100 | + logger.info(f"Search service ready! (took {elapsed:.2f}s)") | ||
| 112 | 101 | ||
| 113 | 102 | ||
| 114 | def get_config() -> SearchConfig: | 103 | def get_config() -> SearchConfig: |
| @@ -305,8 +294,6 @@ else: | @@ -305,8 +294,6 @@ else: | ||
| 305 | 294 | ||
| 306 | 295 | ||
| 307 | if __name__ == "__main__": | 296 | if __name__ == "__main__": |
| 308 | - import uvicorn | ||
| 309 | - | ||
| 310 | parser = argparse.ArgumentParser(description='Start search API service (multi-tenant)') | 297 | parser = argparse.ArgumentParser(description='Start search API service (multi-tenant)') |
| 311 | parser.add_argument('--host', default='0.0.0.0', help='Host to bind to') | 298 | parser.add_argument('--host', default='0.0.0.0', help='Host to bind to') |
| 312 | parser.add_argument('--port', type=int, default=6002, help='Port to bind to') | 299 | parser.add_argument('--port', type=int, default=6002, help='Port to bind to') |
config/config.yaml
| @@ -147,6 +147,14 @@ fields: | @@ -147,6 +147,14 @@ fields: | ||
| 147 | index: false | 147 | index: false |
| 148 | store: true | 148 | store: true |
| 149 | 149 | ||
| 150 | + # 文本嵌入字段(用于语义搜索) | ||
| 151 | + - name: "name_embedding" | ||
| 152 | + type: "TEXT_EMBEDDING" | ||
| 153 | + embedding_dims: 1024 | ||
| 154 | + embedding_similarity: "dot_product" | ||
| 155 | + index: true | ||
| 156 | + store: false | ||
| 157 | + | ||
| 150 | # 嵌套variants字段 | 158 | # 嵌套variants字段 |
| 151 | - name: "variants" | 159 | - name: "variants" |
| 152 | type: "JSON" | 160 | type: "JSON" |
| @@ -239,6 +247,10 @@ query_config: | @@ -239,6 +247,10 @@ query_config: | ||
| 239 | enable_text_embedding: true | 247 | enable_text_embedding: true |
| 240 | enable_query_rewrite: true | 248 | enable_query_rewrite: true |
| 241 | 249 | ||
| 250 | + # Embedding field names (if not set, will auto-detect from fields) | ||
| 251 | + text_embedding_field: "name_embedding" # Field name for text embeddings | ||
| 252 | + image_embedding_field: null # Field name for image embeddings (if not set, will auto-detect) | ||
| 253 | + | ||
| 242 | # Translation API (DeepL) | 254 | # Translation API (DeepL) |
| 243 | translation_service: "deepl" | 255 | translation_service: "deepl" |
| 244 | translation_api_key: null # Set via environment variable | 256 | translation_api_key: null # Set via environment variable |
config/config_loader.py
| @@ -54,6 +54,10 @@ class QueryConfig: | @@ -54,6 +54,10 @@ class QueryConfig: | ||
| 54 | translation_glossary_id: Optional[str] = None # DeepL glossary ID for custom terminology | 54 | translation_glossary_id: Optional[str] = None # DeepL glossary ID for custom terminology |
| 55 | translation_context: str = "e-commerce product search" # Context hint for translation | 55 | translation_context: str = "e-commerce product search" # Context hint for translation |
| 56 | 56 | ||
| 57 | + # Embedding field names - if not set, will auto-detect from fields | ||
| 58 | + text_embedding_field: Optional[str] = None # Field name for text embeddings (e.g., "name_embedding") | ||
| 59 | + image_embedding_field: Optional[str] = None # Field name for image embeddings (e.g., "image_embedding") | ||
| 60 | + | ||
| 57 | # ES source fields configuration - fields to return in search results | 61 | # ES source fields configuration - fields to return in search results |
| 58 | source_fields: List[str] = field(default_factory=lambda: [ | 62 | source_fields: List[str] = field(default_factory=lambda: [ |
| 59 | "id", "spuId", "skuNo", "spuNo", "title", "enSpuName", "brandId", | 63 | "id", "spuId", "skuNo", "spuNo", "title", "enSpuName", "brandId", |
| @@ -213,7 +217,9 @@ class ConfigLoader: | @@ -213,7 +217,9 @@ class ConfigLoader: | ||
| 213 | translation_api_key=query_config_data.get("translation_api_key"), | 217 | translation_api_key=query_config_data.get("translation_api_key"), |
| 214 | translation_service=query_config_data.get("translation_service", "deepl"), | 218 | translation_service=query_config_data.get("translation_service", "deepl"), |
| 215 | translation_glossary_id=query_config_data.get("translation_glossary_id"), | 219 | translation_glossary_id=query_config_data.get("translation_glossary_id"), |
| 216 | - translation_context=query_config_data.get("translation_context", "e-commerce product search") | 220 | + translation_context=query_config_data.get("translation_context", "e-commerce product search"), |
| 221 | + text_embedding_field=query_config_data.get("text_embedding_field"), | ||
| 222 | + image_embedding_field=query_config_data.get("image_embedding_field") | ||
| 217 | ) | 223 | ) |
| 218 | 224 | ||
| 219 | # Parse ranking config | 225 | # Parse ranking config |
config/env_config.py
| @@ -2,11 +2,12 @@ | @@ -2,11 +2,12 @@ | ||
| 2 | Centralized configuration management for SearchEngine. | 2 | Centralized configuration management for SearchEngine. |
| 3 | 3 | ||
| 4 | Loads configuration from environment variables and .env file. | 4 | Loads configuration from environment variables and .env file. |
| 5 | +This module provides a single point for loading .env and setting defaults. | ||
| 6 | +All configuration variables are exported directly - no need for getter functions. | ||
| 5 | """ | 7 | """ |
| 6 | 8 | ||
| 7 | import os | 9 | import os |
| 8 | from pathlib import Path | 10 | from pathlib import Path |
| 9 | -from typing import Dict, Any | ||
| 10 | from dotenv import load_dotenv | 11 | from dotenv import load_dotenv |
| 11 | 12 | ||
| 12 | # Load .env file from project root | 13 | # Load .env file from project root |
| @@ -56,26 +57,6 @@ DB_CONFIG = { | @@ -56,26 +57,6 @@ DB_CONFIG = { | ||
| 56 | } | 57 | } |
| 57 | 58 | ||
| 58 | 59 | ||
| 59 | -def get_es_config() -> Dict[str, Any]: | ||
| 60 | - """Get Elasticsearch configuration.""" | ||
| 61 | - return ES_CONFIG.copy() | ||
| 62 | - | ||
| 63 | - | ||
| 64 | -def get_redis_config() -> Dict[str, Any]: | ||
| 65 | - """Get Redis configuration.""" | ||
| 66 | - return REDIS_CONFIG.copy() | ||
| 67 | - | ||
| 68 | - | ||
| 69 | -def get_deepl_key() -> str: | ||
| 70 | - """Get DeepL API key.""" | ||
| 71 | - return DEEPL_AUTH_KEY | ||
| 72 | - | ||
| 73 | - | ||
| 74 | -def get_db_config() -> Dict[str, Any]: | ||
| 75 | - """Get MySQL database configuration.""" | ||
| 76 | - return DB_CONFIG.copy() | ||
| 77 | - | ||
| 78 | - | ||
| 79 | def print_config(): | 60 | def print_config(): |
| 80 | """Print current configuration (with sensitive data masked).""" | 61 | """Print current configuration (with sensitive data masked).""" |
| 81 | print("=" * 60) | 62 | print("=" * 60) |
embeddings/image_encoder.py
| 1 | """ | 1 | """ |
| 2 | -Image embedding encoder using CN-CLIP model. | 2 | +Image embedding encoder using network service. |
| 3 | 3 | ||
| 4 | -Generates 1024-dimensional vectors for images using the CN-CLIP ViT-H-14 model. | 4 | +Generates embeddings via HTTP API service running on localhost:5001. |
| 5 | """ | 5 | """ |
| 6 | 6 | ||
| 7 | import sys | 7 | import sys |
| 8 | import os | 8 | import os |
| 9 | -import io | ||
| 10 | import requests | 9 | import requests |
| 11 | -import torch | ||
| 12 | import numpy as np | 10 | import numpy as np |
| 13 | from PIL import Image | 11 | from PIL import Image |
| 14 | import logging | 12 | import logging |
| 15 | import threading | 13 | import threading |
| 16 | -from typing import List, Optional, Union | ||
| 17 | -import cn_clip.clip as clip | ||
| 18 | -from cn_clip.clip import load_from_name | 14 | +from typing import List, Optional, Union, Dict, Any |
| 19 | 15 | ||
| 20 | - | ||
| 21 | -# DEFAULT_MODEL_NAME = "ViT-L-14-336" # ["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"] | ||
| 22 | -DEFAULT_MODEL_NAME = "ViT-H-14" | ||
| 23 | -MODEL_DOWNLOAD_DIR = "/data/tw/uat/EsSearcher" | 16 | +logger = logging.getLogger(__name__) |
| 24 | 17 | ||
| 25 | 18 | ||
| 26 | class CLIPImageEncoder: | 19 | class CLIPImageEncoder: |
| 27 | """ | 20 | """ |
| 28 | - CLIP Image Encoder for generating image embeddings using cn_clip. | 21 | + Image Encoder for generating image embeddings using network service. |
| 29 | 22 | ||
| 30 | Thread-safe singleton pattern. | 23 | Thread-safe singleton pattern. |
| 31 | """ | 24 | """ |
| @@ -33,111 +26,80 @@ class CLIPImageEncoder: | @@ -33,111 +26,80 @@ class CLIPImageEncoder: | ||
| 33 | _instance = None | 26 | _instance = None |
| 34 | _lock = threading.Lock() | 27 | _lock = threading.Lock() |
| 35 | 28 | ||
| 36 | - def __new__(cls, model_name=DEFAULT_MODEL_NAME, device=None): | 29 | + def __new__(cls, service_url='http://localhost:5001'): |
| 37 | with cls._lock: | 30 | with cls._lock: |
| 38 | if cls._instance is None: | 31 | if cls._instance is None: |
| 39 | cls._instance = super(CLIPImageEncoder, cls).__new__(cls) | 32 | cls._instance = super(CLIPImageEncoder, cls).__new__(cls) |
| 40 | - print(f"[CLIPImageEncoder] Creating new instance with model: {model_name}") | ||
| 41 | - cls._instance._initialize_model(model_name, device) | 33 | + logger.info(f"Creating CLIPImageEncoder instance with service URL: {service_url}") |
| 34 | + cls._instance.service_url = service_url | ||
| 35 | + cls._instance.endpoint = f"{service_url}/embedding/generate_image_embeddings" | ||
| 42 | return cls._instance | 36 | return cls._instance |
| 43 | 37 | ||
| 44 | - def _initialize_model(self, model_name, device): | ||
| 45 | - """Initialize the CLIP model using cn_clip""" | ||
| 46 | - try: | ||
| 47 | - self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") | ||
| 48 | - self.model, self.preprocess = load_from_name( | ||
| 49 | - model_name, | ||
| 50 | - device=self.device, | ||
| 51 | - download_root=MODEL_DOWNLOAD_DIR | ||
| 52 | - ) | ||
| 53 | - self.model.eval() | ||
| 54 | - self.model_name = model_name | ||
| 55 | - print(f"[CLIPImageEncoder] Model {model_name} initialized successfully on device {self.device}") | ||
| 56 | - | ||
| 57 | - except Exception as e: | ||
| 58 | - print(f"[CLIPImageEncoder] Failed to initialize model: {str(e)}") | ||
| 59 | - raise | 38 | + def _call_service(self, request_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: |
| 39 | + """ | ||
| 40 | + Call the embedding service API. | ||
| 60 | 41 | ||
| 61 | - def validate_image(self, image_data: bytes) -> Image.Image: | ||
| 62 | - """Validate image data and return PIL Image if valid""" | ||
| 63 | - try: | ||
| 64 | - image_stream = io.BytesIO(image_data) | ||
| 65 | - image = Image.open(image_stream) | ||
| 66 | - image.verify() | ||
| 67 | - image_stream.seek(0) | ||
| 68 | - image = Image.open(image_stream) | ||
| 69 | - if image.mode != 'RGB': | ||
| 70 | - image = image.convert('RGB') | ||
| 71 | - return image | ||
| 72 | - except Exception as e: | ||
| 73 | - raise ValueError(f"Invalid image data: {str(e)}") | 42 | + Args: |
| 43 | + request_data: List of dictionaries with id and pic_url fields | ||
| 74 | 44 | ||
| 75 | - def download_image(self, url: str, timeout: int = 10) -> bytes: | ||
| 76 | - """Download image from URL""" | 45 | + Returns: |
| 46 | + List of dictionaries with id, pic_url, embedding and error fields | ||
| 47 | + """ | ||
| 77 | try: | 48 | try: |
| 78 | - if url.startswith(('http://', 'https://')): | ||
| 79 | - response = requests.get(url, timeout=timeout) | ||
| 80 | - if response.status_code != 200: | ||
| 81 | - raise ValueError(f"HTTP {response.status_code}") | ||
| 82 | - return response.content | ||
| 83 | - else: | ||
| 84 | - # Local file path | ||
| 85 | - with open(url, 'rb') as f: | ||
| 86 | - return f.read() | ||
| 87 | - except Exception as e: | ||
| 88 | - raise ValueError(f"Failed to download image from {url}: {str(e)}") | ||
| 89 | - | ||
| 90 | - def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image: | ||
| 91 | - """Preprocess image for CLIP model""" | ||
| 92 | - # Resize if too large | ||
| 93 | - if max(image.size) > max_size: | ||
| 94 | - ratio = max_size / max(image.size) | ||
| 95 | - new_size = tuple(int(dim * ratio) for dim in image.size) | ||
| 96 | - image = image.resize(new_size, Image.Resampling.LANCZOS) | ||
| 97 | - return image | ||
| 98 | - | ||
| 99 | - def encode_text(self, text): | ||
| 100 | - """Encode text to embedding vector using cn_clip""" | ||
| 101 | - text_data = clip.tokenize([text] if type(text) == str else text).to(self.device) | ||
| 102 | - with torch.no_grad(): | ||
| 103 | - text_features = self.model.encode_text(text_data) | ||
| 104 | - text_features /= text_features.norm(dim=-1, keepdim=True) | ||
| 105 | - return text_features | 49 | + response = requests.post( |
| 50 | + self.endpoint, | ||
| 51 | + json=request_data, | ||
| 52 | + timeout=60 | ||
| 53 | + ) | ||
| 54 | + response.raise_for_status() | ||
| 55 | + return response.json() | ||
| 56 | + except requests.exceptions.RequestException as e: | ||
| 57 | + logger.error(f"CLIPImageEncoder service request failed: {e}", exc_info=True) | ||
| 58 | + raise | ||
| 106 | 59 | ||
| 107 | def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: | 60 | def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: |
| 108 | - """Encode image to embedding vector using cn_clip""" | ||
| 109 | - if not isinstance(image, Image.Image): | ||
| 110 | - raise ValueError("CLIPImageEncoder.encode_image Input must be a PIL.Image") | 61 | + """ |
| 62 | + Encode image to embedding vector using network service. | ||
| 111 | 63 | ||
| 112 | - try: | ||
| 113 | - infer_data = self.preprocess(image).unsqueeze(0).to(self.device) | ||
| 114 | - with torch.no_grad(): | ||
| 115 | - image_features = self.model.encode_image(infer_data) | ||
| 116 | - image_features /= image_features.norm(dim=-1, keepdim=True) | ||
| 117 | - return image_features.cpu().numpy().astype('float32')[0] | ||
| 118 | - except Exception as e: | ||
| 119 | - print(f"Failed to process image. Reason: {str(e)}") | ||
| 120 | - return None | 64 | + Note: This method is kept for compatibility but the service only works with URLs. |
| 65 | + """ | ||
| 66 | + logger.warning("encode_image with PIL Image not supported by service, returning None") | ||
| 67 | + return None | ||
| 121 | 68 | ||
| 122 | def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: | 69 | def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: |
| 123 | - """Complete pipeline: download, validate, preprocess and encode image from URL""" | ||
| 124 | - try: | ||
| 125 | - # Download image | ||
| 126 | - image_data = self.download_image(url) | ||
| 127 | - | ||
| 128 | - # Validate image | ||
| 129 | - image = self.validate_image(image_data) | ||
| 130 | - | ||
| 131 | - # Preprocess image | ||
| 132 | - image = self.preprocess_image(image) | 70 | + """ |
| 71 | + Generate image embedding via network service using URL. | ||
| 133 | 72 | ||
| 134 | - # Encode image | ||
| 135 | - embedding = self.encode_image(image) | 73 | + Args: |
| 74 | + url: Image URL to process | ||
| 136 | 75 | ||
| 137 | - return embedding | 76 | + Returns: |
| 77 | + Embedding vector or None if failed | ||
| 78 | + """ | ||
| 79 | + try: | ||
| 80 | + # Prepare request data | ||
| 81 | + request_data = [{ | ||
| 82 | + "id": "image_0", | ||
| 83 | + "pic_url": url | ||
| 84 | + }] | ||
| 85 | + | ||
| 86 | + # Call service | ||
| 87 | + response_data = self._call_service(request_data) | ||
| 88 | + | ||
| 89 | + # Process response | ||
| 90 | + if response_data and len(response_data) > 0: | ||
| 91 | + response_item = response_data[0] | ||
| 92 | + if response_item.get("embedding"): | ||
| 93 | + return np.array(response_item["embedding"], dtype=np.float32) | ||
| 94 | + else: | ||
| 95 | + logger.warning(f"No embedding for URL {url}, error: {response_item.get('error', 'Unknown error')}") | ||
| 96 | + return None | ||
| 97 | + else: | ||
| 98 | + logger.warning(f"No response for URL {url}") | ||
| 99 | + return None | ||
| 138 | 100 | ||
| 139 | except Exception as e: | 101 | except Exception as e: |
| 140 | - print(f"Error processing image from URL {url}: {str(e)}") | 102 | + logger.error(f"Failed to process image from URL {url}: {str(e)}", exc_info=True) |
| 141 | return None | 103 | return None |
| 142 | 104 | ||
| 143 | def encode_batch( | 105 | def encode_batch( |
| @@ -146,33 +108,71 @@ class CLIPImageEncoder: | @@ -146,33 +108,71 @@ class CLIPImageEncoder: | ||
| 146 | batch_size: int = 8 | 108 | batch_size: int = 8 |
| 147 | ) -> List[Optional[np.ndarray]]: | 109 | ) -> List[Optional[np.ndarray]]: |
| 148 | """ | 110 | """ |
| 149 | - Encode a batch of images efficiently. | 111 | + Encode a batch of images efficiently via network service. |
| 150 | 112 | ||
| 151 | Args: | 113 | Args: |
| 152 | images: List of image URLs or PIL Images | 114 | images: List of image URLs or PIL Images |
| 153 | - batch_size: Batch size for processing | 115 | + batch_size: Batch size for processing (used for service requests) |
| 154 | 116 | ||
| 155 | Returns: | 117 | Returns: |
| 156 | List of embeddings (or None for failed images) | 118 | List of embeddings (or None for failed images) |
| 157 | """ | 119 | """ |
| 158 | - results = [] | ||
| 159 | - | ||
| 160 | - for i in range(0, len(images), batch_size): | ||
| 161 | - batch = images[i:i + batch_size] | ||
| 162 | - batch_embeddings = [] | ||
| 163 | - | ||
| 164 | - for img in batch: | ||
| 165 | - if isinstance(img, str): | ||
| 166 | - # URL or file path | ||
| 167 | - emb = self.encode_image_from_url(img) | ||
| 168 | - elif isinstance(img, Image.Image): | ||
| 169 | - # PIL Image | ||
| 170 | - emb = self.encode_image(img) | ||
| 171 | - else: | ||
| 172 | - emb = None | ||
| 173 | - | ||
| 174 | - batch_embeddings.append(emb) | ||
| 175 | - | ||
| 176 | - results.extend(batch_embeddings) | 120 | + # Initialize results with None for all images |
| 121 | + results = [None] * len(images) | ||
| 122 | + | ||
| 123 | + # Filter out PIL Images since service only supports URLs | ||
| 124 | + url_images = [] | ||
| 125 | + url_indices = [] | ||
| 126 | + | ||
| 127 | + for i, img in enumerate(images): | ||
| 128 | + if isinstance(img, str): | ||
| 129 | + url_images.append(img) | ||
| 130 | + url_indices.append(i) | ||
| 131 | + elif isinstance(img, Image.Image): | ||
| 132 | + logger.warning(f"PIL Image at index {i} not supported by service, returning None") | ||
| 133 | + # results[i] is already None | ||
| 134 | + | ||
| 135 | + # Process URLs in batches | ||
| 136 | + for i in range(0, len(url_images), batch_size): | ||
| 137 | + batch_urls = url_images[i:i + batch_size] | ||
| 138 | + batch_indices = url_indices[i:i + batch_size] | ||
| 139 | + | ||
| 140 | + # Prepare request data | ||
| 141 | + request_data = [] | ||
| 142 | + for j, url in enumerate(batch_urls): | ||
| 143 | + request_data.append({ | ||
| 144 | + "id": f"image_{j}", | ||
| 145 | + "pic_url": url | ||
| 146 | + }) | ||
| 147 | + | ||
| 148 | + try: | ||
| 149 | + # Call service | ||
| 150 | + response_data = self._call_service(request_data) | ||
| 151 | + | ||
| 152 | + # Process response | ||
| 153 | + batch_results = [] | ||
| 154 | + for j, url in enumerate(batch_urls): | ||
| 155 | + response_item = None | ||
| 156 | + for item in response_data: | ||
| 157 | + if str(item.get("id")) == f"image_{j}": | ||
| 158 | + response_item = item | ||
| 159 | + break | ||
| 160 | + | ||
| 161 | + if response_item and response_item.get("embedding"): | ||
| 162 | + batch_results.append(np.array(response_item["embedding"], dtype=np.float32)) | ||
| 163 | + else: | ||
| 164 | + error_msg = response_item.get("error", "Unknown error") if response_item else "No response" | ||
| 165 | + logger.warning(f"Failed to encode URL {url}: {error_msg}") | ||
| 166 | + batch_results.append(None) | ||
| 167 | + | ||
| 168 | + # Insert results at the correct positions | ||
| 169 | + for j, result in enumerate(batch_results): | ||
| 170 | + results[batch_indices[j]] = result | ||
| 171 | + | ||
| 172 | + except Exception as e: | ||
| 173 | + logger.error(f"Batch processing failed: {e}", exc_info=True) | ||
| 174 | + # Fill with None for this batch | ||
| 175 | + for j in range(len(batch_urls)): | ||
| 176 | + results[batch_indices[j]] = None | ||
| 177 | 177 | ||
| 178 | return results | 178 | return results |
| @@ -0,0 +1,178 @@ | @@ -0,0 +1,178 @@ | ||
| 1 | +""" | ||
| 2 | +Image embedding encoder using CN-CLIP model. | ||
| 3 | + | ||
| 4 | +Generates 1024-dimensional vectors for images using the CN-CLIP ViT-H-14 model. | ||
| 5 | +""" | ||
| 6 | + | ||
| 7 | +import sys | ||
| 8 | +import os | ||
| 9 | +import io | ||
| 10 | +import requests | ||
| 11 | +import torch | ||
| 12 | +import numpy as np | ||
| 13 | +from PIL import Image | ||
| 14 | +import logging | ||
| 15 | +import threading | ||
| 16 | +from typing import List, Optional, Union | ||
| 17 | +import cn_clip.clip as clip | ||
| 18 | +from cn_clip.clip import load_from_name | ||
| 19 | + | ||
| 20 | + | ||
| 21 | +# DEFAULT_MODEL_NAME = "ViT-L-14-336" # ["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"] | ||
| 22 | +DEFAULT_MODEL_NAME = "ViT-H-14" | ||
| 23 | +MODEL_DOWNLOAD_DIR = "/data/tw/uat/EsSearcher" | ||
| 24 | + | ||
| 25 | + | ||
| 26 | +class CLIPImageEncoder: | ||
| 27 | + """ | ||
| 28 | + CLIP Image Encoder for generating image embeddings using cn_clip. | ||
| 29 | + | ||
| 30 | + Thread-safe singleton pattern. | ||
| 31 | + """ | ||
| 32 | + | ||
| 33 | + _instance = None | ||
| 34 | + _lock = threading.Lock() | ||
| 35 | + | ||
| 36 | + def __new__(cls, model_name=DEFAULT_MODEL_NAME, device=None): | ||
| 37 | + with cls._lock: | ||
| 38 | + if cls._instance is None: | ||
| 39 | + cls._instance = super(CLIPImageEncoder, cls).__new__(cls) | ||
| 40 | + print(f"[CLIPImageEncoder] Creating new instance with model: {model_name}") | ||
| 41 | + cls._instance._initialize_model(model_name, device) | ||
| 42 | + return cls._instance | ||
| 43 | + | ||
| 44 | + def _initialize_model(self, model_name, device): | ||
| 45 | + """Initialize the CLIP model using cn_clip""" | ||
| 46 | + try: | ||
| 47 | + self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") | ||
| 48 | + self.model, self.preprocess = load_from_name( | ||
| 49 | + model_name, | ||
| 50 | + device=self.device, | ||
| 51 | + download_root=MODEL_DOWNLOAD_DIR | ||
| 52 | + ) | ||
| 53 | + self.model.eval() | ||
| 54 | + self.model_name = model_name | ||
| 55 | + print(f"[CLIPImageEncoder] Model {model_name} initialized successfully on device {self.device}") | ||
| 56 | + | ||
| 57 | + except Exception as e: | ||
| 58 | + print(f"[CLIPImageEncoder] Failed to initialize model: {str(e)}") | ||
| 59 | + raise | ||
| 60 | + | ||
| 61 | + def validate_image(self, image_data: bytes) -> Image.Image: | ||
| 62 | + """Validate image data and return PIL Image if valid""" | ||
| 63 | + try: | ||
| 64 | + image_stream = io.BytesIO(image_data) | ||
| 65 | + image = Image.open(image_stream) | ||
| 66 | + image.verify() | ||
| 67 | + image_stream.seek(0) | ||
| 68 | + image = Image.open(image_stream) | ||
| 69 | + if image.mode != 'RGB': | ||
| 70 | + image = image.convert('RGB') | ||
| 71 | + return image | ||
| 72 | + except Exception as e: | ||
| 73 | + raise ValueError(f"Invalid image data: {str(e)}") | ||
| 74 | + | ||
| 75 | + def download_image(self, url: str, timeout: int = 10) -> bytes: | ||
| 76 | + """Download image from URL""" | ||
| 77 | + try: | ||
| 78 | + if url.startswith(('http://', 'https://')): | ||
| 79 | + response = requests.get(url, timeout=timeout) | ||
| 80 | + if response.status_code != 200: | ||
| 81 | + raise ValueError(f"HTTP {response.status_code}") | ||
| 82 | + return response.content | ||
| 83 | + else: | ||
| 84 | + # Local file path | ||
| 85 | + with open(url, 'rb') as f: | ||
| 86 | + return f.read() | ||
| 87 | + except Exception as e: | ||
| 88 | + raise ValueError(f"Failed to download image from {url}: {str(e)}") | ||
| 89 | + | ||
| 90 | + def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image: | ||
| 91 | + """Preprocess image for CLIP model""" | ||
| 92 | + # Resize if too large | ||
| 93 | + if max(image.size) > max_size: | ||
| 94 | + ratio = max_size / max(image.size) | ||
| 95 | + new_size = tuple(int(dim * ratio) for dim in image.size) | ||
| 96 | + image = image.resize(new_size, Image.Resampling.LANCZOS) | ||
| 97 | + return image | ||
| 98 | + | ||
| 99 | + def encode_text(self, text): | ||
| 100 | + """Encode text to embedding vector using cn_clip""" | ||
| 101 | + text_data = clip.tokenize([text] if type(text) == str else text).to(self.device) | ||
| 102 | + with torch.no_grad(): | ||
| 103 | + text_features = self.model.encode_text(text_data) | ||
| 104 | + text_features /= text_features.norm(dim=-1, keepdim=True) | ||
| 105 | + return text_features | ||
| 106 | + | ||
| 107 | + def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: | ||
| 108 | + """Encode image to embedding vector using cn_clip""" | ||
| 109 | + if not isinstance(image, Image.Image): | ||
| 110 | + raise ValueError("CLIPImageEncoder.encode_image Input must be a PIL.Image") | ||
| 111 | + | ||
| 112 | + try: | ||
| 113 | + infer_data = self.preprocess(image).unsqueeze(0).to(self.device) | ||
| 114 | + with torch.no_grad(): | ||
| 115 | + image_features = self.model.encode_image(infer_data) | ||
| 116 | + image_features /= image_features.norm(dim=-1, keepdim=True) | ||
| 117 | + return image_features.cpu().numpy().astype('float32')[0] | ||
| 118 | + except Exception as e: | ||
| 119 | + print(f"Failed to process image. Reason: {str(e)}") | ||
| 120 | + return None | ||
| 121 | + | ||
| 122 | + def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: | ||
| 123 | + """Complete pipeline: download, validate, preprocess and encode image from URL""" | ||
| 124 | + try: | ||
| 125 | + # Download image | ||
| 126 | + image_data = self.download_image(url) | ||
| 127 | + | ||
| 128 | + # Validate image | ||
| 129 | + image = self.validate_image(image_data) | ||
| 130 | + | ||
| 131 | + # Preprocess image | ||
| 132 | + image = self.preprocess_image(image) | ||
| 133 | + | ||
| 134 | + # Encode image | ||
| 135 | + embedding = self.encode_image(image) | ||
| 136 | + | ||
| 137 | + return embedding | ||
| 138 | + | ||
| 139 | + except Exception as e: | ||
| 140 | + print(f"Error processing image from URL {url}: {str(e)}") | ||
| 141 | + return None | ||
| 142 | + | ||
| 143 | + def encode_batch( | ||
| 144 | + self, | ||
| 145 | + images: List[Union[str, Image.Image]], | ||
| 146 | + batch_size: int = 8 | ||
| 147 | + ) -> List[Optional[np.ndarray]]: | ||
| 148 | + """ | ||
| 149 | + Encode a batch of images efficiently. | ||
| 150 | + | ||
| 151 | + Args: | ||
| 152 | + images: List of image URLs or PIL Images | ||
| 153 | + batch_size: Batch size for processing | ||
| 154 | + | ||
| 155 | + Returns: | ||
| 156 | + List of embeddings (or None for failed images) | ||
| 157 | + """ | ||
| 158 | + results = [] | ||
| 159 | + | ||
| 160 | + for i in range(0, len(images), batch_size): | ||
| 161 | + batch = images[i:i + batch_size] | ||
| 162 | + batch_embeddings = [] | ||
| 163 | + | ||
| 164 | + for img in batch: | ||
| 165 | + if isinstance(img, str): | ||
| 166 | + # URL or file path | ||
| 167 | + emb = self.encode_image_from_url(img) | ||
| 168 | + elif isinstance(img, Image.Image): | ||
| 169 | + # PIL Image | ||
| 170 | + emb = self.encode_image(img) | ||
| 171 | + else: | ||
| 172 | + emb = None | ||
| 173 | + | ||
| 174 | + batch_embeddings.append(emb) | ||
| 175 | + | ||
| 176 | + results.extend(batch_embeddings) | ||
| 177 | + | ||
| 178 | + return results |
embeddings/text_encoder.py
| 1 | """ | 1 | """ |
| 2 | -Text embedding encoder using BGE-M3 model. | 2 | +Text embedding encoder using network service. |
| 3 | 3 | ||
| 4 | -Generates 1024-dimensional vectors for text using the BGE-M3 multilingual model. | 4 | +Generates embeddings via HTTP API service running on localhost:5001. |
| 5 | """ | 5 | """ |
| 6 | 6 | ||
| 7 | import sys | 7 | import sys |
| 8 | -import torch | ||
| 9 | -from sentence_transformers import SentenceTransformer | 8 | +import requests |
| 10 | import time | 9 | import time |
| 11 | import threading | 10 | import threading |
| 12 | -from modelscope import snapshot_download | ||
| 13 | -from transformers import AutoModel | ||
| 14 | -import os | ||
| 15 | import numpy as np | 11 | import numpy as np |
| 16 | -from typing import List, Union | 12 | +import logging |
| 13 | +from typing import List, Union, Dict, Any | ||
| 14 | + | ||
| 15 | +logger = logging.getLogger(__name__) | ||
| 17 | 16 | ||
| 18 | 17 | ||
| 19 | class BgeEncoder: | 18 | class BgeEncoder: |
| 20 | """ | 19 | """ |
| 21 | - Singleton text encoder using BGE-M3 model. | 20 | + Singleton text encoder using network service. |
| 22 | 21 | ||
| 23 | - Thread-safe singleton pattern ensures only one model instance exists. | 22 | + Thread-safe singleton pattern ensures only one instance exists. |
| 24 | """ | 23 | """ |
| 25 | _instance = None | 24 | _instance = None |
| 26 | _lock = threading.Lock() | 25 | _lock = threading.Lock() |
| 27 | 26 | ||
| 28 | - def __new__(cls, model_dir='Xorbits/bge-m3'): | 27 | + def __new__(cls, service_url='http://localhost:5001'): |
| 29 | with cls._lock: | 28 | with cls._lock: |
| 30 | if cls._instance is None: | 29 | if cls._instance is None: |
| 31 | cls._instance = super(BgeEncoder, cls).__new__(cls) | 30 | cls._instance = super(BgeEncoder, cls).__new__(cls) |
| 32 | - print(f"[BgeEncoder] Creating a new instance with model directory: {model_dir}") | ||
| 33 | - cls._instance.model = SentenceTransformer(snapshot_download(model_dir)) | ||
| 34 | - print("[BgeEncoder] New instance has been created") | 31 | + logger.info(f"Creating BgeEncoder instance with service URL: {service_url}") |
| 32 | + cls._instance.service_url = service_url | ||
| 33 | + cls._instance.endpoint = f"{service_url}/embedding/generate_embeddings" | ||
| 35 | return cls._instance | 34 | return cls._instance |
| 36 | 35 | ||
| 36 | + def _call_service(self, request_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | ||
| 37 | + """ | ||
| 38 | + Call the embedding service API. | ||
| 39 | + | ||
| 40 | + Args: | ||
| 41 | + request_data: List of dictionaries with id and text fields | ||
| 42 | + | ||
| 43 | + Returns: | ||
| 44 | + List of dictionaries with id and embedding fields | ||
| 45 | + """ | ||
| 46 | + try: | ||
| 47 | + response = requests.post( | ||
| 48 | + self.endpoint, | ||
| 49 | + json=request_data, | ||
| 50 | + timeout=60 | ||
| 51 | + ) | ||
| 52 | + response.raise_for_status() | ||
| 53 | + return response.json() | ||
| 54 | + except requests.exceptions.RequestException as e: | ||
| 55 | + logger.error(f"BgeEncoder service request failed: {e}", exc_info=True) | ||
| 56 | + raise | ||
| 57 | + | ||
| 37 | def encode( | 58 | def encode( |
| 38 | self, | 59 | self, |
| 39 | sentences: Union[str, List[str]], | 60 | sentences: Union[str, List[str]], |
| 40 | normalize_embeddings: bool = True, | 61 | normalize_embeddings: bool = True, |
| 41 | - device: str = 'cuda', | 62 | + device: str = 'cpu', |
| 42 | batch_size: int = 32 | 63 | batch_size: int = 32 |
| 43 | ) -> np.ndarray: | 64 | ) -> np.ndarray: |
| 44 | """ | 65 | """ |
| 45 | - Encode text into embeddings. | 66 | + Encode text into embeddings via network service. |
| 46 | 67 | ||
| 47 | Args: | 68 | Args: |
| 48 | sentences: Single string or list of strings to encode | 69 | sentences: Single string or list of strings to encode |
| 49 | - normalize_embeddings: Whether to normalize embeddings | ||
| 50 | - device: Device to use ('cuda' or 'cpu') | ||
| 51 | - batch_size: Batch size for encoding | 70 | + normalize_embeddings: Whether to normalize embeddings (ignored for service) |
| 71 | + device: Device parameter ignored for service compatibility | ||
| 72 | + batch_size: Batch size for processing (used for service requests) | ||
| 52 | 73 | ||
| 53 | Returns: | 74 | Returns: |
| 54 | numpy array of shape (n, 1024) containing embeddings | 75 | numpy array of shape (n, 1024) containing embeddings |
| 55 | """ | 76 | """ |
| 56 | - # Move model to specified device | ||
| 57 | - if device == 'gpu': | ||
| 58 | - device = 'cuda' | 77 | + # Convert single string to list |
| 78 | + if isinstance(sentences, str): | ||
| 79 | + sentences = [sentences] | ||
| 59 | 80 | ||
| 60 | - # Try requested device, fallback to CPU if CUDA fails | ||
| 61 | - try: | ||
| 62 | - if device == 'cuda': | ||
| 63 | - # Check CUDA memory first | ||
| 64 | - import torch | ||
| 65 | - if torch.cuda.is_available(): | ||
| 66 | - # Check if we have enough memory (at least 1GB free) | ||
| 67 | - free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated() | ||
| 68 | - if free_memory < 1024 * 1024 * 1024: # 1GB | ||
| 69 | - print(f"[BgeEncoder] CUDA memory insufficient ({free_memory/1024/1024:.1f}MB free), falling back to CPU") | ||
| 70 | - device = 'cpu' | ||
| 71 | - else: | ||
| 72 | - print(f"[BgeEncoder] CUDA not available, using CPU") | ||
| 73 | - device = 'cpu' | 81 | + # Prepare request data |
| 82 | + request_data = [] | ||
| 83 | + for i, text in enumerate(sentences): | ||
| 84 | + request_item = { | ||
| 85 | + "id": str(i), | ||
| 86 | + "name_zh": text | ||
| 87 | + } | ||
| 74 | 88 | ||
| 75 | - self.model = self.model.to(device) | 89 | + # Add English and Russian fields as empty for now |
| 90 | + # Could be enhanced with language detection in the future | ||
| 91 | + request_item["name_en"] = None | ||
| 92 | + request_item["name_ru"] = None | ||
| 76 | 93 | ||
| 77 | - embeddings = self.model.encode( | ||
| 78 | - sentences, | ||
| 79 | - normalize_embeddings=normalize_embeddings, | ||
| 80 | - device=device, | ||
| 81 | - show_progress_bar=False, | ||
| 82 | - batch_size=batch_size | ||
| 83 | - ) | 94 | + request_data.append(request_item) |
| 95 | + | ||
| 96 | + try: | ||
| 97 | + # Call service | ||
| 98 | + response_data = self._call_service(request_data) | ||
| 99 | + | ||
| 100 | + # Process response | ||
| 101 | + embeddings = [] | ||
| 102 | + for i, text in enumerate(sentences): | ||
| 103 | + # Find corresponding response by ID | ||
| 104 | + response_item = None | ||
| 105 | + for item in response_data: | ||
| 106 | + if str(item.get("id")) == str(i): | ||
| 107 | + response_item = item | ||
| 108 | + break | ||
| 109 | + | ||
| 110 | + if response_item: | ||
| 111 | + # Try Chinese embedding first, then English, then Russian | ||
| 112 | + embedding = None | ||
| 113 | + for lang in ["embedding_zh", "embedding_en", "embedding_ru"]: | ||
| 114 | + if lang in response_item and response_item[lang] is not None: | ||
| 115 | + embedding = response_item[lang] | ||
| 116 | + break | ||
| 117 | + | ||
| 118 | + if embedding is not None: | ||
| 119 | + embeddings.append(embedding) | ||
| 120 | + else: | ||
| 121 | + logger.warning(f"No embedding found for text {i}: {text[:50]}...") | ||
| 122 | + embeddings.append([0.0] * 1024) | ||
| 123 | + else: | ||
| 124 | + logger.warning(f"No response found for text {i}") | ||
| 125 | + embeddings.append([0.0] * 1024) | ||
| 84 | 126 | ||
| 85 | - return embeddings | 127 | + return np.array(embeddings, dtype=np.float32) |
| 86 | 128 | ||
| 87 | except Exception as e: | 129 | except Exception as e: |
| 88 | - print(f"[BgeEncoder] Device {device} failed: {e}") | ||
| 89 | - if device != 'cpu': | ||
| 90 | - print(f"[BgeEncoder] Falling back to CPU") | ||
| 91 | - try: | ||
| 92 | - self.model = self.model.to('cpu') | ||
| 93 | - embeddings = self.model.encode( | ||
| 94 | - sentences, | ||
| 95 | - normalize_embeddings=normalize_embeddings, | ||
| 96 | - device='cpu', | ||
| 97 | - show_progress_bar=False, | ||
| 98 | - batch_size=batch_size | ||
| 99 | - ) | ||
| 100 | - return embeddings | ||
| 101 | - except Exception as e2: | ||
| 102 | - print(f"[BgeEncoder] CPU also failed: {e2}") | ||
| 103 | - raise | ||
| 104 | - else: | ||
| 105 | - raise | 130 | + logger.error(f"Failed to encode texts: {e}", exc_info=True) |
| 131 | + # Return zero embeddings as fallback | ||
| 132 | + return np.zeros((len(sentences), 1024), dtype=np.float32) | ||
| 106 | 133 | ||
| 107 | def encode_batch( | 134 | def encode_batch( |
| 108 | self, | 135 | self, |
| 109 | texts: List[str], | 136 | texts: List[str], |
| 110 | batch_size: int = 32, | 137 | batch_size: int = 32, |
| 111 | - device: str = 'cuda' | 138 | + device: str = 'cpu' |
| 112 | ) -> np.ndarray: | 139 | ) -> np.ndarray: |
| 113 | """ | 140 | """ |
| 114 | - Encode a batch of texts efficiently. | 141 | + Encode a batch of texts efficiently via network service. |
| 115 | 142 | ||
| 116 | Args: | 143 | Args: |
| 117 | texts: List of texts to encode | 144 | texts: List of texts to encode |
| 118 | batch_size: Batch size for processing | 145 | batch_size: Batch size for processing |
| 119 | - device: Device to use | 146 | + device: Device parameter ignored for service compatibility |
| 120 | 147 | ||
| 121 | Returns: | 148 | Returns: |
| 122 | numpy array of embeddings | 149 | numpy array of embeddings |
| @@ -0,0 +1,124 @@ | @@ -0,0 +1,124 @@ | ||
| 1 | +""" | ||
| 2 | +Text embedding encoder using BGE-M3 model. | ||
| 3 | + | ||
| 4 | +Generates 1024-dimensional vectors for text using the BGE-M3 multilingual model. | ||
| 5 | +""" | ||
| 6 | + | ||
| 7 | +import sys | ||
| 8 | +import torch | ||
| 9 | +from sentence_transformers import SentenceTransformer | ||
| 10 | +import time | ||
| 11 | +import threading | ||
| 12 | +from modelscope import snapshot_download | ||
| 13 | +from transformers import AutoModel | ||
| 14 | +import os | ||
| 15 | +import numpy as np | ||
| 16 | +from typing import List, Union | ||
| 17 | + | ||
| 18 | + | ||
| 19 | +class BgeEncoder: | ||
| 20 | + """ | ||
| 21 | + Singleton text encoder using BGE-M3 model. | ||
| 22 | + | ||
| 23 | + Thread-safe singleton pattern ensures only one model instance exists. | ||
| 24 | + """ | ||
| 25 | + _instance = None | ||
| 26 | + _lock = threading.Lock() | ||
| 27 | + | ||
| 28 | + def __new__(cls, model_dir='Xorbits/bge-m3'): | ||
| 29 | + with cls._lock: | ||
| 30 | + if cls._instance is None: | ||
| 31 | + cls._instance = super(BgeEncoder, cls).__new__(cls) | ||
| 32 | + print(f"[BgeEncoder] Creating a new instance with model directory: {model_dir}") | ||
| 33 | + cls._instance.model = SentenceTransformer(snapshot_download(model_dir)) | ||
| 34 | + print("[BgeEncoder] New instance has been created") | ||
| 35 | + return cls._instance | ||
| 36 | + | ||
| 37 | + def encode( | ||
| 38 | + self, | ||
| 39 | + sentences: Union[str, List[str]], | ||
| 40 | + normalize_embeddings: bool = True, | ||
| 41 | + device: str = 'cuda', | ||
| 42 | + batch_size: int = 32 | ||
| 43 | + ) -> np.ndarray: | ||
| 44 | + """ | ||
| 45 | + Encode text into embeddings. | ||
| 46 | + | ||
| 47 | + Args: | ||
| 48 | + sentences: Single string or list of strings to encode | ||
| 49 | + normalize_embeddings: Whether to normalize embeddings | ||
| 50 | + device: Device to use ('cuda' or 'cpu') | ||
| 51 | + batch_size: Batch size for encoding | ||
| 52 | + | ||
| 53 | + Returns: | ||
| 54 | + numpy array of shape (n, 1024) containing embeddings | ||
| 55 | + """ | ||
| 56 | + # Move model to specified device | ||
| 57 | + if device == 'gpu': | ||
| 58 | + device = 'cuda' | ||
| 59 | + | ||
| 60 | + # Try requested device, fallback to CPU if CUDA fails | ||
| 61 | + try: | ||
| 62 | + if device == 'cuda': | ||
| 63 | + # Check CUDA memory first | ||
| 64 | + import torch | ||
| 65 | + if torch.cuda.is_available(): | ||
| 66 | + # Check if we have enough memory (at least 1GB free) | ||
| 67 | + free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated() | ||
| 68 | + if free_memory < 1024 * 1024 * 1024: # 1GB | ||
| 69 | + print(f"[BgeEncoder] CUDA memory insufficient ({free_memory/1024/1024:.1f}MB free), falling back to CPU") | ||
| 70 | + device = 'cpu' | ||
| 71 | + else: | ||
| 72 | + print(f"[BgeEncoder] CUDA not available, using CPU") | ||
| 73 | + device = 'cpu' | ||
| 74 | + | ||
| 75 | + self.model = self.model.to(device) | ||
| 76 | + | ||
| 77 | + embeddings = self.model.encode( | ||
| 78 | + sentences, | ||
| 79 | + normalize_embeddings=normalize_embeddings, | ||
| 80 | + device=device, | ||
| 81 | + show_progress_bar=False, | ||
| 82 | + batch_size=batch_size | ||
| 83 | + ) | ||
| 84 | + | ||
| 85 | + return embeddings | ||
| 86 | + | ||
| 87 | + except Exception as e: | ||
| 88 | + print(f"[BgeEncoder] Device {device} failed: {e}") | ||
| 89 | + if device != 'cpu': | ||
| 90 | + print(f"[BgeEncoder] Falling back to CPU") | ||
| 91 | + try: | ||
| 92 | + self.model = self.model.to('cpu') | ||
| 93 | + embeddings = self.model.encode( | ||
| 94 | + sentences, | ||
| 95 | + normalize_embeddings=normalize_embeddings, | ||
| 96 | + device='cpu', | ||
| 97 | + show_progress_bar=False, | ||
| 98 | + batch_size=batch_size | ||
| 99 | + ) | ||
| 100 | + return embeddings | ||
| 101 | + except Exception as e2: | ||
| 102 | + print(f"[BgeEncoder] CPU also failed: {e2}") | ||
| 103 | + raise | ||
| 104 | + else: | ||
| 105 | + raise | ||
| 106 | + | ||
| 107 | + def encode_batch( | ||
| 108 | + self, | ||
| 109 | + texts: List[str], | ||
| 110 | + batch_size: int = 32, | ||
| 111 | + device: str = 'cuda' | ||
| 112 | + ) -> np.ndarray: | ||
| 113 | + """ | ||
| 114 | + Encode a batch of texts efficiently. | ||
| 115 | + | ||
| 116 | + Args: | ||
| 117 | + texts: List of texts to encode | ||
| 118 | + batch_size: Batch size for processing | ||
| 119 | + device: Device to use | ||
| 120 | + | ||
| 121 | + Returns: | ||
| 122 | + numpy array of embeddings | ||
| 123 | + """ | ||
| 124 | + return self.encode(texts, batch_size=batch_size, device=device) |
indexer/bulk_indexer.py
| @@ -7,6 +7,7 @@ Handles batch indexing of documents with progress tracking and error handling. | @@ -7,6 +7,7 @@ Handles batch indexing of documents with progress tracking and error handling. | ||
| 7 | from typing import List, Dict, Any, Optional | 7 | from typing import List, Dict, Any, Optional |
| 8 | from elasticsearch.helpers import bulk, BulkIndexError | 8 | from elasticsearch.helpers import bulk, BulkIndexError |
| 9 | from utils.es_client import ESClient | 9 | from utils.es_client import ESClient |
| 10 | +from indexer import MappingGenerator | ||
| 10 | import time | 11 | import time |
| 11 | 12 | ||
| 12 | 13 | ||
| @@ -232,8 +233,6 @@ class IndexingPipeline: | @@ -232,8 +233,6 @@ class IndexingPipeline: | ||
| 232 | Returns: | 233 | Returns: |
| 233 | Indexing statistics | 234 | Indexing statistics |
| 234 | """ | 235 | """ |
| 235 | - from indexer.mapping_generator import MappingGenerator | ||
| 236 | - | ||
| 237 | # Generate and create index | 236 | # Generate and create index |
| 238 | mapping_gen = MappingGenerator(self.config) | 237 | mapping_gen = MappingGenerator(self.config) |
| 239 | mapping = mapping_gen.generate_mapping() | 238 | mapping = mapping_gen.generate_mapping() |
indexer/mapping_generator.py
| @@ -5,6 +5,8 @@ Generates Elasticsearch index mappings from search configuration. | @@ -5,6 +5,8 @@ Generates Elasticsearch index mappings from search configuration. | ||
| 5 | """ | 5 | """ |
| 6 | 6 | ||
| 7 | from typing import Dict, Any | 7 | from typing import Dict, Any |
| 8 | +import logging | ||
| 9 | + | ||
| 8 | from config import ( | 10 | from config import ( |
| 9 | SearchConfig, | 11 | SearchConfig, |
| 10 | FieldConfig, | 12 | FieldConfig, |
| @@ -13,6 +15,8 @@ from config import ( | @@ -13,6 +15,8 @@ from config import ( | ||
| 13 | get_default_similarity | 15 | get_default_similarity |
| 14 | ) | 16 | ) |
| 15 | 17 | ||
| 18 | +logger = logging.getLogger(__name__) | ||
| 19 | + | ||
| 16 | 20 | ||
| 17 | class MappingGenerator: | 21 | class MappingGenerator: |
| 18 | """Generates Elasticsearch mapping from search configuration.""" | 22 | """Generates Elasticsearch mapping from search configuration.""" |
| @@ -85,31 +89,18 @@ class MappingGenerator: | @@ -85,31 +89,18 @@ class MappingGenerator: | ||
| 85 | Get the primary text embedding field name. | 89 | Get the primary text embedding field name. |
| 86 | 90 | ||
| 87 | Returns: | 91 | Returns: |
| 88 | - Field name or empty string if not found | 92 | + Field name or empty string if not configured |
| 89 | """ | 93 | """ |
| 90 | - # Look for name_embedding or first text_embedding field | ||
| 91 | - for field in self.config.fields: | ||
| 92 | - if field.name == "name_embedding": | ||
| 93 | - return field.name | ||
| 94 | - | ||
| 95 | - # Otherwise return first text embedding field | ||
| 96 | - for field in self.config.fields: | ||
| 97 | - if "embedding" in field.name and "image" not in field.name: | ||
| 98 | - return field.name | ||
| 99 | - | ||
| 100 | - return "" | 94 | + return self.config.query_config.text_embedding_field or "" |
| 101 | 95 | ||
| 102 | def get_image_embedding_field(self) -> str: | 96 | def get_image_embedding_field(self) -> str: |
| 103 | """ | 97 | """ |
| 104 | Get the primary image embedding field name. | 98 | Get the primary image embedding field name. |
| 105 | 99 | ||
| 106 | Returns: | 100 | Returns: |
| 107 | - Field name or empty string if not found | 101 | + Field name or empty string if not configured |
| 108 | """ | 102 | """ |
| 109 | - for field in self.config.fields: | ||
| 110 | - if "image" in field.name and "embedding" in field.name: | ||
| 111 | - return field.name | ||
| 112 | - return "" | 103 | + return self.config.query_config.image_embedding_field or "" |
| 113 | 104 | ||
| 114 | def get_field_by_name(self, field_name: str) -> FieldConfig: | 105 | def get_field_by_name(self, field_name: str) -> FieldConfig: |
| 115 | """ | 106 | """ |
| @@ -162,11 +153,11 @@ def create_index_if_not_exists(es_client, index_name: str, mapping: Dict[str, An | @@ -162,11 +153,11 @@ def create_index_if_not_exists(es_client, index_name: str, mapping: Dict[str, An | ||
| 162 | True if index was created, False if it already exists | 153 | True if index was created, False if it already exists |
| 163 | """ | 154 | """ |
| 164 | if es_client.indices.exists(index=index_name): | 155 | if es_client.indices.exists(index=index_name): |
| 165 | - print(f"Index '{index_name}' already exists") | 156 | + logger.info(f"Index '{index_name}' already exists") |
| 166 | return False | 157 | return False |
| 167 | 158 | ||
| 168 | es_client.indices.create(index=index_name, body=mapping) | 159 | es_client.indices.create(index=index_name, body=mapping) |
| 169 | - print(f"Index '{index_name}' created successfully") | 160 | + logger.info(f"Index '{index_name}' created successfully") |
| 170 | return True | 161 | return True |
| 171 | 162 | ||
| 172 | 163 | ||
| @@ -182,11 +173,11 @@ def delete_index_if_exists(es_client, index_name: str) -> bool: | @@ -182,11 +173,11 @@ def delete_index_if_exists(es_client, index_name: str) -> bool: | ||
| 182 | True if index was deleted, False if it didn't exist | 173 | True if index was deleted, False if it didn't exist |
| 183 | """ | 174 | """ |
| 184 | if not es_client.indices.exists(index=index_name): | 175 | if not es_client.indices.exists(index=index_name): |
| 185 | - print(f"Index '{index_name}' does not exist") | 176 | + logger.warning(f"Index '{index_name}' does not exist") |
| 186 | return False | 177 | return False |
| 187 | 178 | ||
| 188 | es_client.indices.delete(index=index_name) | 179 | es_client.indices.delete(index=index_name) |
| 189 | - print(f"Index '{index_name}' deleted successfully") | 180 | + logger.info(f"Index '{index_name}' deleted successfully") |
| 190 | return True | 181 | return True |
| 191 | 182 | ||
| 192 | 183 | ||
| @@ -203,10 +194,10 @@ def update_mapping(es_client, index_name: str, new_fields: Dict[str, Any]) -> bo | @@ -203,10 +194,10 @@ def update_mapping(es_client, index_name: str, new_fields: Dict[str, Any]) -> bo | ||
| 203 | True if successful | 194 | True if successful |
| 204 | """ | 195 | """ |
| 205 | if not es_client.indices.exists(index=index_name): | 196 | if not es_client.indices.exists(index=index_name): |
| 206 | - print(f"Index '{index_name}' does not exist") | 197 | + logger.error(f"Index '{index_name}' does not exist") |
| 207 | return False | 198 | return False |
| 208 | 199 | ||
| 209 | mapping = {"properties": new_fields} | 200 | mapping = {"properties": new_fields} |
| 210 | es_client.indices.put_mapping(index=index_name, body=mapping) | 201 | es_client.indices.put_mapping(index=index_name, body=mapping) |
| 211 | - print(f"Mapping updated for index '{index_name}'") | 202 | + logger.info(f"Mapping updated for index '{index_name}'") |
| 212 | return True | 203 | return True |
query/query_parser.py
| @@ -6,6 +6,7 @@ Handles query rewriting, translation, and embedding generation. | @@ -6,6 +6,7 @@ Handles query rewriting, translation, and embedding generation. | ||
| 6 | 6 | ||
| 7 | from typing import Dict, List, Optional, Any | 7 | from typing import Dict, List, Optional, Any |
| 8 | import numpy as np | 8 | import numpy as np |
| 9 | +import logging | ||
| 9 | 10 | ||
| 10 | from config import SearchConfig, QueryConfig | 11 | from config import SearchConfig, QueryConfig |
| 11 | from embeddings import BgeEncoder | 12 | from embeddings import BgeEncoder |
| @@ -13,6 +14,8 @@ from .language_detector import LanguageDetector | @@ -13,6 +14,8 @@ from .language_detector import LanguageDetector | ||
| 13 | from .translator import Translator | 14 | from .translator import Translator |
| 14 | from .query_rewriter import QueryRewriter, QueryNormalizer | 15 | from .query_rewriter import QueryRewriter, QueryNormalizer |
| 15 | 16 | ||
| 17 | +logger = logging.getLogger(__name__) | ||
| 18 | + | ||
| 16 | 19 | ||
| 17 | class ParsedQuery: | 20 | class ParsedQuery: |
| 18 | """Container for parsed query results.""" | 21 | """Container for parsed query results.""" |
| @@ -87,7 +90,7 @@ class QueryParser: | @@ -87,7 +90,7 @@ class QueryParser: | ||
| 87 | def text_encoder(self) -> BgeEncoder: | 90 | def text_encoder(self) -> BgeEncoder: |
| 88 | """Lazy load text encoder.""" | 91 | """Lazy load text encoder.""" |
| 89 | if self._text_encoder is None and self.query_config.enable_text_embedding: | 92 | if self._text_encoder is None and self.query_config.enable_text_embedding: |
| 90 | - print("[QueryParser] Initializing text encoder...") | 93 | + logger.info("Initializing text encoder (lazy load)...") |
| 91 | self._text_encoder = BgeEncoder() | 94 | self._text_encoder = BgeEncoder() |
| 92 | return self._text_encoder | 95 | return self._text_encoder |
| 93 | 96 | ||
| @@ -95,7 +98,7 @@ class QueryParser: | @@ -95,7 +98,7 @@ class QueryParser: | ||
| 95 | def translator(self) -> Translator: | 98 | def translator(self) -> Translator: |
| 96 | """Lazy load translator.""" | 99 | """Lazy load translator.""" |
| 97 | if self._translator is None and self.query_config.enable_translation: | 100 | if self._translator is None and self.query_config.enable_translation: |
| 98 | - print("[QueryParser] Initializing translator...") | 101 | + logger.info("Initializing translator (lazy load)...") |
| 99 | self._translator = Translator( | 102 | self._translator = Translator( |
| 100 | api_key=self.query_config.translation_api_key, | 103 | api_key=self.query_config.translation_api_key, |
| 101 | use_cache=True, | 104 | use_cache=True, |
| @@ -124,18 +127,17 @@ class QueryParser: | @@ -124,18 +127,17 @@ class QueryParser: | ||
| 124 | extra={'reqid': context.reqid, 'uid': context.uid} | 127 | extra={'reqid': context.reqid, 'uid': context.uid} |
| 125 | ) | 128 | ) |
| 126 | 129 | ||
| 127 | - # Use print statements for backward compatibility if no context | ||
| 128 | def log_info(msg): | 130 | def log_info(msg): |
| 129 | - if logger: | ||
| 130 | - logger.info(msg, extra={'reqid': context.reqid, 'uid': context.uid}) | 131 | + if context and hasattr(context, 'logger'): |
| 132 | + context.logger.info(msg, extra={'reqid': context.reqid, 'uid': context.uid}) | ||
| 131 | else: | 133 | else: |
| 132 | - print(f"[QueryParser] {msg}") | 134 | + logger.info(msg) |
| 133 | 135 | ||
| 134 | def log_debug(msg): | 136 | def log_debug(msg): |
| 135 | - if logger: | ||
| 136 | - logger.debug(msg, extra={'reqid': context.reqid, 'uid': context.uid}) | 137 | + if context and hasattr(context, 'logger'): |
| 138 | + context.logger.debug(msg, extra={'reqid': context.reqid, 'uid': context.uid}) | ||
| 137 | else: | 139 | else: |
| 138 | - print(f"[QueryParser] {msg}") | 140 | + logger.debug(msg) |
| 139 | 141 | ||
| 140 | # Stage 1: Normalize | 142 | # Stage 1: Normalize |
| 141 | normalized = self.normalizer.normalize(query) | 143 | normalized = self.normalizer.normalize(query) |
| @@ -246,15 +248,18 @@ class QueryParser: | @@ -246,15 +248,18 @@ class QueryParser: | ||
| 246 | domain=domain | 248 | domain=domain |
| 247 | ) | 249 | ) |
| 248 | 250 | ||
| 249 | - if logger: | ||
| 250 | - logger.info( | 251 | + if context and hasattr(context, 'logger'): |
| 252 | + context.logger.info( | ||
| 251 | f"查询解析完成 | 原查询: '{query}' | 最终查询: '{rewritten or query_text}' | " | 253 | f"查询解析完成 | 原查询: '{query}' | 最终查询: '{rewritten or query_text}' | " |
| 252 | f"语言: {detected_lang} | 域: {domain} | " | 254 | f"语言: {detected_lang} | 域: {domain} | " |
| 253 | f"翻译数量: {len(translations)} | 向量: {'是' if query_vector is not None else '否'}", | 255 | f"翻译数量: {len(translations)} | 向量: {'是' if query_vector is not None else '否'}", |
| 254 | extra={'reqid': context.reqid, 'uid': context.uid} | 256 | extra={'reqid': context.reqid, 'uid': context.uid} |
| 255 | ) | 257 | ) |
| 256 | else: | 258 | else: |
| 257 | - print(f"[QueryParser] Parsing complete") | 259 | + logger.info( |
| 260 | + f"查询解析完成 | 原查询: '{query}' | 最终查询: '{rewritten or query_text}' | " | ||
| 261 | + f"语言: {detected_lang} | 域: {domain}" | ||
| 262 | + ) | ||
| 258 | 263 | ||
| 259 | return result | 264 | return result |
| 260 | 265 |
query/translator.py
| @@ -8,6 +8,12 @@ import requests | @@ -8,6 +8,12 @@ import requests | ||
| 8 | from typing import Dict, List, Optional | 8 | from typing import Dict, List, Optional |
| 9 | from utils.cache import DictCache | 9 | from utils.cache import DictCache |
| 10 | 10 | ||
| 11 | +# Try to import DEEPL_AUTH_KEY, but allow import to fail | ||
| 12 | +try: | ||
| 13 | + from config.env_config import DEEPL_AUTH_KEY | ||
| 14 | +except ImportError: | ||
| 15 | + DEEPL_AUTH_KEY = None | ||
| 16 | + | ||
| 11 | 17 | ||
| 12 | class Translator: | 18 | class Translator: |
| 13 | """Multi-language translator using DeepL API.""" | 19 | """Multi-language translator using DeepL API.""" |
| @@ -47,12 +53,8 @@ class Translator: | @@ -47,12 +53,8 @@ class Translator: | ||
| 47 | translation_context: Context hint for translation (e.g., "e-commerce", "product search") | 53 | translation_context: Context hint for translation (e.g., "e-commerce", "product search") |
| 48 | """ | 54 | """ |
| 49 | # Get API key from config if not provided | 55 | # Get API key from config if not provided |
| 50 | - if api_key is None: | ||
| 51 | - try: | ||
| 52 | - from config.env_config import get_deepl_key | ||
| 53 | - api_key = get_deepl_key() | ||
| 54 | - except ImportError: | ||
| 55 | - pass | 56 | + if api_key is None and DEEPL_AUTH_KEY: |
| 57 | + api_key = DEEPL_AUTH_KEY | ||
| 56 | 58 | ||
| 57 | self.api_key = api_key | 59 | self.api_key = api_key |
| 58 | self.timeout = timeout | 60 | self.timeout = timeout |
scripts/ingest_shoplazza.py
| @@ -78,13 +78,12 @@ def main(): | @@ -78,13 +78,12 @@ def main(): | ||
| 78 | return 1 | 78 | return 1 |
| 79 | 79 | ||
| 80 | # Connect to Elasticsearch (use unified config loading) | 80 | # Connect to Elasticsearch (use unified config loading) |
| 81 | - from config.env_config import get_es_config | ||
| 82 | - es_config = get_es_config() | 81 | + from config.env_config import ES_CONFIG |
| 83 | 82 | ||
| 84 | # Use provided es_host or fallback to config | 83 | # Use provided es_host or fallback to config |
| 85 | - es_host = args.es_host or es_config.get('host', 'http://localhost:9200') | ||
| 86 | - es_username = es_config.get('username') | ||
| 87 | - es_password = es_config.get('password') | 84 | + es_host = args.es_host or ES_CONFIG.get('host', 'http://localhost:9200') |
| 85 | + es_username = ES_CONFIG.get('username') | ||
| 86 | + es_password = ES_CONFIG.get('password') | ||
| 88 | 87 | ||
| 89 | print(f"Connecting to Elasticsearch: {es_host}") | 88 | print(f"Connecting to Elasticsearch: {es_host}") |
| 90 | if es_username and es_password: | 89 | if es_username and es_password: |
search/multilang_query_builder.py
| @@ -8,11 +8,15 @@ maintaining a unified external interface. | @@ -8,11 +8,15 @@ maintaining a unified external interface. | ||
| 8 | 8 | ||
| 9 | from typing import Dict, Any, List, Optional | 9 | from typing import Dict, Any, List, Optional |
| 10 | import numpy as np | 10 | import numpy as np |
| 11 | +import logging | ||
| 12 | +import re | ||
| 11 | 13 | ||
| 12 | from config import SearchConfig, IndexConfig | 14 | from config import SearchConfig, IndexConfig |
| 13 | from query import ParsedQuery | 15 | from query import ParsedQuery |
| 14 | from .es_query_builder import ESQueryBuilder | 16 | from .es_query_builder import ESQueryBuilder |
| 15 | 17 | ||
| 18 | +logger = logging.getLogger(__name__) | ||
| 19 | + | ||
| 16 | 20 | ||
| 17 | class MultiLanguageQueryBuilder(ESQueryBuilder): | 21 | class MultiLanguageQueryBuilder(ESQueryBuilder): |
| 18 | """ | 22 | """ |
| @@ -139,20 +143,18 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): | @@ -139,20 +143,18 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): | ||
| 139 | min_score=min_score | 143 | min_score=min_score |
| 140 | ) | 144 | ) |
| 141 | 145 | ||
| 142 | - print(f"[MultiLangQueryBuilder] Building query for domain: {domain}") | ||
| 143 | - print(f"[MultiLangQueryBuilder] Detected language: {parsed_query.detected_language}") | ||
| 144 | - print(f"[MultiLangQueryBuilder] Available translations: {list(parsed_query.translations.keys())}") | 146 | + logger.debug(f"Building query for domain: {domain}, language: {parsed_query.detected_language}") |
| 145 | 147 | ||
| 146 | # Build query clause with multi-language support | 148 | # Build query clause with multi-language support |
| 147 | if query_node and isinstance(query_node, tuple) and len(query_node) > 0: | 149 | if query_node and isinstance(query_node, tuple) and len(query_node) > 0: |
| 148 | # Handle boolean query from tuple (AST, score) | 150 | # Handle boolean query from tuple (AST, score) |
| 149 | ast_node = query_node[0] | 151 | ast_node = query_node[0] |
| 150 | query_clause = self._build_boolean_query_from_tuple(ast_node) | 152 | query_clause = self._build_boolean_query_from_tuple(ast_node) |
| 151 | - print(f"[MultiLangQueryBuilder] Using boolean query: {query_clause}") | 153 | + logger.debug(f"Using boolean query") |
| 152 | elif query_node and hasattr(query_node, 'operator') and query_node.operator != 'TERM': | 154 | elif query_node and hasattr(query_node, 'operator') and query_node.operator != 'TERM': |
| 153 | # Handle boolean query using base class method | 155 | # Handle boolean query using base class method |
| 154 | query_clause = self._build_boolean_query(query_node) | 156 | query_clause = self._build_boolean_query(query_node) |
| 155 | - print(f"[MultiLangQueryBuilder] Using boolean query: {query_clause}") | 157 | + logger.debug(f"Using boolean query") |
| 156 | else: | 158 | else: |
| 157 | # Handle text query with multi-language support | 159 | # Handle text query with multi-language support |
| 158 | query_clause = self._build_multilang_text_query(parsed_query, domain_config) | 160 | query_clause = self._build_multilang_text_query(parsed_query, domain_config) |
| @@ -171,7 +173,7 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): | @@ -171,7 +173,7 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): | ||
| 171 | } | 173 | } |
| 172 | } | 174 | } |
| 173 | inner_bool_should.append(knn_query) | 175 | inner_bool_should.append(knn_query) |
| 174 | - print(f"[MultiLangQueryBuilder] KNN query added: field={self.text_embedding_field}, k={knn_k}, num_candidates={knn_num_candidates}") | 176 | + logger.info(f"KNN query added: field={self.text_embedding_field}, k={knn_k}") |
| 175 | else: | 177 | else: |
| 176 | # Debug why KNN is not added | 178 | # Debug why KNN is not added |
| 177 | reasons = [] | 179 | reasons = [] |
| @@ -181,7 +183,7 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): | @@ -181,7 +183,7 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): | ||
| 181 | reasons.append("query_vector is None") | 183 | reasons.append("query_vector is None") |
| 182 | if not self.text_embedding_field: | 184 | if not self.text_embedding_field: |
| 183 | reasons.append(f"text_embedding_field is not set (current: {self.text_embedding_field})") | 185 | reasons.append(f"text_embedding_field is not set (current: {self.text_embedding_field})") |
| 184 | - print(f"[MultiLangQueryBuilder] KNN query NOT added. Reasons: {', '.join(reasons) if reasons else 'unknown'}") | 186 | + logger.debug(f"KNN query NOT added. Reasons: {', '.join(reasons) if reasons else 'unknown'}") |
| 185 | 187 | ||
| 186 | # 构建内层bool结构 | 188 | # 构建内层bool结构 |
| 187 | inner_bool = { | 189 | inner_bool = { |
| @@ -342,7 +344,7 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): | @@ -342,7 +344,7 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): | ||
| 342 | "_name": f"{domain_config.name}_{detected_lang}_query" | 344 | "_name": f"{domain_config.name}_{detected_lang}_query" |
| 343 | } | 345 | } |
| 344 | }) | 346 | }) |
| 345 | - print(f"[MultiLangQueryBuilder] Added query for detected language '{detected_lang}' on fields: {target_fields}") | 347 | + logger.debug(f"Added query for detected language '{detected_lang}'") |
| 346 | 348 | ||
| 347 | # 2. Query in translated languages (only for languages in mapping) | 349 | # 2. Query in translated languages (only for languages in mapping) |
| 348 | for lang, translation in parsed_query.translations.items(): | 350 | for lang, translation in parsed_query.translations.items(): |
| @@ -361,11 +363,11 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): | @@ -361,11 +363,11 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): | ||
| 361 | "_name": f"{domain_config.name}_{lang}_translated_query" | 363 | "_name": f"{domain_config.name}_{lang}_translated_query" |
| 362 | } | 364 | } |
| 363 | }) | 365 | }) |
| 364 | - print(f"[MultiLangQueryBuilder] Added translated query for language '{lang}' on fields: {target_fields}") | 366 | + logger.debug(f"Added translated query for language '{lang}'") |
| 365 | 367 | ||
| 366 | # 3. Fallback: query all fields in mapping if no language-specific query was built | 368 | # 3. Fallback: query all fields in mapping if no language-specific query was built |
| 367 | if not should_clauses: | 369 | if not should_clauses: |
| 368 | - print(f"[MultiLangQueryBuilder] No language mapping matched, using all fields from mapping") | 370 | + logger.debug("No language mapping matched, using all fields from mapping") |
| 369 | # Use all fields from all languages in the mapping | 371 | # Use all fields from all languages in the mapping |
| 370 | all_mapped_fields = [] | 372 | all_mapped_fields = [] |
| 371 | for lang_fields in domain_config.language_field_mapping.values(): | 373 | for lang_fields in domain_config.language_field_mapping.values(): |
| @@ -445,7 +447,6 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): | @@ -445,7 +447,6 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): | ||
| 445 | operator = node[0].__name__ | 447 | operator = node[0].__name__ |
| 446 | elif str(node[0]).startswith('('): | 448 | elif str(node[0]).startswith('('): |
| 447 | # String representation of constructor call | 449 | # String representation of constructor call |
| 448 | - import re | ||
| 449 | match = re.match(r'(\w+)\(', str(node[0])) | 450 | match = re.match(r'(\w+)\(', str(node[0])) |
| 450 | if match: | 451 | if match: |
| 451 | operator = match.group(1) | 452 | operator = match.group(1) |
search/searcher.py
| @@ -6,11 +6,13 @@ Handles query parsing, boolean expressions, ranking, and result formatting. | @@ -6,11 +6,13 @@ Handles query parsing, boolean expressions, ranking, and result formatting. | ||
| 6 | 6 | ||
| 7 | from typing import Dict, Any, List, Optional, Union | 7 | from typing import Dict, Any, List, Optional, Union |
| 8 | import time | 8 | import time |
| 9 | +import logging | ||
| 9 | 10 | ||
| 10 | from config import SearchConfig | 11 | from config import SearchConfig |
| 11 | from utils.es_client import ESClient | 12 | from utils.es_client import ESClient |
| 12 | from query import QueryParser, ParsedQuery | 13 | from query import QueryParser, ParsedQuery |
| 13 | from indexer import MappingGenerator | 14 | from indexer import MappingGenerator |
| 15 | +from embeddings import CLIPImageEncoder | ||
| 14 | from .boolean_parser import BooleanParser, QueryNode | 16 | from .boolean_parser import BooleanParser, QueryNode |
| 15 | from .es_query_builder import ESQueryBuilder | 17 | from .es_query_builder import ESQueryBuilder |
| 16 | from .multilang_query_builder import MultiLanguageQueryBuilder | 18 | from .multilang_query_builder import MultiLanguageQueryBuilder |
| @@ -19,6 +21,8 @@ from context.request_context import RequestContext, RequestContextStage, create_ | @@ -19,6 +21,8 @@ from context.request_context import RequestContext, RequestContextStage, create_ | ||
| 19 | from api.models import FacetResult, FacetValue | 21 | from api.models import FacetResult, FacetValue |
| 20 | from api.result_formatter import ResultFormatter | 22 | from api.result_formatter import ResultFormatter |
| 21 | 23 | ||
| 24 | +logger = logging.getLogger(__name__) | ||
| 25 | + | ||
| 22 | 26 | ||
| 23 | class SearchResult: | 27 | class SearchResult: |
| 24 | """Container for search results (外部友好格式).""" | 28 | """Container for search results (外部友好格式).""" |
| @@ -476,7 +480,6 @@ class Searcher: | @@ -476,7 +480,6 @@ class Searcher: | ||
| 476 | raise ValueError("Image embedding field not configured") | 480 | raise ValueError("Image embedding field not configured") |
| 477 | 481 | ||
| 478 | # Generate image embedding | 482 | # Generate image embedding |
| 479 | - from embeddings import CLIPImageEncoder | ||
| 480 | image_encoder = CLIPImageEncoder() | 483 | image_encoder = CLIPImageEncoder() |
| 481 | image_vector = image_encoder.encode_image_from_url(image_url) | 484 | image_vector = image_encoder.encode_image_from_url(image_url) |
| 482 | 485 | ||
| @@ -575,7 +578,7 @@ class Searcher: | @@ -575,7 +578,7 @@ class Searcher: | ||
| 575 | ) | 578 | ) |
| 576 | return response.get('_source') | 579 | return response.get('_source') |
| 577 | except Exception as e: | 580 | except Exception as e: |
| 578 | - print(f"[Searcher] Failed to get document {doc_id}: {e}") | 581 | + logger.error(f"Failed to get document {doc_id}: {e}", exc_info=True) |
| 579 | return None | 582 | return None |
| 580 | 583 | ||
| 581 | def _standardize_facets( | 584 | def _standardize_facets( |
utils/es_client.py
| @@ -3,8 +3,18 @@ Elasticsearch client wrapper. | @@ -3,8 +3,18 @@ Elasticsearch client wrapper. | ||
| 3 | """ | 3 | """ |
| 4 | 4 | ||
| 5 | from elasticsearch import Elasticsearch | 5 | from elasticsearch import Elasticsearch |
| 6 | +from elasticsearch.helpers import bulk | ||
| 6 | from typing import Dict, Any, List, Optional | 7 | from typing import Dict, Any, List, Optional |
| 7 | import os | 8 | import os |
| 9 | +import logging | ||
| 10 | + | ||
| 11 | +# Try to import ES_CONFIG, but allow import to fail | ||
| 12 | +try: | ||
| 13 | + from config.env_config import ES_CONFIG | ||
| 14 | +except ImportError: | ||
| 15 | + ES_CONFIG = None | ||
| 16 | + | ||
| 17 | +logger = logging.getLogger(__name__) | ||
| 8 | 18 | ||
| 9 | 19 | ||
| 10 | class ESClient: | 20 | class ESClient: |
| @@ -56,7 +66,7 @@ class ESClient: | @@ -56,7 +66,7 @@ class ESClient: | ||
| 56 | try: | 66 | try: |
| 57 | return self.client.ping() | 67 | return self.client.ping() |
| 58 | except Exception as e: | 68 | except Exception as e: |
| 59 | - print(f"Failed to ping Elasticsearch: {e}") | 69 | + logger.error(f"Failed to ping Elasticsearch: {e}", exc_info=True) |
| 60 | return False | 70 | return False |
| 61 | 71 | ||
| 62 | def create_index(self, index_name: str, body: Dict[str, Any]) -> bool: | 72 | def create_index(self, index_name: str, body: Dict[str, Any]) -> bool: |
| @@ -72,12 +82,10 @@ class ESClient: | @@ -72,12 +82,10 @@ class ESClient: | ||
| 72 | """ | 82 | """ |
| 73 | try: | 83 | try: |
| 74 | self.client.indices.create(index=index_name, body=body) | 84 | self.client.indices.create(index=index_name, body=body) |
| 75 | - print(f"Index '{index_name}' created successfully") | 85 | + logger.info(f"Index '{index_name}' created successfully") |
| 76 | return True | 86 | return True |
| 77 | except Exception as e: | 87 | except Exception as e: |
| 78 | - print(f"ERROR: Failed to create index '{index_name}': {e}") | ||
| 79 | - import traceback | ||
| 80 | - traceback.print_exc() | 88 | + logger.error(f"Failed to create index '{index_name}': {e}", exc_info=True) |
| 81 | return False | 89 | return False |
| 82 | 90 | ||
| 83 | def delete_index(self, index_name: str) -> bool: | 91 | def delete_index(self, index_name: str) -> bool: |
| @@ -93,13 +101,13 @@ class ESClient: | @@ -93,13 +101,13 @@ class ESClient: | ||
| 93 | try: | 101 | try: |
| 94 | if self.client.indices.exists(index=index_name): | 102 | if self.client.indices.exists(index=index_name): |
| 95 | self.client.indices.delete(index=index_name) | 103 | self.client.indices.delete(index=index_name) |
| 96 | - print(f"Index '{index_name}' deleted successfully") | 104 | + logger.info(f"Index '{index_name}' deleted successfully") |
| 97 | return True | 105 | return True |
| 98 | else: | 106 | else: |
| 99 | - print(f"Index '{index_name}' does not exist") | 107 | + logger.warning(f"Index '{index_name}' does not exist") |
| 100 | return False | 108 | return False |
| 101 | except Exception as e: | 109 | except Exception as e: |
| 102 | - print(f"Failed to delete index '{index_name}': {e}") | 110 | + logger.error(f"Failed to delete index '{index_name}': {e}", exc_info=True) |
| 103 | return False | 111 | return False |
| 104 | 112 | ||
| 105 | def index_exists(self, index_name: str) -> bool: | 113 | def index_exists(self, index_name: str) -> bool: |
| @@ -117,8 +125,6 @@ class ESClient: | @@ -117,8 +125,6 @@ class ESClient: | ||
| 117 | Returns: | 125 | Returns: |
| 118 | Dictionary with results | 126 | Dictionary with results |
| 119 | """ | 127 | """ |
| 120 | - from elasticsearch.helpers import bulk | ||
| 121 | - | ||
| 122 | actions = [] | 128 | actions = [] |
| 123 | for doc in docs: | 129 | for doc in docs: |
| 124 | action = { | 130 | action = { |
| @@ -140,7 +146,7 @@ class ESClient: | @@ -140,7 +146,7 @@ class ESClient: | ||
| 140 | 'errors': failed | 146 | 'errors': failed |
| 141 | } | 147 | } |
| 142 | except Exception as e: | 148 | except Exception as e: |
| 143 | - print(f"Bulk indexing failed: {e}") | 149 | + logger.error(f"Bulk indexing failed: {e}", exc_info=True) |
| 144 | return { | 150 | return { |
| 145 | 'success': 0, | 151 | 'success': 0, |
| 146 | 'failed': len(docs), | 152 | 'failed': len(docs), |
| @@ -174,7 +180,7 @@ class ESClient: | @@ -174,7 +180,7 @@ class ESClient: | ||
| 174 | from_=from_ | 180 | from_=from_ |
| 175 | ) | 181 | ) |
| 176 | except Exception as e: | 182 | except Exception as e: |
| 177 | - print(f"Search failed: {e}") | 183 | + logger.error(f"Search failed: {e}", exc_info=True) |
| 178 | return { | 184 | return { |
| 179 | 'hits': { | 185 | 'hits': { |
| 180 | 'total': {'value': 0}, | 186 | 'total': {'value': 0}, |
| @@ -188,7 +194,7 @@ class ESClient: | @@ -188,7 +194,7 @@ class ESClient: | ||
| 188 | try: | 194 | try: |
| 189 | return self.client.indices.get_mapping(index=index_name) | 195 | return self.client.indices.get_mapping(index=index_name) |
| 190 | except Exception as e: | 196 | except Exception as e: |
| 191 | - print(f"Failed to get mapping for '{index_name}': {e}") | 197 | + logger.error(f"Failed to get mapping for '{index_name}': {e}", exc_info=True) |
| 192 | return {} | 198 | return {} |
| 193 | 199 | ||
| 194 | def refresh(self, index_name: str) -> bool: | 200 | def refresh(self, index_name: str) -> bool: |
| @@ -197,7 +203,7 @@ class ESClient: | @@ -197,7 +203,7 @@ class ESClient: | ||
| 197 | self.client.indices.refresh(index=index_name) | 203 | self.client.indices.refresh(index=index_name) |
| 198 | return True | 204 | return True |
| 199 | except Exception as e: | 205 | except Exception as e: |
| 200 | - print(f"Failed to refresh index '{index_name}': {e}") | 206 | + logger.error(f"Failed to refresh index '{index_name}': {e}", exc_info=True) |
| 201 | return False | 207 | return False |
| 202 | 208 | ||
| 203 | def count(self, index_name: str, body: Optional[Dict[str, Any]] = None) -> int: | 209 | def count(self, index_name: str, body: Optional[Dict[str, Any]] = None) -> int: |
| @@ -215,7 +221,7 @@ class ESClient: | @@ -215,7 +221,7 @@ class ESClient: | ||
| 215 | result = self.client.count(index=index_name, body=body) | 221 | result = self.client.count(index=index_name, body=body) |
| 216 | return result['count'] | 222 | return result['count'] |
| 217 | except Exception as e: | 223 | except Exception as e: |
| 218 | - print(f"Count failed: {e}") | 224 | + logger.error(f"Count failed: {e}", exc_info=True) |
| 219 | return 0 | 225 | return 0 |
| 220 | 226 | ||
| 221 | 227 | ||
| @@ -231,15 +237,13 @@ def get_es_client_from_env() -> ESClient: | @@ -231,15 +237,13 @@ def get_es_client_from_env() -> ESClient: | ||
| 231 | Returns: | 237 | Returns: |
| 232 | ESClient instance | 238 | ESClient instance |
| 233 | """ | 239 | """ |
| 234 | - try: | ||
| 235 | - from config.env_config import get_es_config | ||
| 236 | - es_config = get_es_config() | 240 | + if ES_CONFIG: |
| 237 | return ESClient( | 241 | return ESClient( |
| 238 | - hosts=[es_config['host']], | ||
| 239 | - username=es_config.get('username'), | ||
| 240 | - password=es_config.get('password') | 242 | + hosts=[ES_CONFIG['host']], |
| 243 | + username=ES_CONFIG.get('username'), | ||
| 244 | + password=ES_CONFIG.get('password') | ||
| 241 | ) | 245 | ) |
| 242 | - except ImportError: | 246 | + else: |
| 243 | # Fallback to env variables | 247 | # Fallback to env variables |
| 244 | return ESClient( | 248 | return ESClient( |
| 245 | hosts=[os.getenv('ES_HOST', 'http://localhost:9200')], | 249 | hosts=[os.getenv('ES_HOST', 'http://localhost:9200')], |