Commit 325eec032f3afde0f658f0b687d7132340027b68
1 parent: 3bb1af6b
1. 日志、配置基础设施,使用优化
2. 向量服务不用本地预估,改用网络服务
Showing 16 changed files with 645 additions and 326 deletions
Show diff stats
api/app.py
| ... | ... | @@ -9,6 +9,8 @@ import os |
| 9 | 9 | import sys |
| 10 | 10 | import logging |
| 11 | 11 | import time |
| 12 | +import argparse | |
| 13 | +import uvicorn | |
| 12 | 14 | from collections import defaultdict, deque |
| 13 | 15 | from typing import Optional |
| 14 | 16 | from fastapi import FastAPI, Request, HTTPException |
| ... | ... | @@ -20,7 +22,6 @@ from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware |
| 20 | 22 | from slowapi import Limiter, _rate_limit_exceeded_handler |
| 21 | 23 | from slowapi.util import get_remote_address |
| 22 | 24 | from slowapi.errors import RateLimitExceeded |
| 23 | -import argparse | |
| 24 | 25 | |
| 25 | 26 | # Configure logging with better formatting |
| 26 | 27 | logging.basicConfig( |
| ... | ... | @@ -40,6 +41,7 @@ limiter = Limiter(key_func=get_remote_address) |
| 40 | 41 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| 41 | 42 | |
| 42 | 43 | from config import ConfigLoader, SearchConfig |
| 44 | +from config.env_config import ES_CONFIG | |
| 43 | 45 | from utils import ESClient |
| 44 | 46 | from search import Searcher |
| 45 | 47 | from query import QueryParser |
| ... | ... | @@ -60,55 +62,42 @@ def init_service(es_host: str = "http://localhost:9200"): |
| 60 | 62 | """ |
| 61 | 63 | global _config, _es_client, _searcher, _query_parser |
| 62 | 64 | |
| 63 | - print("Initializing search service (multi-tenant)") | |
| 65 | + start_time = time.time() | |
| 66 | + logger.info("Initializing search service (multi-tenant)") | |
| 64 | 67 | |
| 65 | - # Load unified configuration | |
| 68 | + # Load and validate configuration | |
| 69 | + logger.info("Loading configuration...") | |
| 66 | 70 | config_loader = ConfigLoader("config/config.yaml") |
| 67 | 71 | _config = config_loader.load_config() |
| 68 | - | |
| 69 | - # Validate configuration | |
| 70 | 72 | errors = config_loader.validate_config(_config) |
| 71 | 73 | if errors: |
| 72 | 74 | raise ValueError(f"Configuration validation failed: {errors}") |
| 75 | + logger.info(f"Configuration loaded: {_config.es_index_name}") | |
| 73 | 76 | |
| 74 | - print(f"Configuration loaded: {_config.es_index_name}") | |
| 77 | + # Get ES credentials | |
| 78 | + es_username = os.getenv('ES_USERNAME') or ES_CONFIG.get('username') | |
| 79 | + es_password = os.getenv('ES_PASSWORD') or ES_CONFIG.get('password') | |
| 75 | 80 | |
| 76 | - # Get ES credentials from environment variables or .env file | |
| 77 | - es_username = os.getenv('ES_USERNAME') | |
| 78 | - es_password = os.getenv('ES_PASSWORD') | |
| 79 | - | |
| 80 | - # Try to load from config if not in env | |
| 81 | - if not es_username or not es_password: | |
| 82 | - try: | |
| 83 | - from config.env_config import get_es_config | |
| 84 | - es_config = get_es_config() | |
| 85 | - es_username = es_username or es_config.get('username') | |
| 86 | - es_password = es_password or es_config.get('password') | |
| 87 | - except Exception: | |
| 88 | - pass | |
| 89 | - | |
| 90 | - # Initialize ES client with authentication if credentials are available | |
| 81 | + # Connect to Elasticsearch | |
| 82 | + logger.info(f"Connecting to Elasticsearch at {es_host}...") | |
| 91 | 83 | if es_username and es_password: |
| 92 | - print(f"Connecting to Elasticsearch with authentication: {es_username}") | |
| 93 | 84 | _es_client = ESClient(hosts=[es_host], username=es_username, password=es_password) |
| 94 | 85 | else: |
| 95 | - print(f"Connecting to Elasticsearch without authentication") | |
| 96 | 86 | _es_client = ESClient(hosts=[es_host]) |
| 97 | 87 | |
| 98 | 88 | if not _es_client.ping(): |
| 99 | 89 | raise ConnectionError(f"Failed to connect to Elasticsearch at {es_host}") |
| 90 | + logger.info("Elasticsearch connected") | |
| 100 | 91 | |
| 101 | - print(f"Connected to Elasticsearch: {es_host}") | |
| 102 | - | |
| 103 | - # Initialize query parser | |
| 92 | + # Initialize components | |
| 93 | + logger.info("Initializing query parser...") | |
| 104 | 94 | _query_parser = QueryParser(_config) |
| 105 | - print("Query parser initialized") | |
| 106 | - | |
| 107 | - # Initialize searcher | |
| 95 | + | |
| 96 | + logger.info("Initializing searcher...") | |
| 108 | 97 | _searcher = Searcher(_config, _es_client, _query_parser) |
| 109 | - print("Searcher initialized") | |
| 110 | - | |
| 111 | - print("Search service ready!") | |
| 98 | + | |
| 99 | + elapsed = time.time() - start_time | |
| 100 | + logger.info(f"Search service ready! (took {elapsed:.2f}s)") | |
| 112 | 101 | |
| 113 | 102 | |
| 114 | 103 | def get_config() -> SearchConfig: |
| ... | ... | @@ -305,8 +294,6 @@ else: |
| 305 | 294 | |
| 306 | 295 | |
| 307 | 296 | if __name__ == "__main__": |
| 308 | - import uvicorn | |
| 309 | - | |
| 310 | 297 | parser = argparse.ArgumentParser(description='Start search API service (multi-tenant)') |
| 311 | 298 | parser.add_argument('--host', default='0.0.0.0', help='Host to bind to') |
| 312 | 299 | parser.add_argument('--port', type=int, default=6002, help='Port to bind to') | ... | ... |
config/config.yaml
| ... | ... | @@ -147,6 +147,14 @@ fields: |
| 147 | 147 | index: false |
| 148 | 148 | store: true |
| 149 | 149 | |
| 150 | + # 文本嵌入字段(用于语义搜索) | |
| 151 | + - name: "name_embedding" | |
| 152 | + type: "TEXT_EMBEDDING" | |
| 153 | + embedding_dims: 1024 | |
| 154 | + embedding_similarity: "dot_product" | |
| 155 | + index: true | |
| 156 | + store: false | |
| 157 | + | |
| 150 | 158 | # 嵌套variants字段 |
| 151 | 159 | - name: "variants" |
| 152 | 160 | type: "JSON" |
| ... | ... | @@ -239,6 +247,10 @@ query_config: |
| 239 | 247 | enable_text_embedding: true |
| 240 | 248 | enable_query_rewrite: true |
| 241 | 249 | |
| 250 | + # Embedding field names (if not set, will auto-detect from fields) | |
| 251 | + text_embedding_field: "name_embedding" # Field name for text embeddings | |
| 252 | + image_embedding_field: null # Field name for image embeddings (if not set, will auto-detect) | |
| 253 | + | |
| 242 | 254 | # Translation API (DeepL) |
| 243 | 255 | translation_service: "deepl" |
| 244 | 256 | translation_api_key: null # Set via environment variable | ... | ... |
config/config_loader.py
| ... | ... | @@ -54,6 +54,10 @@ class QueryConfig: |
| 54 | 54 | translation_glossary_id: Optional[str] = None # DeepL glossary ID for custom terminology |
| 55 | 55 | translation_context: str = "e-commerce product search" # Context hint for translation |
| 56 | 56 | |
| 57 | + # Embedding field names - if not set, will auto-detect from fields | |
| 58 | + text_embedding_field: Optional[str] = None # Field name for text embeddings (e.g., "name_embedding") | |
| 59 | + image_embedding_field: Optional[str] = None # Field name for image embeddings (e.g., "image_embedding") | |
| 60 | + | |
| 57 | 61 | # ES source fields configuration - fields to return in search results |
| 58 | 62 | source_fields: List[str] = field(default_factory=lambda: [ |
| 59 | 63 | "id", "spuId", "skuNo", "spuNo", "title", "enSpuName", "brandId", |
| ... | ... | @@ -213,7 +217,9 @@ class ConfigLoader: |
| 213 | 217 | translation_api_key=query_config_data.get("translation_api_key"), |
| 214 | 218 | translation_service=query_config_data.get("translation_service", "deepl"), |
| 215 | 219 | translation_glossary_id=query_config_data.get("translation_glossary_id"), |
| 216 | - translation_context=query_config_data.get("translation_context", "e-commerce product search") | |
| 220 | + translation_context=query_config_data.get("translation_context", "e-commerce product search"), | |
| 221 | + text_embedding_field=query_config_data.get("text_embedding_field"), | |
| 222 | + image_embedding_field=query_config_data.get("image_embedding_field") | |
| 217 | 223 | ) |
| 218 | 224 | |
| 219 | 225 | # Parse ranking config | ... | ... |
config/env_config.py
| ... | ... | @@ -2,11 +2,12 @@ |
| 2 | 2 | Centralized configuration management for SearchEngine. |
| 3 | 3 | |
| 4 | 4 | Loads configuration from environment variables and .env file. |
| 5 | +This module provides a single point for loading .env and setting defaults. | |
| 6 | +All configuration variables are exported directly - no need for getter functions. | |
| 5 | 7 | """ |
| 6 | 8 | |
| 7 | 9 | import os |
| 8 | 10 | from pathlib import Path |
| 9 | -from typing import Dict, Any | |
| 10 | 11 | from dotenv import load_dotenv |
| 11 | 12 | |
| 12 | 13 | # Load .env file from project root |
| ... | ... | @@ -56,26 +57,6 @@ DB_CONFIG = { |
| 56 | 57 | } |
| 57 | 58 | |
| 58 | 59 | |
| 59 | -def get_es_config() -> Dict[str, Any]: | |
| 60 | - """Get Elasticsearch configuration.""" | |
| 61 | - return ES_CONFIG.copy() | |
| 62 | - | |
| 63 | - | |
| 64 | -def get_redis_config() -> Dict[str, Any]: | |
| 65 | - """Get Redis configuration.""" | |
| 66 | - return REDIS_CONFIG.copy() | |
| 67 | - | |
| 68 | - | |
| 69 | -def get_deepl_key() -> str: | |
| 70 | - """Get DeepL API key.""" | |
| 71 | - return DEEPL_AUTH_KEY | |
| 72 | - | |
| 73 | - | |
| 74 | -def get_db_config() -> Dict[str, Any]: | |
| 75 | - """Get MySQL database configuration.""" | |
| 76 | - return DB_CONFIG.copy() | |
| 77 | - | |
| 78 | - | |
| 79 | 60 | def print_config(): |
| 80 | 61 | """Print current configuration (with sensitive data masked).""" |
| 81 | 62 | print("=" * 60) | ... | ... |
embeddings/image_encoder.py
| 1 | 1 | """ |
| 2 | -Image embedding encoder using CN-CLIP model. | |
| 2 | +Image embedding encoder using network service. | |
| 3 | 3 | |
| 4 | -Generates 1024-dimensional vectors for images using the CN-CLIP ViT-H-14 model. | |
| 4 | +Generates embeddings via HTTP API service running on localhost:5001. | |
| 5 | 5 | """ |
| 6 | 6 | |
| 7 | 7 | import sys |
| 8 | 8 | import os |
| 9 | -import io | |
| 10 | 9 | import requests |
| 11 | -import torch | |
| 12 | 10 | import numpy as np |
| 13 | 11 | from PIL import Image |
| 14 | 12 | import logging |
| 15 | 13 | import threading |
| 16 | -from typing import List, Optional, Union | |
| 17 | -import cn_clip.clip as clip | |
| 18 | -from cn_clip.clip import load_from_name | |
| 14 | +from typing import List, Optional, Union, Dict, Any | |
| 19 | 15 | |
| 20 | - | |
| 21 | -# DEFAULT_MODEL_NAME = "ViT-L-14-336" # ["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"] | |
| 22 | -DEFAULT_MODEL_NAME = "ViT-H-14" | |
| 23 | -MODEL_DOWNLOAD_DIR = "/data/tw/uat/EsSearcher" | |
| 16 | +logger = logging.getLogger(__name__) | |
| 24 | 17 | |
| 25 | 18 | |
| 26 | 19 | class CLIPImageEncoder: |
| 27 | 20 | """ |
| 28 | - CLIP Image Encoder for generating image embeddings using cn_clip. | |
| 21 | + Image Encoder for generating image embeddings using network service. | |
| 29 | 22 | |
| 30 | 23 | Thread-safe singleton pattern. |
| 31 | 24 | """ |
| ... | ... | @@ -33,111 +26,80 @@ class CLIPImageEncoder: |
| 33 | 26 | _instance = None |
| 34 | 27 | _lock = threading.Lock() |
| 35 | 28 | |
| 36 | - def __new__(cls, model_name=DEFAULT_MODEL_NAME, device=None): | |
| 29 | + def __new__(cls, service_url='http://localhost:5001'): | |
| 37 | 30 | with cls._lock: |
| 38 | 31 | if cls._instance is None: |
| 39 | 32 | cls._instance = super(CLIPImageEncoder, cls).__new__(cls) |
| 40 | - print(f"[CLIPImageEncoder] Creating new instance with model: {model_name}") | |
| 41 | - cls._instance._initialize_model(model_name, device) | |
| 33 | + logger.info(f"Creating CLIPImageEncoder instance with service URL: {service_url}") | |
| 34 | + cls._instance.service_url = service_url | |
| 35 | + cls._instance.endpoint = f"{service_url}/embedding/generate_image_embeddings" | |
| 42 | 36 | return cls._instance |
| 43 | 37 | |
| 44 | - def _initialize_model(self, model_name, device): | |
| 45 | - """Initialize the CLIP model using cn_clip""" | |
| 46 | - try: | |
| 47 | - self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") | |
| 48 | - self.model, self.preprocess = load_from_name( | |
| 49 | - model_name, | |
| 50 | - device=self.device, | |
| 51 | - download_root=MODEL_DOWNLOAD_DIR | |
| 52 | - ) | |
| 53 | - self.model.eval() | |
| 54 | - self.model_name = model_name | |
| 55 | - print(f"[CLIPImageEncoder] Model {model_name} initialized successfully on device {self.device}") | |
| 56 | - | |
| 57 | - except Exception as e: | |
| 58 | - print(f"[CLIPImageEncoder] Failed to initialize model: {str(e)}") | |
| 59 | - raise | |
| 38 | + def _call_service(self, request_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | |
| 39 | + """ | |
| 40 | + Call the embedding service API. | |
| 60 | 41 | |
| 61 | - def validate_image(self, image_data: bytes) -> Image.Image: | |
| 62 | - """Validate image data and return PIL Image if valid""" | |
| 63 | - try: | |
| 64 | - image_stream = io.BytesIO(image_data) | |
| 65 | - image = Image.open(image_stream) | |
| 66 | - image.verify() | |
| 67 | - image_stream.seek(0) | |
| 68 | - image = Image.open(image_stream) | |
| 69 | - if image.mode != 'RGB': | |
| 70 | - image = image.convert('RGB') | |
| 71 | - return image | |
| 72 | - except Exception as e: | |
| 73 | - raise ValueError(f"Invalid image data: {str(e)}") | |
| 42 | + Args: | |
| 43 | + request_data: List of dictionaries with id and pic_url fields | |
| 74 | 44 | |
| 75 | - def download_image(self, url: str, timeout: int = 10) -> bytes: | |
| 76 | - """Download image from URL""" | |
| 45 | + Returns: | |
| 46 | + List of dictionaries with id, pic_url, embedding and error fields | |
| 47 | + """ | |
| 77 | 48 | try: |
| 78 | - if url.startswith(('http://', 'https://')): | |
| 79 | - response = requests.get(url, timeout=timeout) | |
| 80 | - if response.status_code != 200: | |
| 81 | - raise ValueError(f"HTTP {response.status_code}") | |
| 82 | - return response.content | |
| 83 | - else: | |
| 84 | - # Local file path | |
| 85 | - with open(url, 'rb') as f: | |
| 86 | - return f.read() | |
| 87 | - except Exception as e: | |
| 88 | - raise ValueError(f"Failed to download image from {url}: {str(e)}") | |
| 89 | - | |
| 90 | - def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image: | |
| 91 | - """Preprocess image for CLIP model""" | |
| 92 | - # Resize if too large | |
| 93 | - if max(image.size) > max_size: | |
| 94 | - ratio = max_size / max(image.size) | |
| 95 | - new_size = tuple(int(dim * ratio) for dim in image.size) | |
| 96 | - image = image.resize(new_size, Image.Resampling.LANCZOS) | |
| 97 | - return image | |
| 98 | - | |
| 99 | - def encode_text(self, text): | |
| 100 | - """Encode text to embedding vector using cn_clip""" | |
| 101 | - text_data = clip.tokenize([text] if type(text) == str else text).to(self.device) | |
| 102 | - with torch.no_grad(): | |
| 103 | - text_features = self.model.encode_text(text_data) | |
| 104 | - text_features /= text_features.norm(dim=-1, keepdim=True) | |
| 105 | - return text_features | |
| 49 | + response = requests.post( | |
| 50 | + self.endpoint, | |
| 51 | + json=request_data, | |
| 52 | + timeout=60 | |
| 53 | + ) | |
| 54 | + response.raise_for_status() | |
| 55 | + return response.json() | |
| 56 | + except requests.exceptions.RequestException as e: | |
| 57 | + logger.error(f"CLIPImageEncoder service request failed: {e}", exc_info=True) | |
| 58 | + raise | |
| 106 | 59 | |
| 107 | 60 | def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: |
| 108 | - """Encode image to embedding vector using cn_clip""" | |
| 109 | - if not isinstance(image, Image.Image): | |
| 110 | - raise ValueError("CLIPImageEncoder.encode_image Input must be a PIL.Image") | |
| 61 | + """ | |
| 62 | + Encode image to embedding vector using network service. | |
| 111 | 63 | |
| 112 | - try: | |
| 113 | - infer_data = self.preprocess(image).unsqueeze(0).to(self.device) | |
| 114 | - with torch.no_grad(): | |
| 115 | - image_features = self.model.encode_image(infer_data) | |
| 116 | - image_features /= image_features.norm(dim=-1, keepdim=True) | |
| 117 | - return image_features.cpu().numpy().astype('float32')[0] | |
| 118 | - except Exception as e: | |
| 119 | - print(f"Failed to process image. Reason: {str(e)}") | |
| 120 | - return None | |
| 64 | + Note: This method is kept for compatibility but the service only works with URLs. | |
| 65 | + """ | |
| 66 | + logger.warning("encode_image with PIL Image not supported by service, returning None") | |
| 67 | + return None | |
| 121 | 68 | |
| 122 | 69 | def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: |
| 123 | - """Complete pipeline: download, validate, preprocess and encode image from URL""" | |
| 124 | - try: | |
| 125 | - # Download image | |
| 126 | - image_data = self.download_image(url) | |
| 127 | - | |
| 128 | - # Validate image | |
| 129 | - image = self.validate_image(image_data) | |
| 130 | - | |
| 131 | - # Preprocess image | |
| 132 | - image = self.preprocess_image(image) | |
| 70 | + """ | |
| 71 | + Generate image embedding via network service using URL. | |
| 133 | 72 | |
| 134 | - # Encode image | |
| 135 | - embedding = self.encode_image(image) | |
| 73 | + Args: | |
| 74 | + url: Image URL to process | |
| 136 | 75 | |
| 137 | - return embedding | |
| 76 | + Returns: | |
| 77 | + Embedding vector or None if failed | |
| 78 | + """ | |
| 79 | + try: | |
| 80 | + # Prepare request data | |
| 81 | + request_data = [{ | |
| 82 | + "id": "image_0", | |
| 83 | + "pic_url": url | |
| 84 | + }] | |
| 85 | + | |
| 86 | + # Call service | |
| 87 | + response_data = self._call_service(request_data) | |
| 88 | + | |
| 89 | + # Process response | |
| 90 | + if response_data and len(response_data) > 0: | |
| 91 | + response_item = response_data[0] | |
| 92 | + if response_item.get("embedding"): | |
| 93 | + return np.array(response_item["embedding"], dtype=np.float32) | |
| 94 | + else: | |
| 95 | + logger.warning(f"No embedding for URL {url}, error: {response_item.get('error', 'Unknown error')}") | |
| 96 | + return None | |
| 97 | + else: | |
| 98 | + logger.warning(f"No response for URL {url}") | |
| 99 | + return None | |
| 138 | 100 | |
| 139 | 101 | except Exception as e: |
| 140 | - print(f"Error processing image from URL {url}: {str(e)}") | |
| 102 | + logger.error(f"Failed to process image from URL {url}: {str(e)}", exc_info=True) | |
| 141 | 103 | return None |
| 142 | 104 | |
| 143 | 105 | def encode_batch( |
| ... | ... | @@ -146,33 +108,71 @@ class CLIPImageEncoder: |
| 146 | 108 | batch_size: int = 8 |
| 147 | 109 | ) -> List[Optional[np.ndarray]]: |
| 148 | 110 | """ |
| 149 | - Encode a batch of images efficiently. | |
| 111 | + Encode a batch of images efficiently via network service. | |
| 150 | 112 | |
| 151 | 113 | Args: |
| 152 | 114 | images: List of image URLs or PIL Images |
| 153 | - batch_size: Batch size for processing | |
| 115 | + batch_size: Batch size for processing (used for service requests) | |
| 154 | 116 | |
| 155 | 117 | Returns: |
| 156 | 118 | List of embeddings (or None for failed images) |
| 157 | 119 | """ |
| 158 | - results = [] | |
| 159 | - | |
| 160 | - for i in range(0, len(images), batch_size): | |
| 161 | - batch = images[i:i + batch_size] | |
| 162 | - batch_embeddings = [] | |
| 163 | - | |
| 164 | - for img in batch: | |
| 165 | - if isinstance(img, str): | |
| 166 | - # URL or file path | |
| 167 | - emb = self.encode_image_from_url(img) | |
| 168 | - elif isinstance(img, Image.Image): | |
| 169 | - # PIL Image | |
| 170 | - emb = self.encode_image(img) | |
| 171 | - else: | |
| 172 | - emb = None | |
| 173 | - | |
| 174 | - batch_embeddings.append(emb) | |
| 175 | - | |
| 176 | - results.extend(batch_embeddings) | |
| 120 | + # Initialize results with None for all images | |
| 121 | + results = [None] * len(images) | |
| 122 | + | |
| 123 | + # Filter out PIL Images since service only supports URLs | |
| 124 | + url_images = [] | |
| 125 | + url_indices = [] | |
| 126 | + | |
| 127 | + for i, img in enumerate(images): | |
| 128 | + if isinstance(img, str): | |
| 129 | + url_images.append(img) | |
| 130 | + url_indices.append(i) | |
| 131 | + elif isinstance(img, Image.Image): | |
| 132 | + logger.warning(f"PIL Image at index {i} not supported by service, returning None") | |
| 133 | + # results[i] is already None | |
| 134 | + | |
| 135 | + # Process URLs in batches | |
| 136 | + for i in range(0, len(url_images), batch_size): | |
| 137 | + batch_urls = url_images[i:i + batch_size] | |
| 138 | + batch_indices = url_indices[i:i + batch_size] | |
| 139 | + | |
| 140 | + # Prepare request data | |
| 141 | + request_data = [] | |
| 142 | + for j, url in enumerate(batch_urls): | |
| 143 | + request_data.append({ | |
| 144 | + "id": f"image_{j}", | |
| 145 | + "pic_url": url | |
| 146 | + }) | |
| 147 | + | |
| 148 | + try: | |
| 149 | + # Call service | |
| 150 | + response_data = self._call_service(request_data) | |
| 151 | + | |
| 152 | + # Process response | |
| 153 | + batch_results = [] | |
| 154 | + for j, url in enumerate(batch_urls): | |
| 155 | + response_item = None | |
| 156 | + for item in response_data: | |
| 157 | + if str(item.get("id")) == f"image_{j}": | |
| 158 | + response_item = item | |
| 159 | + break | |
| 160 | + | |
| 161 | + if response_item and response_item.get("embedding"): | |
| 162 | + batch_results.append(np.array(response_item["embedding"], dtype=np.float32)) | |
| 163 | + else: | |
| 164 | + error_msg = response_item.get("error", "Unknown error") if response_item else "No response" | |
| 165 | + logger.warning(f"Failed to encode URL {url}: {error_msg}") | |
| 166 | + batch_results.append(None) | |
| 167 | + | |
| 168 | + # Insert results at the correct positions | |
| 169 | + for j, result in enumerate(batch_results): | |
| 170 | + results[batch_indices[j]] = result | |
| 171 | + | |
| 172 | + except Exception as e: | |
| 173 | + logger.error(f"Batch processing failed: {e}", exc_info=True) | |
| 174 | + # Fill with None for this batch | |
| 175 | + for j in range(len(batch_urls)): | |
| 176 | + results[batch_indices[j]] = None | |
| 177 | 177 | |
| 178 | 178 | return results | ... | ... |
| ... | ... | @@ -0,0 +1,178 @@ |
| 1 | +""" | |
| 2 | +Image embedding encoder using CN-CLIP model. | |
| 3 | + | |
| 4 | +Generates 1024-dimensional vectors for images using the CN-CLIP ViT-H-14 model. | |
| 5 | +""" | |
| 6 | + | |
| 7 | +import sys | |
| 8 | +import os | |
| 9 | +import io | |
| 10 | +import requests | |
| 11 | +import torch | |
| 12 | +import numpy as np | |
| 13 | +from PIL import Image | |
| 14 | +import logging | |
| 15 | +import threading | |
| 16 | +from typing import List, Optional, Union | |
| 17 | +import cn_clip.clip as clip | |
| 18 | +from cn_clip.clip import load_from_name | |
| 19 | + | |
| 20 | + | |
| 21 | +# DEFAULT_MODEL_NAME = "ViT-L-14-336" # ["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"] | |
| 22 | +DEFAULT_MODEL_NAME = "ViT-H-14" | |
| 23 | +MODEL_DOWNLOAD_DIR = "/data/tw/uat/EsSearcher" | |
| 24 | + | |
| 25 | + | |
| 26 | +class CLIPImageEncoder: | |
| 27 | + """ | |
| 28 | + CLIP Image Encoder for generating image embeddings using cn_clip. | |
| 29 | + | |
| 30 | + Thread-safe singleton pattern. | |
| 31 | + """ | |
| 32 | + | |
| 33 | + _instance = None | |
| 34 | + _lock = threading.Lock() | |
| 35 | + | |
| 36 | + def __new__(cls, model_name=DEFAULT_MODEL_NAME, device=None): | |
| 37 | + with cls._lock: | |
| 38 | + if cls._instance is None: | |
| 39 | + cls._instance = super(CLIPImageEncoder, cls).__new__(cls) | |
| 40 | + print(f"[CLIPImageEncoder] Creating new instance with model: {model_name}") | |
| 41 | + cls._instance._initialize_model(model_name, device) | |
| 42 | + return cls._instance | |
| 43 | + | |
| 44 | + def _initialize_model(self, model_name, device): | |
| 45 | + """Initialize the CLIP model using cn_clip""" | |
| 46 | + try: | |
| 47 | + self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") | |
| 48 | + self.model, self.preprocess = load_from_name( | |
| 49 | + model_name, | |
| 50 | + device=self.device, | |
| 51 | + download_root=MODEL_DOWNLOAD_DIR | |
| 52 | + ) | |
| 53 | + self.model.eval() | |
| 54 | + self.model_name = model_name | |
| 55 | + print(f"[CLIPImageEncoder] Model {model_name} initialized successfully on device {self.device}") | |
| 56 | + | |
| 57 | + except Exception as e: | |
| 58 | + print(f"[CLIPImageEncoder] Failed to initialize model: {str(e)}") | |
| 59 | + raise | |
| 60 | + | |
| 61 | + def validate_image(self, image_data: bytes) -> Image.Image: | |
| 62 | + """Validate image data and return PIL Image if valid""" | |
| 63 | + try: | |
| 64 | + image_stream = io.BytesIO(image_data) | |
| 65 | + image = Image.open(image_stream) | |
| 66 | + image.verify() | |
| 67 | + image_stream.seek(0) | |
| 68 | + image = Image.open(image_stream) | |
| 69 | + if image.mode != 'RGB': | |
| 70 | + image = image.convert('RGB') | |
| 71 | + return image | |
| 72 | + except Exception as e: | |
| 73 | + raise ValueError(f"Invalid image data: {str(e)}") | |
| 74 | + | |
| 75 | + def download_image(self, url: str, timeout: int = 10) -> bytes: | |
| 76 | + """Download image from URL""" | |
| 77 | + try: | |
| 78 | + if url.startswith(('http://', 'https://')): | |
| 79 | + response = requests.get(url, timeout=timeout) | |
| 80 | + if response.status_code != 200: | |
| 81 | + raise ValueError(f"HTTP {response.status_code}") | |
| 82 | + return response.content | |
| 83 | + else: | |
| 84 | + # Local file path | |
| 85 | + with open(url, 'rb') as f: | |
| 86 | + return f.read() | |
| 87 | + except Exception as e: | |
| 88 | + raise ValueError(f"Failed to download image from {url}: {str(e)}") | |
| 89 | + | |
| 90 | + def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image: | |
| 91 | + """Preprocess image for CLIP model""" | |
| 92 | + # Resize if too large | |
| 93 | + if max(image.size) > max_size: | |
| 94 | + ratio = max_size / max(image.size) | |
| 95 | + new_size = tuple(int(dim * ratio) for dim in image.size) | |
| 96 | + image = image.resize(new_size, Image.Resampling.LANCZOS) | |
| 97 | + return image | |
| 98 | + | |
| 99 | + def encode_text(self, text): | |
| 100 | + """Encode text to embedding vector using cn_clip""" | |
| 101 | + text_data = clip.tokenize([text] if type(text) == str else text).to(self.device) | |
| 102 | + with torch.no_grad(): | |
| 103 | + text_features = self.model.encode_text(text_data) | |
| 104 | + text_features /= text_features.norm(dim=-1, keepdim=True) | |
| 105 | + return text_features | |
| 106 | + | |
| 107 | + def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: | |
| 108 | + """Encode image to embedding vector using cn_clip""" | |
| 109 | + if not isinstance(image, Image.Image): | |
| 110 | + raise ValueError("CLIPImageEncoder.encode_image Input must be a PIL.Image") | |
| 111 | + | |
| 112 | + try: | |
| 113 | + infer_data = self.preprocess(image).unsqueeze(0).to(self.device) | |
| 114 | + with torch.no_grad(): | |
| 115 | + image_features = self.model.encode_image(infer_data) | |
| 116 | + image_features /= image_features.norm(dim=-1, keepdim=True) | |
| 117 | + return image_features.cpu().numpy().astype('float32')[0] | |
| 118 | + except Exception as e: | |
| 119 | + print(f"Failed to process image. Reason: {str(e)}") | |
| 120 | + return None | |
| 121 | + | |
| 122 | + def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: | |
| 123 | + """Complete pipeline: download, validate, preprocess and encode image from URL""" | |
| 124 | + try: | |
| 125 | + # Download image | |
| 126 | + image_data = self.download_image(url) | |
| 127 | + | |
| 128 | + # Validate image | |
| 129 | + image = self.validate_image(image_data) | |
| 130 | + | |
| 131 | + # Preprocess image | |
| 132 | + image = self.preprocess_image(image) | |
| 133 | + | |
| 134 | + # Encode image | |
| 135 | + embedding = self.encode_image(image) | |
| 136 | + | |
| 137 | + return embedding | |
| 138 | + | |
| 139 | + except Exception as e: | |
| 140 | + print(f"Error processing image from URL {url}: {str(e)}") | |
| 141 | + return None | |
| 142 | + | |
| 143 | + def encode_batch( | |
| 144 | + self, | |
| 145 | + images: List[Union[str, Image.Image]], | |
| 146 | + batch_size: int = 8 | |
| 147 | + ) -> List[Optional[np.ndarray]]: | |
| 148 | + """ | |
| 149 | + Encode a batch of images efficiently. | |
| 150 | + | |
| 151 | + Args: | |
| 152 | + images: List of image URLs or PIL Images | |
| 153 | + batch_size: Batch size for processing | |
| 154 | + | |
| 155 | + Returns: | |
| 156 | + List of embeddings (or None for failed images) | |
| 157 | + """ | |
| 158 | + results = [] | |
| 159 | + | |
| 160 | + for i in range(0, len(images), batch_size): | |
| 161 | + batch = images[i:i + batch_size] | |
| 162 | + batch_embeddings = [] | |
| 163 | + | |
| 164 | + for img in batch: | |
| 165 | + if isinstance(img, str): | |
| 166 | + # URL or file path | |
| 167 | + emb = self.encode_image_from_url(img) | |
| 168 | + elif isinstance(img, Image.Image): | |
| 169 | + # PIL Image | |
| 170 | + emb = self.encode_image(img) | |
| 171 | + else: | |
| 172 | + emb = None | |
| 173 | + | |
| 174 | + batch_embeddings.append(emb) | |
| 175 | + | |
| 176 | + results.extend(batch_embeddings) | |
| 177 | + | |
| 178 | + return results | ... | ... |
embeddings/text_encoder.py
| 1 | 1 | """ |
| 2 | -Text embedding encoder using BGE-M3 model. | |
| 2 | +Text embedding encoder using network service. | |
| 3 | 3 | |
| 4 | -Generates 1024-dimensional vectors for text using the BGE-M3 multilingual model. | |
| 4 | +Generates embeddings via HTTP API service running on localhost:5001. | |
| 5 | 5 | """ |
| 6 | 6 | |
| 7 | 7 | import sys |
| 8 | -import torch | |
| 9 | -from sentence_transformers import SentenceTransformer | |
| 8 | +import requests | |
| 10 | 9 | import time |
| 11 | 10 | import threading |
| 12 | -from modelscope import snapshot_download | |
| 13 | -from transformers import AutoModel | |
| 14 | -import os | |
| 15 | 11 | import numpy as np |
| 16 | -from typing import List, Union | |
| 12 | +import logging | |
| 13 | +from typing import List, Union, Dict, Any | |
| 14 | + | |
| 15 | +logger = logging.getLogger(__name__) | |
| 17 | 16 | |
| 18 | 17 | |
| 19 | 18 | class BgeEncoder: |
| 20 | 19 | """ |
| 21 | - Singleton text encoder using BGE-M3 model. | |
| 20 | + Singleton text encoder using network service. | |
| 22 | 21 | |
| 23 | - Thread-safe singleton pattern ensures only one model instance exists. | |
| 22 | + Thread-safe singleton pattern ensures only one instance exists. | |
| 24 | 23 | """ |
| 25 | 24 | _instance = None |
| 26 | 25 | _lock = threading.Lock() |
| 27 | 26 | |
| 28 | - def __new__(cls, model_dir='Xorbits/bge-m3'): | |
| 27 | + def __new__(cls, service_url='http://localhost:5001'): | |
| 29 | 28 | with cls._lock: |
| 30 | 29 | if cls._instance is None: |
| 31 | 30 | cls._instance = super(BgeEncoder, cls).__new__(cls) |
| 32 | - print(f"[BgeEncoder] Creating a new instance with model directory: {model_dir}") | |
| 33 | - cls._instance.model = SentenceTransformer(snapshot_download(model_dir)) | |
| 34 | - print("[BgeEncoder] New instance has been created") | |
| 31 | + logger.info(f"Creating BgeEncoder instance with service URL: {service_url}") | |
| 32 | + cls._instance.service_url = service_url | |
| 33 | + cls._instance.endpoint = f"{service_url}/embedding/generate_embeddings" | |
| 35 | 34 | return cls._instance |
| 36 | 35 | |
| 36 | + def _call_service(self, request_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | |
| 37 | + """ | |
| 38 | + Call the embedding service API. | |
| 39 | + | |
| 40 | + Args: | |
| 41 | + request_data: List of dictionaries with id and text fields | |
| 42 | + | |
| 43 | + Returns: | |
| 44 | + List of dictionaries with id and embedding fields | |
| 45 | + """ | |
| 46 | + try: | |
| 47 | + response = requests.post( | |
| 48 | + self.endpoint, | |
| 49 | + json=request_data, | |
| 50 | + timeout=60 | |
| 51 | + ) | |
| 52 | + response.raise_for_status() | |
| 53 | + return response.json() | |
| 54 | + except requests.exceptions.RequestException as e: | |
| 55 | + logger.error(f"BgeEncoder service request failed: {e}", exc_info=True) | |
| 56 | + raise | |
| 57 | + | |
| 37 | 58 | def encode( |
| 38 | 59 | self, |
| 39 | 60 | sentences: Union[str, List[str]], |
| 40 | 61 | normalize_embeddings: bool = True, |
| 41 | - device: str = 'cuda', | |
| 62 | + device: str = 'cpu', | |
| 42 | 63 | batch_size: int = 32 |
| 43 | 64 | ) -> np.ndarray: |
| 44 | 65 | """ |
| 45 | - Encode text into embeddings. | |
| 66 | + Encode text into embeddings via network service. | |
| 46 | 67 | |
| 47 | 68 | Args: |
| 48 | 69 | sentences: Single string or list of strings to encode |
| 49 | - normalize_embeddings: Whether to normalize embeddings | |
| 50 | - device: Device to use ('cuda' or 'cpu') | |
| 51 | - batch_size: Batch size for encoding | |
| 70 | + normalize_embeddings: Whether to normalize embeddings (ignored for service) | |
| 71 | + device: Ignored by the network service; kept for backward compatibility | |
| 72 | + batch_size: Currently unused; all texts are sent to the service in a single request | |
| 52 | 73 | |
| 53 | 74 | Returns: |
| 54 | 75 | numpy array of shape (n, 1024) containing embeddings |
| 55 | 76 | """ |
| 56 | - # Move model to specified device | |
| 57 | - if device == 'gpu': | |
| 58 | - device = 'cuda' | |
| 77 | + # Convert single string to list | |
| 78 | + if isinstance(sentences, str): | |
| 79 | + sentences = [sentences] | |
| 59 | 80 | |
| 60 | - # Try requested device, fallback to CPU if CUDA fails | |
| 61 | - try: | |
| 62 | - if device == 'cuda': | |
| 63 | - # Check CUDA memory first | |
| 64 | - import torch | |
| 65 | - if torch.cuda.is_available(): | |
| 66 | - # Check if we have enough memory (at least 1GB free) | |
| 67 | - free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated() | |
| 68 | - if free_memory < 1024 * 1024 * 1024: # 1GB | |
| 69 | - print(f"[BgeEncoder] CUDA memory insufficient ({free_memory/1024/1024:.1f}MB free), falling back to CPU") | |
| 70 | - device = 'cpu' | |
| 71 | - else: | |
| 72 | - print(f"[BgeEncoder] CUDA not available, using CPU") | |
| 73 | - device = 'cpu' | |
| 81 | + # Prepare request data | |
| 82 | + request_data = [] | |
| 83 | + for i, text in enumerate(sentences): | |
| 84 | + request_item = { | |
| 85 | + "id": str(i), | |
| 86 | + "name_zh": text | |
| 87 | + } | |
| 74 | 88 | |
| 75 | - self.model = self.model.to(device) | |
| 89 | + # Add English and Russian fields as empty for now | |
| 90 | + # Could be enhanced with language detection in the future | |
| 91 | + request_item["name_en"] = None | |
| 92 | + request_item["name_ru"] = None | |
| 76 | 93 | |
| 77 | - embeddings = self.model.encode( | |
| 78 | - sentences, | |
| 79 | - normalize_embeddings=normalize_embeddings, | |
| 80 | - device=device, | |
| 81 | - show_progress_bar=False, | |
| 82 | - batch_size=batch_size | |
| 83 | - ) | |
| 94 | + request_data.append(request_item) | |
| 95 | + | |
| 96 | + try: | |
| 97 | + # Call service | |
| 98 | + response_data = self._call_service(request_data) | |
| 99 | + | |
| 100 | + # Process response | |
| 101 | + embeddings = [] | |
| 102 | + for i, text in enumerate(sentences): | |
| 103 | + # Find corresponding response by ID | |
| 104 | + response_item = None | |
| 105 | + for item in response_data: | |
| 106 | + if str(item.get("id")) == str(i): | |
| 107 | + response_item = item | |
| 108 | + break | |
| 109 | + | |
| 110 | + if response_item: | |
| 111 | + # Try Chinese embedding first, then English, then Russian | |
| 112 | + embedding = None | |
| 113 | + for lang in ["embedding_zh", "embedding_en", "embedding_ru"]: | |
| 114 | + if lang in response_item and response_item[lang] is not None: | |
| 115 | + embedding = response_item[lang] | |
| 116 | + break | |
| 117 | + | |
| 118 | + if embedding is not None: | |
| 119 | + embeddings.append(embedding) | |
| 120 | + else: | |
| 121 | + logger.warning(f"No embedding found for text {i}: {text[:50]}...") | |
| 122 | + embeddings.append([0.0] * 1024) | |
| 123 | + else: | |
| 124 | + logger.warning(f"No response found for text {i}") | |
| 125 | + embeddings.append([0.0] * 1024) | |
| 84 | 126 | |
| 85 | - return embeddings | |
| 127 | + return np.array(embeddings, dtype=np.float32) | |
| 86 | 128 | |
| 87 | 129 | except Exception as e: |
| 88 | - print(f"[BgeEncoder] Device {device} failed: {e}") | |
| 89 | - if device != 'cpu': | |
| 90 | - print(f"[BgeEncoder] Falling back to CPU") | |
| 91 | - try: | |
| 92 | - self.model = self.model.to('cpu') | |
| 93 | - embeddings = self.model.encode( | |
| 94 | - sentences, | |
| 95 | - normalize_embeddings=normalize_embeddings, | |
| 96 | - device='cpu', | |
| 97 | - show_progress_bar=False, | |
| 98 | - batch_size=batch_size | |
| 99 | - ) | |
| 100 | - return embeddings | |
| 101 | - except Exception as e2: | |
| 102 | - print(f"[BgeEncoder] CPU also failed: {e2}") | |
| 103 | - raise | |
| 104 | - else: | |
| 105 | - raise | |
| 130 | + logger.error(f"Failed to encode texts: {e}", exc_info=True) | |
| 131 | + # Return zero embeddings as fallback | |
| 132 | + return np.zeros((len(sentences), 1024), dtype=np.float32) | |
| 106 | 133 | |
| 107 | 134 | def encode_batch( |
| 108 | 135 | self, |
| 109 | 136 | texts: List[str], |
| 110 | 137 | batch_size: int = 32, |
| 111 | - device: str = 'cuda' | |
| 138 | + device: str = 'cpu' | |
| 112 | 139 | ) -> np.ndarray: |
| 113 | 140 | """ |
| 114 | - Encode a batch of texts efficiently. | |
| 141 | + Encode a batch of texts efficiently via network service. | |
| 115 | 142 | |
| 116 | 143 | Args: |
| 117 | 144 | texts: List of texts to encode |
| 118 | 145 | batch_size: Batch size for processing |
| 119 | - device: Device to use | |
| 146 | + device: Device parameter ignored for service compatibility | |
| 120 | 147 | |
| 121 | 148 | Returns: |
| 122 | 149 | numpy array of embeddings | ... | ... |
| ... | ... | @@ -0,0 +1,124 @@ |
| 1 | +""" | |
| 2 | +Text embedding encoder using BGE-M3 model. | |
| 3 | + | |
| 4 | +Generates 1024-dimensional vectors for text using the BGE-M3 multilingual model. | |
| 5 | +""" | |
| 6 | + | |
| 7 | +import sys | |
| 8 | +import torch | |
| 9 | +from sentence_transformers import SentenceTransformer | |
| 10 | +import time | |
| 11 | +import threading | |
| 12 | +from modelscope import snapshot_download | |
| 13 | +from transformers import AutoModel | |
| 14 | +import os | |
| 15 | +import numpy as np | |
| 16 | +from typing import List, Union | |
| 17 | + | |
| 18 | + | |
| 19 | +class BgeEncoder: | |
| 20 | + """ | |
| 21 | + Singleton text encoder using BGE-M3 model. | |
| 22 | + | |
| 23 | + Thread-safe singleton pattern ensures only one model instance exists. | |
| 24 | + """ | |
| 25 | + _instance = None | |
| 26 | + _lock = threading.Lock() | |
| 27 | + | |
| 28 | + def __new__(cls, model_dir='Xorbits/bge-m3'): | |
| 29 | + with cls._lock: | |
| 30 | + if cls._instance is None: | |
| 31 | + cls._instance = super(BgeEncoder, cls).__new__(cls) | |
| 32 | + print(f"[BgeEncoder] Creating a new instance with model directory: {model_dir}") | |
| 33 | + cls._instance.model = SentenceTransformer(snapshot_download(model_dir)) | |
| 34 | + print("[BgeEncoder] New instance has been created") | |
| 35 | + return cls._instance | |
| 36 | + | |
| 37 | + def encode( | |
| 38 | + self, | |
| 39 | + sentences: Union[str, List[str]], | |
| 40 | + normalize_embeddings: bool = True, | |
| 41 | + device: str = 'cuda', | |
| 42 | + batch_size: int = 32 | |
| 43 | + ) -> np.ndarray: | |
| 44 | + """ | |
| 45 | + Encode text into embeddings. | |
| 46 | + | |
| 47 | + Args: | |
| 48 | + sentences: Single string or list of strings to encode | |
| 49 | + normalize_embeddings: Whether to normalize embeddings | |
| 50 | + device: Device to use ('cuda' or 'cpu') | |
| 51 | + batch_size: Batch size for encoding | |
| 52 | + | |
| 53 | + Returns: | |
| 54 | + numpy array of shape (n, 1024) containing embeddings | |
| 55 | + """ | |
| 56 | + # Move model to specified device | |
| 57 | + if device == 'gpu': | |
| 58 | + device = 'cuda' | |
| 59 | + | |
| 60 | + # Try requested device, fallback to CPU if CUDA fails | |
| 61 | + try: | |
| 62 | + if device == 'cuda': | |
| 63 | + # Check CUDA memory first | |
| 64 | + import torch | |
| 65 | + if torch.cuda.is_available(): | |
| 66 | + # Check if we have enough memory (at least 1GB free) | |
| 67 | + free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated() | |
| 68 | + if free_memory < 1024 * 1024 * 1024: # 1GB | |
| 69 | + print(f"[BgeEncoder] CUDA memory insufficient ({free_memory/1024/1024:.1f}MB free), falling back to CPU") | |
| 70 | + device = 'cpu' | |
| 71 | + else: | |
| 72 | + print(f"[BgeEncoder] CUDA not available, using CPU") | |
| 73 | + device = 'cpu' | |
| 74 | + | |
| 75 | + self.model = self.model.to(device) | |
| 76 | + | |
| 77 | + embeddings = self.model.encode( | |
| 78 | + sentences, | |
| 79 | + normalize_embeddings=normalize_embeddings, | |
| 80 | + device=device, | |
| 81 | + show_progress_bar=False, | |
| 82 | + batch_size=batch_size | |
| 83 | + ) | |
| 84 | + | |
| 85 | + return embeddings | |
| 86 | + | |
| 87 | + except Exception as e: | |
| 88 | + print(f"[BgeEncoder] Device {device} failed: {e}") | |
| 89 | + if device != 'cpu': | |
| 90 | + print(f"[BgeEncoder] Falling back to CPU") | |
| 91 | + try: | |
| 92 | + self.model = self.model.to('cpu') | |
| 93 | + embeddings = self.model.encode( | |
| 94 | + sentences, | |
| 95 | + normalize_embeddings=normalize_embeddings, | |
| 96 | + device='cpu', | |
| 97 | + show_progress_bar=False, | |
| 98 | + batch_size=batch_size | |
| 99 | + ) | |
| 100 | + return embeddings | |
| 101 | + except Exception as e2: | |
| 102 | + print(f"[BgeEncoder] CPU also failed: {e2}") | |
| 103 | + raise | |
| 104 | + else: | |
| 105 | + raise | |
| 106 | + | |
| 107 | + def encode_batch( | |
| 108 | + self, | |
| 109 | + texts: List[str], | |
| 110 | + batch_size: int = 32, | |
| 111 | + device: str = 'cuda' | |
| 112 | + ) -> np.ndarray: | |
| 113 | + """ | |
| 114 | + Encode a batch of texts efficiently. | |
| 115 | + | |
| 116 | + Args: | |
| 117 | + texts: List of texts to encode | |
| 118 | + batch_size: Batch size for processing | |
| 119 | + device: Device to use | |
| 120 | + | |
| 121 | + Returns: | |
| 122 | + numpy array of embeddings | |
| 123 | + """ | |
| 124 | + return self.encode(texts, batch_size=batch_size, device=device) | ... | ... |
indexer/bulk_indexer.py
| ... | ... | @@ -7,6 +7,7 @@ Handles batch indexing of documents with progress tracking and error handling. |
| 7 | 7 | from typing import List, Dict, Any, Optional |
| 8 | 8 | from elasticsearch.helpers import bulk, BulkIndexError |
| 9 | 9 | from utils.es_client import ESClient |
| 10 | +from indexer import MappingGenerator | |
| 10 | 11 | import time |
| 11 | 12 | |
| 12 | 13 | |
| ... | ... | @@ -232,8 +233,6 @@ class IndexingPipeline: |
| 232 | 233 | Returns: |
| 233 | 234 | Indexing statistics |
| 234 | 235 | """ |
| 235 | - from indexer.mapping_generator import MappingGenerator | |
| 236 | - | |
| 237 | 236 | # Generate and create index |
| 238 | 237 | mapping_gen = MappingGenerator(self.config) |
| 239 | 238 | mapping = mapping_gen.generate_mapping() | ... | ... |
indexer/mapping_generator.py
| ... | ... | @@ -5,6 +5,8 @@ Generates Elasticsearch index mappings from search configuration. |
| 5 | 5 | """ |
| 6 | 6 | |
| 7 | 7 | from typing import Dict, Any |
| 8 | +import logging | |
| 9 | + | |
| 8 | 10 | from config import ( |
| 9 | 11 | SearchConfig, |
| 10 | 12 | FieldConfig, |
| ... | ... | @@ -13,6 +15,8 @@ from config import ( |
| 13 | 15 | get_default_similarity |
| 14 | 16 | ) |
| 15 | 17 | |
| 18 | +logger = logging.getLogger(__name__) | |
| 19 | + | |
| 16 | 20 | |
| 17 | 21 | class MappingGenerator: |
| 18 | 22 | """Generates Elasticsearch mapping from search configuration.""" |
| ... | ... | @@ -85,31 +89,18 @@ class MappingGenerator: |
| 85 | 89 | Get the primary text embedding field name. |
| 86 | 90 | |
| 87 | 91 | Returns: |
| 88 | - Field name or empty string if not found | |
| 92 | + Field name or empty string if not configured | |
| 89 | 93 | """ |
| 90 | - # Look for name_embedding or first text_embedding field | |
| 91 | - for field in self.config.fields: | |
| 92 | - if field.name == "name_embedding": | |
| 93 | - return field.name | |
| 94 | - | |
| 95 | - # Otherwise return first text embedding field | |
| 96 | - for field in self.config.fields: | |
| 97 | - if "embedding" in field.name and "image" not in field.name: | |
| 98 | - return field.name | |
| 99 | - | |
| 100 | - return "" | |
| 94 | + return self.config.query_config.text_embedding_field or "" | |
| 101 | 95 | |
| 102 | 96 | def get_image_embedding_field(self) -> str: |
| 103 | 97 | """ |
| 104 | 98 | Get the primary image embedding field name. |
| 105 | 99 | |
| 106 | 100 | Returns: |
| 107 | - Field name or empty string if not found | |
| 101 | + Field name or empty string if not configured | |
| 108 | 102 | """ |
| 109 | - for field in self.config.fields: | |
| 110 | - if "image" in field.name and "embedding" in field.name: | |
| 111 | - return field.name | |
| 112 | - return "" | |
| 103 | + return self.config.query_config.image_embedding_field or "" | |
| 113 | 104 | |
| 114 | 105 | def get_field_by_name(self, field_name: str) -> FieldConfig: |
| 115 | 106 | """ |
| ... | ... | @@ -162,11 +153,11 @@ def create_index_if_not_exists(es_client, index_name: str, mapping: Dict[str, An |
| 162 | 153 | True if index was created, False if it already exists |
| 163 | 154 | """ |
| 164 | 155 | if es_client.indices.exists(index=index_name): |
| 165 | - print(f"Index '{index_name}' already exists") | |
| 156 | + logger.info(f"Index '{index_name}' already exists") | |
| 166 | 157 | return False |
| 167 | 158 | |
| 168 | 159 | es_client.indices.create(index=index_name, body=mapping) |
| 169 | - print(f"Index '{index_name}' created successfully") | |
| 160 | + logger.info(f"Index '{index_name}' created successfully") | |
| 170 | 161 | return True |
| 171 | 162 | |
| 172 | 163 | |
| ... | ... | @@ -182,11 +173,11 @@ def delete_index_if_exists(es_client, index_name: str) -> bool: |
| 182 | 173 | True if index was deleted, False if it didn't exist |
| 183 | 174 | """ |
| 184 | 175 | if not es_client.indices.exists(index=index_name): |
| 185 | - print(f"Index '{index_name}' does not exist") | |
| 176 | + logger.warning(f"Index '{index_name}' does not exist") | |
| 186 | 177 | return False |
| 187 | 178 | |
| 188 | 179 | es_client.indices.delete(index=index_name) |
| 189 | - print(f"Index '{index_name}' deleted successfully") | |
| 180 | + logger.info(f"Index '{index_name}' deleted successfully") | |
| 190 | 181 | return True |
| 191 | 182 | |
| 192 | 183 | |
| ... | ... | @@ -203,10 +194,10 @@ def update_mapping(es_client, index_name: str, new_fields: Dict[str, Any]) -> bo |
| 203 | 194 | True if successful |
| 204 | 195 | """ |
| 205 | 196 | if not es_client.indices.exists(index=index_name): |
| 206 | - print(f"Index '{index_name}' does not exist") | |
| 197 | + logger.error(f"Index '{index_name}' does not exist") | |
| 207 | 198 | return False |
| 208 | 199 | |
| 209 | 200 | mapping = {"properties": new_fields} |
| 210 | 201 | es_client.indices.put_mapping(index=index_name, body=mapping) |
| 211 | - print(f"Mapping updated for index '{index_name}'") | |
| 202 | + logger.info(f"Mapping updated for index '{index_name}'") | |
| 212 | 203 | return True | ... | ... |
query/query_parser.py
| ... | ... | @@ -6,6 +6,7 @@ Handles query rewriting, translation, and embedding generation. |
| 6 | 6 | |
| 7 | 7 | from typing import Dict, List, Optional, Any |
| 8 | 8 | import numpy as np |
| 9 | +import logging | |
| 9 | 10 | |
| 10 | 11 | from config import SearchConfig, QueryConfig |
| 11 | 12 | from embeddings import BgeEncoder |
| ... | ... | @@ -13,6 +14,8 @@ from .language_detector import LanguageDetector |
| 13 | 14 | from .translator import Translator |
| 14 | 15 | from .query_rewriter import QueryRewriter, QueryNormalizer |
| 15 | 16 | |
| 17 | +logger = logging.getLogger(__name__) | |
| 18 | + | |
| 16 | 19 | |
| 17 | 20 | class ParsedQuery: |
| 18 | 21 | """Container for parsed query results.""" |
| ... | ... | @@ -87,7 +90,7 @@ class QueryParser: |
| 87 | 90 | def text_encoder(self) -> BgeEncoder: |
| 88 | 91 | """Lazy load text encoder.""" |
| 89 | 92 | if self._text_encoder is None and self.query_config.enable_text_embedding: |
| 90 | - print("[QueryParser] Initializing text encoder...") | |
| 93 | + logger.info("Initializing text encoder (lazy load)...") | |
| 91 | 94 | self._text_encoder = BgeEncoder() |
| 92 | 95 | return self._text_encoder |
| 93 | 96 | |
| ... | ... | @@ -95,7 +98,7 @@ class QueryParser: |
| 95 | 98 | def translator(self) -> Translator: |
| 96 | 99 | """Lazy load translator.""" |
| 97 | 100 | if self._translator is None and self.query_config.enable_translation: |
| 98 | - print("[QueryParser] Initializing translator...") | |
| 101 | + logger.info("Initializing translator (lazy load)...") | |
| 99 | 102 | self._translator = Translator( |
| 100 | 103 | api_key=self.query_config.translation_api_key, |
| 101 | 104 | use_cache=True, |
| ... | ... | @@ -124,18 +127,17 @@ class QueryParser: |
| 124 | 127 | extra={'reqid': context.reqid, 'uid': context.uid} |
| 125 | 128 | ) |
| 126 | 129 | |
| 127 | - # Use print statements for backward compatibility if no context | |
| 128 | 130 | def log_info(msg): |
| 129 | - if logger: | |
| 130 | - logger.info(msg, extra={'reqid': context.reqid, 'uid': context.uid}) | |
| 131 | + if context and hasattr(context, 'logger'): | |
| 132 | + context.logger.info(msg, extra={'reqid': context.reqid, 'uid': context.uid}) | |
| 131 | 133 | else: |
| 132 | - print(f"[QueryParser] {msg}") | |
| 134 | + logger.info(msg) | |
| 133 | 135 | |
| 134 | 136 | def log_debug(msg): |
| 135 | - if logger: | |
| 136 | - logger.debug(msg, extra={'reqid': context.reqid, 'uid': context.uid}) | |
| 137 | + if context and hasattr(context, 'logger'): | |
| 138 | + context.logger.debug(msg, extra={'reqid': context.reqid, 'uid': context.uid}) | |
| 137 | 139 | else: |
| 138 | - print(f"[QueryParser] {msg}") | |
| 140 | + logger.debug(msg) | |
| 139 | 141 | |
| 140 | 142 | # Stage 1: Normalize |
| 141 | 143 | normalized = self.normalizer.normalize(query) |
| ... | ... | @@ -246,15 +248,18 @@ class QueryParser: |
| 246 | 248 | domain=domain |
| 247 | 249 | ) |
| 248 | 250 | |
| 249 | - if logger: | |
| 250 | - logger.info( | |
| 251 | + if context and hasattr(context, 'logger'): | |
| 252 | + context.logger.info( | |
| 251 | 253 | f"查询解析完成 | 原查询: '{query}' | 最终查询: '{rewritten or query_text}' | " |
| 252 | 254 | f"语言: {detected_lang} | 域: {domain} | " |
| 253 | 255 | f"翻译数量: {len(translations)} | 向量: {'是' if query_vector is not None else '否'}", |
| 254 | 256 | extra={'reqid': context.reqid, 'uid': context.uid} |
| 255 | 257 | ) |
| 256 | 258 | else: |
| 257 | - print(f"[QueryParser] Parsing complete") | |
| 259 | + logger.info( | |
| 260 | + f"查询解析完成 | 原查询: '{query}' | 最终查询: '{rewritten or query_text}' | " | |
| 261 | + f"语言: {detected_lang} | 域: {domain}" | |
| 262 | + ) | |
| 258 | 263 | |
| 259 | 264 | return result |
| 260 | 265 | ... | ... |
query/translator.py
| ... | ... | @@ -8,6 +8,12 @@ import requests |
| 8 | 8 | from typing import Dict, List, Optional |
| 9 | 9 | from utils.cache import DictCache |
| 10 | 10 | |
| 11 | +# Try to import DEEPL_AUTH_KEY, but allow import to fail | |
| 12 | +try: | |
| 13 | + from config.env_config import DEEPL_AUTH_KEY | |
| 14 | +except ImportError: | |
| 15 | + DEEPL_AUTH_KEY = None | |
| 16 | + | |
| 11 | 17 | |
| 12 | 18 | class Translator: |
| 13 | 19 | """Multi-language translator using DeepL API.""" |
| ... | ... | @@ -47,12 +53,8 @@ class Translator: |
| 47 | 53 | translation_context: Context hint for translation (e.g., "e-commerce", "product search") |
| 48 | 54 | """ |
| 49 | 55 | # Get API key from config if not provided |
| 50 | - if api_key is None: | |
| 51 | - try: | |
| 52 | - from config.env_config import get_deepl_key | |
| 53 | - api_key = get_deepl_key() | |
| 54 | - except ImportError: | |
| 55 | - pass | |
| 56 | + if api_key is None and DEEPL_AUTH_KEY: | |
| 57 | + api_key = DEEPL_AUTH_KEY | |
| 56 | 58 | |
| 57 | 59 | self.api_key = api_key |
| 58 | 60 | self.timeout = timeout | ... | ... |
scripts/ingest_shoplazza.py
| ... | ... | @@ -78,13 +78,12 @@ def main(): |
| 78 | 78 | return 1 |
| 79 | 79 | |
| 80 | 80 | # Connect to Elasticsearch (use unified config loading) |
| 81 | - from config.env_config import get_es_config | |
| 82 | - es_config = get_es_config() | |
| 81 | + from config.env_config import ES_CONFIG | |
| 83 | 82 | |
| 84 | 83 | # Use provided es_host or fallback to config |
| 85 | - es_host = args.es_host or es_config.get('host', 'http://localhost:9200') | |
| 86 | - es_username = es_config.get('username') | |
| 87 | - es_password = es_config.get('password') | |
| 84 | + es_host = args.es_host or ES_CONFIG.get('host', 'http://localhost:9200') | |
| 85 | + es_username = ES_CONFIG.get('username') | |
| 86 | + es_password = ES_CONFIG.get('password') | |
| 88 | 87 | |
| 89 | 88 | print(f"Connecting to Elasticsearch: {es_host}") |
| 90 | 89 | if es_username and es_password: | ... | ... |
search/multilang_query_builder.py
| ... | ... | @@ -8,11 +8,15 @@ maintaining a unified external interface. |
| 8 | 8 | |
| 9 | 9 | from typing import Dict, Any, List, Optional |
| 10 | 10 | import numpy as np |
| 11 | +import logging | |
| 12 | +import re | |
| 11 | 13 | |
| 12 | 14 | from config import SearchConfig, IndexConfig |
| 13 | 15 | from query import ParsedQuery |
| 14 | 16 | from .es_query_builder import ESQueryBuilder |
| 15 | 17 | |
| 18 | +logger = logging.getLogger(__name__) | |
| 19 | + | |
| 16 | 20 | |
| 17 | 21 | class MultiLanguageQueryBuilder(ESQueryBuilder): |
| 18 | 22 | """ |
| ... | ... | @@ -139,20 +143,18 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): |
| 139 | 143 | min_score=min_score |
| 140 | 144 | ) |
| 141 | 145 | |
| 142 | - print(f"[MultiLangQueryBuilder] Building query for domain: {domain}") | |
| 143 | - print(f"[MultiLangQueryBuilder] Detected language: {parsed_query.detected_language}") | |
| 144 | - print(f"[MultiLangQueryBuilder] Available translations: {list(parsed_query.translations.keys())}") | |
| 146 | + logger.debug(f"Building query for domain: {domain}, language: {parsed_query.detected_language}") | |
| 145 | 147 | |
| 146 | 148 | # Build query clause with multi-language support |
| 147 | 149 | if query_node and isinstance(query_node, tuple) and len(query_node) > 0: |
| 148 | 150 | # Handle boolean query from tuple (AST, score) |
| 149 | 151 | ast_node = query_node[0] |
| 150 | 152 | query_clause = self._build_boolean_query_from_tuple(ast_node) |
| 151 | - print(f"[MultiLangQueryBuilder] Using boolean query: {query_clause}") | |
| 153 | + logger.debug(f"Using boolean query") | |
| 152 | 154 | elif query_node and hasattr(query_node, 'operator') and query_node.operator != 'TERM': |
| 153 | 155 | # Handle boolean query using base class method |
| 154 | 156 | query_clause = self._build_boolean_query(query_node) |
| 155 | - print(f"[MultiLangQueryBuilder] Using boolean query: {query_clause}") | |
| 157 | + logger.debug(f"Using boolean query") | |
| 156 | 158 | else: |
| 157 | 159 | # Handle text query with multi-language support |
| 158 | 160 | query_clause = self._build_multilang_text_query(parsed_query, domain_config) |
| ... | ... | @@ -171,7 +173,7 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): |
| 171 | 173 | } |
| 172 | 174 | } |
| 173 | 175 | inner_bool_should.append(knn_query) |
| 174 | - print(f"[MultiLangQueryBuilder] KNN query added: field={self.text_embedding_field}, k={knn_k}, num_candidates={knn_num_candidates}") | |
| 176 | + logger.info(f"KNN query added: field={self.text_embedding_field}, k={knn_k}") | |
| 175 | 177 | else: |
| 176 | 178 | # Debug why KNN is not added |
| 177 | 179 | reasons = [] |
| ... | ... | @@ -181,7 +183,7 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): |
| 181 | 183 | reasons.append("query_vector is None") |
| 182 | 184 | if not self.text_embedding_field: |
| 183 | 185 | reasons.append(f"text_embedding_field is not set (current: {self.text_embedding_field})") |
| 184 | - print(f"[MultiLangQueryBuilder] KNN query NOT added. Reasons: {', '.join(reasons) if reasons else 'unknown'}") | |
| 186 | + logger.debug(f"KNN query NOT added. Reasons: {', '.join(reasons) if reasons else 'unknown'}") | |
| 185 | 187 | |
| 186 | 188 | # Build the inner bool structure |
| 187 | 189 | inner_bool = { |
| ... | ... | @@ -342,7 +344,7 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): |
| 342 | 344 | "_name": f"{domain_config.name}_{detected_lang}_query" |
| 343 | 345 | } |
| 344 | 346 | }) |
| 345 | - print(f"[MultiLangQueryBuilder] Added query for detected language '{detected_lang}' on fields: {target_fields}") | |
| 347 | + logger.debug(f"Added query for detected language '{detected_lang}'") | |
| 346 | 348 | |
| 347 | 349 | # 2. Query in translated languages (only for languages in mapping) |
| 348 | 350 | for lang, translation in parsed_query.translations.items(): |
| ... | ... | @@ -361,11 +363,11 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): |
| 361 | 363 | "_name": f"{domain_config.name}_{lang}_translated_query" |
| 362 | 364 | } |
| 363 | 365 | }) |
| 364 | - print(f"[MultiLangQueryBuilder] Added translated query for language '{lang}' on fields: {target_fields}") | |
| 366 | + logger.debug(f"Added translated query for language '{lang}'") | |
| 365 | 367 | |
| 366 | 368 | # 3. Fallback: query all fields in mapping if no language-specific query was built |
| 367 | 369 | if not should_clauses: |
| 368 | - print(f"[MultiLangQueryBuilder] No language mapping matched, using all fields from mapping") | |
| 370 | + logger.debug("No language mapping matched, using all fields from mapping") | |
| 369 | 371 | # Use all fields from all languages in the mapping |
| 370 | 372 | all_mapped_fields = [] |
| 371 | 373 | for lang_fields in domain_config.language_field_mapping.values(): |
| ... | ... | @@ -445,7 +447,6 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): |
| 445 | 447 | operator = node[0].__name__ |
| 446 | 448 | elif str(node[0]).startswith('('): |
| 447 | 449 | # String representation of constructor call |
| 448 | - import re | |
| 449 | 450 | match = re.match(r'(\w+)\(', str(node[0])) |
| 450 | 451 | if match: |
| 451 | 452 | operator = match.group(1) | ... | ... |
search/searcher.py
| ... | ... | @@ -6,11 +6,13 @@ Handles query parsing, boolean expressions, ranking, and result formatting. |
| 6 | 6 | |
| 7 | 7 | from typing import Dict, Any, List, Optional, Union |
| 8 | 8 | import time |
| 9 | +import logging | |
| 9 | 10 | |
| 10 | 11 | from config import SearchConfig |
| 11 | 12 | from utils.es_client import ESClient |
| 12 | 13 | from query import QueryParser, ParsedQuery |
| 13 | 14 | from indexer import MappingGenerator |
| 15 | +from embeddings import CLIPImageEncoder | |
| 14 | 16 | from .boolean_parser import BooleanParser, QueryNode |
| 15 | 17 | from .es_query_builder import ESQueryBuilder |
| 16 | 18 | from .multilang_query_builder import MultiLanguageQueryBuilder |
| ... | ... | @@ -19,6 +21,8 @@ from context.request_context import RequestContext, RequestContextStage, create_ |
| 19 | 21 | from api.models import FacetResult, FacetValue |
| 20 | 22 | from api.result_formatter import ResultFormatter |
| 21 | 23 | |
| 24 | +logger = logging.getLogger(__name__) | |
| 25 | + | |
| 22 | 26 | |
| 23 | 27 | class SearchResult: |
| 24 | 28 | """Container for search results (外部友好格式).""" |
| ... | ... | @@ -476,7 +480,6 @@ class Searcher: |
| 476 | 480 | raise ValueError("Image embedding field not configured") |
| 477 | 481 | |
| 478 | 482 | # Generate image embedding |
| 479 | - from embeddings import CLIPImageEncoder | |
| 480 | 483 | image_encoder = CLIPImageEncoder() |
| 481 | 484 | image_vector = image_encoder.encode_image_from_url(image_url) |
| 482 | 485 | |
| ... | ... | @@ -575,7 +578,7 @@ class Searcher: |
| 575 | 578 | ) |
| 576 | 579 | return response.get('_source') |
| 577 | 580 | except Exception as e: |
| 578 | - print(f"[Searcher] Failed to get document {doc_id}: {e}") | |
| 581 | + logger.error(f"Failed to get document {doc_id}: {e}", exc_info=True) | |
| 579 | 582 | return None |
| 580 | 583 | |
| 581 | 584 | def _standardize_facets( | ... | ... |
utils/es_client.py
| ... | ... | @@ -3,8 +3,18 @@ Elasticsearch client wrapper. |
| 3 | 3 | """ |
| 4 | 4 | |
| 5 | 5 | from elasticsearch import Elasticsearch |
| 6 | +from elasticsearch.helpers import bulk | |
| 6 | 7 | from typing import Dict, Any, List, Optional |
| 7 | 8 | import os |
| 9 | +import logging | |
| 10 | + | |
| 11 | +# Try to import ES_CONFIG, but allow import to fail | |
| 12 | +try: | |
| 13 | + from config.env_config import ES_CONFIG | |
| 14 | +except ImportError: | |
| 15 | + ES_CONFIG = None | |
| 16 | + | |
| 17 | +logger = logging.getLogger(__name__) | |
| 8 | 18 | |
| 9 | 19 | |
| 10 | 20 | class ESClient: |
| ... | ... | @@ -56,7 +66,7 @@ class ESClient: |
| 56 | 66 | try: |
| 57 | 67 | return self.client.ping() |
| 58 | 68 | except Exception as e: |
| 59 | - print(f"Failed to ping Elasticsearch: {e}") | |
| 69 | + logger.error(f"Failed to ping Elasticsearch: {e}", exc_info=True) | |
| 60 | 70 | return False |
| 61 | 71 | |
| 62 | 72 | def create_index(self, index_name: str, body: Dict[str, Any]) -> bool: |
| ... | ... | @@ -72,12 +82,10 @@ class ESClient: |
| 72 | 82 | """ |
| 73 | 83 | try: |
| 74 | 84 | self.client.indices.create(index=index_name, body=body) |
| 75 | - print(f"Index '{index_name}' created successfully") | |
| 85 | + logger.info(f"Index '{index_name}' created successfully") | |
| 76 | 86 | return True |
| 77 | 87 | except Exception as e: |
| 78 | - print(f"ERROR: Failed to create index '{index_name}': {e}") | |
| 79 | - import traceback | |
| 80 | - traceback.print_exc() | |
| 88 | + logger.error(f"Failed to create index '{index_name}': {e}", exc_info=True) | |
| 81 | 89 | return False |
| 82 | 90 | |
| 83 | 91 | def delete_index(self, index_name: str) -> bool: |
| ... | ... | @@ -93,13 +101,13 @@ class ESClient: |
| 93 | 101 | try: |
| 94 | 102 | if self.client.indices.exists(index=index_name): |
| 95 | 103 | self.client.indices.delete(index=index_name) |
| 96 | - print(f"Index '{index_name}' deleted successfully") | |
| 104 | + logger.info(f"Index '{index_name}' deleted successfully") | |
| 97 | 105 | return True |
| 98 | 106 | else: |
| 99 | - print(f"Index '{index_name}' does not exist") | |
| 107 | + logger.warning(f"Index '{index_name}' does not exist") | |
| 100 | 108 | return False |
| 101 | 109 | except Exception as e: |
| 102 | - print(f"Failed to delete index '{index_name}': {e}") | |
| 110 | + logger.error(f"Failed to delete index '{index_name}': {e}", exc_info=True) | |
| 103 | 111 | return False |
| 104 | 112 | |
| 105 | 113 | def index_exists(self, index_name: str) -> bool: |
| ... | ... | @@ -117,8 +125,6 @@ class ESClient: |
| 117 | 125 | Returns: |
| 118 | 126 | Dictionary with results |
| 119 | 127 | """ |
| 120 | - from elasticsearch.helpers import bulk | |
| 121 | - | |
| 122 | 128 | actions = [] |
| 123 | 129 | for doc in docs: |
| 124 | 130 | action = { |
| ... | ... | @@ -140,7 +146,7 @@ class ESClient: |
| 140 | 146 | 'errors': failed |
| 141 | 147 | } |
| 142 | 148 | except Exception as e: |
| 143 | - print(f"Bulk indexing failed: {e}") | |
| 149 | + logger.error(f"Bulk indexing failed: {e}", exc_info=True) | |
| 144 | 150 | return { |
| 145 | 151 | 'success': 0, |
| 146 | 152 | 'failed': len(docs), |
| ... | ... | @@ -174,7 +180,7 @@ class ESClient: |
| 174 | 180 | from_=from_ |
| 175 | 181 | ) |
| 176 | 182 | except Exception as e: |
| 177 | - print(f"Search failed: {e}") | |
| 183 | + logger.error(f"Search failed: {e}", exc_info=True) | |
| 178 | 184 | return { |
| 179 | 185 | 'hits': { |
| 180 | 186 | 'total': {'value': 0}, |
| ... | ... | @@ -188,7 +194,7 @@ class ESClient: |
| 188 | 194 | try: |
| 189 | 195 | return self.client.indices.get_mapping(index=index_name) |
| 190 | 196 | except Exception as e: |
| 191 | - print(f"Failed to get mapping for '{index_name}': {e}") | |
| 197 | + logger.error(f"Failed to get mapping for '{index_name}': {e}", exc_info=True) | |
| 192 | 198 | return {} |
| 193 | 199 | |
| 194 | 200 | def refresh(self, index_name: str) -> bool: |
| ... | ... | @@ -197,7 +203,7 @@ class ESClient: |
| 197 | 203 | self.client.indices.refresh(index=index_name) |
| 198 | 204 | return True |
| 199 | 205 | except Exception as e: |
| 200 | - print(f"Failed to refresh index '{index_name}': {e}") | |
| 206 | + logger.error(f"Failed to refresh index '{index_name}': {e}", exc_info=True) | |
| 201 | 207 | return False |
| 202 | 208 | |
| 203 | 209 | def count(self, index_name: str, body: Optional[Dict[str, Any]] = None) -> int: |
| ... | ... | @@ -215,7 +221,7 @@ class ESClient: |
| 215 | 221 | result = self.client.count(index=index_name, body=body) |
| 216 | 222 | return result['count'] |
| 217 | 223 | except Exception as e: |
| 218 | - print(f"Count failed: {e}") | |
| 224 | + logger.error(f"Count failed: {e}", exc_info=True) | |
| 219 | 225 | return 0 |
| 220 | 226 | |
| 221 | 227 | |
| ... | ... | @@ -231,15 +237,13 @@ def get_es_client_from_env() -> ESClient: |
| 231 | 237 | Returns: |
| 232 | 238 | ESClient instance |
| 233 | 239 | """ |
| 234 | - try: | |
| 235 | - from config.env_config import get_es_config | |
| 236 | - es_config = get_es_config() | |
| 240 | + if ES_CONFIG: | |
| 237 | 241 | return ESClient( |
| 238 | - hosts=[es_config['host']], | |
| 239 | - username=es_config.get('username'), | |
| 240 | - password=es_config.get('password') | |
| 242 | + hosts=[ES_CONFIG['host']], | |
| 243 | + username=ES_CONFIG.get('username'), | |
| 244 | + password=ES_CONFIG.get('password') | |
| 241 | 245 | ) |
| 242 | - except ImportError: | |
| 246 | + else: | |
| 243 | 247 | # Fallback to env variables |
| 244 | 248 | return ESClient( |
| 245 | 249 | hosts=[os.getenv('ES_HOST', 'http://localhost:9200')], | ... | ... |