diff --git a/api/app.py b/api/app.py index 2f06c51..316c18e 100644 --- a/api/app.py +++ b/api/app.py @@ -9,6 +9,8 @@ import os import sys import logging import time +import argparse +import uvicorn from collections import defaultdict, deque from typing import Optional from fastapi import FastAPI, Request, HTTPException @@ -20,7 +22,6 @@ from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded -import argparse # Configure logging with better formatting logging.basicConfig( @@ -40,6 +41,7 @@ limiter = Limiter(key_func=get_remote_address) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from config import ConfigLoader, SearchConfig +from config.env_config import ES_CONFIG from utils import ESClient from search import Searcher from query import QueryParser @@ -60,55 +62,42 @@ def init_service(es_host: str = "http://localhost:9200"): """ global _config, _es_client, _searcher, _query_parser - print("Initializing search service (multi-tenant)") + start_time = time.time() + logger.info("Initializing search service (multi-tenant)") - # Load unified configuration + # Load and validate configuration + logger.info("Loading configuration...") config_loader = ConfigLoader("config/config.yaml") _config = config_loader.load_config() - - # Validate configuration errors = config_loader.validate_config(_config) if errors: raise ValueError(f"Configuration validation failed: {errors}") + logger.info(f"Configuration loaded: {_config.es_index_name}") - print(f"Configuration loaded: {_config.es_index_name}") + # Get ES credentials + es_username = os.getenv('ES_USERNAME') or ES_CONFIG.get('username') + es_password = os.getenv('ES_PASSWORD') or ES_CONFIG.get('password') - # Get ES credentials from environment variables or .env file - es_username = os.getenv('ES_USERNAME') - es_password = os.getenv('ES_PASSWORD') - - # Try to load from config if not in env - if not es_username or not es_password: - try: - from config.env_config import get_es_config - es_config = get_es_config() - es_username = es_username or es_config.get('username') - es_password = es_password or es_config.get('password') - except Exception: - pass - - # Initialize ES client with authentication if credentials are available + # Connect to Elasticsearch + logger.info(f"Connecting to Elasticsearch at {es_host}...") if es_username and es_password: - print(f"Connecting to Elasticsearch with authentication: {es_username}") _es_client = ESClient(hosts=[es_host], username=es_username, password=es_password) else: - print(f"Connecting to Elasticsearch without authentication") _es_client = ESClient(hosts=[es_host]) if not _es_client.ping(): raise ConnectionError(f"Failed to connect to Elasticsearch at {es_host}") + logger.info("Elasticsearch connected") - print(f"Connected to Elasticsearch: {es_host}") - - # Initialize query parser + # Initialize components + logger.info("Initializing query parser...") _query_parser = QueryParser(_config) - print("Query parser initialized") - - # Initialize searcher + + logger.info("Initializing searcher...") _searcher = Searcher(_config, _es_client, _query_parser) - print("Searcher initialized") - - print("Search service ready!") + + elapsed = time.time() - start_time + logger.info(f"Search service ready! (took {elapsed:.2f}s)") def get_config() -> SearchConfig: @@ -305,8 +294,6 @@ else: if __name__ == "__main__": - import uvicorn - parser = argparse.ArgumentParser(description='Start search API service (multi-tenant)') parser.add_argument('--host', default='0.0.0.0', help='Host to bind to') parser.add_argument('--port', type=int, default=6002, help='Port to bind to') diff --git a/config/config.yaml b/config/config.yaml index 5b6ef04..d59af1c 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -147,6 +147,14 @@ fields: index: false store: true + # 文本嵌入字段(用于语义搜索) + - name: "name_embedding" + type: "TEXT_EMBEDDING" + embedding_dims: 1024 + embedding_similarity: "dot_product" + index: true + store: false + # 嵌套variants字段 - name: "variants" type: "JSON" @@ -239,6 +247,10 @@ query_config: enable_text_embedding: true enable_query_rewrite: true + # Embedding field names (if not set, will auto-detect from fields) + text_embedding_field: "name_embedding" # Field name for text embeddings + image_embedding_field: null # Field name for image embeddings (if not set, will auto-detect) + # Translation API (DeepL) translation_service: "deepl" translation_api_key: null # Set via environment variable diff --git a/config/config_loader.py b/config/config_loader.py index aea676e..85e176d 100644 --- a/config/config_loader.py +++ b/config/config_loader.py @@ -54,6 +54,10 @@ class QueryConfig: translation_glossary_id: Optional[str] = None # DeepL glossary ID for custom terminology translation_context: str = "e-commerce product search" # Context hint for translation + # Embedding field names - if not set, will auto-detect from fields + text_embedding_field: Optional[str] = None # Field name for text embeddings (e.g., "name_embedding") + image_embedding_field: Optional[str] = None # Field name for image embeddings (e.g., "image_embedding") + # ES source fields configuration - fields to return in search results source_fields: List[str] = field(default_factory=lambda: [ "id", "spuId", "skuNo", "spuNo", "title", "enSpuName", "brandId", @@ -213,7 +217,9 @@ class ConfigLoader: translation_api_key=query_config_data.get("translation_api_key"), translation_service=query_config_data.get("translation_service", "deepl"), translation_glossary_id=query_config_data.get("translation_glossary_id"), - translation_context=query_config_data.get("translation_context", "e-commerce product search") + translation_context=query_config_data.get("translation_context", "e-commerce product search"), + text_embedding_field=query_config_data.get("text_embedding_field"), + image_embedding_field=query_config_data.get("image_embedding_field") ) # Parse ranking config diff --git a/config/env_config.py b/config/env_config.py index 6d6c82c..e77a4f2 100644 --- a/config/env_config.py +++ b/config/env_config.py @@ -2,11 +2,12 @@ Centralized configuration management for SearchEngine. Loads configuration from environment variables and .env file. +This module provides a single point for loading .env and setting defaults. +All configuration variables are exported directly - no need for getter functions. """ import os from pathlib import Path -from typing import Dict, Any from dotenv import load_dotenv # Load .env file from project root @@ -56,26 +57,6 @@ DB_CONFIG = { } -def get_es_config() -> Dict[str, Any]: - """Get Elasticsearch configuration.""" - return ES_CONFIG.copy() - - -def get_redis_config() -> Dict[str, Any]: - """Get Redis configuration.""" - return REDIS_CONFIG.copy() - - -def get_deepl_key() -> str: - """Get DeepL API key.""" - return DEEPL_AUTH_KEY - - -def get_db_config() -> Dict[str, Any]: - """Get MySQL database configuration.""" - return DB_CONFIG.copy() - - def print_config(): """Print current configuration (with sensitive data masked).""" print("=" * 60) diff --git a/embeddings/image_encoder.py b/embeddings/image_encoder.py index 0ecaf6d..f415b28 100644 --- a/embeddings/image_encoder.py +++ b/embeddings/image_encoder.py @@ -1,31 +1,24 @@ """ -Image embedding encoder using CN-CLIP model. +Image embedding encoder using network service. -Generates 1024-dimensional vectors for images using the CN-CLIP ViT-H-14 model. +Generates embeddings via HTTP API service running on localhost:5001. """ import sys import os -import io import requests -import torch import numpy as np from PIL import Image import logging import threading -from typing import List, Optional, Union -import cn_clip.clip as clip -from cn_clip.clip import load_from_name +from typing import List, Optional, Union, Dict, Any - -# DEFAULT_MODEL_NAME = "ViT-L-14-336" # ["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"] -DEFAULT_MODEL_NAME = "ViT-H-14" -MODEL_DOWNLOAD_DIR = "/data/tw/uat/EsSearcher" +logger = logging.getLogger(__name__) class CLIPImageEncoder: """ - CLIP Image Encoder for generating image embeddings using cn_clip. + Image Encoder for generating image embeddings using network service. Thread-safe singleton pattern. """ @@ -33,111 +26,80 @@ class CLIPImageEncoder: _instance = None _lock = threading.Lock() - def __new__(cls, model_name=DEFAULT_MODEL_NAME, device=None): + def __new__(cls, service_url='http://localhost:5001'): with cls._lock: if cls._instance is None: cls._instance = super(CLIPImageEncoder, cls).__new__(cls) - print(f"[CLIPImageEncoder] Creating new instance with model: {model_name}") - cls._instance._initialize_model(model_name, device) + logger.info(f"Creating CLIPImageEncoder instance with service URL: {service_url}") + cls._instance.service_url = service_url + cls._instance.endpoint = f"{service_url}/embedding/generate_image_embeddings" return cls._instance - def _initialize_model(self, model_name, device): - """Initialize the CLIP model using cn_clip""" - try: - self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") - self.model, self.preprocess = load_from_name( - model_name, - device=self.device, - download_root=MODEL_DOWNLOAD_DIR - ) - self.model.eval() - self.model_name = model_name - print(f"[CLIPImageEncoder] Model {model_name} initialized successfully on device {self.device}") - - except Exception as e: - print(f"[CLIPImageEncoder] Failed to initialize model: {str(e)}") - raise + def _call_service(self, request_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Call the embedding service API. - def validate_image(self, image_data: bytes) -> Image.Image: - """Validate image data and return PIL Image if valid""" - try: - image_stream = io.BytesIO(image_data) - image = Image.open(image_stream) - image.verify() - image_stream.seek(0) - image = Image.open(image_stream) - if image.mode != 'RGB': - image = image.convert('RGB') - return image - except Exception as e: - raise ValueError(f"Invalid image data: {str(e)}") + Args: + request_data: List of dictionaries with id and pic_url fields - def download_image(self, url: str, timeout: int = 10) -> bytes: - """Download image from URL""" + Returns: + List of dictionaries with id, pic_url, embedding and error fields + """ try: - if url.startswith(('http://', 'https://')): - response = requests.get(url, timeout=timeout) - if response.status_code != 200: - raise ValueError(f"HTTP {response.status_code}") - return response.content - else: - # Local file path - with open(url, 'rb') as f: - return f.read() - except Exception as e: - raise ValueError(f"Failed to download image from {url}: {str(e)}") - - def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image: - """Preprocess image for CLIP model""" - # Resize if too large - if max(image.size) > max_size: - ratio = max_size / max(image.size) - new_size = tuple(int(dim * ratio) for dim in image.size) - image = image.resize(new_size, Image.Resampling.LANCZOS) - return image - - def encode_text(self, text): - """Encode text to embedding vector using cn_clip""" - text_data = clip.tokenize([text] if type(text) == str else text).to(self.device) - with torch.no_grad(): - text_features = self.model.encode_text(text_data) - text_features /= text_features.norm(dim=-1, keepdim=True) - return text_features + response = requests.post( + self.endpoint, + json=request_data, + timeout=60 + ) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + logger.error(f"CLIPImageEncoder service request failed: {e}", exc_info=True) + raise def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: - """Encode image to embedding vector using cn_clip""" - if not isinstance(image, Image.Image): - raise ValueError("CLIPImageEncoder.encode_image Input must be a PIL.Image") + """ + Encode image to embedding vector using network service. - try: - infer_data = self.preprocess(image).unsqueeze(0).to(self.device) - with torch.no_grad(): - image_features = self.model.encode_image(infer_data) - image_features /= image_features.norm(dim=-1, keepdim=True) - return image_features.cpu().numpy().astype('float32')[0] - except Exception as e: - print(f"Failed to process image. Reason: {str(e)}") - return None + Note: This method is kept for compatibility but the service only works with URLs. + """ + logger.warning("encode_image with PIL Image not supported by service, returning None") + return None def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: - """Complete pipeline: download, validate, preprocess and encode image from URL""" - try: - # Download image - image_data = self.download_image(url) - - # Validate image - image = self.validate_image(image_data) - - # Preprocess image - image = self.preprocess_image(image) + """ + Generate image embedding via network service using URL. - # Encode image - embedding = self.encode_image(image) + Args: + url: Image URL to process - return embedding + Returns: + Embedding vector or None if failed + """ + try: + # Prepare request data + request_data = [{ + "id": "image_0", + "pic_url": url + }] + + # Call service + response_data = self._call_service(request_data) + + # Process response + if response_data and len(response_data) > 0: + response_item = response_data[0] + if response_item.get("embedding"): + return np.array(response_item["embedding"], dtype=np.float32) + else: + logger.warning(f"No embedding for URL {url}, error: {response_item.get('error', 'Unknown error')}") + return None + else: + logger.warning(f"No response for URL {url}") + return None except Exception as e: - print(f"Error processing image from URL {url}: {str(e)}") + logger.error(f"Failed to process image from URL {url}: {str(e)}", exc_info=True) return None def encode_batch( @@ -146,33 +108,71 @@ class CLIPImageEncoder: batch_size: int = 8 ) -> List[Optional[np.ndarray]]: """ - Encode a batch of images efficiently. + Encode a batch of images efficiently via network service. Args: images: List of image URLs or PIL Images - batch_size: Batch size for processing + batch_size: Batch size for processing (used for service requests) Returns: List of embeddings (or None for failed images) """ - results = [] - - for i in range(0, len(images), batch_size): - batch = images[i:i + batch_size] - batch_embeddings = [] - - for img in batch: - if isinstance(img, str): - # URL or file path - emb = self.encode_image_from_url(img) - elif isinstance(img, Image.Image): - # PIL Image - emb = self.encode_image(img) - else: - emb = None - - batch_embeddings.append(emb) - - results.extend(batch_embeddings) + # Initialize results with None for all images + results = [None] * len(images) + + # Filter out PIL Images since service only supports URLs + url_images = [] + url_indices = [] + + for i, img in enumerate(images): + if isinstance(img, str): + url_images.append(img) + url_indices.append(i) + elif isinstance(img, Image.Image): + logger.warning(f"PIL Image at index {i} not supported by service, returning None") + # results[i] is already None + + # Process URLs in batches + for i in range(0, len(url_images), batch_size): + batch_urls = url_images[i:i + batch_size] + batch_indices = url_indices[i:i + batch_size] + + # Prepare request data + request_data = [] + for j, url in enumerate(batch_urls): + request_data.append({ + "id": f"image_{j}", + "pic_url": url + }) + + try: + # Call service + response_data = self._call_service(request_data) + + # Process response + batch_results = [] + for j, url in enumerate(batch_urls): + response_item = None + for item in response_data: + if str(item.get("id")) == f"image_{j}": + response_item = item + break + + if response_item and response_item.get("embedding"): + batch_results.append(np.array(response_item["embedding"], dtype=np.float32)) + else: + error_msg = response_item.get("error", "Unknown error") if response_item else "No response" + logger.warning(f"Failed to encode URL {url}: {error_msg}") + batch_results.append(None) + + # Insert results at the correct positions + for j, result in enumerate(batch_results): + results[batch_indices[j]] = result + + except Exception as e: + logger.error(f"Batch processing failed: {e}", exc_info=True) + # Fill with None for this batch + for j in range(len(batch_urls)): + results[batch_indices[j]] = None return results diff --git a/embeddings/image_encoder__local.py b/embeddings/image_encoder__local.py new file mode 100644 index 0000000..0ecaf6d --- /dev/null +++ b/embeddings/image_encoder__local.py @@ -0,0 +1,178 @@ +""" +Image embedding encoder using CN-CLIP model. + +Generates 1024-dimensional vectors for images using the CN-CLIP ViT-H-14 model. +""" + +import sys +import os +import io +import requests +import torch +import numpy as np +from PIL import Image +import logging +import threading +from typing import List, Optional, Union +import cn_clip.clip as clip +from cn_clip.clip import load_from_name + + +# DEFAULT_MODEL_NAME = "ViT-L-14-336" # ["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"] +DEFAULT_MODEL_NAME = "ViT-H-14" +MODEL_DOWNLOAD_DIR = "/data/tw/uat/EsSearcher" + + +class CLIPImageEncoder: + """ + CLIP Image Encoder for generating image embeddings using cn_clip. + + Thread-safe singleton pattern. + """ + + _instance = None + _lock = threading.Lock() + + def __new__(cls, model_name=DEFAULT_MODEL_NAME, device=None): + with cls._lock: + if cls._instance is None: + cls._instance = super(CLIPImageEncoder, cls).__new__(cls) + print(f"[CLIPImageEncoder] Creating new instance with model: {model_name}") + cls._instance._initialize_model(model_name, device) + return cls._instance + + def _initialize_model(self, model_name, device): + """Initialize the CLIP model using cn_clip""" + try: + self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") + self.model, self.preprocess = load_from_name( + model_name, + device=self.device, + download_root=MODEL_DOWNLOAD_DIR + ) + self.model.eval() + self.model_name = model_name + print(f"[CLIPImageEncoder] Model {model_name} initialized successfully on device {self.device}") + + except Exception as e: + print(f"[CLIPImageEncoder] Failed to initialize model: {str(e)}") + raise + + def validate_image(self, image_data: bytes) -> Image.Image: + """Validate image data and return PIL Image if valid""" + try: + image_stream = io.BytesIO(image_data) + image = Image.open(image_stream) + image.verify() + image_stream.seek(0) + image = Image.open(image_stream) + if image.mode != 'RGB': + image = image.convert('RGB') + return image + except Exception as e: + raise ValueError(f"Invalid image data: {str(e)}") + + def download_image(self, url: str, timeout: int = 10) -> bytes: + """Download image from URL""" + try: + if url.startswith(('http://', 'https://')): + response = requests.get(url, timeout=timeout) + if response.status_code != 200: + raise ValueError(f"HTTP {response.status_code}") + return response.content + else: + # Local file path + with open(url, 'rb') as f: + return f.read() + except Exception as e: + raise ValueError(f"Failed to download image from {url}: {str(e)}") + + def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image: + """Preprocess image for CLIP model""" + # Resize if too large + if max(image.size) > max_size: + ratio = max_size / max(image.size) + new_size = tuple(int(dim * ratio) for dim in image.size) + image = image.resize(new_size, Image.Resampling.LANCZOS) + return image + + def encode_text(self, text): + """Encode text to embedding vector using cn_clip""" + text_data = clip.tokenize([text] if type(text) == str else text).to(self.device) + with torch.no_grad(): + text_features = self.model.encode_text(text_data) + text_features /= text_features.norm(dim=-1, keepdim=True) + return text_features + + def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: + """Encode image to embedding vector using cn_clip""" + if not isinstance(image, Image.Image): + raise ValueError("CLIPImageEncoder.encode_image Input must be a PIL.Image") + + try: + infer_data = self.preprocess(image).unsqueeze(0).to(self.device) + with torch.no_grad(): + image_features = self.model.encode_image(infer_data) + image_features /= image_features.norm(dim=-1, keepdim=True) + return image_features.cpu().numpy().astype('float32')[0] + except Exception as e: + print(f"Failed to process image. Reason: {str(e)}") + return None + + def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: + """Complete pipeline: download, validate, preprocess and encode image from URL""" + try: + # Download image + image_data = self.download_image(url) + + # Validate image + image = self.validate_image(image_data) + + # Preprocess image + image = self.preprocess_image(image) + + # Encode image + embedding = self.encode_image(image) + + return embedding + + except Exception as e: + print(f"Error processing image from URL {url}: {str(e)}") + return None + + def encode_batch( + self, + images: List[Union[str, Image.Image]], + batch_size: int = 8 + ) -> List[Optional[np.ndarray]]: + """ + Encode a batch of images efficiently. + + Args: + images: List of image URLs or PIL Images + batch_size: Batch size for processing + + Returns: + List of embeddings (or None for failed images) + """ + results = [] + + for i in range(0, len(images), batch_size): + batch = images[i:i + batch_size] + batch_embeddings = [] + + for img in batch: + if isinstance(img, str): + # URL or file path + emb = self.encode_image_from_url(img) + elif isinstance(img, Image.Image): + # PIL Image + emb = self.encode_image(img) + else: + emb = None + + batch_embeddings.append(emb) + + results.extend(batch_embeddings) + + return results diff --git a/embeddings/text_encoder.py b/embeddings/text_encoder.py index d2a893c..43369fb 100644 --- a/embeddings/text_encoder.py +++ b/embeddings/text_encoder.py @@ -1,122 +1,149 @@ """ -Text embedding encoder using BGE-M3 model. +Text embedding encoder using network service. -Generates 1024-dimensional vectors for text using the BGE-M3 multilingual model. +Generates embeddings via HTTP API service running on localhost:5001. """ import sys -import torch -from sentence_transformers import SentenceTransformer +import requests import time import threading -from modelscope import snapshot_download -from transformers import AutoModel -import os import numpy as np -from typing import List, Union +import logging +from typing import List, Union, Dict, Any + +logger = logging.getLogger(__name__) class BgeEncoder: """ - Singleton text encoder using BGE-M3 model. + Singleton text encoder using network service. - Thread-safe singleton pattern ensures only one model instance exists. + Thread-safe singleton pattern ensures only one instance exists. """ _instance = None _lock = threading.Lock() - def __new__(cls, model_dir='Xorbits/bge-m3'): + def __new__(cls, service_url='http://localhost:5001'): with cls._lock: if cls._instance is None: cls._instance = super(BgeEncoder, cls).__new__(cls) - print(f"[BgeEncoder] Creating a new instance with model directory: {model_dir}") - cls._instance.model = SentenceTransformer(snapshot_download(model_dir)) - print("[BgeEncoder] New instance has been created") + logger.info(f"Creating BgeEncoder instance with service URL: {service_url}") + cls._instance.service_url = service_url + cls._instance.endpoint = f"{service_url}/embedding/generate_embeddings" return cls._instance + def _call_service(self, request_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Call the embedding service API. + + Args: + request_data: List of dictionaries with id and text fields + + Returns: + List of dictionaries with id and embedding fields + """ + try: + response = requests.post( + self.endpoint, + json=request_data, + timeout=60 + ) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + logger.error(f"BgeEncoder service request failed: {e}", exc_info=True) + raise + def encode( self, sentences: Union[str, List[str]], normalize_embeddings: bool = True, - device: str = 'cuda', + device: str = 'cpu', batch_size: int = 32 ) -> np.ndarray: """ - Encode text into embeddings. + Encode text into embeddings via network service. Args: sentences: Single string or list of strings to encode - normalize_embeddings: Whether to normalize embeddings - device: Device to use ('cuda' or 'cpu') - batch_size: Batch size for encoding + normalize_embeddings: Whether to normalize embeddings (ignored for service) + device: Device parameter ignored for service compatibility + batch_size: Batch size for processing (used for service requests) Returns: numpy array of shape (n, 1024) containing embeddings """ - # Move model to specified device - if device == 'gpu': - device = 'cuda' + # Convert single string to list + if isinstance(sentences, str): + sentences = [sentences] - # Try requested device, fallback to CPU if CUDA fails - try: - if device == 'cuda': - # Check CUDA memory first - import torch - if torch.cuda.is_available(): - # Check if we have enough memory (at least 1GB free) - free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated() - if free_memory < 1024 * 1024 * 1024: # 1GB - print(f"[BgeEncoder] CUDA memory insufficient ({free_memory/1024/1024:.1f}MB free), falling back to CPU") - device = 'cpu' - else: - print(f"[BgeEncoder] CUDA not available, using CPU") - device = 'cpu' + # Prepare request data + request_data = [] + for i, text in enumerate(sentences): + request_item = { + "id": str(i), + "name_zh": text + } - self.model = self.model.to(device) + # Add English and Russian fields as empty for now + # Could be enhanced with language detection in the future + request_item["name_en"] = None + request_item["name_ru"] = None - embeddings = self.model.encode( - sentences, - normalize_embeddings=normalize_embeddings, - device=device, - show_progress_bar=False, - batch_size=batch_size - ) + request_data.append(request_item) + + try: + # Call service + response_data = self._call_service(request_data) + + # Process response + embeddings = [] + for i, text in enumerate(sentences): + # Find corresponding response by ID + response_item = None + for item in response_data: + if str(item.get("id")) == str(i): + response_item = item + break + + if response_item: + # Try Chinese embedding first, then English, then Russian + embedding = None + for lang in ["embedding_zh", "embedding_en", "embedding_ru"]: + if lang in response_item and response_item[lang] is not None: + embedding = response_item[lang] + break + + if embedding is not None: + embeddings.append(embedding) + else: + logger.warning(f"No embedding found for text {i}: {text[:50]}...") + embeddings.append([0.0] * 1024) + else: + logger.warning(f"No response found for text {i}") + embeddings.append([0.0] * 1024) - return embeddings + return np.array(embeddings, dtype=np.float32) except Exception as e: - print(f"[BgeEncoder] Device {device} failed: {e}") - if device != 'cpu': - print(f"[BgeEncoder] Falling back to CPU") - try: - self.model = self.model.to('cpu') - embeddings = self.model.encode( - sentences, - normalize_embeddings=normalize_embeddings, - device='cpu', - show_progress_bar=False, - batch_size=batch_size - ) - return embeddings - except Exception as e2: - print(f"[BgeEncoder] CPU also failed: {e2}") - raise - else: - raise + logger.error(f"Failed to encode texts: {e}", exc_info=True) + # Return zero embeddings as fallback + return np.zeros((len(sentences), 1024), dtype=np.float32) def encode_batch( self, texts: List[str], batch_size: int = 32, - device: str = 'cuda' + device: str = 'cpu' ) -> np.ndarray: """ - Encode a batch of texts efficiently. + Encode a batch of texts efficiently via network service. Args: texts: List of texts to encode batch_size: Batch size for processing - device: Device to use + device: Device parameter ignored for service compatibility Returns: numpy array of embeddings diff --git a/embeddings/text_encoder__local.py b/embeddings/text_encoder__local.py new file mode 100644 index 0000000..d2a893c --- /dev/null +++ b/embeddings/text_encoder__local.py @@ -0,0 +1,124 @@ +""" +Text embedding encoder using BGE-M3 model. + +Generates 1024-dimensional vectors for text using the BGE-M3 multilingual model. +""" + +import sys +import torch +from sentence_transformers import SentenceTransformer +import time +import threading +from modelscope import snapshot_download +from transformers import AutoModel +import os +import numpy as np +from typing import List, Union + + +class BgeEncoder: + """ + Singleton text encoder using BGE-M3 model. + + Thread-safe singleton pattern ensures only one model instance exists. + """ + _instance = None + _lock = threading.Lock() + + def __new__(cls, model_dir='Xorbits/bge-m3'): + with cls._lock: + if cls._instance is None: + cls._instance = super(BgeEncoder, cls).__new__(cls) + print(f"[BgeEncoder] Creating a new instance with model directory: {model_dir}") + cls._instance.model = SentenceTransformer(snapshot_download(model_dir)) + print("[BgeEncoder] New instance has been created") + return cls._instance + + def encode( + self, + sentences: Union[str, List[str]], + normalize_embeddings: bool = True, + device: str = 'cuda', + batch_size: int = 32 + ) -> np.ndarray: + """ + Encode text into embeddings. + + Args: + sentences: Single string or list of strings to encode + normalize_embeddings: Whether to normalize embeddings + device: Device to use ('cuda' or 'cpu') + batch_size: Batch size for encoding + + Returns: + numpy array of shape (n, 1024) containing embeddings + """ + # Move model to specified device + if device == 'gpu': + device = 'cuda' + + # Try requested device, fallback to CPU if CUDA fails + try: + if device == 'cuda': + # Check CUDA memory first + import torch + if torch.cuda.is_available(): + # Check if we have enough memory (at least 1GB free) + free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated() + if free_memory < 1024 * 1024 * 1024: # 1GB + print(f"[BgeEncoder] CUDA memory insufficient ({free_memory/1024/1024:.1f}MB free), falling back to CPU") + device = 'cpu' + else: + print(f"[BgeEncoder] CUDA not available, using CPU") + device = 'cpu' + + self.model = self.model.to(device) + + embeddings = self.model.encode( + sentences, + normalize_embeddings=normalize_embeddings, + device=device, + show_progress_bar=False, + batch_size=batch_size + ) + + return embeddings + + except Exception as e: + print(f"[BgeEncoder] Device {device} failed: {e}") + if device != 'cpu': + print(f"[BgeEncoder] Falling back to CPU") + try: + self.model = self.model.to('cpu') + embeddings = self.model.encode( + sentences, + normalize_embeddings=normalize_embeddings, + device='cpu', + show_progress_bar=False, + batch_size=batch_size + ) + return embeddings + except Exception as e2: + print(f"[BgeEncoder] CPU also failed: {e2}") + raise + else: + raise + + def encode_batch( + self, + texts: List[str], + batch_size: int = 32, + device: str = 'cuda' + ) -> np.ndarray: + """ + Encode a batch of texts efficiently. + + Args: + texts: List of texts to encode + batch_size: Batch size for processing + device: Device to use + + Returns: + numpy array of embeddings + """ + return self.encode(texts, batch_size=batch_size, device=device) diff --git a/indexer/bulk_indexer.py b/indexer/bulk_indexer.py index 291e996..9811a6e 100644 --- a/indexer/bulk_indexer.py +++ b/indexer/bulk_indexer.py @@ -7,6 +7,7 @@ Handles batch indexing of documents with progress tracking and error handling. from typing import List, Dict, Any, Optional from elasticsearch.helpers import bulk, BulkIndexError from utils.es_client import ESClient +from indexer import MappingGenerator import time @@ -232,8 +233,6 @@ class IndexingPipeline: Returns: Indexing statistics """ - from indexer.mapping_generator import MappingGenerator - # Generate and create index mapping_gen = MappingGenerator(self.config) mapping = mapping_gen.generate_mapping() diff --git a/indexer/mapping_generator.py b/indexer/mapping_generator.py index e5b387f..dfb7566 100644 --- a/indexer/mapping_generator.py +++ b/indexer/mapping_generator.py @@ -5,6 +5,8 @@ Generates Elasticsearch index mappings from search configuration. """ from typing import Dict, Any +import logging + from config import ( SearchConfig, FieldConfig, @@ -13,6 +15,8 @@ from config import ( get_default_similarity ) +logger = logging.getLogger(__name__) + class MappingGenerator: """Generates Elasticsearch mapping from search configuration.""" @@ -85,31 +89,18 @@ class MappingGenerator: Get the primary text embedding field name. Returns: - Field name or empty string if not found + Field name or empty string if not configured """ - # Look for name_embedding or first text_embedding field - for field in self.config.fields: - if field.name == "name_embedding": - return field.name - - # Otherwise return first text embedding field - for field in self.config.fields: - if "embedding" in field.name and "image" not in field.name: - return field.name - - return "" + return self.config.query_config.text_embedding_field or "" def get_image_embedding_field(self) -> str: """ Get the primary image embedding field name. Returns: - Field name or empty string if not found + Field name or empty string if not configured """ - for field in self.config.fields: - if "image" in field.name and "embedding" in field.name: - return field.name - return "" + return self.config.query_config.image_embedding_field or "" def get_field_by_name(self, field_name: str) -> FieldConfig: """ @@ -162,11 +153,11 @@ def create_index_if_not_exists(es_client, index_name: str, mapping: Dict[str, An True if index was created, False if it already exists """ if es_client.indices.exists(index=index_name): - print(f"Index '{index_name}' already exists") + logger.info(f"Index '{index_name}' already exists") return False es_client.indices.create(index=index_name, body=mapping) - print(f"Index '{index_name}' created successfully") + logger.info(f"Index '{index_name}' created successfully") return True @@ -182,11 +173,11 @@ def delete_index_if_exists(es_client, index_name: str) -> bool: True if index was deleted, False if it didn't exist """ if not es_client.indices.exists(index=index_name): - print(f"Index '{index_name}' does not exist") + logger.warning(f"Index '{index_name}' does not exist") return False es_client.indices.delete(index=index_name) - print(f"Index '{index_name}' deleted successfully") + logger.info(f"Index '{index_name}' deleted successfully") return True @@ -203,10 +194,10 @@ def update_mapping(es_client, index_name: str, new_fields: Dict[str, Any]) -> bo True if successful """ if not es_client.indices.exists(index=index_name): - print(f"Index '{index_name}' does not exist") + logger.error(f"Index '{index_name}' does not exist") return False mapping = {"properties": new_fields} es_client.indices.put_mapping(index=index_name, body=mapping) - print(f"Mapping updated for index '{index_name}'") + logger.info(f"Mapping updated for index '{index_name}'") return True diff --git a/query/query_parser.py b/query/query_parser.py index bf4de44..ab7f9c7 100644 --- a/query/query_parser.py +++ b/query/query_parser.py @@ -6,6 +6,7 @@ Handles query rewriting, translation, and embedding generation. from typing import Dict, List, Optional, Any import numpy as np +import logging from config import SearchConfig, QueryConfig from embeddings import BgeEncoder @@ -13,6 +14,8 @@ from .language_detector import LanguageDetector from .translator import Translator from .query_rewriter import QueryRewriter, QueryNormalizer +logger = logging.getLogger(__name__) + class ParsedQuery: """Container for parsed query results.""" @@ -87,7 +90,7 @@ class QueryParser: def text_encoder(self) -> BgeEncoder: """Lazy load text encoder.""" if self._text_encoder is None and self.query_config.enable_text_embedding: - print("[QueryParser] Initializing text encoder...") + logger.info("Initializing text encoder (lazy load)...") self._text_encoder = BgeEncoder() return self._text_encoder @@ -95,7 +98,7 @@ class QueryParser: def translator(self) -> Translator: """Lazy load translator.""" if self._translator is None and self.query_config.enable_translation: - print("[QueryParser] Initializing translator...") + logger.info("Initializing translator (lazy load)...") self._translator = Translator( api_key=self.query_config.translation_api_key, use_cache=True, @@ -124,18 +127,17 @@ class QueryParser: extra={'reqid': context.reqid, 'uid': context.uid} ) - # Use print statements for backward compatibility if no context def log_info(msg): - if logger: - logger.info(msg, extra={'reqid': context.reqid, 'uid': context.uid}) + if context and hasattr(context, 'logger'): + context.logger.info(msg, extra={'reqid': context.reqid, 'uid': context.uid}) else: - print(f"[QueryParser] {msg}") + logger.info(msg) def log_debug(msg): - if logger: - logger.debug(msg, extra={'reqid': context.reqid, 'uid': context.uid}) + if context and hasattr(context, 'logger'): + context.logger.debug(msg, extra={'reqid': context.reqid, 'uid': context.uid}) else: - print(f"[QueryParser] {msg}") + logger.debug(msg) # Stage 1: Normalize normalized = self.normalizer.normalize(query) @@ -246,15 +248,18 @@ class QueryParser: domain=domain ) - if logger: - logger.info( + if context and hasattr(context, 'logger'): + context.logger.info( f"查询解析完成 | 原查询: '{query}' | 最终查询: '{rewritten or query_text}' | " f"语言: {detected_lang} | 域: {domain} | " f"翻译数量: {len(translations)} | 向量: {'是' if query_vector is not None else '否'}", extra={'reqid': context.reqid, 'uid': context.uid} ) else: - print(f"[QueryParser] Parsing complete") + logger.info( + f"查询解析完成 | 原查询: '{query}' | 最终查询: '{rewritten or query_text}' | " + f"语言: {detected_lang} | 域: {domain}" + ) return result diff --git a/query/translator.py b/query/translator.py index dd9117b..2f99f75 100644 --- a/query/translator.py +++ b/query/translator.py @@ -8,6 +8,12 @@ import requests from typing import Dict, List, Optional from utils.cache import DictCache +# Try to import DEEPL_AUTH_KEY, but allow import to fail +try: + from config.env_config import DEEPL_AUTH_KEY +except ImportError: + DEEPL_AUTH_KEY = None + class Translator: """Multi-language translator using DeepL API.""" @@ -47,12 +53,8 @@ class Translator: translation_context: Context hint for translation (e.g., "e-commerce", "product search") """ # Get API key from config if not provided - if api_key is None: - try: - from config.env_config import get_deepl_key - api_key = get_deepl_key() - except ImportError: - pass + if api_key is None and DEEPL_AUTH_KEY: + api_key = DEEPL_AUTH_KEY self.api_key = api_key self.timeout = timeout diff --git a/scripts/ingest_shoplazza.py b/scripts/ingest_shoplazza.py index 697b7c1..2debe72 100644 --- a/scripts/ingest_shoplazza.py +++ b/scripts/ingest_shoplazza.py @@ -78,13 +78,12 @@ def main(): return 1 # Connect to Elasticsearch (use unified config loading) - from config.env_config import get_es_config - es_config = get_es_config() + from config.env_config import ES_CONFIG # Use provided es_host or fallback to config - es_host = args.es_host or es_config.get('host', 'http://localhost:9200') - es_username = es_config.get('username') - es_password = es_config.get('password') + es_host = args.es_host or ES_CONFIG.get('host', 'http://localhost:9200') + es_username = ES_CONFIG.get('username') + es_password = ES_CONFIG.get('password') print(f"Connecting to Elasticsearch: {es_host}") if es_username and es_password: diff --git a/search/multilang_query_builder.py b/search/multilang_query_builder.py index 156ef65..9558db2 100644 --- a/search/multilang_query_builder.py +++ b/search/multilang_query_builder.py @@ -8,11 +8,15 @@ maintaining a unified external interface. from typing import Dict, Any, List, Optional import numpy as np +import logging +import re from config import SearchConfig, IndexConfig from query import ParsedQuery from .es_query_builder import ESQueryBuilder +logger = logging.getLogger(__name__) + class MultiLanguageQueryBuilder(ESQueryBuilder): """ @@ -139,20 +143,18 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): min_score=min_score ) - print(f"[MultiLangQueryBuilder] Building query for domain: {domain}") - print(f"[MultiLangQueryBuilder] Detected language: {parsed_query.detected_language}") - print(f"[MultiLangQueryBuilder] Available translations: {list(parsed_query.translations.keys())}") + logger.debug(f"Building query for domain: {domain}, language: {parsed_query.detected_language}") # Build query clause with multi-language support if query_node and isinstance(query_node, tuple) and len(query_node) > 0: # Handle boolean query from tuple (AST, score) ast_node = query_node[0] query_clause = self._build_boolean_query_from_tuple(ast_node) - print(f"[MultiLangQueryBuilder] Using boolean query: {query_clause}") + logger.debug(f"Using boolean query") elif query_node and hasattr(query_node, 'operator') and query_node.operator != 'TERM': # Handle boolean query using base class method query_clause = self._build_boolean_query(query_node) - print(f"[MultiLangQueryBuilder] Using boolean query: {query_clause}") + logger.debug(f"Using boolean query") else: # Handle text query with multi-language support query_clause = self._build_multilang_text_query(parsed_query, domain_config) @@ -171,7 +173,7 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): } } inner_bool_should.append(knn_query) - print(f"[MultiLangQueryBuilder] KNN query added: field={self.text_embedding_field}, k={knn_k}, num_candidates={knn_num_candidates}") + logger.info(f"KNN query added: field={self.text_embedding_field}, k={knn_k}") else: # Debug why KNN is not added reasons = [] @@ -181,7 +183,7 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): reasons.append("query_vector is None") if not self.text_embedding_field: reasons.append(f"text_embedding_field is not set (current: {self.text_embedding_field})") - print(f"[MultiLangQueryBuilder] KNN query NOT added. Reasons: {', '.join(reasons) if reasons else 'unknown'}") + logger.debug(f"KNN query NOT added. Reasons: {', '.join(reasons) if reasons else 'unknown'}") # 构建内层bool结构 inner_bool = { @@ -342,7 +344,7 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): "_name": f"{domain_config.name}_{detected_lang}_query" } }) - print(f"[MultiLangQueryBuilder] Added query for detected language '{detected_lang}' on fields: {target_fields}") + logger.debug(f"Added query for detected language '{detected_lang}'") # 2. Query in translated languages (only for languages in mapping) for lang, translation in parsed_query.translations.items(): @@ -361,11 +363,11 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): "_name": f"{domain_config.name}_{lang}_translated_query" } }) - print(f"[MultiLangQueryBuilder] Added translated query for language '{lang}' on fields: {target_fields}") + logger.debug(f"Added translated query for language '{lang}'") # 3. Fallback: query all fields in mapping if no language-specific query was built if not should_clauses: - print(f"[MultiLangQueryBuilder] No language mapping matched, using all fields from mapping") + logger.debug("No language mapping matched, using all fields from mapping") # Use all fields from all languages in the mapping all_mapped_fields = [] for lang_fields in domain_config.language_field_mapping.values(): @@ -445,7 +447,6 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): operator = node[0].__name__ elif str(node[0]).startswith('('): # String representation of constructor call - import re match = re.match(r'(\w+)\(', str(node[0])) if match: operator = match.group(1) diff --git a/search/searcher.py b/search/searcher.py index 87f28da..ee0443f 100644 --- a/search/searcher.py +++ b/search/searcher.py @@ -6,11 +6,13 @@ Handles query parsing, boolean expressions, ranking, and result formatting. from typing import Dict, Any, List, Optional, Union import time +import logging from config import SearchConfig from utils.es_client import ESClient from query import QueryParser, ParsedQuery from indexer import MappingGenerator +from embeddings import CLIPImageEncoder from .boolean_parser import BooleanParser, QueryNode from .es_query_builder import ESQueryBuilder from .multilang_query_builder import MultiLanguageQueryBuilder @@ -19,6 +21,8 @@ from context.request_context import RequestContext, RequestContextStage, create_ from api.models import FacetResult, FacetValue from api.result_formatter import ResultFormatter +logger = logging.getLogger(__name__) + class SearchResult: """Container for search results (外部友好格式).""" @@ -476,7 +480,6 @@ class Searcher: raise ValueError("Image embedding field not configured") # Generate image embedding - from embeddings import CLIPImageEncoder image_encoder = CLIPImageEncoder() image_vector = image_encoder.encode_image_from_url(image_url) @@ -575,7 +578,7 @@ class Searcher: ) return response.get('_source') except Exception as e: - print(f"[Searcher] Failed to get document {doc_id}: {e}") + logger.error(f"Failed to get document {doc_id}: {e}", exc_info=True) return None def _standardize_facets( diff --git a/utils/es_client.py b/utils/es_client.py index 2bd5114..ae08dfd 100644 --- a/utils/es_client.py +++ b/utils/es_client.py @@ -3,8 +3,18 @@ Elasticsearch client wrapper. """ from elasticsearch import Elasticsearch +from elasticsearch.helpers import bulk from typing import Dict, Any, List, Optional import os +import logging + +# Try to import ES_CONFIG, but allow import to fail +try: + from config.env_config import ES_CONFIG +except ImportError: + ES_CONFIG = None + +logger = logging.getLogger(__name__) class ESClient: @@ -56,7 +66,7 @@ class ESClient: try: return self.client.ping() except Exception as e: - print(f"Failed to ping Elasticsearch: {e}") + logger.error(f"Failed to ping Elasticsearch: {e}", exc_info=True) return False def create_index(self, index_name: str, body: Dict[str, Any]) -> bool: @@ -72,12 +82,10 @@ class ESClient: """ try: self.client.indices.create(index=index_name, body=body) - print(f"Index '{index_name}' created successfully") + logger.info(f"Index '{index_name}' created successfully") return True except Exception as e: - print(f"ERROR: Failed to create index '{index_name}': {e}") - import traceback - traceback.print_exc() + logger.error(f"Failed to create index '{index_name}': {e}", exc_info=True) return False def delete_index(self, index_name: str) -> bool: @@ -93,13 +101,13 @@ class ESClient: try: if self.client.indices.exists(index=index_name): self.client.indices.delete(index=index_name) - print(f"Index '{index_name}' deleted successfully") + logger.info(f"Index '{index_name}' deleted successfully") return True else: - print(f"Index '{index_name}' does not exist") + logger.warning(f"Index '{index_name}' does not exist") return False except Exception as e: - print(f"Failed to delete index '{index_name}': {e}") + logger.error(f"Failed to delete index '{index_name}': {e}", exc_info=True) return False def index_exists(self, index_name: str) -> bool: @@ -117,8 +125,6 @@ class ESClient: Returns: Dictionary with results """ - from elasticsearch.helpers import bulk - actions = [] for doc in docs: action = { @@ -140,7 +146,7 @@ class ESClient: 'errors': failed } except Exception as e: - print(f"Bulk indexing failed: {e}") + logger.error(f"Bulk indexing failed: {e}", exc_info=True) return { 'success': 0, 'failed': len(docs), @@ -174,7 +180,7 @@ class ESClient: from_=from_ ) except Exception as e: - print(f"Search failed: {e}") + logger.error(f"Search failed: {e}", exc_info=True) return { 'hits': { 'total': {'value': 0}, @@ -188,7 +194,7 @@ class ESClient: try: return self.client.indices.get_mapping(index=index_name) except Exception as e: - print(f"Failed to get mapping for '{index_name}': {e}") + logger.error(f"Failed to get mapping for '{index_name}': {e}", exc_info=True) return {} def refresh(self, index_name: str) -> bool: @@ -197,7 +203,7 @@ class ESClient: self.client.indices.refresh(index=index_name) return True except Exception as e: - print(f"Failed to refresh index '{index_name}': {e}") + logger.error(f"Failed to refresh index '{index_name}': {e}", exc_info=True) return False def count(self, index_name: str, body: Optional[Dict[str, Any]] = None) -> int: @@ -215,7 +221,7 @@ class ESClient: result = self.client.count(index=index_name, body=body) return result['count'] except Exception as e: - print(f"Count failed: {e}") + logger.error(f"Count failed: {e}", exc_info=True) return 0 @@ -231,15 +237,13 @@ def get_es_client_from_env() -> ESClient: Returns: ESClient instance """ - try: - from config.env_config import get_es_config - es_config = get_es_config() + if ES_CONFIG: return ESClient( - hosts=[es_config['host']], - username=es_config.get('username'), - password=es_config.get('password') + hosts=[ES_CONFIG['host']], + username=ES_CONFIG.get('username'), + password=ES_CONFIG.get('password') ) - except ImportError: + else: # Fallback to env variables return ESClient( hosts=[os.getenv('ES_HOST', 'http://localhost:9200')], -- libgit2 0.21.2