Commit 200fdddf9b8cac307bb67bd9eac4bef1eb9a0e22
1 parent
654f20d1
embed norm
Showing
10 changed files
with
128 additions
and
44 deletions
Show diff stats
embeddings/README.md
| @@ -14,7 +14,7 @@ | @@ -14,7 +14,7 @@ | ||
| 14 | - **clip-as-service 客户端**:`clip_as_service_encoder.py`(图片向量,推荐) | 14 | - **clip-as-service 客户端**:`clip_as_service_encoder.py`(图片向量,推荐) |
| 15 | - **向量化服务(FastAPI)**:`server.py` | 15 | - **向量化服务(FastAPI)**:`server.py` |
| 16 | - **统一配置**:`config.py` | 16 | - **统一配置**:`config.py` |
| 17 | -- **接口契约**:`protocols.ImageEncoderProtocol`(图片编码统一为 `encode_image_urls(urls, batch_size)`,本地 CN-CLIP 与 clip-as-service 均实现该接口) | 17 | +- **接口契约**:`protocols.ImageEncoderProtocol`(图片编码统一为 `encode_image_urls(urls, batch_size, normalize_embeddings)`,本地 CN-CLIP 与 clip-as-service 均实现该接口) |
| 18 | 18 | ||
| 19 | 说明:历史上的云端 embedding 试验实现(DashScope)已从主仓库移除,当前仅维护 6005 这条统一向量服务链路。 | 19 | 说明:历史上的云端 embedding 试验实现(DashScope)已从主仓库移除,当前仅维护 6005 这条统一向量服务链路。 |
| 20 | 20 | ||
| @@ -29,10 +29,12 @@ | @@ -29,10 +29,12 @@ | ||
| 29 | 29 | ||
| 30 | - `POST /embed/text` | 30 | - `POST /embed/text` |
| 31 | - 入参:`["文本1", "文本2", ...]` | 31 | - 入参:`["文本1", "文本2", ...]` |
| 32 | + - 可选 query 参数:`normalize=true|false`(不传则使用服务端默认) | ||
| 32 | - 出参:`[[...], [...], ...]`(与输入按 index 对齐,失败直接报错) | 33 | - 出参:`[[...], [...], ...]`(与输入按 index 对齐,失败直接报错) |
| 33 | 34 | ||
| 34 | - `POST /embed/image` | 35 | - `POST /embed/image` |
| 35 | - 入参:`["url或本地路径1", ...]` | 36 | - 入参:`["url或本地路径1", ...]` |
| 37 | + - 可选 query 参数:`normalize=true|false`(不传则使用服务端默认) | ||
| 36 | - 出参:`[[...], [...], ...]`(与输入按 index 对齐,失败直接报错) | 38 | - 出参:`[[...], [...], ...]`(与输入按 index 对齐,失败直接报错) |
| 37 | 39 | ||
| 38 | ### 图片向量:clip-as-service(推荐) | 40 | ### 图片向量:clip-as-service(推荐) |
| @@ -77,6 +79,7 @@ TEI_USE_GPU=0 ./scripts/start_tei_service.sh | @@ -77,6 +79,7 @@ TEI_USE_GPU=0 ./scripts/start_tei_service.sh | ||
| 77 | 79 | ||
| 78 | - `PORT`: 服务端口(默认 6005) | 80 | - `PORT`: 服务端口(默认 6005) |
| 79 | - `TEXT_MODEL_ID`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE`, `TEXT_NORMALIZE_EMBEDDINGS` | 81 | - `TEXT_MODEL_ID`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE`, `TEXT_NORMALIZE_EMBEDDINGS` |
| 82 | +- `IMAGE_NORMALIZE_EMBEDDINGS`(默认 true) | ||
| 80 | - `USE_CLIP_AS_SERVICE`, `CLIP_AS_SERVICE_SERVER`:图片向量(clip-as-service) | 83 | - `USE_CLIP_AS_SERVICE`, `CLIP_AS_SERVICE_SERVER`:图片向量(clip-as-service) |
| 81 | - `IMAGE_MODEL_NAME`, `IMAGE_DEVICE`:本地 CN-CLIP(当 `USE_CLIP_AS_SERVICE=false` 时) | 84 | - `IMAGE_MODEL_NAME`, `IMAGE_DEVICE`:本地 CN-CLIP(当 `USE_CLIP_AS_SERVICE=false` 时) |
| 82 | - TEI 相关:`TEI_USE_GPU`、`TEI_VERSION`、`TEI_MAX_BATCH_TOKENS`、`TEI_MAX_CLIENT_BATCH_SIZE`、`TEI_HEALTH_TIMEOUT_SEC` | 85 | - TEI 相关:`TEI_USE_GPU`、`TEI_VERSION`、`TEI_MAX_BATCH_TOKENS`、`TEI_MAX_CLIENT_BATCH_SIZE`、`TEI_HEALTH_TIMEOUT_SEC` |
embeddings/clip_as_service_encoder.py
| @@ -79,6 +79,7 @@ class ClipAsServiceImageEncoder: | @@ -79,6 +79,7 @@ class ClipAsServiceImageEncoder: | ||
| 79 | self, | 79 | self, |
| 80 | urls: List[str], | 80 | urls: List[str], |
| 81 | batch_size: Optional[int] = None, | 81 | batch_size: Optional[int] = None, |
| 82 | + normalize_embeddings: bool = True, | ||
| 82 | ) -> List[np.ndarray]: | 83 | ) -> List[np.ndarray]: |
| 83 | """ | 84 | """ |
| 84 | Encode a list of image URLs to vectors. | 85 | Encode a list of image URLs to vectors. |
| @@ -117,10 +118,15 @@ class ClipAsServiceImageEncoder: | @@ -117,10 +118,15 @@ class ClipAsServiceImageEncoder: | ||
| 117 | vec = np.asarray(row, dtype=np.float32) | 118 | vec = np.asarray(row, dtype=np.float32) |
| 118 | if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): | 119 | if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): |
| 119 | raise RuntimeError("clip-as-service returned invalid embedding vector") | 120 | raise RuntimeError("clip-as-service returned invalid embedding vector") |
| 121 | + if normalize_embeddings: | ||
| 122 | + norm = float(np.linalg.norm(vec)) | ||
| 123 | + if not np.isfinite(norm) or norm <= 0.0: | ||
| 124 | + raise RuntimeError("clip-as-service returned zero/invalid norm vector") | ||
| 125 | + vec = vec / norm | ||
| 120 | out.append(vec) | 126 | out.append(vec) |
| 121 | return out | 127 | return out |
| 122 | 128 | ||
| 123 | - def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: | 129 | + def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> Optional[np.ndarray]: |
| 124 | """Encode a single image URL. Returns 1024-dim vector or None.""" | 130 | """Encode a single image URL. Returns 1024-dim vector or None.""" |
| 125 | - results = self.encode_image_urls([url], batch_size=1) | 131 | + results = self.encode_image_urls([url], batch_size=1, normalize_embeddings=normalize_embeddings) |
| 126 | return results[0] if results else None | 132 | return results[0] if results else None |
embeddings/clip_model.py
| @@ -76,25 +76,27 @@ class ClipImageModel(object): | @@ -76,25 +76,27 @@ class ClipImageModel(object): | ||
| 76 | text_features /= text_features.norm(dim=-1, keepdim=True) | 76 | text_features /= text_features.norm(dim=-1, keepdim=True) |
| 77 | return text_features | 77 | return text_features |
| 78 | 78 | ||
| 79 | - def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: | 79 | + def encode_image(self, image: Image.Image, normalize_embeddings: bool = True) -> Optional[np.ndarray]: |
| 80 | if not isinstance(image, Image.Image): | 80 | if not isinstance(image, Image.Image): |
| 81 | raise ValueError("ClipImageModel.encode_image input must be a PIL.Image") | 81 | raise ValueError("ClipImageModel.encode_image input must be a PIL.Image") |
| 82 | infer_data = self.preprocess(image).unsqueeze(0).to(self.device) | 82 | infer_data = self.preprocess(image).unsqueeze(0).to(self.device) |
| 83 | with torch.no_grad(): | 83 | with torch.no_grad(): |
| 84 | image_features = self.model.encode_image(infer_data) | 84 | image_features = self.model.encode_image(infer_data) |
| 85 | - image_features /= image_features.norm(dim=-1, keepdim=True) | 85 | + if normalize_embeddings: |
| 86 | + image_features /= image_features.norm(dim=-1, keepdim=True) | ||
| 86 | return image_features.cpu().numpy().astype("float32")[0] | 87 | return image_features.cpu().numpy().astype("float32")[0] |
| 87 | 88 | ||
| 88 | - def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: | 89 | + def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> Optional[np.ndarray]: |
| 89 | image_data = self.download_image(url) | 90 | image_data = self.download_image(url) |
| 90 | image = self.validate_image(image_data) | 91 | image = self.validate_image(image_data) |
| 91 | image = self.preprocess_image(image) | 92 | image = self.preprocess_image(image) |
| 92 | - return self.encode_image(image) | 93 | + return self.encode_image(image, normalize_embeddings=normalize_embeddings) |
| 93 | 94 | ||
| 94 | def encode_image_urls( | 95 | def encode_image_urls( |
| 95 | self, | 96 | self, |
| 96 | urls: List[str], | 97 | urls: List[str], |
| 97 | batch_size: Optional[int] = None, | 98 | batch_size: Optional[int] = None, |
| 99 | + normalize_embeddings: bool = True, | ||
| 98 | ) -> List[Optional[np.ndarray]]: | 100 | ) -> List[Optional[np.ndarray]]: |
| 99 | """ | 101 | """ |
| 100 | Encode a list of image URLs to vectors. Same interface as ClipAsServiceImageEncoder. | 102 | Encode a list of image URLs to vectors. Same interface as ClipAsServiceImageEncoder. |
| @@ -106,19 +108,27 @@ class ClipImageModel(object): | @@ -106,19 +108,27 @@ class ClipImageModel(object): | ||
| 106 | Returns: | 108 | Returns: |
| 107 | List of vectors (or None for failed items), same length as urls. | 109 | List of vectors (or None for failed items), same length as urls. |
| 108 | """ | 110 | """ |
| 109 | - return self.encode_batch(urls, batch_size=batch_size or 8) | 111 | + return self.encode_batch( |
| 112 | + urls, | ||
| 113 | + batch_size=batch_size or 8, | ||
| 114 | + normalize_embeddings=normalize_embeddings, | ||
| 115 | + ) | ||
| 110 | 116 | ||
| 111 | - def encode_batch(self, images: List[Union[str, Image.Image]], batch_size: int = 8) -> List[Optional[np.ndarray]]: | 117 | + def encode_batch( |
| 118 | + self, | ||
| 119 | + images: List[Union[str, Image.Image]], | ||
| 120 | + batch_size: int = 8, | ||
| 121 | + normalize_embeddings: bool = True, | ||
| 122 | + ) -> List[Optional[np.ndarray]]: | ||
| 112 | results: List[Optional[np.ndarray]] = [] | 123 | results: List[Optional[np.ndarray]] = [] |
| 113 | for i in range(0, len(images), batch_size): | 124 | for i in range(0, len(images), batch_size): |
| 114 | batch = images[i : i + batch_size] | 125 | batch = images[i : i + batch_size] |
| 115 | for img in batch: | 126 | for img in batch: |
| 116 | if isinstance(img, str): | 127 | if isinstance(img, str): |
| 117 | - results.append(self.encode_image_from_url(img)) | 128 | + results.append(self.encode_image_from_url(img, normalize_embeddings=normalize_embeddings)) |
| 118 | elif isinstance(img, Image.Image): | 129 | elif isinstance(img, Image.Image): |
| 119 | - results.append(self.encode_image(img)) | 130 | + results.append(self.encode_image(img, normalize_embeddings=normalize_embeddings)) |
| 120 | else: | 131 | else: |
| 121 | results.append(None) | 132 | results.append(None) |
| 122 | return results | 133 | return results |
| 123 | 134 | ||
| 124 | - |
embeddings/config.py
| @@ -37,6 +37,7 @@ class EmbeddingConfig(object): | @@ -37,6 +37,7 @@ class EmbeddingConfig(object): | ||
| 37 | 37 | ||
| 38 | # Service behavior | 38 | # Service behavior |
| 39 | IMAGE_BATCH_SIZE = 8 | 39 | IMAGE_BATCH_SIZE = 8 |
| 40 | + IMAGE_NORMALIZE_EMBEDDINGS = os.getenv("IMAGE_NORMALIZE_EMBEDDINGS", "true").lower() in ("1", "true", "yes") | ||
| 40 | 41 | ||
| 41 | 42 | ||
| 42 | CONFIG = EmbeddingConfig() | 43 | CONFIG = EmbeddingConfig() |
embeddings/image_encoder.py
| @@ -26,7 +26,7 @@ class CLIPImageEncoder: | @@ -26,7 +26,7 @@ class CLIPImageEncoder: | ||
| 26 | self.endpoint = f"{self.service_url}/embed/image" | 26 | self.endpoint = f"{self.service_url}/embed/image" |
| 27 | logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url) | 27 | logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url) |
| 28 | 28 | ||
| 29 | - def _call_service(self, request_data: List[str]) -> List[Any]: | 29 | + def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: |
| 30 | """ | 30 | """ |
| 31 | Call the embedding service API. | 31 | Call the embedding service API. |
| 32 | 32 | ||
| @@ -39,6 +39,7 @@ class CLIPImageEncoder: | @@ -39,6 +39,7 @@ class CLIPImageEncoder: | ||
| 39 | try: | 39 | try: |
| 40 | response = requests.post( | 40 | response = requests.post( |
| 41 | self.endpoint, | 41 | self.endpoint, |
| 42 | + params={"normalize": "true" if normalize_embeddings else "false"}, | ||
| 42 | json=request_data, | 43 | json=request_data, |
| 43 | timeout=60 | 44 | timeout=60 |
| 44 | ) | 45 | ) |
| @@ -56,7 +57,7 @@ class CLIPImageEncoder: | @@ -56,7 +57,7 @@ class CLIPImageEncoder: | ||
| 56 | """ | 57 | """ |
| 57 | raise NotImplementedError("encode_image with PIL Image is not supported by embedding service") | 58 | raise NotImplementedError("encode_image with PIL Image is not supported by embedding service") |
| 58 | 59 | ||
| 59 | - def encode_image_from_url(self, url: str) -> np.ndarray: | 60 | + def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> np.ndarray: |
| 60 | """ | 61 | """ |
| 61 | Generate image embedding via network service using URL. | 62 | Generate image embedding via network service using URL. |
| 62 | 63 | ||
| @@ -66,7 +67,7 @@ class CLIPImageEncoder: | @@ -66,7 +67,7 @@ class CLIPImageEncoder: | ||
| 66 | Returns: | 67 | Returns: |
| 67 | Embedding vector | 68 | Embedding vector |
| 68 | """ | 69 | """ |
| 69 | - response_data = self._call_service([url]) | 70 | + response_data = self._call_service([url], normalize_embeddings=normalize_embeddings) |
| 70 | if not response_data or len(response_data) != 1 or response_data[0] is None: | 71 | if not response_data or len(response_data) != 1 or response_data[0] is None: |
| 71 | raise RuntimeError(f"No image embedding returned for URL: {url}") | 72 | raise RuntimeError(f"No image embedding returned for URL: {url}") |
| 72 | vec = np.array(response_data[0], dtype=np.float32) | 73 | vec = np.array(response_data[0], dtype=np.float32) |
| @@ -77,7 +78,8 @@ class CLIPImageEncoder: | @@ -77,7 +78,8 @@ class CLIPImageEncoder: | ||
| 77 | def encode_batch( | 78 | def encode_batch( |
| 78 | self, | 79 | self, |
| 79 | images: List[Union[str, Image.Image]], | 80 | images: List[Union[str, Image.Image]], |
| 80 | - batch_size: int = 8 | 81 | + batch_size: int = 8, |
| 82 | + normalize_embeddings: bool = True, | ||
| 81 | ) -> List[np.ndarray]: | 83 | ) -> List[np.ndarray]: |
| 82 | """ | 84 | """ |
| 83 | Encode a batch of images efficiently via network service. | 85 | Encode a batch of images efficiently via network service. |
| @@ -98,7 +100,7 @@ class CLIPImageEncoder: | @@ -98,7 +100,7 @@ class CLIPImageEncoder: | ||
| 98 | results: List[np.ndarray] = [] | 100 | results: List[np.ndarray] = [] |
| 99 | for i in range(0, len(images), batch_size): | 101 | for i in range(0, len(images), batch_size): |
| 100 | batch_urls = [str(u).strip() for u in images[i:i + batch_size]] | 102 | batch_urls = [str(u).strip() for u in images[i:i + batch_size]] |
| 101 | - response_data = self._call_service(batch_urls) | 103 | + response_data = self._call_service(batch_urls, normalize_embeddings=normalize_embeddings) |
| 102 | if not response_data or len(response_data) != len(batch_urls): | 104 | if not response_data or len(response_data) != len(batch_urls): |
| 103 | raise RuntimeError( | 105 | raise RuntimeError( |
| 104 | f"Image embedding response length mismatch: expected {len(batch_urls)}, " | 106 | f"Image embedding response length mismatch: expected {len(batch_urls)}, " |
| @@ -119,6 +121,7 @@ class CLIPImageEncoder: | @@ -119,6 +121,7 @@ class CLIPImageEncoder: | ||
| 119 | self, | 121 | self, |
| 120 | urls: List[str], | 122 | urls: List[str], |
| 121 | batch_size: Optional[int] = None, | 123 | batch_size: Optional[int] = None, |
| 124 | + normalize_embeddings: bool = True, | ||
| 122 | ) -> List[np.ndarray]: | 125 | ) -> List[np.ndarray]: |
| 123 | """ | 126 | """ |
| 124 | 与 ClipImageModel / ClipAsServiceImageEncoder 一致的接口,供索引器 document_transformer 调用。 | 127 | 与 ClipImageModel / ClipAsServiceImageEncoder 一致的接口,供索引器 document_transformer 调用。 |
| @@ -130,4 +133,8 @@ class CLIPImageEncoder: | @@ -130,4 +133,8 @@ class CLIPImageEncoder: | ||
| 130 | Returns: | 133 | Returns: |
| 131 | 与 urls 等长的向量列表 | 134 | 与 urls 等长的向量列表 |
| 132 | """ | 135 | """ |
| 133 | - return self.encode_batch(urls, batch_size=batch_size or 8) | 136 | + return self.encode_batch( |
| 137 | + urls, | ||
| 138 | + batch_size=batch_size or 8, | ||
| 139 | + normalize_embeddings=normalize_embeddings, | ||
| 140 | + ) |
embeddings/protocols.py
| @@ -17,6 +17,7 @@ class ImageEncoderProtocol(Protocol): | @@ -17,6 +17,7 @@ class ImageEncoderProtocol(Protocol): | ||
| 17 | self, | 17 | self, |
| 18 | urls: List[str], | 18 | urls: List[str], |
| 19 | batch_size: Optional[int] = None, | 19 | batch_size: Optional[int] = None, |
| 20 | + normalize_embeddings: bool = True, | ||
| 20 | ) -> List[Optional[np.ndarray]]: | 21 | ) -> List[Optional[np.ndarray]]: |
| 21 | """ | 22 | """ |
| 22 | Encode a list of image URLs to vectors. | 23 | Encode a list of image URLs to vectors. |
embeddings/server.py
| @@ -112,14 +112,24 @@ def load_models(): | @@ -112,14 +112,24 @@ def load_models(): | ||
| 112 | logger.info("All embedding models loaded successfully, service ready") | 112 | logger.info("All embedding models loaded successfully, service ready") |
| 113 | 113 | ||
| 114 | 114 | ||
| 115 | -def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]: | 115 | +def _normalize_vector(vec: np.ndarray) -> np.ndarray: |
| 116 | + norm = float(np.linalg.norm(vec)) | ||
| 117 | + if not np.isfinite(norm) or norm <= 0.0: | ||
| 118 | + raise RuntimeError("Embedding vector has invalid norm (must be > 0)") | ||
| 119 | + return vec / norm | ||
| 120 | + | ||
| 121 | + | ||
| 122 | +def _as_list(embedding: Optional[np.ndarray], normalize: bool = False) -> Optional[List[float]]: | ||
| 116 | if embedding is None: | 123 | if embedding is None: |
| 117 | return None | 124 | return None |
| 118 | if not isinstance(embedding, np.ndarray): | 125 | if not isinstance(embedding, np.ndarray): |
| 119 | embedding = np.array(embedding, dtype=np.float32) | 126 | embedding = np.array(embedding, dtype=np.float32) |
| 120 | if embedding.ndim != 1: | 127 | if embedding.ndim != 1: |
| 121 | embedding = embedding.reshape(-1) | 128 | embedding = embedding.reshape(-1) |
| 122 | - return embedding.astype(np.float32).tolist() | 129 | + embedding = embedding.astype(np.float32, copy=False) |
| 130 | + if normalize: | ||
| 131 | + embedding = _normalize_vector(embedding).astype(np.float32, copy=False) | ||
| 132 | + return embedding.tolist() | ||
| 123 | 133 | ||
| 124 | 134 | ||
| 125 | @app.get("/health") | 135 | @app.get("/health") |
| @@ -134,9 +144,10 @@ def health() -> Dict[str, Any]: | @@ -134,9 +144,10 @@ def health() -> Dict[str, Any]: | ||
| 134 | 144 | ||
| 135 | 145 | ||
| 136 | @app.post("/embed/text") | 146 | @app.post("/embed/text") |
| 137 | -def embed_text(texts: List[str]) -> List[Optional[List[float]]]: | 147 | +def embed_text(texts: List[str], normalize: Optional[bool] = None) -> List[Optional[List[float]]]: |
| 138 | if _text_model is None: | 148 | if _text_model is None: |
| 139 | raise RuntimeError("Text model not loaded") | 149 | raise RuntimeError("Text model not loaded") |
| 150 | + effective_normalize = bool(CONFIG.TEXT_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize) | ||
| 140 | normalized: List[str] = [] | 151 | normalized: List[str] = [] |
| 141 | for i, t in enumerate(texts): | 152 | for i, t in enumerate(texts): |
| 142 | if not isinstance(t, str): | 153 | if not isinstance(t, str): |
| @@ -152,7 +163,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: | @@ -152,7 +163,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: | ||
| 152 | normalized, | 163 | normalized, |
| 153 | batch_size=int(CONFIG.TEXT_BATCH_SIZE), | 164 | batch_size=int(CONFIG.TEXT_BATCH_SIZE), |
| 154 | device=CONFIG.TEXT_DEVICE, | 165 | device=CONFIG.TEXT_DEVICE, |
| 155 | - normalize_embeddings=bool(CONFIG.TEXT_NORMALIZE_EMBEDDINGS), | 166 | + normalize_embeddings=effective_normalize, |
| 156 | ) | 167 | ) |
| 157 | except Exception as e: | 168 | except Exception as e: |
| 158 | logger.error("Text embedding backend failure: %s", e, exc_info=True) | 169 | logger.error("Text embedding backend failure: %s", e, exc_info=True) |
| @@ -167,7 +178,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: | @@ -167,7 +178,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: | ||
| 167 | ) | 178 | ) |
| 168 | out: List[Optional[List[float]]] = [] | 179 | out: List[Optional[List[float]]] = [] |
| 169 | for i, emb in enumerate(embs): | 180 | for i, emb in enumerate(embs): |
| 170 | - vec = _as_list(emb) | 181 | + vec = _as_list(emb, normalize=effective_normalize) |
| 171 | if vec is None: | 182 | if vec is None: |
| 172 | raise RuntimeError(f"Text model returned empty embedding for index {i}") | 183 | raise RuntimeError(f"Text model returned empty embedding for index {i}") |
| 173 | out.append(vec) | 184 | out.append(vec) |
| @@ -175,9 +186,10 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: | @@ -175,9 +186,10 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: | ||
| 175 | 186 | ||
| 176 | 187 | ||
| 177 | @app.post("/embed/image") | 188 | @app.post("/embed/image") |
| 178 | -def embed_image(images: List[str]) -> List[Optional[List[float]]]: | 189 | +def embed_image(images: List[str], normalize: Optional[bool] = None) -> List[Optional[List[float]]]: |
| 179 | if _image_model is None: | 190 | if _image_model is None: |
| 180 | raise RuntimeError("Image model not loaded") | 191 | raise RuntimeError("Image model not loaded") |
| 192 | + effective_normalize = bool(CONFIG.IMAGE_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize) | ||
| 181 | urls: List[str] = [] | 193 | urls: List[str] = [] |
| 182 | for i, url_or_path in enumerate(images): | 194 | for i, url_or_path in enumerate(images): |
| 183 | if not isinstance(url_or_path, str): | 195 | if not isinstance(url_or_path, str): |
| @@ -188,7 +200,11 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]: | @@ -188,7 +200,11 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]: | ||
| 188 | urls.append(s) | 200 | urls.append(s) |
| 189 | 201 | ||
| 190 | with _image_encode_lock: | 202 | with _image_encode_lock: |
| 191 | - vectors = _image_model.encode_image_urls(urls, batch_size=CONFIG.IMAGE_BATCH_SIZE) | 203 | + vectors = _image_model.encode_image_urls( |
| 204 | + urls, | ||
| 205 | + batch_size=CONFIG.IMAGE_BATCH_SIZE, | ||
| 206 | + normalize_embeddings=effective_normalize, | ||
| 207 | + ) | ||
| 192 | if vectors is None or len(vectors) != len(urls): | 208 | if vectors is None or len(vectors) != len(urls): |
| 193 | raise RuntimeError( | 209 | raise RuntimeError( |
| 194 | f"Image model response length mismatch: expected {len(urls)}, " | 210 | f"Image model response length mismatch: expected {len(urls)}, " |
| @@ -196,7 +212,7 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]: | @@ -196,7 +212,7 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]: | ||
| 196 | ) | 212 | ) |
| 197 | out: List[Optional[List[float]]] = [] | 213 | out: List[Optional[List[float]]] = [] |
| 198 | for i, vec in enumerate(vectors): | 214 | for i, vec in enumerate(vectors): |
| 199 | - out_vec = _as_list(vec) | 215 | + out_vec = _as_list(vec, normalize=effective_normalize) |
| 200 | if out_vec is None: | 216 | if out_vec is None: |
| 201 | raise RuntimeError(f"Image model returned empty embedding for index {i}") | 217 | raise RuntimeError(f"Image model returned empty embedding for index {i}") |
| 202 | out.append(out_vec) | 218 | out.append(out_vec) |
embeddings/text_encoder.py
| @@ -50,7 +50,7 @@ class TextEmbeddingEncoder: | @@ -50,7 +50,7 @@ class TextEmbeddingEncoder: | ||
| 50 | logger.warning("Failed to initialize Redis cache for embeddings: %s, continuing without cache", e) | 50 | logger.warning("Failed to initialize Redis cache for embeddings: %s, continuing without cache", e) |
| 51 | self.redis_client = None | 51 | self.redis_client = None |
| 52 | 52 | ||
| 53 | - def _call_service(self, request_data: List[str]) -> List[Any]: | 53 | + def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: |
| 54 | """ | 54 | """ |
| 55 | Call the embedding service API. | 55 | Call the embedding service API. |
| 56 | 56 | ||
| @@ -63,6 +63,7 @@ class TextEmbeddingEncoder: | @@ -63,6 +63,7 @@ class TextEmbeddingEncoder: | ||
| 63 | try: | 63 | try: |
| 64 | response = requests.post( | 64 | response = requests.post( |
| 65 | self.endpoint, | 65 | self.endpoint, |
| 66 | + params={"normalize": "true" if normalize_embeddings else "false"}, | ||
| 66 | json=request_data, | 67 | json=request_data, |
| 67 | timeout=60 | 68 | timeout=60 |
| 68 | ) | 69 | ) |
| @@ -84,7 +85,7 @@ class TextEmbeddingEncoder: | @@ -84,7 +85,7 @@ class TextEmbeddingEncoder: | ||
| 84 | 85 | ||
| 85 | Args: | 86 | Args: |
| 86 | sentences: Single string or list of strings to encode | 87 | sentences: Single string or list of strings to encode |
| 87 | - normalize_embeddings: Whether to normalize embeddings (ignored for service) | 88 | + normalize_embeddings: Whether to request normalized embeddings from service |
| 88 | device: Device parameter ignored for service compatibility | 89 | device: Device parameter ignored for service compatibility |
| 89 | batch_size: Batch size for processing (used for service requests) | 90 | batch_size: Batch size for processing (used for service requests) |
| 90 | 91 | ||
| @@ -103,7 +104,7 @@ class TextEmbeddingEncoder: | @@ -103,7 +104,7 @@ class TextEmbeddingEncoder: | ||
| 103 | embeddings: List[Optional[np.ndarray]] = [None] * len(sentences) | 104 | embeddings: List[Optional[np.ndarray]] = [None] * len(sentences) |
| 104 | 105 | ||
| 105 | for i, text in enumerate(sentences): | 106 | for i, text in enumerate(sentences): |
| 106 | - cached = self._get_cached_embedding(text, "generic") | 107 | + cached = self._get_cached_embedding(text, "generic", normalize_embeddings) |
| 107 | if cached is not None: | 108 | if cached is not None: |
| 108 | embeddings[i] = cached | 109 | embeddings[i] = cached |
| 109 | else: | 110 | else: |
| @@ -115,7 +116,7 @@ class TextEmbeddingEncoder: | @@ -115,7 +116,7 @@ class TextEmbeddingEncoder: | ||
| 115 | 116 | ||
| 116 | # If there are uncached texts, call service | 117 | # If there are uncached texts, call service |
| 117 | if uncached_texts: | 118 | if uncached_texts: |
| 118 | - response_data = self._call_service(request_data) | 119 | + response_data = self._call_service(request_data, normalize_embeddings=normalize_embeddings) |
| 119 | 120 | ||
| 120 | # Process response | 121 | # Process response |
| 121 | for i, text in enumerate(uncached_texts): | 122 | for i, text in enumerate(uncached_texts): |
| @@ -129,7 +130,7 @@ class TextEmbeddingEncoder: | @@ -129,7 +130,7 @@ class TextEmbeddingEncoder: | ||
| 129 | embedding_array = np.array(embedding, dtype=np.float32) | 130 | embedding_array = np.array(embedding, dtype=np.float32) |
| 130 | if self._is_valid_embedding(embedding_array): | 131 | if self._is_valid_embedding(embedding_array): |
| 131 | embeddings[original_idx] = embedding_array | 132 | embeddings[original_idx] = embedding_array |
| 132 | - self._set_cached_embedding(text, "generic", embedding_array) | 133 | + self._set_cached_embedding(text, "generic", embedding_array, normalize_embeddings) |
| 133 | else: | 134 | else: |
| 134 | raise ValueError( | 135 | raise ValueError( |
| 135 | f"Invalid embedding returned from service for text index {original_idx}" | 136 | f"Invalid embedding returned from service for text index {original_idx}" |
| @@ -144,7 +145,8 @@ class TextEmbeddingEncoder: | @@ -144,7 +145,8 @@ class TextEmbeddingEncoder: | ||
| 144 | self, | 145 | self, |
| 145 | texts: List[str], | 146 | texts: List[str], |
| 146 | batch_size: int = 32, | 147 | batch_size: int = 32, |
| 147 | - device: str = 'cpu' | 148 | + device: str = 'cpu', |
| 149 | + normalize_embeddings: bool = True, | ||
| 148 | ) -> np.ndarray: | 150 | ) -> np.ndarray: |
| 149 | """ | 151 | """ |
| 150 | Encode a batch of texts efficiently via network service. | 152 | Encode a batch of texts efficiently via network service. |
| @@ -157,11 +159,17 @@ class TextEmbeddingEncoder: | @@ -157,11 +159,17 @@ class TextEmbeddingEncoder: | ||
| 157 | Returns: | 159 | Returns: |
| 158 | numpy array of embeddings | 160 | numpy array of embeddings |
| 159 | """ | 161 | """ |
| 160 | - return self.encode(texts, batch_size=batch_size, device=device) | 162 | + return self.encode( |
| 163 | + texts, | ||
| 164 | + batch_size=batch_size, | ||
| 165 | + device=device, | ||
| 166 | + normalize_embeddings=normalize_embeddings, | ||
| 167 | + ) | ||
| 161 | 168 | ||
| 162 | - def _get_cache_key(self, query: str, language: str) -> str: | 169 | + def _get_cache_key(self, query: str, language: str, normalize_embeddings: bool = True) -> str: |
| 163 | """Generate a cache key for the query""" | 170 | """Generate a cache key for the query""" |
| 164 | - return f"embedding:{language}:{query}" | 171 | + norm_flag = "norm1" if normalize_embeddings else "norm0" |
| 172 | + return f"embedding:{language}:{norm_flag}:{query}" | ||
| 165 | 173 | ||
| 166 | def _is_valid_embedding(self, embedding: np.ndarray) -> bool: | 174 | def _is_valid_embedding(self, embedding: np.ndarray) -> bool: |
| 167 | """ | 175 | """ |
| @@ -184,13 +192,18 @@ class TextEmbeddingEncoder: | @@ -184,13 +192,18 @@ class TextEmbeddingEncoder: | ||
| 184 | return False | 192 | return False |
| 185 | return True | 193 | return True |
| 186 | 194 | ||
| 187 | - def _get_cached_embedding(self, query: str, language: str) -> Optional[np.ndarray]: | 195 | + def _get_cached_embedding( |
| 196 | + self, | ||
| 197 | + query: str, | ||
| 198 | + language: str, | ||
| 199 | + normalize_embeddings: bool = True, | ||
| 200 | + ) -> Optional[np.ndarray]: | ||
| 188 | """Get embedding from cache if exists (with sliding expiration)""" | 201 | """Get embedding from cache if exists (with sliding expiration)""" |
| 189 | if not self.redis_client: | 202 | if not self.redis_client: |
| 190 | return None | 203 | return None |
| 191 | 204 | ||
| 192 | try: | 205 | try: |
| 193 | - cache_key = self._get_cache_key(query, language) | 206 | + cache_key = self._get_cache_key(query, language, normalize_embeddings) |
| 194 | cached_data = self.redis_client.get(cache_key) | 207 | cached_data = self.redis_client.get(cache_key) |
| 195 | if cached_data: | 208 | if cached_data: |
| 196 | embedding = pickle.loads(cached_data) | 209 | embedding = pickle.loads(cached_data) |
| @@ -216,13 +229,19 @@ class TextEmbeddingEncoder: | @@ -216,13 +229,19 @@ class TextEmbeddingEncoder: | ||
| 216 | logger.error(f"Error retrieving embedding from cache: {e}") | 229 | logger.error(f"Error retrieving embedding from cache: {e}") |
| 217 | return None | 230 | return None |
| 218 | 231 | ||
| 219 | - def _set_cached_embedding(self, query: str, language: str, embedding: np.ndarray) -> bool: | 232 | + def _set_cached_embedding( |
| 233 | + self, | ||
| 234 | + query: str, | ||
| 235 | + language: str, | ||
| 236 | + embedding: np.ndarray, | ||
| 237 | + normalize_embeddings: bool = True, | ||
| 238 | + ) -> bool: | ||
| 220 | """Store embedding in cache""" | 239 | """Store embedding in cache""" |
| 221 | if not self.redis_client: | 240 | if not self.redis_client: |
| 222 | return False | 241 | return False |
| 223 | 242 | ||
| 224 | try: | 243 | try: |
| 225 | - cache_key = self._get_cache_key(query, language) | 244 | + cache_key = self._get_cache_key(query, language, normalize_embeddings) |
| 226 | serialized_data = pickle.dumps(embedding) | 245 | serialized_data = pickle.dumps(embedding) |
| 227 | self.redis_client.setex( | 246 | self.redis_client.setex( |
| 228 | cache_key, | 247 | cache_key, |
restart.sh
| @@ -4,4 +4,7 @@ | @@ -4,4 +4,7 @@ | ||
| 4 | 4 | ||
| 5 | cd "$(dirname "$0")" | 5 | cd "$(dirname "$0")" |
| 6 | 6 | ||
| 7 | -./scripts/service_ctl.sh restart | 7 | +START_EMBEDDING=1 START_TRANSLATOR=1 START_RERANKER=1 START_TEI=1 CNCLIP_DEVICE=cuda TEI_USE_GPU=1 ./scripts/service_ctl.sh restart |
| 8 | + | ||
| 9 | +# ./scripts/service_ctl.sh restart | ||
| 10 | + |
scripts/service_ctl.sh
| @@ -190,7 +190,14 @@ start_one() { | @@ -190,7 +190,14 @@ start_one() { | ||
| 190 | nohup bash -lc "${cmd}" > "${lf}" 2>&1 & | 190 | nohup bash -lc "${cmd}" > "${lf}" 2>&1 & |
| 191 | local pid=$! | 191 | local pid=$! |
| 192 | echo "${pid}" > "${pf}" | 192 | echo "${pid}" > "${pf}" |
| 193 | - if wait_for_health "${service}"; then | 193 | + # Some services (notably reranker with vLLM backend) can take longer |
| 194 | + # to load models / compile graphs on first start. Give them a longer | ||
| 195 | + # health check window to avoid false negatives. | ||
| 196 | + local retries=30 | ||
| 197 | + if [ "${service}" = "reranker" ]; then | ||
| 198 | + retries=90 | ||
| 199 | + fi | ||
| 200 | + if wait_for_health "${service}" "${retries}"; then | ||
| 194 | echo "[ok] ${service} healthy (pid=${pid}, log=${lf})" | 201 | echo "[ok] ${service} healthy (pid=${pid}, log=${lf})" |
| 195 | else | 202 | else |
| 196 | echo "[error] ${service} health check timeout, inspect ${lf}" >&2 | 203 | echo "[error] ${service} health check timeout, inspect ${lf}" >&2 |
| @@ -320,9 +327,14 @@ resolve_targets() { | @@ -320,9 +327,14 @@ resolve_targets() { | ||
| 320 | if [ "${START_RERANKER:-0}" = "1" ]; then targets+=("reranker"); fi | 327 | if [ "${START_RERANKER:-0}" = "1" ]; then targets+=("reranker"); fi |
| 321 | echo "${targets[@]}" | 328 | echo "${targets[@]}" |
| 322 | ;; | 329 | ;; |
| 323 | - stop|restart|status) | 330 | + stop|status) |
| 324 | echo "$(all_services)" | 331 | echo "$(all_services)" |
| 325 | ;; | 332 | ;; |
| 333 | + restart) | ||
| 334 | + # Restart with no explicit services should preserve start-order dependency | ||
| 335 | + # behavior (e.g. tei/cnclip before embedding). | ||
| 336 | + echo "$(resolve_targets start)" | ||
| 337 | + ;; | ||
| 326 | *) | 338 | *) |
| 327 | echo "" | 339 | echo "" |
| 328 | ;; | 340 | ;; |
| @@ -340,7 +352,7 @@ Usage: | @@ -340,7 +352,7 @@ Usage: | ||
| 340 | Default target set (when no service provided): | 352 | Default target set (when no service provided): |
| 341 | start -> backend indexer frontend (+ optional by env flags) | 353 | start -> backend indexer frontend (+ optional by env flags) |
| 342 | stop -> all known services | 354 | stop -> all known services |
| 343 | - restart -> all known services | 355 | + restart -> stop all known services, then start with start targets |
| 344 | status -> all known services | 356 | status -> all known services |
| 345 | 357 | ||
| 346 | Optional startup flags: | 358 | Optional startup flags: |
| @@ -361,7 +373,13 @@ main() { | @@ -361,7 +373,13 @@ main() { | ||
| 361 | shift || true | 373 | shift || true |
| 362 | 374 | ||
| 363 | load_env_file "${PROJECT_ROOT}/.env" | 375 | load_env_file "${PROJECT_ROOT}/.env" |
| 376 | + local stop_targets="" | ||
| 364 | local targets | 377 | local targets |
| 378 | + # For restart without explicit services, stop everything first, then start | ||
| 379 | + # with dependency-aware start targets. | ||
| 380 | + if [ "${action}" = "restart" ] && [ "$#" -eq 0 ]; then | ||
| 381 | + stop_targets="$(resolve_targets stop)" | ||
| 382 | + fi | ||
| 365 | targets="$(resolve_targets "${action}" "$@")" | 383 | targets="$(resolve_targets "${action}" "$@")" |
| 366 | if [ -z "${targets}" ]; then | 384 | if [ -z "${targets}" ]; then |
| 367 | usage | 385 | usage |
| @@ -380,7 +398,7 @@ main() { | @@ -380,7 +398,7 @@ main() { | ||
| 380 | done | 398 | done |
| 381 | ;; | 399 | ;; |
| 382 | restart) | 400 | restart) |
| 383 | - for svc in ${targets}; do | 401 | + for svc in ${stop_targets:-${targets}}; do |
| 384 | stop_one "${svc}" | 402 | stop_one "${svc}" |
| 385 | done | 403 | done |
| 386 | for svc in ${targets}; do | 404 | for svc in ${targets}; do |