Commit 7bfb99466aff30553a9efc1b275b46624471e659
1 parent
9c712e64
向量化模块
Showing
11 changed files
with
454 additions
and
356 deletions
Show diff stats
README.md
| @@ -4,7 +4,7 @@ | @@ -4,7 +4,7 @@ | ||
| 4 | 语义: | 4 | 语义: |
| 5 | 5 | ||
| 6 | 6 | ||
| 7 | -query anchor | 7 | +query anchor |
| 8 | 我想给elasticsearch 增加字段 query anchor ,即哪些query点击到了这个doc,一个doc下面有多个query anchor,每个query anchor又有这两个属性:weight、dweight,分别代表 query在doc下的点击分布权重、doc在query下的点击分布权重。请问该如何设计这两个ES字段。 | 8 | 我想给elasticsearch 增加字段 query anchor ,即哪些query点击到了这个doc,一个doc下面有多个query anchor,每个query anchor又有这两个属性:weight、dweight,分别代表 query在doc下的点击分布权重、doc在query下的点击分布权重。请问该如何设计这两个ES字段。 |
| 9 | 9 | ||
| 10 | 需要有zh en两套query anchor,因为他们的解析器不一样。 | 10 | 需要有zh en两套query anchor,因为他们的解析器不一样。 |
| @@ -89,6 +89,13 @@ docker run -d --name es -p 9200:9200 elasticsearch:8.11.0 | @@ -89,6 +89,13 @@ docker run -d --name es -p 9200:9200 elasticsearch:8.11.0 | ||
| 89 | # 4. 启动服务 | 89 | # 4. 启动服务 |
| 90 | ./run.sh | 90 | ./run.sh |
| 91 | 91 | ||
| 92 | +# (可选)启动本地向量服务(BGE-M3 / CN-CLIP,本地模型推理) | ||
| 93 | +# 提供: POST http://localhost:6005/embed/text | ||
| 94 | +# POST http://localhost:6005/embed/image | ||
| 95 | +./scripts/start_embedding_service.sh | ||
| 96 | +# | ||
| 97 | +# 详细说明见:`embeddings/README.md` | ||
| 98 | + | ||
| 92 | # 5. 调用文本搜索 API | 99 | # 5. 调用文本搜索 API |
| 93 | curl -X POST http://localhost:6002/search/ \ | 100 | curl -X POST http://localhost:6002/search/ \ |
| 94 | -H "Content-Type: application/json" \ | 101 | -H "Content-Type: application/json" \ |
| @@ -0,0 +1,40 @@ | @@ -0,0 +1,40 @@ | ||
| 1 | +## 向量化模块(embeddings) | ||
| 2 | + | ||
| 3 | +这个目录是一个完整的“向量化模块”,包含: | ||
| 4 | + | ||
| 5 | +- **HTTP 客户端**:`text_encoder.py` / `image_encoder.py`(供搜索/索引模块调用) | ||
| 6 | +- **本地模型实现**:`bge_model.py` / `clip_model.py` | ||
| 7 | +- **向量化服务(FastAPI)**:`server.py` | ||
| 8 | +- **统一配置**:`config.py` | ||
| 9 | + | ||
| 10 | +### 服务接口 | ||
| 11 | + | ||
| 12 | +- `POST /embed/text` | ||
| 13 | + - 入参:`["文本1", "文本2", ...]` | ||
| 14 | + - 出参:`[[...], null, ...]`(与输入按 index 对齐,失败为 `null`) | ||
| 15 | + | ||
| 16 | +- `POST /embed/image` | ||
| 17 | + - 入参:`["url或本地路径1", ...]` | ||
| 18 | + - 出参:`[[...], null, ...]`(与输入按 index 对齐,失败为 `null`) | ||
| 19 | + | ||
| 20 | +### 启动服务 | ||
| 21 | + | ||
| 22 | +使用仓库脚本启动(默认端口 6005): | ||
| 23 | + | ||
| 24 | +```bash | ||
| 25 | +./scripts/start_embedding_service.sh | ||
| 26 | +``` | ||
| 27 | + | ||
| 28 | +### 修改配置 | ||
| 29 | + | ||
| 30 | +编辑 `embeddings/config.py`: | ||
| 31 | + | ||
| 32 | +- `PORT`: 服务端口(默认 6005) | ||
| 33 | +- `TEXT_MODEL_DIR`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE` | ||
| 34 | +- `IMAGE_MODEL_NAME`, `IMAGE_DEVICE` | ||
| 35 | + | ||
| 36 | +### 目录说明(旧文件) | ||
| 37 | + | ||
| 38 | +旧的 `vector_service/` 目录与 `*_encoder__local.py` 文件已经废弃,统一由本目录实现与维护。 | ||
| 39 | + | ||
| 40 | + |
embeddings/__init__.py
| 1 | -"""Embeddings package initialization.""" | 1 | +""" |
| 2 | +Embeddings module. | ||
| 2 | 3 | ||
| 3 | -from .text_encoder import BgeEncoder | ||
| 4 | -from .image_encoder import CLIPImageEncoder | 4 | +Important: keep package import lightweight. |
| 5 | 5 | ||
| 6 | -__all__ = [ | ||
| 7 | - 'BgeEncoder', | ||
| 8 | - 'CLIPImageEncoder', | ||
| 9 | -] | 6 | +Some callers do: |
| 7 | + - `from embeddings import BgeEncoder` | ||
| 8 | + - `from embeddings import CLIPImageEncoder` | ||
| 9 | + | ||
| 10 | +But the underlying implementations may import heavy optional deps (Pillow, torch, etc). | ||
| 11 | +To avoid importing those at package import time (and to allow the embedding service to boot | ||
| 12 | +without importing client code), we provide small lazy factories here. | ||
| 13 | +""" | ||
| 14 | + | ||
| 15 | + | ||
| 16 | +class BgeEncoder(object): | ||
| 17 | + """Lazy factory for `embeddings.text_encoder.BgeEncoder`.""" | ||
| 18 | + | ||
| 19 | + def __new__(cls, *args, **kwargs): | ||
| 20 | + from .text_encoder import BgeEncoder as _Real | ||
| 21 | + | ||
| 22 | + return _Real(*args, **kwargs) | ||
| 23 | + | ||
| 24 | + | ||
| 25 | +class CLIPImageEncoder(object): | ||
| 26 | + """Lazy factory for `embeddings.image_encoder.CLIPImageEncoder`.""" | ||
| 27 | + | ||
| 28 | + def __new__(cls, *args, **kwargs): | ||
| 29 | + from .image_encoder import CLIPImageEncoder as _Real | ||
| 30 | + | ||
| 31 | + return _Real(*args, **kwargs) | ||
| 32 | + | ||
| 33 | + | ||
| 34 | +__all__ = ["BgeEncoder", "CLIPImageEncoder"] |
"""
BGE-M3 local text embedding implementation.

Internal model implementation used by the embedding service.
"""

import logging
import threading
from typing import List, Union

import numpy as np
from sentence_transformers import SentenceTransformer
from modelscope import snapshot_download

logger = logging.getLogger(__name__)

# Minimum free CUDA memory (bytes) required before attempting GPU inference.
_MIN_FREE_CUDA_BYTES = 1024 * 1024 * 1024  # 1GB


class BgeTextModel(object):
    """
    Thread-safe singleton text encoder using the BGE-M3 model (local inference).

    NOTE(review): because this is a singleton, only the ``model_dir`` passed on
    the *first* construction is honored; later constructor arguments are
    silently ignored.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, model_dir: str = "Xorbits/bge-m3"):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(BgeTextModel, cls).__new__(cls)
                # snapshot_download resolves (and if needed downloads) the
                # model files locally before SentenceTransformer loads them.
                cls._instance.model = SentenceTransformer(snapshot_download(model_dir))
        return cls._instance

    @staticmethod
    def _resolve_device(device: str) -> str:
        """Map the requested device to one that is actually usable.

        "gpu" is accepted as an alias for "cuda"; CUDA is used only when it
        is available and device 0 has at least ``_MIN_FREE_CUDA_BYTES`` free.
        """
        if device == "gpu":
            device = "cuda"
        if device == "cuda":
            import torch

            if torch.cuda.is_available():
                free_memory = (
                    torch.cuda.get_device_properties(0).total_memory
                    - torch.cuda.memory_allocated()
                )
                if free_memory < _MIN_FREE_CUDA_BYTES:
                    device = "cpu"
            else:
                device = "cpu"
        return device

    def _encode_on(
        self,
        device: str,
        sentences: Union[str, List[str]],
        normalize_embeddings: bool,
        batch_size: int,
    ) -> np.ndarray:
        """Move the (shared, singleton) model to ``device`` and encode there."""
        self.model = self.model.to(device)
        return self.model.encode(
            sentences,
            normalize_embeddings=normalize_embeddings,
            device=device,
            show_progress_bar=False,
            batch_size=batch_size,
        )

    def encode(
        self,
        sentences: Union[str, List[str]],
        normalize_embeddings: bool = True,
        device: str = "cuda",
        batch_size: int = 32,
    ) -> np.ndarray:
        """Encode ``sentences`` into embedding vectors.

        Tries the requested device first; if that attempt fails and the device
        was not already CPU, retries once on CPU before re-raising.
        """
        try:
            device = self._resolve_device(device)
            return self._encode_on(device, sentences, normalize_embeddings, batch_size)
        except Exception:
            if device == "cpu":
                raise
            # CUDA inference failed (OOM, driver issue, ...): fall back to
            # CPU, but record the original error instead of hiding it.
            logger.exception("Encoding on %s failed; retrying on CPU", device)
            return self._encode_on("cpu", sentences, normalize_embeddings, batch_size)

    def encode_batch(self, texts: List[str], batch_size: int = 32, device: str = "cuda") -> np.ndarray:
        """Convenience wrapper used by the embedding service."""
        return self.encode(texts, batch_size=batch_size, device=device)
embeddings/image_encoder__local.py renamed to embeddings/clip_model.py
| 1 | """ | 1 | """ |
| 2 | -Image embedding encoder using CN-CLIP model. | 2 | +CN-CLIP local image embedding implementation. |
| 3 | 3 | ||
| 4 | -Generates 1024-dimensional vectors for images using the CN-CLIP ViT-H-14 model. | 4 | +Internal model implementation used by the embedding service. |
| 5 | """ | 5 | """ |
| 6 | 6 | ||
| 7 | -import sys | ||
| 8 | -import os | ||
| 9 | import io | 7 | import io |
| 8 | +import threading | ||
| 9 | +from typing import List, Optional, Union | ||
| 10 | + | ||
| 11 | +import numpy as np | ||
| 10 | import requests | 12 | import requests |
| 11 | import torch | 13 | import torch |
| 12 | -import numpy as np | ||
| 13 | from PIL import Image | 14 | from PIL import Image |
| 14 | -import logging | ||
| 15 | -import threading | ||
| 16 | -from typing import List, Optional, Union | ||
| 17 | -import cn_clip.clip as clip | ||
| 18 | from cn_clip.clip import load_from_name | 15 | from cn_clip.clip import load_from_name |
| 16 | +import cn_clip.clip as clip | ||
| 19 | 17 | ||
| 20 | 18 | ||
| 21 | -# DEFAULT_MODEL_NAME = "ViT-L-14-336" # ["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"] | ||
| 22 | DEFAULT_MODEL_NAME = "ViT-H-14" | 19 | DEFAULT_MODEL_NAME = "ViT-H-14" |
| 23 | MODEL_DOWNLOAD_DIR = "/data/tw/uat/EsSearcher" | 20 | MODEL_DOWNLOAD_DIR = "/data/tw/uat/EsSearcher" |
| 24 | 21 | ||
| 25 | 22 | ||
| 26 | -class CLIPImageEncoder: | 23 | +class ClipImageModel(object): |
| 27 | """ | 24 | """ |
| 28 | - CLIP Image Encoder for generating image embeddings using cn_clip. | ||
| 29 | - | ||
| 30 | - Thread-safe singleton pattern. | 25 | + Thread-safe singleton image encoder using cn_clip (local inference). |
| 31 | """ | 26 | """ |
| 32 | 27 | ||
| 33 | _instance = None | 28 | _instance = None |
| 34 | _lock = threading.Lock() | 29 | _lock = threading.Lock() |
| 35 | 30 | ||
| 36 | - def __new__(cls, model_name=DEFAULT_MODEL_NAME, device=None): | 31 | + def __new__(cls, model_name: str = DEFAULT_MODEL_NAME, device: Optional[str] = None): |
| 37 | with cls._lock: | 32 | with cls._lock: |
| 38 | if cls._instance is None: | 33 | if cls._instance is None: |
| 39 | - cls._instance = super(CLIPImageEncoder, cls).__new__(cls) | ||
| 40 | - print(f"[CLIPImageEncoder] Creating new instance with model: {model_name}") | 34 | + cls._instance = super(ClipImageModel, cls).__new__(cls) |
| 41 | cls._instance._initialize_model(model_name, device) | 35 | cls._instance._initialize_model(model_name, device) |
| 42 | return cls._instance | 36 | return cls._instance |
| 43 | 37 | ||
| 44 | - def _initialize_model(self, model_name, device): | ||
| 45 | - """Initialize the CLIP model using cn_clip""" | ||
| 46 | - try: | ||
| 47 | - self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") | ||
| 48 | - self.model, self.preprocess = load_from_name( | ||
| 49 | - model_name, | ||
| 50 | - device=self.device, | ||
| 51 | - download_root=MODEL_DOWNLOAD_DIR | ||
| 52 | - ) | ||
| 53 | - self.model.eval() | ||
| 54 | - self.model_name = model_name | ||
| 55 | - print(f"[CLIPImageEncoder] Model {model_name} initialized successfully on device {self.device}") | ||
| 56 | - | ||
| 57 | - except Exception as e: | ||
| 58 | - print(f"[CLIPImageEncoder] Failed to initialize model: {str(e)}") | ||
| 59 | - raise | 38 | + def _initialize_model(self, model_name: str, device: Optional[str]): |
| 39 | + self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") | ||
| 40 | + self.model, self.preprocess = load_from_name( | ||
| 41 | + model_name, device=self.device, download_root=MODEL_DOWNLOAD_DIR | ||
| 42 | + ) | ||
| 43 | + self.model.eval() | ||
| 44 | + self.model_name = model_name | ||
| 60 | 45 | ||
| 61 | def validate_image(self, image_data: bytes) -> Image.Image: | 46 | def validate_image(self, image_data: bytes) -> Image.Image: |
| 62 | - """Validate image data and return PIL Image if valid""" | ||
| 63 | - try: | ||
| 64 | - image_stream = io.BytesIO(image_data) | ||
| 65 | - image = Image.open(image_stream) | ||
| 66 | - image.verify() | ||
| 67 | - image_stream.seek(0) | ||
| 68 | - image = Image.open(image_stream) | ||
| 69 | - if image.mode != 'RGB': | ||
| 70 | - image = image.convert('RGB') | ||
| 71 | - return image | ||
| 72 | - except Exception as e: | ||
| 73 | - raise ValueError(f"Invalid image data: {str(e)}") | 47 | + image_stream = io.BytesIO(image_data) |
| 48 | + image = Image.open(image_stream) | ||
| 49 | + image.verify() | ||
| 50 | + image_stream.seek(0) | ||
| 51 | + image = Image.open(image_stream) | ||
| 52 | + if image.mode != "RGB": | ||
| 53 | + image = image.convert("RGB") | ||
| 54 | + return image | ||
| 74 | 55 | ||
| 75 | def download_image(self, url: str, timeout: int = 10) -> bytes: | 56 | def download_image(self, url: str, timeout: int = 10) -> bytes: |
| 76 | - """Download image from URL""" | ||
| 77 | - try: | ||
| 78 | - if url.startswith(('http://', 'https://')): | ||
| 79 | - response = requests.get(url, timeout=timeout) | ||
| 80 | - if response.status_code != 200: | ||
| 81 | - raise ValueError(f"HTTP {response.status_code}") | ||
| 82 | - return response.content | ||
| 83 | - else: | ||
| 84 | - # Local file path | ||
| 85 | - with open(url, 'rb') as f: | ||
| 86 | - return f.read() | ||
| 87 | - except Exception as e: | ||
| 88 | - raise ValueError(f"Failed to download image from {url}: {str(e)}") | 57 | + if url.startswith(("http://", "https://")): |
| 58 | + response = requests.get(url, timeout=timeout) | ||
| 59 | + if response.status_code != 200: | ||
| 60 | + raise ValueError("HTTP %s" % response.status_code) | ||
| 61 | + return response.content | ||
| 62 | + with open(url, "rb") as f: | ||
| 63 | + return f.read() | ||
| 89 | 64 | ||
| 90 | def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image: | 65 | def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image: |
| 91 | - """Preprocess image for CLIP model""" | ||
| 92 | - # Resize if too large | ||
| 93 | if max(image.size) > max_size: | 66 | if max(image.size) > max_size: |
| 94 | - ratio = max_size / max(image.size) | 67 | + ratio = float(max_size) / float(max(image.size)) |
| 95 | new_size = tuple(int(dim * ratio) for dim in image.size) | 68 | new_size = tuple(int(dim * ratio) for dim in image.size) |
| 96 | image = image.resize(new_size, Image.Resampling.LANCZOS) | 69 | image = image.resize(new_size, Image.Resampling.LANCZOS) |
| 97 | return image | 70 | return image |
| 98 | 71 | ||
| 99 | def encode_text(self, text): | 72 | def encode_text(self, text): |
| 100 | - """Encode text to embedding vector using cn_clip""" | ||
| 101 | - text_data = clip.tokenize([text] if type(text) == str else text).to(self.device) | 73 | + text_data = clip.tokenize([text] if isinstance(text, str) else text).to(self.device) |
| 102 | with torch.no_grad(): | 74 | with torch.no_grad(): |
| 103 | text_features = self.model.encode_text(text_data) | 75 | text_features = self.model.encode_text(text_data) |
| 104 | text_features /= text_features.norm(dim=-1, keepdim=True) | 76 | text_features /= text_features.norm(dim=-1, keepdim=True) |
| 105 | return text_features | 77 | return text_features |
| 106 | 78 | ||
| 107 | def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: | 79 | def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: |
| 108 | - """Encode image to embedding vector using cn_clip""" | ||
| 109 | if not isinstance(image, Image.Image): | 80 | if not isinstance(image, Image.Image): |
| 110 | - raise ValueError("CLIPImageEncoder.encode_image Input must be a PIL.Image") | ||
| 111 | - | ||
| 112 | - try: | ||
| 113 | - infer_data = self.preprocess(image).unsqueeze(0).to(self.device) | ||
| 114 | - with torch.no_grad(): | ||
| 115 | - image_features = self.model.encode_image(infer_data) | ||
| 116 | - image_features /= image_features.norm(dim=-1, keepdim=True) | ||
| 117 | - return image_features.cpu().numpy().astype('float32')[0] | ||
| 118 | - except Exception as e: | ||
| 119 | - print(f"Failed to process image. Reason: {str(e)}") | ||
| 120 | - return None | 81 | + raise ValueError("ClipImageModel.encode_image input must be a PIL.Image") |
| 82 | + infer_data = self.preprocess(image).unsqueeze(0).to(self.device) | ||
| 83 | + with torch.no_grad(): | ||
| 84 | + image_features = self.model.encode_image(infer_data) | ||
| 85 | + image_features /= image_features.norm(dim=-1, keepdim=True) | ||
| 86 | + return image_features.cpu().numpy().astype("float32")[0] | ||
| 121 | 87 | ||
| 122 | def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: | 88 | def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: |
| 123 | - """Complete pipeline: download, validate, preprocess and encode image from URL""" | ||
| 124 | - try: | ||
| 125 | - # Download image | ||
| 126 | - image_data = self.download_image(url) | ||
| 127 | - | ||
| 128 | - # Validate image | ||
| 129 | - image = self.validate_image(image_data) | ||
| 130 | - | ||
| 131 | - # Preprocess image | ||
| 132 | - image = self.preprocess_image(image) | ||
| 133 | - | ||
| 134 | - # Encode image | ||
| 135 | - embedding = self.encode_image(image) | ||
| 136 | - | ||
| 137 | - return embedding | ||
| 138 | - | ||
| 139 | - except Exception as e: | ||
| 140 | - print(f"Error processing image from URL {url}: {str(e)}") | ||
| 141 | - return None | ||
| 142 | - | ||
| 143 | - def encode_batch( | ||
| 144 | - self, | ||
| 145 | - images: List[Union[str, Image.Image]], | ||
| 146 | - batch_size: int = 8 | ||
| 147 | - ) -> List[Optional[np.ndarray]]: | ||
| 148 | - """ | ||
| 149 | - Encode a batch of images efficiently. | ||
| 150 | - | ||
| 151 | - Args: | ||
| 152 | - images: List of image URLs or PIL Images | ||
| 153 | - batch_size: Batch size for processing | ||
| 154 | - | ||
| 155 | - Returns: | ||
| 156 | - List of embeddings (or None for failed images) | ||
| 157 | - """ | ||
| 158 | - results = [] | 89 | + image_data = self.download_image(url) |
| 90 | + image = self.validate_image(image_data) | ||
| 91 | + image = self.preprocess_image(image) | ||
| 92 | + return self.encode_image(image) | ||
| 159 | 93 | ||
| 94 | + def encode_batch(self, images: List[Union[str, Image.Image]], batch_size: int = 8) -> List[Optional[np.ndarray]]: | ||
| 95 | + results: List[Optional[np.ndarray]] = [] | ||
| 160 | for i in range(0, len(images), batch_size): | 96 | for i in range(0, len(images), batch_size): |
| 161 | - batch = images[i:i + batch_size] | ||
| 162 | - batch_embeddings = [] | ||
| 163 | - | 97 | + batch = images[i : i + batch_size] |
| 164 | for img in batch: | 98 | for img in batch: |
| 165 | if isinstance(img, str): | 99 | if isinstance(img, str): |
| 166 | - # URL or file path | ||
| 167 | - emb = self.encode_image_from_url(img) | 100 | + results.append(self.encode_image_from_url(img)) |
| 168 | elif isinstance(img, Image.Image): | 101 | elif isinstance(img, Image.Image): |
| 169 | - # PIL Image | ||
| 170 | - emb = self.encode_image(img) | 102 | + results.append(self.encode_image(img)) |
| 171 | else: | 103 | else: |
| 172 | - emb = None | ||
| 173 | - | ||
| 174 | - batch_embeddings.append(emb) | 104 | + results.append(None) |
| 105 | + return results | ||
| 175 | 106 | ||
| 176 | - results.extend(batch_embeddings) | ||
| 177 | 107 | ||
| 178 | - return results |
"""
Embedding module configuration.

This module is intentionally a plain Python file (no env var parsing, no
extra deps). Edit values here to configure:
- server host/port
- local model settings (paths/devices/batch sizes)
"""

from typing import Optional


class EmbeddingConfig(object):
    """Static configuration for the embedding service and its local models."""

    # Server bind address and port.
    HOST: str = "0.0.0.0"
    PORT: int = 6005

    # Text embeddings (BGE-M3).
    TEXT_MODEL_DIR: str = "Xorbits/bge-m3"
    TEXT_DEVICE: str = "cuda"  # "cuda" or "cpu" (model may fall back to CPU if needed)
    TEXT_BATCH_SIZE: int = 32

    # Image embeddings (CN-CLIP).
    IMAGE_MODEL_NAME: str = "ViT-H-14"
    IMAGE_DEVICE: Optional[str] = None  # "cuda" / "cpu" / None (auto-detect)

    # Service behavior.
    IMAGE_BATCH_SIZE: int = 8


# Shared module-level instance imported by the service.
CONFIG = EmbeddingConfig()
embeddings/image_encoder.py
| 1 | """ | 1 | """ |
| 2 | Image embedding encoder using network service. | 2 | Image embedding encoder using network service. |
| 3 | 3 | ||
| 4 | -Generates embeddings via HTTP API service running on localhost:5001. | 4 | +Generates embeddings via HTTP API service (default localhost:6005). |
| 5 | """ | 5 | """ |
| 6 | 6 | ||
| 7 | import sys | 7 | import sys |
| @@ -26,24 +26,25 @@ class CLIPImageEncoder: | @@ -26,24 +26,25 @@ class CLIPImageEncoder: | ||
| 26 | _instance = None | 26 | _instance = None |
| 27 | _lock = threading.Lock() | 27 | _lock = threading.Lock() |
| 28 | 28 | ||
| 29 | - def __new__(cls, service_url='http://localhost:5001'): | 29 | + def __new__(cls, service_url: Optional[str] = None): |
| 30 | with cls._lock: | 30 | with cls._lock: |
| 31 | if cls._instance is None: | 31 | if cls._instance is None: |
| 32 | cls._instance = super(CLIPImageEncoder, cls).__new__(cls) | 32 | cls._instance = super(CLIPImageEncoder, cls).__new__(cls) |
| 33 | - logger.info(f"Creating CLIPImageEncoder instance with service URL: {service_url}") | ||
| 34 | - cls._instance.service_url = service_url | ||
| 35 | - cls._instance.endpoint = f"{service_url}/embedding/generate_image_embeddings" | 33 | + resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL", "http://localhost:6005") |
| 34 | + logger.info(f"Creating CLIPImageEncoder instance with service URL: {resolved_url}") | ||
| 35 | + cls._instance.service_url = resolved_url | ||
| 36 | + cls._instance.endpoint = f"{resolved_url}/embed/image" | ||
| 36 | return cls._instance | 37 | return cls._instance |
| 37 | 38 | ||
| 38 | - def _call_service(self, request_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | 39 | + def _call_service(self, request_data: List[str]) -> List[Any]: |
| 39 | """ | 40 | """ |
| 40 | Call the embedding service API. | 41 | Call the embedding service API. |
| 41 | 42 | ||
| 42 | Args: | 43 | Args: |
| 43 | - request_data: List of dictionaries with id and pic_url fields | 44 | + request_data: List of image URLs / local file paths |
| 44 | 45 | ||
| 45 | Returns: | 46 | Returns: |
| 46 | - List of dictionaries with id, pic_url, embedding and error fields | 47 | + List of embeddings (list[float]) or nulls (None), aligned to input order |
| 47 | """ | 48 | """ |
| 48 | try: | 49 | try: |
| 49 | response = requests.post( | 50 | response = requests.post( |
| @@ -77,26 +78,11 @@ class CLIPImageEncoder: | @@ -77,26 +78,11 @@ class CLIPImageEncoder: | ||
| 77 | Embedding vector or None if failed | 78 | Embedding vector or None if failed |
| 78 | """ | 79 | """ |
| 79 | try: | 80 | try: |
| 80 | - # Prepare request data | ||
| 81 | - request_data = [{ | ||
| 82 | - "id": "image_0", | ||
| 83 | - "pic_url": url | ||
| 84 | - }] | ||
| 85 | - | ||
| 86 | - # Call service | ||
| 87 | - response_data = self._call_service(request_data) | ||
| 88 | - | ||
| 89 | - # Process response | ||
| 90 | - if response_data and len(response_data) > 0: | ||
| 91 | - response_item = response_data[0] | ||
| 92 | - if response_item.get("embedding"): | ||
| 93 | - return np.array(response_item["embedding"], dtype=np.float32) | ||
| 94 | - else: | ||
| 95 | - logger.warning(f"No embedding for URL {url}, error: {response_item.get('error', 'Unknown error')}") | ||
| 96 | - return None | ||
| 97 | - else: | ||
| 98 | - logger.warning(f"No response for URL {url}") | ||
| 99 | - return None | 81 | + response_data = self._call_service([url]) |
| 82 | + if response_data and len(response_data) > 0 and response_data[0] is not None: | ||
| 83 | + return np.array(response_data[0], dtype=np.float32) | ||
| 84 | + logger.warning(f"No embedding for URL {url}") | ||
| 85 | + return None | ||
| 100 | 86 | ||
| 101 | except Exception as e: | 87 | except Exception as e: |
| 102 | logger.error(f"Failed to process image from URL {url}: {str(e)}", exc_info=True) | 88 | logger.error(f"Failed to process image from URL {url}: {str(e)}", exc_info=True) |
| @@ -137,32 +123,17 @@ class CLIPImageEncoder: | @@ -137,32 +123,17 @@ class CLIPImageEncoder: | ||
| 137 | batch_urls = url_images[i:i + batch_size] | 123 | batch_urls = url_images[i:i + batch_size] |
| 138 | batch_indices = url_indices[i:i + batch_size] | 124 | batch_indices = url_indices[i:i + batch_size] |
| 139 | 125 | ||
| 140 | - # Prepare request data | ||
| 141 | - request_data = [] | ||
| 142 | - for j, url in enumerate(batch_urls): | ||
| 143 | - request_data.append({ | ||
| 144 | - "id": f"image_{j}", | ||
| 145 | - "pic_url": url | ||
| 146 | - }) | ||
| 147 | - | ||
| 148 | try: | 126 | try: |
| 149 | # Call service | 127 | # Call service |
| 150 | - response_data = self._call_service(request_data) | 128 | + response_data = self._call_service(batch_urls) |
| 151 | 129 | ||
| 152 | - # Process response | 130 | + # Process response (aligned list) |
| 153 | batch_results = [] | 131 | batch_results = [] |
| 154 | for j, url in enumerate(batch_urls): | 132 | for j, url in enumerate(batch_urls): |
| 155 | - response_item = None | ||
| 156 | - for item in response_data: | ||
| 157 | - if str(item.get("id")) == f"image_{j}": | ||
| 158 | - response_item = item | ||
| 159 | - break | ||
| 160 | - | ||
| 161 | - if response_item and response_item.get("embedding"): | ||
| 162 | - batch_results.append(np.array(response_item["embedding"], dtype=np.float32)) | 133 | + if response_data and j < len(response_data) and response_data[j] is not None: |
| 134 | + batch_results.append(np.array(response_data[j], dtype=np.float32)) | ||
| 163 | else: | 135 | else: |
| 164 | - error_msg = response_item.get("error", "Unknown error") if response_item else "No response" | ||
| 165 | - logger.warning(f"Failed to encode URL {url}: {error_msg}") | 136 | + logger.warning(f"Failed to encode URL {url}: no embedding") |
| 166 | batch_results.append(None) | 137 | batch_results.append(None) |
| 167 | 138 | ||
| 168 | # Insert results at the correct positions | 139 | # Insert results at the correct positions |
| @@ -0,0 +1,122 @@ | @@ -0,0 +1,122 @@ | ||
| 1 | +""" | ||
| 2 | +Embedding service (FastAPI). | ||
| 3 | + | ||
| 4 | +API (simple list-in, list-out; aligned by index; failures -> null): | ||
| 5 | +- POST /embed/text body: ["text1", "text2", ...] -> [[...], null, ...] | ||
| 6 | +- POST /embed/image body: ["url_or_path1", ...] -> [[...], null, ...] | ||
| 7 | +""" | ||
| 8 | + | ||
| 9 | +import threading | ||
| 10 | +from typing import Any, Dict, List, Optional | ||
| 11 | + | ||
| 12 | +import numpy as np | ||
| 13 | +from fastapi import FastAPI | ||
| 14 | + | ||
| 15 | +from embeddings.config import CONFIG | ||
| 16 | +from embeddings.bge_model import BgeTextModel | ||
| 17 | +from embeddings.clip_model import ClipImageModel | ||
| 18 | + | ||
| 19 | + | ||
| 20 | +app = FastAPI(title="SearchEngine Embedding Service", version="1.0.0") | ||
| 21 | + | ||
| 22 | +_text_model = None | ||
| 23 | +_image_model = None | ||
| 24 | + | ||
| 25 | +_text_init_lock = threading.Lock() | ||
| 26 | +_image_init_lock = threading.Lock() | ||
| 27 | + | ||
| 28 | +_text_encode_lock = threading.Lock() | ||
| 29 | +_image_encode_lock = threading.Lock() | ||
| 30 | + | ||
| 31 | + | ||
| 32 | +def _get_text_model(): | ||
| 33 | + global _text_model | ||
| 34 | + if _text_model is None: | ||
| 35 | + with _text_init_lock: | ||
| 36 | + if _text_model is None: | ||
| 37 | + _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR) | ||
| 38 | + return _text_model | ||
| 39 | + | ||
| 40 | + | ||
| 41 | +def _get_image_model(): | ||
| 42 | + global _image_model | ||
| 43 | + if _image_model is None: | ||
| 44 | + with _image_init_lock: | ||
| 45 | + if _image_model is None: | ||
| 46 | + _image_model = ClipImageModel( | ||
| 47 | + model_name=CONFIG.IMAGE_MODEL_NAME, | ||
| 48 | + device=CONFIG.IMAGE_DEVICE, | ||
| 49 | + ) | ||
| 50 | + return _image_model | ||
| 51 | + | ||
| 52 | + | ||
| 53 | +def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]: | ||
| 54 | + if embedding is None: | ||
| 55 | + return None | ||
| 56 | + if not isinstance(embedding, np.ndarray): | ||
| 57 | + embedding = np.array(embedding, dtype=np.float32) | ||
| 58 | + if embedding.ndim != 1: | ||
| 59 | + embedding = embedding.reshape(-1) | ||
| 60 | + return embedding.astype(np.float32).tolist() | ||
| 61 | + | ||
| 62 | + | ||
| 63 | +@app.get("/health") | ||
| 64 | +def health() -> Dict[str, Any]: | ||
| 65 | + return {"status": "ok"} | ||
| 66 | + | ||
| 67 | + | ||
| 68 | +@app.post("/embed/text") | ||
| 69 | +def embed_text(texts: List[str]) -> List[Optional[List[float]]]: | ||
| 70 | + model = _get_text_model() | ||
| 71 | + out: List[Optional[List[float]]] = [None] * len(texts) | ||
| 72 | + | ||
| 73 | + indexed_texts: List[tuple] = [] | ||
| 74 | + for i, t in enumerate(texts): | ||
| 75 | + if t is None: | ||
| 76 | + continue | ||
| 77 | + if not isinstance(t, str): | ||
| 78 | + t = str(t) | ||
| 79 | + t = t.strip() | ||
| 80 | + if not t: | ||
| 81 | + continue | ||
| 82 | + indexed_texts.append((i, t)) | ||
| 83 | + | ||
| 84 | + if not indexed_texts: | ||
| 85 | + return out | ||
| 86 | + | ||
| 87 | + batch_texts = [t for _, t in indexed_texts] | ||
| 88 | + try: | ||
| 89 | + with _text_encode_lock: | ||
| 90 | + embs = model.encode_batch( | ||
| 91 | + batch_texts, batch_size=int(CONFIG.TEXT_BATCH_SIZE), device=CONFIG.TEXT_DEVICE | ||
| 92 | + ) | ||
| 93 | + for j, (idx, _t) in enumerate(indexed_texts): | ||
| 94 | + out[idx] = _as_list(embs[j]) | ||
| 95 | + except Exception: | ||
| 96 | + # keep Nones | ||
| 97 | + pass | ||
| 98 | + return out | ||
| 99 | + | ||
| 100 | + | ||
| 101 | +@app.post("/embed/image") | ||
| 102 | +def embed_image(images: List[str]) -> List[Optional[List[float]]]: | ||
| 103 | + model = _get_image_model() | ||
| 104 | + out: List[Optional[List[float]]] = [None] * len(images) | ||
| 105 | + | ||
| 106 | + with _image_encode_lock: | ||
| 107 | + for i, url_or_path in enumerate(images): | ||
| 108 | + try: | ||
| 109 | + if url_or_path is None: | ||
| 110 | + continue | ||
| 111 | + if not isinstance(url_or_path, str): | ||
| 112 | + url_or_path = str(url_or_path) | ||
| 113 | + url_or_path = url_or_path.strip() | ||
| 114 | + if not url_or_path: | ||
| 115 | + continue | ||
| 116 | + emb = model.encode_image_from_url(url_or_path) | ||
| 117 | + out[i] = _as_list(emb) | ||
| 118 | + except Exception: | ||
| 119 | + out[i] = None | ||
| 120 | + return out | ||
| 121 | + | ||
| 122 | + |
embeddings/text_encoder.py
| 1 | """ | 1 | """ |
| 2 | Text embedding encoder using network service. | 2 | Text embedding encoder using network service. |
| 3 | 3 | ||
| 4 | -Generates embeddings via HTTP API service running on localhost:5001. | 4 | +Generates embeddings via HTTP API service (default localhost:6005). |
| 5 | """ | 5 | """ |
| 6 | 6 | ||
| 7 | import sys | 7 | import sys |
| @@ -11,6 +11,7 @@ import threading | @@ -11,6 +11,7 @@ import threading | ||
| 11 | import numpy as np | 11 | import numpy as np |
| 12 | import pickle | 12 | import pickle |
| 13 | import redis | 13 | import redis |
| 14 | +import os | ||
| 14 | from datetime import timedelta | 15 | from datetime import timedelta |
| 15 | from typing import List, Union, Dict, Any, Optional | 16 | from typing import List, Union, Dict, Any, Optional |
| 16 | import logging | 17 | import logging |
| @@ -33,13 +34,14 @@ class BgeEncoder: | @@ -33,13 +34,14 @@ class BgeEncoder: | ||
| 33 | _instance = None | 34 | _instance = None |
| 34 | _lock = threading.Lock() | 35 | _lock = threading.Lock() |
| 35 | 36 | ||
| 36 | - def __new__(cls, service_url='http://localhost:5001'): | 37 | + def __new__(cls, service_url: Optional[str] = None): |
| 37 | with cls._lock: | 38 | with cls._lock: |
| 38 | if cls._instance is None: | 39 | if cls._instance is None: |
| 39 | cls._instance = super(BgeEncoder, cls).__new__(cls) | 40 | cls._instance = super(BgeEncoder, cls).__new__(cls) |
| 40 | - logger.info(f"Creating BgeEncoder instance with service URL: {service_url}") | ||
| 41 | - cls._instance.service_url = service_url | ||
| 42 | - cls._instance.endpoint = f"{service_url}/embedding/generate_embeddings" | 41 | + resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL", "http://localhost:6005") |
| 42 | + logger.info(f"Creating BgeEncoder instance with service URL: {resolved_url}") | ||
| 43 | + cls._instance.service_url = resolved_url | ||
| 44 | + cls._instance.endpoint = f"{resolved_url}/embed/text" | ||
| 43 | 45 | ||
| 44 | # Initialize Redis cache | 46 | # Initialize Redis cache |
| 45 | try: | 47 | try: |
| @@ -62,15 +64,15 @@ class BgeEncoder: | @@ -62,15 +64,15 @@ class BgeEncoder: | ||
| 62 | cls._instance.redis_client = None | 64 | cls._instance.redis_client = None |
| 63 | return cls._instance | 65 | return cls._instance |
| 64 | 66 | ||
| 65 | - def _call_service(self, request_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | 67 | + def _call_service(self, request_data: List[str]) -> List[Any]: |
| 66 | """ | 68 | """ |
| 67 | Call the embedding service API. | 69 | Call the embedding service API. |
| 68 | 70 | ||
| 69 | Args: | 71 | Args: |
| 70 | - request_data: List of dictionaries with id and text fields | 72 | + request_data: List of texts |
| 71 | 73 | ||
| 72 | Returns: | 74 | Returns: |
| 73 | - List of dictionaries with id and embedding fields | 75 | + List of embeddings (list[float]) or nulls (None), aligned to input order |
| 74 | """ | 76 | """ |
| 75 | try: | 77 | try: |
| 76 | response = requests.post( | 78 | response = requests.post( |
| @@ -126,19 +128,7 @@ class BgeEncoder: | @@ -126,19 +128,7 @@ class BgeEncoder: | ||
| 126 | uncached_texts.append(text) | 128 | uncached_texts.append(text) |
| 127 | 129 | ||
| 128 | # Prepare request data for uncached texts (after cache check) | 130 | # Prepare request data for uncached texts (after cache check) |
| 129 | - request_data = [] | ||
| 130 | - for i, text in enumerate(uncached_texts): | ||
| 131 | - request_item = { | ||
| 132 | - "id": str(uncached_indices[i]), | ||
| 133 | - "name_zh": text | ||
| 134 | - } | ||
| 135 | - | ||
| 136 | - # Add English and Russian fields as empty for now | ||
| 137 | - # Could be enhanced with language detection in the future | ||
| 138 | - request_item["name_en"] = None | ||
| 139 | - request_item["name_ru"] = None | ||
| 140 | - | ||
| 141 | - request_data.append(request_item) | 131 | + request_data = list(uncached_texts) |
| 142 | 132 | ||
| 143 | # If there are uncached texts, call service | 133 | # If there are uncached texts, call service |
| 144 | if uncached_texts: | 134 | if uncached_texts: |
| @@ -149,43 +139,27 @@ class BgeEncoder: | @@ -149,43 +139,27 @@ class BgeEncoder: | ||
| 149 | # Process response | 139 | # Process response |
| 150 | for i, text in enumerate(uncached_texts): | 140 | for i, text in enumerate(uncached_texts): |
| 151 | original_idx = uncached_indices[i] | 141 | original_idx = uncached_indices[i] |
| 152 | - # Find corresponding response by ID | ||
| 153 | - response_item = None | ||
| 154 | - for item in response_data: | ||
| 155 | - if str(item.get("id")) == str(original_idx): | ||
| 156 | - response_item = item | ||
| 157 | - break | ||
| 158 | - | ||
| 159 | - if response_item: | ||
| 160 | - # Try Chinese embedding first, then English, then Russian | 142 | + if response_data and i < len(response_data): |
| 143 | + embedding = response_data[i] | ||
| 144 | + else: | ||
| 161 | embedding = None | 145 | embedding = None |
| 162 | - for lang in ["embedding_zh", "embedding_en", "embedding_ru"]: | ||
| 163 | - if lang in response_item and response_item[lang] is not None: | ||
| 164 | - embedding = response_item[lang] | ||
| 165 | - break | ||
| 166 | 146 | ||
| 167 | - if embedding is not None: | ||
| 168 | - embedding_array = np.array(embedding, dtype=np.float32) | ||
| 169 | - # Validate embedding from service - if invalid, treat as no result | ||
| 170 | - if self._is_valid_embedding(embedding_array): | ||
| 171 | - embeddings[original_idx] = embedding_array | ||
| 172 | - # Cache the embedding | ||
| 173 | - self._set_cached_embedding(text, 'en', embedding_array) | ||
| 174 | - else: | ||
| 175 | - logger.warning( | ||
| 176 | - f"Invalid embedding returned from service for text {original_idx} " | ||
| 177 | - f"(contains NaN/Inf or invalid shape), treating as no result. " | ||
| 178 | - f"Text preview: {text[:50]}..." | ||
| 179 | - ) | ||
| 180 | - # 不生成兜底向量,保持为 None | ||
| 181 | - embeddings[original_idx] = None | 147 | + if embedding is not None: |
| 148 | + embedding_array = np.array(embedding, dtype=np.float32) | ||
| 149 | + # Validate embedding from service - if invalid, treat as no result | ||
| 150 | + if self._is_valid_embedding(embedding_array): | ||
| 151 | + embeddings[original_idx] = embedding_array | ||
| 152 | + # Cache the embedding | ||
| 153 | + self._set_cached_embedding(text, 'en', embedding_array) | ||
| 182 | else: | 154 | else: |
| 183 | - logger.warning(f"No embedding found for text {original_idx}: {text[:50]}...") | ||
| 184 | - # 不生成兜底向量,保持为 None | 155 | + logger.warning( |
| 156 | + f"Invalid embedding returned from service for text {original_idx} " | ||
| 157 | + f"(contains NaN/Inf or invalid shape), treating as no result. " | ||
| 158 | + f"Text preview: {text[:50]}..." | ||
| 159 | + ) | ||
| 185 | embeddings[original_idx] = None | 160 | embeddings[original_idx] = None |
| 186 | else: | 161 | else: |
| 187 | - logger.warning(f"No response found for text {original_idx}") | ||
| 188 | - # 不生成兜底向量,保持为 None | 162 | + logger.warning(f"No embedding found for text {original_idx}: {text[:50]}...") |
| 189 | embeddings[original_idx] = None | 163 | embeddings[original_idx] = None |
| 190 | 164 | ||
| 191 | except Exception as e: | 165 | except Exception as e: |
embeddings/text_encoder__local.py deleted
| @@ -1,124 +0,0 @@ | @@ -1,124 +0,0 @@ | ||
| 1 | -""" | ||
| 2 | -Text embedding encoder using BGE-M3 model. | ||
| 3 | - | ||
| 4 | -Generates 1024-dimensional vectors for text using the BGE-M3 multilingual model. | ||
| 5 | -""" | ||
| 6 | - | ||
| 7 | -import sys | ||
| 8 | -import torch | ||
| 9 | -from sentence_transformers import SentenceTransformer | ||
| 10 | -import time | ||
| 11 | -import threading | ||
| 12 | -from modelscope import snapshot_download | ||
| 13 | -from transformers import AutoModel | ||
| 14 | -import os | ||
| 15 | -import numpy as np | ||
| 16 | -from typing import List, Union | ||
| 17 | - | ||
| 18 | - | ||
| 19 | -class BgeEncoder: | ||
| 20 | - """ | ||
| 21 | - Singleton text encoder using BGE-M3 model. | ||
| 22 | - | ||
| 23 | - Thread-safe singleton pattern ensures only one model instance exists. | ||
| 24 | - """ | ||
| 25 | - _instance = None | ||
| 26 | - _lock = threading.Lock() | ||
| 27 | - | ||
| 28 | - def __new__(cls, model_dir='Xorbits/bge-m3'): | ||
| 29 | - with cls._lock: | ||
| 30 | - if cls._instance is None: | ||
| 31 | - cls._instance = super(BgeEncoder, cls).__new__(cls) | ||
| 32 | - print(f"[BgeEncoder] Creating a new instance with model directory: {model_dir}") | ||
| 33 | - cls._instance.model = SentenceTransformer(snapshot_download(model_dir)) | ||
| 34 | - print("[BgeEncoder] New instance has been created") | ||
| 35 | - return cls._instance | ||
| 36 | - | ||
| 37 | - def encode( | ||
| 38 | - self, | ||
| 39 | - sentences: Union[str, List[str]], | ||
| 40 | - normalize_embeddings: bool = True, | ||
| 41 | - device: str = 'cuda', | ||
| 42 | - batch_size: int = 32 | ||
| 43 | - ) -> np.ndarray: | ||
| 44 | - """ | ||
| 45 | - Encode text into embeddings. | ||
| 46 | - | ||
| 47 | - Args: | ||
| 48 | - sentences: Single string or list of strings to encode | ||
| 49 | - normalize_embeddings: Whether to normalize embeddings | ||
| 50 | - device: Device to use ('cuda' or 'cpu') | ||
| 51 | - batch_size: Batch size for encoding | ||
| 52 | - | ||
| 53 | - Returns: | ||
| 54 | - numpy array of shape (n, 1024) containing embeddings | ||
| 55 | - """ | ||
| 56 | - # Move model to specified device | ||
| 57 | - if device == 'gpu': | ||
| 58 | - device = 'cuda' | ||
| 59 | - | ||
| 60 | - # Try requested device, fallback to CPU if CUDA fails | ||
| 61 | - try: | ||
| 62 | - if device == 'cuda': | ||
| 63 | - # Check CUDA memory first | ||
| 64 | - import torch | ||
| 65 | - if torch.cuda.is_available(): | ||
| 66 | - # Check if we have enough memory (at least 1GB free) | ||
| 67 | - free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated() | ||
| 68 | - if free_memory < 1024 * 1024 * 1024: # 1GB | ||
| 69 | - print(f"[BgeEncoder] CUDA memory insufficient ({free_memory/1024/1024:.1f}MB free), falling back to CPU") | ||
| 70 | - device = 'cpu' | ||
| 71 | - else: | ||
| 72 | - print(f"[BgeEncoder] CUDA not available, using CPU") | ||
| 73 | - device = 'cpu' | ||
| 74 | - | ||
| 75 | - self.model = self.model.to(device) | ||
| 76 | - | ||
| 77 | - embeddings = self.model.encode( | ||
| 78 | - sentences, | ||
| 79 | - normalize_embeddings=normalize_embeddings, | ||
| 80 | - device=device, | ||
| 81 | - show_progress_bar=False, | ||
| 82 | - batch_size=batch_size | ||
| 83 | - ) | ||
| 84 | - | ||
| 85 | - return embeddings | ||
| 86 | - | ||
| 87 | - except Exception as e: | ||
| 88 | - print(f"[BgeEncoder] Device {device} failed: {e}") | ||
| 89 | - if device != 'cpu': | ||
| 90 | - print(f"[BgeEncoder] Falling back to CPU") | ||
| 91 | - try: | ||
| 92 | - self.model = self.model.to('cpu') | ||
| 93 | - embeddings = self.model.encode( | ||
| 94 | - sentences, | ||
| 95 | - normalize_embeddings=normalize_embeddings, | ||
| 96 | - device='cpu', | ||
| 97 | - show_progress_bar=False, | ||
| 98 | - batch_size=batch_size | ||
| 99 | - ) | ||
| 100 | - return embeddings | ||
| 101 | - except Exception as e2: | ||
| 102 | - print(f"[BgeEncoder] CPU also failed: {e2}") | ||
| 103 | - raise | ||
| 104 | - else: | ||
| 105 | - raise | ||
| 106 | - | ||
| 107 | - def encode_batch( | ||
| 108 | - self, | ||
| 109 | - texts: List[str], | ||
| 110 | - batch_size: int = 32, | ||
| 111 | - device: str = 'cuda' | ||
| 112 | - ) -> np.ndarray: | ||
| 113 | - """ | ||
| 114 | - Encode a batch of texts efficiently. | ||
| 115 | - | ||
| 116 | - Args: | ||
| 117 | - texts: List of texts to encode | ||
| 118 | - batch_size: Batch size for processing | ||
| 119 | - device: Device to use | ||
| 120 | - | ||
| 121 | - Returns: | ||
| 122 | - numpy array of embeddings | ||
| 123 | - """ | ||
| 124 | - return self.encode(texts, batch_size=batch_size, device=device) |
| @@ -0,0 +1,40 @@ | @@ -0,0 +1,40 @@ | ||
| 1 | +#!/bin/bash | ||
| 2 | +# | ||
| 3 | +# Start Local Embedding Service | ||
| 4 | +# | ||
| 5 | +# This service exposes: | ||
| 6 | +# - POST /embed/text | ||
| 7 | +# - POST /embed/image | ||
| 8 | +# | ||
| 9 | +# Defaults are defined in `embeddings/config.py` | ||
| 10 | +# | ||
| 11 | +set -e | ||
| 12 | + | ||
| 13 | +cd "$(dirname "$0")/.." | ||
| 14 | + | ||
| 15 | +# Load conda env if available (keep consistent with other scripts) | ||
| 16 | +if [ -f "/home/tw/miniconda3/etc/profile.d/conda.sh" ]; then | ||
| 17 | + source /home/tw/miniconda3/etc/profile.d/conda.sh | ||
| 18 | + conda activate searchengine | ||
| 19 | +fi | ||
| 20 | + | ||
| 21 | +EMBEDDING_SERVICE_HOST=$(python -c "from embeddings.config import CONFIG; print(CONFIG.HOST)") | ||
| 22 | +EMBEDDING_SERVICE_PORT=$(python -c "from embeddings.config import CONFIG; print(CONFIG.PORT)") | ||
| 23 | + | ||
| 24 | +echo "========================================" | ||
| 25 | +echo "Starting Local Embedding Service" | ||
| 26 | +echo "========================================" | ||
| 27 | +echo "Host: ${EMBEDDING_SERVICE_HOST}" | ||
| 28 | +echo "Port: ${EMBEDDING_SERVICE_PORT}" | ||
| 29 | +echo | ||
| 30 | +echo "Tips:" | ||
| 31 | +echo " - Use a single worker (GPU models cannot be safely duplicated across workers)." | ||
| 32 | +echo " - Clients can set EMBEDDING_SERVICE_URL=http://localhost:${EMBEDDING_SERVICE_PORT}" | ||
| 33 | +echo | ||
| 34 | + | ||
| 35 | +exec python -m uvicorn embeddings.server:app \ | ||
| 36 | + --host "${EMBEDDING_SERVICE_HOST}" \ | ||
| 37 | + --port "${EMBEDDING_SERVICE_PORT}" \ | ||
| 38 | + --workers 1 | ||
| 39 | + | ||
| 40 | + |