Commit 7bfb99466aff30553a9efc1b275b46624471e659
1 parent
9c712e64
向量化模块
Showing
11 changed files
with
454 additions
and
356 deletions
Show diff stats
README.md
| ... | ... | @@ -4,7 +4,7 @@ |
| 4 | 4 | 语义: |
| 5 | 5 | |
| 6 | 6 | |
| 7 | -query anchor | |
| 7 | +query anchor | |
| 8 | 8 | 我想给elasticsearch 增加字段 query anchor ,即哪些query点击到了这个doc,一个doc下面有多个query anchor,每个query anchor又有这两个属性:weight、dweight,分别代表 query在doc下的点击分布权重、doc在query下的点击分布权重。请问该如何设计这两个ES字段。 |
| 9 | 9 | |
| 10 | 10 | 需要有zh en两套query anchor,因为他们的解析器不一样。 |
| ... | ... | @@ -89,6 +89,13 @@ docker run -d --name es -p 9200:9200 elasticsearch:8.11.0 |
| 89 | 89 | # 4. 启动服务 |
| 90 | 90 | ./run.sh |
| 91 | 91 | |
| 92 | +# (可选)启动本地向量服务(BGE-M3 / CN-CLIP,本地模型推理) | |
| 93 | +# 提供: POST http://localhost:6005/embed/text | |
| 94 | +# POST http://localhost:6005/embed/image | |
| 95 | +./scripts/start_embedding_service.sh | |
| 96 | +# | |
| 97 | +# 详细说明见:`embeddings/README.md` | |
| 98 | + | |
| 92 | 99 | # 5. 调用文本搜索 API |
| 93 | 100 | curl -X POST http://localhost:6002/search/ \ |
| 94 | 101 | -H "Content-Type: application/json" \ | ... | ... |
embeddings/README.md
| ... | ... | @@ -0,0 +1,40 @@ |
| 1 | +## 向量化模块(embeddings) | |
| 2 | + | |
| 3 | +这个目录是一个完整的“向量化模块”,包含: | |
| 4 | + | |
| 5 | +- **HTTP 客户端**:`text_encoder.py` / `image_encoder.py`(供搜索/索引模块调用) | |
| 6 | +- **本地模型实现**:`bge_model.py` / `clip_model.py` | |
| 7 | +- **向量化服务(FastAPI)**:`server.py` | |
| 8 | +- **统一配置**:`config.py` | |
| 9 | + | |
| 10 | +### 服务接口 | |
| 11 | + | |
| 12 | +- `POST /embed/text` | |
| 13 | + - 入参:`["文本1", "文本2", ...]` | |
| 14 | + - 出参:`[[...], null, ...]`(与输入按 index 对齐,失败为 `null`) | |
| 15 | + | |
| 16 | +- `POST /embed/image` | |
| 17 | + - 入参:`["url或本地路径1", ...]` | |
| 18 | + - 出参:`[[...], null, ...]`(与输入按 index 对齐,失败为 `null`) | |
| 19 | + | |
| 20 | +### 启动服务 | |
| 21 | + | |
| 22 | +使用仓库脚本启动(默认端口 6005): | |
| 23 | + | |
| 24 | +```bash | |
| 25 | +./scripts/start_embedding_service.sh | |
| 26 | +``` | |
| 27 | + | |
| 28 | +### 修改配置 | |
| 29 | + | |
| 30 | +编辑 `embeddings/config.py`: | |
| 31 | + | |
| 32 | +- `PORT`: 服务端口(默认 6005) | |
| 33 | +- `TEXT_MODEL_DIR`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE` | |
| 34 | +- `IMAGE_MODEL_NAME`, `IMAGE_DEVICE` | |
| 35 | + | |
| 36 | +### 目录说明(旧文件) | |
| 37 | + | |
| 38 | +旧的 `vector_service/` 目录与 `*_encoder__local.py` 文件已经废弃,统一由本目录实现与维护。 | |
| 39 | + | |
| 40 | + | ... | ... |
embeddings/__init__.py
| 1 | -"""Embeddings package initialization.""" | |
| 1 | +""" | |
| 2 | +Embeddings module. | |
| 2 | 3 | |
| 3 | -from .text_encoder import BgeEncoder | |
| 4 | -from .image_encoder import CLIPImageEncoder | |
| 4 | +Important: keep package import lightweight. | |
| 5 | 5 | |
| 6 | -__all__ = [ | |
| 7 | - 'BgeEncoder', | |
| 8 | - 'CLIPImageEncoder', | |
| 9 | -] | |
| 6 | +Some callers do: | |
| 7 | + - `from embeddings import BgeEncoder` | |
| 8 | + - `from embeddings import CLIPImageEncoder` | |
| 9 | + | |
| 10 | +But the underlying implementations may import heavy optional deps (Pillow, torch, etc). | |
| 11 | +To avoid importing those at package import time (and to allow the embedding service to boot | |
| 12 | +without importing client code), we provide small lazy factories here. | |
| 13 | +""" | |
| 14 | + | |
| 15 | + | |
| 16 | +class BgeEncoder(object): | |
| 17 | + """Lazy factory for `embeddings.text_encoder.BgeEncoder`.""" | |
| 18 | + | |
| 19 | + def __new__(cls, *args, **kwargs): | |
| 20 | + from .text_encoder import BgeEncoder as _Real | |
| 21 | + | |
| 22 | + return _Real(*args, **kwargs) | |
| 23 | + | |
| 24 | + | |
| 25 | +class CLIPImageEncoder(object): | |
| 26 | + """Lazy factory for `embeddings.image_encoder.CLIPImageEncoder`.""" | |
| 27 | + | |
| 28 | + def __new__(cls, *args, **kwargs): | |
| 29 | + from .image_encoder import CLIPImageEncoder as _Real | |
| 30 | + | |
| 31 | + return _Real(*args, **kwargs) | |
| 32 | + | |
| 33 | + | |
| 34 | +__all__ = ["BgeEncoder", "CLIPImageEncoder"] | ... | ... |
embeddings/bge_model.py
| ... | ... | @@ -0,0 +1,81 @@ |
| 1 | +""" | |
| 2 | +BGE-M3 local text embedding implementation. | |
| 3 | + | |
| 4 | +Internal model implementation used by the embedding service. | |
| 5 | +""" | |
| 6 | + | |
| 7 | +import threading | |
| 8 | +from typing import List, Union | |
| 9 | + | |
| 10 | +import numpy as np | |
| 11 | +from sentence_transformers import SentenceTransformer | |
| 12 | +from modelscope import snapshot_download | |
| 13 | + | |
| 14 | + | |
| 15 | +class BgeTextModel(object): | |
| 16 | + """ | |
| 17 | + Thread-safe singleton text encoder using BGE-M3 model (local inference). | |
| 18 | + """ | |
| 19 | + | |
| 20 | + _instance = None | |
| 21 | + _lock = threading.Lock() | |
| 22 | + | |
| 23 | + def __new__(cls, model_dir: str = "Xorbits/bge-m3"): | |
| 24 | + with cls._lock: | |
| 25 | + if cls._instance is None: | |
| 26 | + cls._instance = super(BgeTextModel, cls).__new__(cls) | |
| 27 | + cls._instance.model = SentenceTransformer(snapshot_download(model_dir)) | |
| 28 | + return cls._instance | |
| 29 | + | |
| 30 | + def encode( | |
| 31 | + self, | |
| 32 | + sentences: Union[str, List[str]], | |
| 33 | + normalize_embeddings: bool = True, | |
| 34 | + device: str = "cuda", | |
| 35 | + batch_size: int = 32, | |
| 36 | + ) -> np.ndarray: | |
| 37 | + if device == "gpu": | |
| 38 | + device = "cuda" | |
| 39 | + | |
| 40 | + # Try requested device, fallback to CPU if CUDA fails | |
| 41 | + try: | |
| 42 | + if device == "cuda": | |
| 43 | + import torch | |
| 44 | + | |
| 45 | + if torch.cuda.is_available(): | |
| 46 | + free_memory = ( | |
| 47 | + torch.cuda.get_device_properties(0).total_memory | |
| 48 | + - torch.cuda.memory_allocated() | |
| 49 | + ) | |
| 50 | + if free_memory < 1024 * 1024 * 1024: # 1GB | |
| 51 | + device = "cpu" | |
| 52 | + else: | |
| 53 | + device = "cpu" | |
| 54 | + | |
| 55 | + self.model = self.model.to(device) | |
| 56 | + embeddings = self.model.encode( | |
| 57 | + sentences, | |
| 58 | + normalize_embeddings=normalize_embeddings, | |
| 59 | + device=device, | |
| 60 | + show_progress_bar=False, | |
| 61 | + batch_size=batch_size, | |
| 62 | + ) | |
| 63 | + return embeddings | |
| 64 | + | |
| 65 | + except Exception: | |
| 66 | + if device != "cpu": | |
| 67 | + self.model = self.model.to("cpu") | |
| 68 | + embeddings = self.model.encode( | |
| 69 | + sentences, | |
| 70 | + normalize_embeddings=normalize_embeddings, | |
| 71 | + device="cpu", | |
| 72 | + show_progress_bar=False, | |
| 73 | + batch_size=batch_size, | |
| 74 | + ) | |
| 75 | + return embeddings | |
| 76 | + raise | |
| 77 | + | |
| 78 | + def encode_batch(self, texts: List[str], batch_size: int = 32, device: str = "cuda") -> np.ndarray: | |
| 79 | + return self.encode(texts, batch_size=batch_size, device=device) | |
| 80 | + | |
| 81 | + | ... | ... |
embeddings/image_encoder__local.py renamed to embeddings/clip_model.py
| 1 | 1 | """ |
| 2 | -Image embedding encoder using CN-CLIP model. | |
| 2 | +CN-CLIP local image embedding implementation. | |
| 3 | 3 | |
| 4 | -Generates 1024-dimensional vectors for images using the CN-CLIP ViT-H-14 model. | |
| 4 | +Internal model implementation used by the embedding service. | |
| 5 | 5 | """ |
| 6 | 6 | |
| 7 | -import sys | |
| 8 | -import os | |
| 9 | 7 | import io |
| 8 | +import threading | |
| 9 | +from typing import List, Optional, Union | |
| 10 | + | |
| 11 | +import numpy as np | |
| 10 | 12 | import requests |
| 11 | 13 | import torch |
| 12 | -import numpy as np | |
| 13 | 14 | from PIL import Image |
| 14 | -import logging | |
| 15 | -import threading | |
| 16 | -from typing import List, Optional, Union | |
| 17 | -import cn_clip.clip as clip | |
| 18 | 15 | from cn_clip.clip import load_from_name |
| 16 | +import cn_clip.clip as clip | |
| 19 | 17 | |
| 20 | 18 | |
| 21 | -# DEFAULT_MODEL_NAME = "ViT-L-14-336" # ["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"] | |
| 22 | 19 | DEFAULT_MODEL_NAME = "ViT-H-14" |
| 23 | 20 | MODEL_DOWNLOAD_DIR = "/data/tw/uat/EsSearcher" |
| 24 | 21 | |
| 25 | 22 | |
| 26 | -class CLIPImageEncoder: | |
| 23 | +class ClipImageModel(object): | |
| 27 | 24 | """ |
| 28 | - CLIP Image Encoder for generating image embeddings using cn_clip. | |
| 29 | - | |
| 30 | - Thread-safe singleton pattern. | |
| 25 | + Thread-safe singleton image encoder using cn_clip (local inference). | |
| 31 | 26 | """ |
| 32 | 27 | |
| 33 | 28 | _instance = None |
| 34 | 29 | _lock = threading.Lock() |
| 35 | 30 | |
| 36 | - def __new__(cls, model_name=DEFAULT_MODEL_NAME, device=None): | |
| 31 | + def __new__(cls, model_name: str = DEFAULT_MODEL_NAME, device: Optional[str] = None): | |
| 37 | 32 | with cls._lock: |
| 38 | 33 | if cls._instance is None: |
| 39 | - cls._instance = super(CLIPImageEncoder, cls).__new__(cls) | |
| 40 | - print(f"[CLIPImageEncoder] Creating new instance with model: {model_name}") | |
| 34 | + cls._instance = super(ClipImageModel, cls).__new__(cls) | |
| 41 | 35 | cls._instance._initialize_model(model_name, device) |
| 42 | 36 | return cls._instance |
| 43 | 37 | |
| 44 | - def _initialize_model(self, model_name, device): | |
| 45 | - """Initialize the CLIP model using cn_clip""" | |
| 46 | - try: | |
| 47 | - self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") | |
| 48 | - self.model, self.preprocess = load_from_name( | |
| 49 | - model_name, | |
| 50 | - device=self.device, | |
| 51 | - download_root=MODEL_DOWNLOAD_DIR | |
| 52 | - ) | |
| 53 | - self.model.eval() | |
| 54 | - self.model_name = model_name | |
| 55 | - print(f"[CLIPImageEncoder] Model {model_name} initialized successfully on device {self.device}") | |
| 56 | - | |
| 57 | - except Exception as e: | |
| 58 | - print(f"[CLIPImageEncoder] Failed to initialize model: {str(e)}") | |
| 59 | - raise | |
| 38 | + def _initialize_model(self, model_name: str, device: Optional[str]): | |
| 39 | + self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") | |
| 40 | + self.model, self.preprocess = load_from_name( | |
| 41 | + model_name, device=self.device, download_root=MODEL_DOWNLOAD_DIR | |
| 42 | + ) | |
| 43 | + self.model.eval() | |
| 44 | + self.model_name = model_name | |
| 60 | 45 | |
| 61 | 46 | def validate_image(self, image_data: bytes) -> Image.Image: |
| 62 | - """Validate image data and return PIL Image if valid""" | |
| 63 | - try: | |
| 64 | - image_stream = io.BytesIO(image_data) | |
| 65 | - image = Image.open(image_stream) | |
| 66 | - image.verify() | |
| 67 | - image_stream.seek(0) | |
| 68 | - image = Image.open(image_stream) | |
| 69 | - if image.mode != 'RGB': | |
| 70 | - image = image.convert('RGB') | |
| 71 | - return image | |
| 72 | - except Exception as e: | |
| 73 | - raise ValueError(f"Invalid image data: {str(e)}") | |
| 47 | + image_stream = io.BytesIO(image_data) | |
| 48 | + image = Image.open(image_stream) | |
| 49 | + image.verify() | |
| 50 | + image_stream.seek(0) | |
| 51 | + image = Image.open(image_stream) | |
| 52 | + if image.mode != "RGB": | |
| 53 | + image = image.convert("RGB") | |
| 54 | + return image | |
| 74 | 55 | |
| 75 | 56 | def download_image(self, url: str, timeout: int = 10) -> bytes: |
| 76 | - """Download image from URL""" | |
| 77 | - try: | |
| 78 | - if url.startswith(('http://', 'https://')): | |
| 79 | - response = requests.get(url, timeout=timeout) | |
| 80 | - if response.status_code != 200: | |
| 81 | - raise ValueError(f"HTTP {response.status_code}") | |
| 82 | - return response.content | |
| 83 | - else: | |
| 84 | - # Local file path | |
| 85 | - with open(url, 'rb') as f: | |
| 86 | - return f.read() | |
| 87 | - except Exception as e: | |
| 88 | - raise ValueError(f"Failed to download image from {url}: {str(e)}") | |
| 57 | + if url.startswith(("http://", "https://")): | |
| 58 | + response = requests.get(url, timeout=timeout) | |
| 59 | + if response.status_code != 200: | |
| 60 | + raise ValueError("HTTP %s" % response.status_code) | |
| 61 | + return response.content | |
| 62 | + with open(url, "rb") as f: | |
| 63 | + return f.read() | |
| 89 | 64 | |
| 90 | 65 | def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image: |
| 91 | - """Preprocess image for CLIP model""" | |
| 92 | - # Resize if too large | |
| 93 | 66 | if max(image.size) > max_size: |
| 94 | - ratio = max_size / max(image.size) | |
| 67 | + ratio = float(max_size) / float(max(image.size)) | |
| 95 | 68 | new_size = tuple(int(dim * ratio) for dim in image.size) |
| 96 | 69 | image = image.resize(new_size, Image.Resampling.LANCZOS) |
| 97 | 70 | return image |
| 98 | 71 | |
| 99 | 72 | def encode_text(self, text): |
| 100 | - """Encode text to embedding vector using cn_clip""" | |
| 101 | - text_data = clip.tokenize([text] if type(text) == str else text).to(self.device) | |
| 73 | + text_data = clip.tokenize([text] if isinstance(text, str) else text).to(self.device) | |
| 102 | 74 | with torch.no_grad(): |
| 103 | 75 | text_features = self.model.encode_text(text_data) |
| 104 | 76 | text_features /= text_features.norm(dim=-1, keepdim=True) |
| 105 | 77 | return text_features |
| 106 | 78 | |
| 107 | 79 | def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: |
| 108 | - """Encode image to embedding vector using cn_clip""" | |
| 109 | 80 | if not isinstance(image, Image.Image): |
| 110 | - raise ValueError("CLIPImageEncoder.encode_image Input must be a PIL.Image") | |
| 111 | - | |
| 112 | - try: | |
| 113 | - infer_data = self.preprocess(image).unsqueeze(0).to(self.device) | |
| 114 | - with torch.no_grad(): | |
| 115 | - image_features = self.model.encode_image(infer_data) | |
| 116 | - image_features /= image_features.norm(dim=-1, keepdim=True) | |
| 117 | - return image_features.cpu().numpy().astype('float32')[0] | |
| 118 | - except Exception as e: | |
| 119 | - print(f"Failed to process image. Reason: {str(e)}") | |
| 120 | - return None | |
| 81 | + raise ValueError("ClipImageModel.encode_image input must be a PIL.Image") | |
| 82 | + infer_data = self.preprocess(image).unsqueeze(0).to(self.device) | |
| 83 | + with torch.no_grad(): | |
| 84 | + image_features = self.model.encode_image(infer_data) | |
| 85 | + image_features /= image_features.norm(dim=-1, keepdim=True) | |
| 86 | + return image_features.cpu().numpy().astype("float32")[0] | |
| 121 | 87 | |
| 122 | 88 | def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: |
| 123 | - """Complete pipeline: download, validate, preprocess and encode image from URL""" | |
| 124 | - try: | |
| 125 | - # Download image | |
| 126 | - image_data = self.download_image(url) | |
| 127 | - | |
| 128 | - # Validate image | |
| 129 | - image = self.validate_image(image_data) | |
| 130 | - | |
| 131 | - # Preprocess image | |
| 132 | - image = self.preprocess_image(image) | |
| 133 | - | |
| 134 | - # Encode image | |
| 135 | - embedding = self.encode_image(image) | |
| 136 | - | |
| 137 | - return embedding | |
| 138 | - | |
| 139 | - except Exception as e: | |
| 140 | - print(f"Error processing image from URL {url}: {str(e)}") | |
| 141 | - return None | |
| 142 | - | |
| 143 | - def encode_batch( | |
| 144 | - self, | |
| 145 | - images: List[Union[str, Image.Image]], | |
| 146 | - batch_size: int = 8 | |
| 147 | - ) -> List[Optional[np.ndarray]]: | |
| 148 | - """ | |
| 149 | - Encode a batch of images efficiently. | |
| 150 | - | |
| 151 | - Args: | |
| 152 | - images: List of image URLs or PIL Images | |
| 153 | - batch_size: Batch size for processing | |
| 154 | - | |
| 155 | - Returns: | |
| 156 | - List of embeddings (or None for failed images) | |
| 157 | - """ | |
| 158 | - results = [] | |
| 89 | + image_data = self.download_image(url) | |
| 90 | + image = self.validate_image(image_data) | |
| 91 | + image = self.preprocess_image(image) | |
| 92 | + return self.encode_image(image) | |
| 159 | 93 | |
| 94 | + def encode_batch(self, images: List[Union[str, Image.Image]], batch_size: int = 8) -> List[Optional[np.ndarray]]: | |
| 95 | + results: List[Optional[np.ndarray]] = [] | |
| 160 | 96 | for i in range(0, len(images), batch_size): |
| 161 | - batch = images[i:i + batch_size] | |
| 162 | - batch_embeddings = [] | |
| 163 | - | |
| 97 | + batch = images[i : i + batch_size] | |
| 164 | 98 | for img in batch: |
| 165 | 99 | if isinstance(img, str): |
| 166 | - # URL or file path | |
| 167 | - emb = self.encode_image_from_url(img) | |
| 100 | + results.append(self.encode_image_from_url(img)) | |
| 168 | 101 | elif isinstance(img, Image.Image): |
| 169 | - # PIL Image | |
| 170 | - emb = self.encode_image(img) | |
| 102 | + results.append(self.encode_image(img)) | |
| 171 | 103 | else: |
| 172 | - emb = None | |
| 173 | - | |
| 174 | - batch_embeddings.append(emb) | |
| 104 | + results.append(None) | |
| 105 | + return results | |
| 175 | 106 | |
| 176 | - results.extend(batch_embeddings) | |
| 177 | 107 | |
| 178 | - return results | ... | ... |
embeddings/config.py
| ... | ... | @@ -0,0 +1,33 @@ |
| 1 | +""" | |
| 2 | +Embedding module configuration. | |
| 3 | + | |
| 4 | +This module is intentionally a plain Python file (no env var parsing, no extra deps). | |
| 5 | +Edit values here to configure: | |
| 6 | +- server host/port | |
| 7 | +- local model settings (paths/devices/batch sizes) | |
| 8 | +""" | |
| 9 | + | |
| 10 | +from typing import Optional | |
| 11 | + | |
| 12 | + | |
| 13 | +class EmbeddingConfig(object): | |
| 14 | + # Server | |
| 15 | + HOST = "0.0.0.0" | |
| 16 | + PORT = 6005 | |
| 17 | + | |
| 18 | + # Text embeddings (BGE-M3) | |
| 19 | + TEXT_MODEL_DIR = "Xorbits/bge-m3" | |
| 20 | + TEXT_DEVICE = "cuda" # "cuda" or "cpu" (model may fall back to CPU if needed) | |
| 21 | + TEXT_BATCH_SIZE = 32 | |
| 22 | + | |
| 23 | + # Image embeddings (CN-CLIP) | |
| 24 | + IMAGE_MODEL_NAME = "ViT-H-14" | |
| 25 | + IMAGE_DEVICE = None # type: Optional[str] # "cuda" / "cpu" / None(auto) | |
| 26 | + | |
| 27 | + # Service behavior | |
| 28 | + IMAGE_BATCH_SIZE = 8 | |
| 29 | + | |
| 30 | + | |
| 31 | +CONFIG = EmbeddingConfig() | |
| 32 | + | |
| 33 | + | ... | ... |
embeddings/image_encoder.py
| 1 | 1 | """ |
| 2 | 2 | Image embedding encoder using network service. |
| 3 | 3 | |
| 4 | -Generates embeddings via HTTP API service running on localhost:5001. | |
| 4 | +Generates embeddings via HTTP API service (default localhost:6005). | |
| 5 | 5 | """ |
| 6 | 6 | |
| 7 | 7 | import sys |
| ... | ... | @@ -26,24 +26,25 @@ class CLIPImageEncoder: |
| 26 | 26 | _instance = None |
| 27 | 27 | _lock = threading.Lock() |
| 28 | 28 | |
| 29 | - def __new__(cls, service_url='http://localhost:5001'): | |
| 29 | + def __new__(cls, service_url: Optional[str] = None): | |
| 30 | 30 | with cls._lock: |
| 31 | 31 | if cls._instance is None: |
| 32 | 32 | cls._instance = super(CLIPImageEncoder, cls).__new__(cls) |
| 33 | - logger.info(f"Creating CLIPImageEncoder instance with service URL: {service_url}") | |
| 34 | - cls._instance.service_url = service_url | |
| 35 | - cls._instance.endpoint = f"{service_url}/embedding/generate_image_embeddings" | |
| 33 | + resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL", "http://localhost:6005") | |
| 34 | + logger.info(f"Creating CLIPImageEncoder instance with service URL: {resolved_url}") | |
| 35 | + cls._instance.service_url = resolved_url | |
| 36 | + cls._instance.endpoint = f"{resolved_url}/embed/image" | |
| 36 | 37 | return cls._instance |
| 37 | 38 | |
| 38 | - def _call_service(self, request_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | |
| 39 | + def _call_service(self, request_data: List[str]) -> List[Any]: | |
| 39 | 40 | """ |
| 40 | 41 | Call the embedding service API. |
| 41 | 42 | |
| 42 | 43 | Args: |
| 43 | - request_data: List of dictionaries with id and pic_url fields | |
| 44 | + request_data: List of image URLs / local file paths | |
| 44 | 45 | |
| 45 | 46 | Returns: |
| 46 | - List of dictionaries with id, pic_url, embedding and error fields | |
| 47 | + List of embeddings (list[float]) or nulls (None), aligned to input order | |
| 47 | 48 | """ |
| 48 | 49 | try: |
| 49 | 50 | response = requests.post( |
| ... | ... | @@ -77,26 +78,11 @@ class CLIPImageEncoder: |
| 77 | 78 | Embedding vector or None if failed |
| 78 | 79 | """ |
| 79 | 80 | try: |
| 80 | - # Prepare request data | |
| 81 | - request_data = [{ | |
| 82 | - "id": "image_0", | |
| 83 | - "pic_url": url | |
| 84 | - }] | |
| 85 | - | |
| 86 | - # Call service | |
| 87 | - response_data = self._call_service(request_data) | |
| 88 | - | |
| 89 | - # Process response | |
| 90 | - if response_data and len(response_data) > 0: | |
| 91 | - response_item = response_data[0] | |
| 92 | - if response_item.get("embedding"): | |
| 93 | - return np.array(response_item["embedding"], dtype=np.float32) | |
| 94 | - else: | |
| 95 | - logger.warning(f"No embedding for URL {url}, error: {response_item.get('error', 'Unknown error')}") | |
| 96 | - return None | |
| 97 | - else: | |
| 98 | - logger.warning(f"No response for URL {url}") | |
| 99 | - return None | |
| 81 | + response_data = self._call_service([url]) | |
| 82 | + if response_data and len(response_data) > 0 and response_data[0] is not None: | |
| 83 | + return np.array(response_data[0], dtype=np.float32) | |
| 84 | + logger.warning(f"No embedding for URL {url}") | |
| 85 | + return None | |
| 100 | 86 | |
| 101 | 87 | except Exception as e: |
| 102 | 88 | logger.error(f"Failed to process image from URL {url}: {str(e)}", exc_info=True) |
| ... | ... | @@ -137,32 +123,17 @@ class CLIPImageEncoder: |
| 137 | 123 | batch_urls = url_images[i:i + batch_size] |
| 138 | 124 | batch_indices = url_indices[i:i + batch_size] |
| 139 | 125 | |
| 140 | - # Prepare request data | |
| 141 | - request_data = [] | |
| 142 | - for j, url in enumerate(batch_urls): | |
| 143 | - request_data.append({ | |
| 144 | - "id": f"image_{j}", | |
| 145 | - "pic_url": url | |
| 146 | - }) | |
| 147 | - | |
| 148 | 126 | try: |
| 149 | 127 | # Call service |
| 150 | - response_data = self._call_service(request_data) | |
| 128 | + response_data = self._call_service(batch_urls) | |
| 151 | 129 | |
| 152 | - # Process response | |
| 130 | + # Process response (aligned list) | |
| 153 | 131 | batch_results = [] |
| 154 | 132 | for j, url in enumerate(batch_urls): |
| 155 | - response_item = None | |
| 156 | - for item in response_data: | |
| 157 | - if str(item.get("id")) == f"image_{j}": | |
| 158 | - response_item = item | |
| 159 | - break | |
| 160 | - | |
| 161 | - if response_item and response_item.get("embedding"): | |
| 162 | - batch_results.append(np.array(response_item["embedding"], dtype=np.float32)) | |
| 133 | + if response_data and j < len(response_data) and response_data[j] is not None: | |
| 134 | + batch_results.append(np.array(response_data[j], dtype=np.float32)) | |
| 163 | 135 | else: |
| 164 | - error_msg = response_item.get("error", "Unknown error") if response_item else "No response" | |
| 165 | - logger.warning(f"Failed to encode URL {url}: {error_msg}") | |
| 136 | + logger.warning(f"Failed to encode URL {url}: no embedding") | |
| 166 | 137 | batch_results.append(None) |
| 167 | 138 | |
| 168 | 139 | # Insert results at the correct positions | ... | ... |
embeddings/server.py
| ... | ... | @@ -0,0 +1,122 @@ |
| 1 | +""" | |
| 2 | +Embedding service (FastAPI). | |
| 3 | + | |
| 4 | +API (simple list-in, list-out; aligned by index; failures -> null): | |
| 5 | +- POST /embed/text body: ["text1", "text2", ...] -> [[...], null, ...] | |
| 6 | +- POST /embed/image body: ["url_or_path1", ...] -> [[...], null, ...] | |
| 7 | +""" | |
| 8 | + | |
| 9 | +import threading | |
| 10 | +from typing import Any, Dict, List, Optional | |
| 11 | + | |
| 12 | +import numpy as np | |
| 13 | +from fastapi import FastAPI | |
| 14 | + | |
| 15 | +from embeddings.config import CONFIG | |
| 16 | +from embeddings.bge_model import BgeTextModel | |
| 17 | +from embeddings.clip_model import ClipImageModel | |
| 18 | + | |
| 19 | + | |
| 20 | +app = FastAPI(title="SearchEngine Embedding Service", version="1.0.0") | |
| 21 | + | |
| 22 | +_text_model = None | |
| 23 | +_image_model = None | |
| 24 | + | |
| 25 | +_text_init_lock = threading.Lock() | |
| 26 | +_image_init_lock = threading.Lock() | |
| 27 | + | |
| 28 | +_text_encode_lock = threading.Lock() | |
| 29 | +_image_encode_lock = threading.Lock() | |
| 30 | + | |
| 31 | + | |
| 32 | +def _get_text_model(): | |
| 33 | + global _text_model | |
| 34 | + if _text_model is None: | |
| 35 | + with _text_init_lock: | |
| 36 | + if _text_model is None: | |
| 37 | + _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR) | |
| 38 | + return _text_model | |
| 39 | + | |
| 40 | + | |
| 41 | +def _get_image_model(): | |
| 42 | + global _image_model | |
| 43 | + if _image_model is None: | |
| 44 | + with _image_init_lock: | |
| 45 | + if _image_model is None: | |
| 46 | + _image_model = ClipImageModel( | |
| 47 | + model_name=CONFIG.IMAGE_MODEL_NAME, | |
| 48 | + device=CONFIG.IMAGE_DEVICE, | |
| 49 | + ) | |
| 50 | + return _image_model | |
| 51 | + | |
| 52 | + | |
| 53 | +def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]: | |
| 54 | + if embedding is None: | |
| 55 | + return None | |
| 56 | + if not isinstance(embedding, np.ndarray): | |
| 57 | + embedding = np.array(embedding, dtype=np.float32) | |
| 58 | + if embedding.ndim != 1: | |
| 59 | + embedding = embedding.reshape(-1) | |
| 60 | + return embedding.astype(np.float32).tolist() | |
| 61 | + | |
| 62 | + | |
| 63 | +@app.get("/health") | |
| 64 | +def health() -> Dict[str, Any]: | |
| 65 | + return {"status": "ok"} | |
| 66 | + | |
| 67 | + | |
| 68 | +@app.post("/embed/text") | |
| 69 | +def embed_text(texts: List[str]) -> List[Optional[List[float]]]: | |
| 70 | + model = _get_text_model() | |
| 71 | + out: List[Optional[List[float]]] = [None] * len(texts) | |
| 72 | + | |
| 73 | + indexed_texts: List[tuple] = [] | |
| 74 | + for i, t in enumerate(texts): | |
| 75 | + if t is None: | |
| 76 | + continue | |
| 77 | + if not isinstance(t, str): | |
| 78 | + t = str(t) | |
| 79 | + t = t.strip() | |
| 80 | + if not t: | |
| 81 | + continue | |
| 82 | + indexed_texts.append((i, t)) | |
| 83 | + | |
| 84 | + if not indexed_texts: | |
| 85 | + return out | |
| 86 | + | |
| 87 | + batch_texts = [t for _, t in indexed_texts] | |
| 88 | + try: | |
| 89 | + with _text_encode_lock: | |
| 90 | + embs = model.encode_batch( | |
| 91 | + batch_texts, batch_size=int(CONFIG.TEXT_BATCH_SIZE), device=CONFIG.TEXT_DEVICE | |
| 92 | + ) | |
| 93 | + for j, (idx, _t) in enumerate(indexed_texts): | |
| 94 | + out[idx] = _as_list(embs[j]) | |
| 95 | + except Exception: | |
| 96 | + # keep Nones | |
| 97 | + pass | |
| 98 | + return out | |
| 99 | + | |
| 100 | + | |
| 101 | +@app.post("/embed/image") | |
| 102 | +def embed_image(images: List[str]) -> List[Optional[List[float]]]: | |
| 103 | + model = _get_image_model() | |
| 104 | + out: List[Optional[List[float]]] = [None] * len(images) | |
| 105 | + | |
| 106 | + with _image_encode_lock: | |
| 107 | + for i, url_or_path in enumerate(images): | |
| 108 | + try: | |
| 109 | + if url_or_path is None: | |
| 110 | + continue | |
| 111 | + if not isinstance(url_or_path, str): | |
| 112 | + url_or_path = str(url_or_path) | |
| 113 | + url_or_path = url_or_path.strip() | |
| 114 | + if not url_or_path: | |
| 115 | + continue | |
| 116 | + emb = model.encode_image_from_url(url_or_path) | |
| 117 | + out[i] = _as_list(emb) | |
| 118 | + except Exception: | |
| 119 | + out[i] = None | |
| 120 | + return out | |
| 121 | + | |
| 122 | + | ... | ... |
embeddings/text_encoder.py
| 1 | 1 | """ |
| 2 | 2 | Text embedding encoder using network service. |
| 3 | 3 | |
| 4 | -Generates embeddings via HTTP API service running on localhost:5001. | |
| 4 | +Generates embeddings via HTTP API service (default localhost:6005). | |
| 5 | 5 | """ |
| 6 | 6 | |
| 7 | 7 | import sys |
| ... | ... | @@ -11,6 +11,7 @@ import threading |
| 11 | 11 | import numpy as np |
| 12 | 12 | import pickle |
| 13 | 13 | import redis |
| 14 | +import os | |
| 14 | 15 | from datetime import timedelta |
| 15 | 16 | from typing import List, Union, Dict, Any, Optional |
| 16 | 17 | import logging |
| ... | ... | @@ -33,13 +34,14 @@ class BgeEncoder: |
| 33 | 34 | _instance = None |
| 34 | 35 | _lock = threading.Lock() |
| 35 | 36 | |
| 36 | - def __new__(cls, service_url='http://localhost:5001'): | |
| 37 | + def __new__(cls, service_url: Optional[str] = None): | |
| 37 | 38 | with cls._lock: |
| 38 | 39 | if cls._instance is None: |
| 39 | 40 | cls._instance = super(BgeEncoder, cls).__new__(cls) |
| 40 | - logger.info(f"Creating BgeEncoder instance with service URL: {service_url}") | |
| 41 | - cls._instance.service_url = service_url | |
| 42 | - cls._instance.endpoint = f"{service_url}/embedding/generate_embeddings" | |
| 41 | + resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL", "http://localhost:6005") | |
| 42 | + logger.info(f"Creating BgeEncoder instance with service URL: {resolved_url}") | |
| 43 | + cls._instance.service_url = resolved_url | |
| 44 | + cls._instance.endpoint = f"{resolved_url}/embed/text" | |
| 43 | 45 | |
| 44 | 46 | # Initialize Redis cache |
| 45 | 47 | try: |
| ... | ... | @@ -62,15 +64,15 @@ class BgeEncoder: |
| 62 | 64 | cls._instance.redis_client = None |
| 63 | 65 | return cls._instance |
| 64 | 66 | |
| 65 | - def _call_service(self, request_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | |
| 67 | + def _call_service(self, request_data: List[str]) -> List[Any]: | |
| 66 | 68 | """ |
| 67 | 69 | Call the embedding service API. |
| 68 | 70 | |
| 69 | 71 | Args: |
| 70 | - request_data: List of dictionaries with id and text fields | |
| 72 | + request_data: List of texts | |
| 71 | 73 | |
| 72 | 74 | Returns: |
| 73 | - List of dictionaries with id and embedding fields | |
| 75 | + List of embeddings (list[float]) or nulls (None), aligned to input order | |
| 74 | 76 | """ |
| 75 | 77 | try: |
| 76 | 78 | response = requests.post( |
| ... | ... | @@ -126,19 +128,7 @@ class BgeEncoder: |
| 126 | 128 | uncached_texts.append(text) |
| 127 | 129 | |
| 128 | 130 | # Prepare request data for uncached texts (after cache check) |
| 129 | - request_data = [] | |
| 130 | - for i, text in enumerate(uncached_texts): | |
| 131 | - request_item = { | |
| 132 | - "id": str(uncached_indices[i]), | |
| 133 | - "name_zh": text | |
| 134 | - } | |
| 135 | - | |
| 136 | - # Add English and Russian fields as empty for now | |
| 137 | - # Could be enhanced with language detection in the future | |
| 138 | - request_item["name_en"] = None | |
| 139 | - request_item["name_ru"] = None | |
| 140 | - | |
| 141 | - request_data.append(request_item) | |
| 131 | + request_data = list(uncached_texts) | |
| 142 | 132 | |
| 143 | 133 | # If there are uncached texts, call service |
| 144 | 134 | if uncached_texts: |
| ... | ... | @@ -149,43 +139,27 @@ class BgeEncoder: |
| 149 | 139 | # Process response |
| 150 | 140 | for i, text in enumerate(uncached_texts): |
| 151 | 141 | original_idx = uncached_indices[i] |
| 152 | - # Find corresponding response by ID | |
| 153 | - response_item = None | |
| 154 | - for item in response_data: | |
| 155 | - if str(item.get("id")) == str(original_idx): | |
| 156 | - response_item = item | |
| 157 | - break | |
| 158 | - | |
| 159 | - if response_item: | |
| 160 | - # Try Chinese embedding first, then English, then Russian | |
| 142 | + if response_data and i < len(response_data): | |
| 143 | + embedding = response_data[i] | |
| 144 | + else: | |
| 161 | 145 | embedding = None |
| 162 | - for lang in ["embedding_zh", "embedding_en", "embedding_ru"]: | |
| 163 | - if lang in response_item and response_item[lang] is not None: | |
| 164 | - embedding = response_item[lang] | |
| 165 | - break | |
| 166 | 146 | |
| 167 | - if embedding is not None: | |
| 168 | - embedding_array = np.array(embedding, dtype=np.float32) | |
| 169 | - # Validate embedding from service - if invalid, treat as no result | |
| 170 | - if self._is_valid_embedding(embedding_array): | |
| 171 | - embeddings[original_idx] = embedding_array | |
| 172 | - # Cache the embedding | |
| 173 | - self._set_cached_embedding(text, 'en', embedding_array) | |
| 174 | - else: | |
| 175 | - logger.warning( | |
| 176 | - f"Invalid embedding returned from service for text {original_idx} " | |
| 177 | - f"(contains NaN/Inf or invalid shape), treating as no result. " | |
| 178 | - f"Text preview: {text[:50]}..." | |
| 179 | - ) | |
| 180 | - # 不生成兜底向量,保持为 None | |
| 181 | - embeddings[original_idx] = None | |
| 147 | + if embedding is not None: | |
| 148 | + embedding_array = np.array(embedding, dtype=np.float32) | |
| 149 | + # Validate embedding from service - if invalid, treat as no result | |
| 150 | + if self._is_valid_embedding(embedding_array): | |
| 151 | + embeddings[original_idx] = embedding_array | |
| 152 | + # Cache the embedding | |
| 153 | + self._set_cached_embedding(text, 'en', embedding_array) | |
| 182 | 154 | else: |
| 183 | - logger.warning(f"No embedding found for text {original_idx}: {text[:50]}...") | |
| 184 | - # 不生成兜底向量,保持为 None | |
| 155 | + logger.warning( | |
| 156 | + f"Invalid embedding returned from service for text {original_idx} " | |
| 157 | + f"(contains NaN/Inf or invalid shape), treating as no result. " | |
| 158 | + f"Text preview: {text[:50]}..." | |
| 159 | + ) | |
| 185 | 160 | embeddings[original_idx] = None |
| 186 | 161 | else: |
| 187 | - logger.warning(f"No response found for text {original_idx}") | |
| 188 | - # 不生成兜底向量,保持为 None | |
| 162 | + logger.warning(f"No embedding found for text {original_idx}: {text[:50]}...") | |
| 189 | 163 | embeddings[original_idx] = None |
| 190 | 164 | |
| 191 | 165 | except Exception as e: | ... | ... |
embeddings/text_encoder__local.py deleted
| ... | ... | @@ -1,124 +0,0 @@ |
| 1 | -""" | |
| 2 | -Text embedding encoder using BGE-M3 model. | |
| 3 | - | |
| 4 | -Generates 1024-dimensional vectors for text using the BGE-M3 multilingual model. | |
| 5 | -""" | |
| 6 | - | |
| 7 | -import sys | |
| 8 | -import torch | |
| 9 | -from sentence_transformers import SentenceTransformer | |
| 10 | -import time | |
| 11 | -import threading | |
| 12 | -from modelscope import snapshot_download | |
| 13 | -from transformers import AutoModel | |
| 14 | -import os | |
| 15 | -import numpy as np | |
| 16 | -from typing import List, Union | |
| 17 | - | |
| 18 | - | |
| 19 | -class BgeEncoder: | |
| 20 | - """ | |
| 21 | - Singleton text encoder using BGE-M3 model. | |
| 22 | - | |
| 23 | - Thread-safe singleton pattern ensures only one model instance exists. | |
| 24 | - """ | |
| 25 | - _instance = None | |
| 26 | - _lock = threading.Lock() | |
| 27 | - | |
| 28 | - def __new__(cls, model_dir='Xorbits/bge-m3'): | |
| 29 | - with cls._lock: | |
| 30 | - if cls._instance is None: | |
| 31 | - cls._instance = super(BgeEncoder, cls).__new__(cls) | |
| 32 | - print(f"[BgeEncoder] Creating a new instance with model directory: {model_dir}") | |
| 33 | - cls._instance.model = SentenceTransformer(snapshot_download(model_dir)) | |
| 34 | - print("[BgeEncoder] New instance has been created") | |
| 35 | - return cls._instance | |
| 36 | - | |
| 37 | - def encode( | |
| 38 | - self, | |
| 39 | - sentences: Union[str, List[str]], | |
| 40 | - normalize_embeddings: bool = True, | |
| 41 | - device: str = 'cuda', | |
| 42 | - batch_size: int = 32 | |
| 43 | - ) -> np.ndarray: | |
| 44 | - """ | |
| 45 | - Encode text into embeddings. | |
| 46 | - | |
| 47 | - Args: | |
| 48 | - sentences: Single string or list of strings to encode | |
| 49 | - normalize_embeddings: Whether to normalize embeddings | |
| 50 | - device: Device to use ('cuda' or 'cpu') | |
| 51 | - batch_size: Batch size for encoding | |
| 52 | - | |
| 53 | - Returns: | |
| 54 | - numpy array of shape (n, 1024) containing embeddings | |
| 55 | - """ | |
| 56 | - # Move model to specified device | |
| 57 | - if device == 'gpu': | |
| 58 | - device = 'cuda' | |
| 59 | - | |
| 60 | - # Try requested device, fallback to CPU if CUDA fails | |
| 61 | - try: | |
| 62 | - if device == 'cuda': | |
| 63 | - # Check CUDA memory first | |
| 64 | - import torch | |
| 65 | - if torch.cuda.is_available(): | |
| 66 | - # Check if we have enough memory (at least 1GB free) | |
| 67 | - free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated() | |
| 68 | - if free_memory < 1024 * 1024 * 1024: # 1GB | |
| 69 | - print(f"[BgeEncoder] CUDA memory insufficient ({free_memory/1024/1024:.1f}MB free), falling back to CPU") | |
| 70 | - device = 'cpu' | |
| 71 | - else: | |
| 72 | - print(f"[BgeEncoder] CUDA not available, using CPU") | |
| 73 | - device = 'cpu' | |
| 74 | - | |
| 75 | - self.model = self.model.to(device) | |
| 76 | - | |
| 77 | - embeddings = self.model.encode( | |
| 78 | - sentences, | |
| 79 | - normalize_embeddings=normalize_embeddings, | |
| 80 | - device=device, | |
| 81 | - show_progress_bar=False, | |
| 82 | - batch_size=batch_size | |
| 83 | - ) | |
| 84 | - | |
| 85 | - return embeddings | |
| 86 | - | |
| 87 | - except Exception as e: | |
| 88 | - print(f"[BgeEncoder] Device {device} failed: {e}") | |
| 89 | - if device != 'cpu': | |
| 90 | - print(f"[BgeEncoder] Falling back to CPU") | |
| 91 | - try: | |
| 92 | - self.model = self.model.to('cpu') | |
| 93 | - embeddings = self.model.encode( | |
| 94 | - sentences, | |
| 95 | - normalize_embeddings=normalize_embeddings, | |
| 96 | - device='cpu', | |
| 97 | - show_progress_bar=False, | |
| 98 | - batch_size=batch_size | |
| 99 | - ) | |
| 100 | - return embeddings | |
| 101 | - except Exception as e2: | |
| 102 | - print(f"[BgeEncoder] CPU also failed: {e2}") | |
| 103 | - raise | |
| 104 | - else: | |
| 105 | - raise | |
| 106 | - | |
| 107 | - def encode_batch( | |
| 108 | - self, | |
| 109 | - texts: List[str], | |
| 110 | - batch_size: int = 32, | |
| 111 | - device: str = 'cuda' | |
| 112 | - ) -> np.ndarray: | |
| 113 | - """ | |
| 114 | - Encode a batch of texts efficiently. | |
| 115 | - | |
| 116 | - Args: | |
| 117 | - texts: List of texts to encode | |
| 118 | - batch_size: Batch size for processing | |
| 119 | - device: Device to use | |
| 120 | - | |
| 121 | - Returns: | |
| 122 | - numpy array of embeddings | |
| 123 | - """ | |
| 124 | - return self.encode(texts, batch_size=batch_size, device=device) |
| ... | ... | @@ -0,0 +1,40 @@ |
| 1 | +#!/bin/bash | |
| 2 | +# | |
| 3 | +# Start Local Embedding Service | |
| 4 | +# | |
| 5 | +# This service exposes: | |
| 6 | +# - POST /embed/text | |
| 7 | +# - POST /embed/image | |
| 8 | +# | |
| 9 | +# Defaults are defined in `embeddings/config.py` | |
| 10 | +# | |
| 11 | +set -e | |
| 12 | + | |
| 13 | +cd "$(dirname "$0")/.." | |
| 14 | + | |
| 15 | +# Load conda env if available (keep consistent with other scripts) | |
| 16 | +if [ -f "/home/tw/miniconda3/etc/profile.d/conda.sh" ]; then | |
| 17 | + source /home/tw/miniconda3/etc/profile.d/conda.sh | |
| 18 | + conda activate searchengine | |
| 19 | +fi | |
| 20 | + | |
| 21 | +EMBEDDING_SERVICE_HOST=$(python -c "from embeddings.config import CONFIG; print(CONFIG.HOST)") | |
| 22 | +EMBEDDING_SERVICE_PORT=$(python -c "from embeddings.config import CONFIG; print(CONFIG.PORT)") | |
| 23 | + | |
| 24 | +echo "========================================" | |
| 25 | +echo "Starting Local Embedding Service" | |
| 26 | +echo "========================================" | |
| 27 | +echo "Host: ${EMBEDDING_SERVICE_HOST}" | |
| 28 | +echo "Port: ${EMBEDDING_SERVICE_PORT}" | |
| 29 | +echo | |
| 30 | +echo "Tips:" | |
| 31 | +echo " - Use a single worker (GPU models cannot be safely duplicated across workers)." | |
| 32 | +echo " - Clients can set EMBEDDING_SERVICE_URL=http://localhost:${EMBEDDING_SERVICE_PORT}" | |
| 33 | +echo | |
| 34 | + | |
| 35 | +exec python -m uvicorn embeddings.server:app \ | |
| 36 | + --host "${EMBEDDING_SERVICE_HOST}" \ | |
| 37 | + --port "${EMBEDDING_SERVICE_PORT}" \ | |
| 38 | + --workers 1 | |
| 39 | + | |
| 40 | + | ... | ... |