Commit 200fdddf9b8cac307bb67bd9eac4bef1eb9a0e22
1 parent
654f20d1
embed norm
Showing
10 changed files
with
128 additions
and
44 deletions
Show diff stats
embeddings/README.md
| ... | ... | @@ -14,7 +14,7 @@ |
| 14 | 14 | - **clip-as-service 客户端**:`clip_as_service_encoder.py`(图片向量,推荐) |
| 15 | 15 | - **向量化服务(FastAPI)**:`server.py` |
| 16 | 16 | - **统一配置**:`config.py` |
| 17 | -- **接口契约**:`protocols.ImageEncoderProtocol`(图片编码统一为 `encode_image_urls(urls, batch_size)`,本地 CN-CLIP 与 clip-as-service 均实现该接口) | |
| 17 | +- **接口契约**:`protocols.ImageEncoderProtocol`(图片编码统一为 `encode_image_urls(urls, batch_size, normalize_embeddings)`,本地 CN-CLIP 与 clip-as-service 均实现该接口) | |
| 18 | 18 | |
| 19 | 19 | 说明:历史上的云端 embedding 试验实现(DashScope)已从主仓库移除,当前仅维护 6005 这条统一向量服务链路。 |
| 20 | 20 | |
| ... | ... | @@ -29,10 +29,12 @@ |
| 29 | 29 | |
| 30 | 30 | - `POST /embed/text` |
| 31 | 31 | - 入参:`["文本1", "文本2", ...]` |
| 32 | + - 可选 query 参数:`normalize=true|false`(不传则使用服务端默认) | |
| 32 | 33 | - 出参:`[[...], [...], ...]`(与输入按 index 对齐,失败直接报错) |
| 33 | 34 | |
| 34 | 35 | - `POST /embed/image` |
| 35 | 36 | - 入参:`["url或本地路径1", ...]` |
| 37 | + - 可选 query 参数:`normalize=true|false`(不传则使用服务端默认) | |
| 36 | 38 | - 出参:`[[...], [...], ...]`(与输入按 index 对齐,失败直接报错) |
| 37 | 39 | |
| 38 | 40 | ### 图片向量:clip-as-service(推荐) |
| ... | ... | @@ -77,6 +79,7 @@ TEI_USE_GPU=0 ./scripts/start_tei_service.sh |
| 77 | 79 | |
| 78 | 80 | - `PORT`: 服务端口(默认 6005) |
| 79 | 81 | - `TEXT_MODEL_ID`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE`, `TEXT_NORMALIZE_EMBEDDINGS` |
| 82 | +- `IMAGE_NORMALIZE_EMBEDDINGS`(默认 true) | |
| 80 | 83 | - `USE_CLIP_AS_SERVICE`, `CLIP_AS_SERVICE_SERVER`:图片向量(clip-as-service) |
| 81 | 84 | - `IMAGE_MODEL_NAME`, `IMAGE_DEVICE`:本地 CN-CLIP(当 `USE_CLIP_AS_SERVICE=false` 时) |
| 82 | 85 | - TEI 相关:`TEI_USE_GPU`、`TEI_VERSION`、`TEI_MAX_BATCH_TOKENS`、`TEI_MAX_CLIENT_BATCH_SIZE`、`TEI_HEALTH_TIMEOUT_SEC` | ... | ... |
embeddings/clip_as_service_encoder.py
| ... | ... | @@ -79,6 +79,7 @@ class ClipAsServiceImageEncoder: |
| 79 | 79 | self, |
| 80 | 80 | urls: List[str], |
| 81 | 81 | batch_size: Optional[int] = None, |
| 82 | + normalize_embeddings: bool = True, | |
| 82 | 83 | ) -> List[np.ndarray]: |
| 83 | 84 | """ |
| 84 | 85 | Encode a list of image URLs to vectors. |
| ... | ... | @@ -117,10 +118,15 @@ class ClipAsServiceImageEncoder: |
| 117 | 118 | vec = np.asarray(row, dtype=np.float32) |
| 118 | 119 | if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): |
| 119 | 120 | raise RuntimeError("clip-as-service returned invalid embedding vector") |
| 121 | + if normalize_embeddings: | |
| 122 | + norm = float(np.linalg.norm(vec)) | |
| 123 | + if not np.isfinite(norm) or norm <= 0.0: | |
| 124 | + raise RuntimeError("clip-as-service returned zero/invalid norm vector") | |
| 125 | + vec = vec / norm | |
| 120 | 126 | out.append(vec) |
| 121 | 127 | return out |
| 122 | 128 | |
| 123 | - def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: | |
| 129 | + def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> Optional[np.ndarray]: | |
| 124 | 130 | """Encode a single image URL. Returns 1024-dim vector or None.""" |
| 125 | - results = self.encode_image_urls([url], batch_size=1) | |
| 131 | + results = self.encode_image_urls([url], batch_size=1, normalize_embeddings=normalize_embeddings) | |
| 126 | 132 | return results[0] if results else None | ... | ... |
embeddings/clip_model.py
| ... | ... | @@ -76,25 +76,27 @@ class ClipImageModel(object): |
| 76 | 76 | text_features /= text_features.norm(dim=-1, keepdim=True) |
| 77 | 77 | return text_features |
| 78 | 78 | |
| 79 | - def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: | |
| 79 | + def encode_image(self, image: Image.Image, normalize_embeddings: bool = True) -> Optional[np.ndarray]: | |
| 80 | 80 | if not isinstance(image, Image.Image): |
| 81 | 81 | raise ValueError("ClipImageModel.encode_image input must be a PIL.Image") |
| 82 | 82 | infer_data = self.preprocess(image).unsqueeze(0).to(self.device) |
| 83 | 83 | with torch.no_grad(): |
| 84 | 84 | image_features = self.model.encode_image(infer_data) |
| 85 | - image_features /= image_features.norm(dim=-1, keepdim=True) | |
| 85 | + if normalize_embeddings: | |
| 86 | + image_features /= image_features.norm(dim=-1, keepdim=True) | |
| 86 | 87 | return image_features.cpu().numpy().astype("float32")[0] |
| 87 | 88 | |
| 88 | - def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: | |
| 89 | + def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> Optional[np.ndarray]: | |
| 89 | 90 | image_data = self.download_image(url) |
| 90 | 91 | image = self.validate_image(image_data) |
| 91 | 92 | image = self.preprocess_image(image) |
| 92 | - return self.encode_image(image) | |
| 93 | + return self.encode_image(image, normalize_embeddings=normalize_embeddings) | |
| 93 | 94 | |
| 94 | 95 | def encode_image_urls( |
| 95 | 96 | self, |
| 96 | 97 | urls: List[str], |
| 97 | 98 | batch_size: Optional[int] = None, |
| 99 | + normalize_embeddings: bool = True, | |
| 98 | 100 | ) -> List[Optional[np.ndarray]]: |
| 99 | 101 | """ |
| 100 | 102 | Encode a list of image URLs to vectors. Same interface as ClipAsServiceImageEncoder. |
| ... | ... | @@ -106,19 +108,27 @@ class ClipImageModel(object): |
| 106 | 108 | Returns: |
| 107 | 109 | List of vectors (or None for failed items), same length as urls. |
| 108 | 110 | """ |
| 109 | - return self.encode_batch(urls, batch_size=batch_size or 8) | |
| 111 | + return self.encode_batch( | |
| 112 | + urls, | |
| 113 | + batch_size=batch_size or 8, | |
| 114 | + normalize_embeddings=normalize_embeddings, | |
| 115 | + ) | |
| 110 | 116 | |
| 111 | - def encode_batch(self, images: List[Union[str, Image.Image]], batch_size: int = 8) -> List[Optional[np.ndarray]]: | |
| 117 | + def encode_batch( | |
| 118 | + self, | |
| 119 | + images: List[Union[str, Image.Image]], | |
| 120 | + batch_size: int = 8, | |
| 121 | + normalize_embeddings: bool = True, | |
| 122 | + ) -> List[Optional[np.ndarray]]: | |
| 112 | 123 | results: List[Optional[np.ndarray]] = [] |
| 113 | 124 | for i in range(0, len(images), batch_size): |
| 114 | 125 | batch = images[i : i + batch_size] |
| 115 | 126 | for img in batch: |
| 116 | 127 | if isinstance(img, str): |
| 117 | - results.append(self.encode_image_from_url(img)) | |
| 128 | + results.append(self.encode_image_from_url(img, normalize_embeddings=normalize_embeddings)) | |
| 118 | 129 | elif isinstance(img, Image.Image): |
| 119 | - results.append(self.encode_image(img)) | |
| 130 | + results.append(self.encode_image(img, normalize_embeddings=normalize_embeddings)) | |
| 120 | 131 | else: |
| 121 | 132 | results.append(None) |
| 122 | 133 | return results |
| 123 | 134 | |
| 124 | - | ... | ... |
embeddings/config.py
embeddings/image_encoder.py
| ... | ... | @@ -26,7 +26,7 @@ class CLIPImageEncoder: |
| 26 | 26 | self.endpoint = f"{self.service_url}/embed/image" |
| 27 | 27 | logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url) |
| 28 | 28 | |
| 29 | - def _call_service(self, request_data: List[str]) -> List[Any]: | |
| 29 | + def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: | |
| 30 | 30 | """ |
| 31 | 31 | Call the embedding service API. |
| 32 | 32 | |
| ... | ... | @@ -39,6 +39,7 @@ class CLIPImageEncoder: |
| 39 | 39 | try: |
| 40 | 40 | response = requests.post( |
| 41 | 41 | self.endpoint, |
| 42 | + params={"normalize": "true" if normalize_embeddings else "false"}, | |
| 42 | 43 | json=request_data, |
| 43 | 44 | timeout=60 |
| 44 | 45 | ) |
| ... | ... | @@ -56,7 +57,7 @@ class CLIPImageEncoder: |
| 56 | 57 | """ |
| 57 | 58 | raise NotImplementedError("encode_image with PIL Image is not supported by embedding service") |
| 58 | 59 | |
| 59 | - def encode_image_from_url(self, url: str) -> np.ndarray: | |
| 60 | + def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> np.ndarray: | |
| 60 | 61 | """ |
| 61 | 62 | Generate image embedding via network service using URL. |
| 62 | 63 | |
| ... | ... | @@ -66,7 +67,7 @@ class CLIPImageEncoder: |
| 66 | 67 | Returns: |
| 67 | 68 | Embedding vector |
| 68 | 69 | """ |
| 69 | - response_data = self._call_service([url]) | |
| 70 | + response_data = self._call_service([url], normalize_embeddings=normalize_embeddings) | |
| 70 | 71 | if not response_data or len(response_data) != 1 or response_data[0] is None: |
| 71 | 72 | raise RuntimeError(f"No image embedding returned for URL: {url}") |
| 72 | 73 | vec = np.array(response_data[0], dtype=np.float32) |
| ... | ... | @@ -77,7 +78,8 @@ class CLIPImageEncoder: |
| 77 | 78 | def encode_batch( |
| 78 | 79 | self, |
| 79 | 80 | images: List[Union[str, Image.Image]], |
| 80 | - batch_size: int = 8 | |
| 81 | + batch_size: int = 8, | |
| 82 | + normalize_embeddings: bool = True, | |
| 81 | 83 | ) -> List[np.ndarray]: |
| 82 | 84 | """ |
| 83 | 85 | Encode a batch of images efficiently via network service. |
| ... | ... | @@ -98,7 +100,7 @@ class CLIPImageEncoder: |
| 98 | 100 | results: List[np.ndarray] = [] |
| 99 | 101 | for i in range(0, len(images), batch_size): |
| 100 | 102 | batch_urls = [str(u).strip() for u in images[i:i + batch_size]] |
| 101 | - response_data = self._call_service(batch_urls) | |
| 103 | + response_data = self._call_service(batch_urls, normalize_embeddings=normalize_embeddings) | |
| 102 | 104 | if not response_data or len(response_data) != len(batch_urls): |
| 103 | 105 | raise RuntimeError( |
| 104 | 106 | f"Image embedding response length mismatch: expected {len(batch_urls)}, " |
| ... | ... | @@ -119,6 +121,7 @@ class CLIPImageEncoder: |
| 119 | 121 | self, |
| 120 | 122 | urls: List[str], |
| 121 | 123 | batch_size: Optional[int] = None, |
| 124 | + normalize_embeddings: bool = True, | |
| 122 | 125 | ) -> List[np.ndarray]: |
| 123 | 126 | """ |
| 124 | 127 | 与 ClipImageModel / ClipAsServiceImageEncoder 一致的接口,供索引器 document_transformer 调用。 |
| ... | ... | @@ -130,4 +133,8 @@ class CLIPImageEncoder: |
| 130 | 133 | Returns: |
| 131 | 134 | 与 urls 等长的向量列表 |
| 132 | 135 | """ |
| 133 | - return self.encode_batch(urls, batch_size=batch_size or 8) | |
| 136 | + return self.encode_batch( | |
| 137 | + urls, | |
| 138 | + batch_size=batch_size or 8, | |
| 139 | + normalize_embeddings=normalize_embeddings, | |
| 140 | + ) | ... | ... |
embeddings/protocols.py
embeddings/server.py
| ... | ... | @@ -112,14 +112,24 @@ def load_models(): |
| 112 | 112 | logger.info("All embedding models loaded successfully, service ready") |
| 113 | 113 | |
| 114 | 114 | |
| 115 | -def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]: | |
| 115 | +def _normalize_vector(vec: np.ndarray) -> np.ndarray: | |
| 116 | + norm = float(np.linalg.norm(vec)) | |
| 117 | + if not np.isfinite(norm) or norm <= 0.0: | |
| 118 | + raise RuntimeError("Embedding vector has invalid norm (must be > 0)") | |
| 119 | + return vec / norm | |
| 120 | + | |
| 121 | + | |
| 122 | +def _as_list(embedding: Optional[np.ndarray], normalize: bool = False) -> Optional[List[float]]: | |
| 116 | 123 | if embedding is None: |
| 117 | 124 | return None |
| 118 | 125 | if not isinstance(embedding, np.ndarray): |
| 119 | 126 | embedding = np.array(embedding, dtype=np.float32) |
| 120 | 127 | if embedding.ndim != 1: |
| 121 | 128 | embedding = embedding.reshape(-1) |
| 122 | - return embedding.astype(np.float32).tolist() | |
| 129 | + embedding = embedding.astype(np.float32, copy=False) | |
| 130 | + if normalize: | |
| 131 | + embedding = _normalize_vector(embedding).astype(np.float32, copy=False) | |
| 132 | + return embedding.tolist() | |
| 123 | 133 | |
| 124 | 134 | |
| 125 | 135 | @app.get("/health") |
| ... | ... | @@ -134,9 +144,10 @@ def health() -> Dict[str, Any]: |
| 134 | 144 | |
| 135 | 145 | |
| 136 | 146 | @app.post("/embed/text") |
| 137 | -def embed_text(texts: List[str]) -> List[Optional[List[float]]]: | |
| 147 | +def embed_text(texts: List[str], normalize: Optional[bool] = None) -> List[Optional[List[float]]]: | |
| 138 | 148 | if _text_model is None: |
| 139 | 149 | raise RuntimeError("Text model not loaded") |
| 150 | + effective_normalize = bool(CONFIG.TEXT_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize) | |
| 140 | 151 | normalized: List[str] = [] |
| 141 | 152 | for i, t in enumerate(texts): |
| 142 | 153 | if not isinstance(t, str): |
| ... | ... | @@ -152,7 +163,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: |
| 152 | 163 | normalized, |
| 153 | 164 | batch_size=int(CONFIG.TEXT_BATCH_SIZE), |
| 154 | 165 | device=CONFIG.TEXT_DEVICE, |
| 155 | - normalize_embeddings=bool(CONFIG.TEXT_NORMALIZE_EMBEDDINGS), | |
| 166 | + normalize_embeddings=effective_normalize, | |
| 156 | 167 | ) |
| 157 | 168 | except Exception as e: |
| 158 | 169 | logger.error("Text embedding backend failure: %s", e, exc_info=True) |
| ... | ... | @@ -167,7 +178,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: |
| 167 | 178 | ) |
| 168 | 179 | out: List[Optional[List[float]]] = [] |
| 169 | 180 | for i, emb in enumerate(embs): |
| 170 | - vec = _as_list(emb) | |
| 181 | + vec = _as_list(emb, normalize=effective_normalize) | |
| 171 | 182 | if vec is None: |
| 172 | 183 | raise RuntimeError(f"Text model returned empty embedding for index {i}") |
| 173 | 184 | out.append(vec) |
| ... | ... | @@ -175,9 +186,10 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: |
| 175 | 186 | |
| 176 | 187 | |
| 177 | 188 | @app.post("/embed/image") |
| 178 | -def embed_image(images: List[str]) -> List[Optional[List[float]]]: | |
| 189 | +def embed_image(images: List[str], normalize: Optional[bool] = None) -> List[Optional[List[float]]]: | |
| 179 | 190 | if _image_model is None: |
| 180 | 191 | raise RuntimeError("Image model not loaded") |
| 192 | + effective_normalize = bool(CONFIG.IMAGE_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize) | |
| 181 | 193 | urls: List[str] = [] |
| 182 | 194 | for i, url_or_path in enumerate(images): |
| 183 | 195 | if not isinstance(url_or_path, str): |
| ... | ... | @@ -188,7 +200,11 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]: |
| 188 | 200 | urls.append(s) |
| 189 | 201 | |
| 190 | 202 | with _image_encode_lock: |
| 191 | - vectors = _image_model.encode_image_urls(urls, batch_size=CONFIG.IMAGE_BATCH_SIZE) | |
| 203 | + vectors = _image_model.encode_image_urls( | |
| 204 | + urls, | |
| 205 | + batch_size=CONFIG.IMAGE_BATCH_SIZE, | |
| 206 | + normalize_embeddings=effective_normalize, | |
| 207 | + ) | |
| 192 | 208 | if vectors is None or len(vectors) != len(urls): |
| 193 | 209 | raise RuntimeError( |
| 194 | 210 | f"Image model response length mismatch: expected {len(urls)}, " |
| ... | ... | @@ -196,7 +212,7 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]: |
| 196 | 212 | ) |
| 197 | 213 | out: List[Optional[List[float]]] = [] |
| 198 | 214 | for i, vec in enumerate(vectors): |
| 199 | - out_vec = _as_list(vec) | |
| 215 | + out_vec = _as_list(vec, normalize=effective_normalize) | |
| 200 | 216 | if out_vec is None: |
| 201 | 217 | raise RuntimeError(f"Image model returned empty embedding for index {i}") |
| 202 | 218 | out.append(out_vec) | ... | ... |
embeddings/text_encoder.py
| ... | ... | @@ -50,7 +50,7 @@ class TextEmbeddingEncoder: |
| 50 | 50 | logger.warning("Failed to initialize Redis cache for embeddings: %s, continuing without cache", e) |
| 51 | 51 | self.redis_client = None |
| 52 | 52 | |
| 53 | - def _call_service(self, request_data: List[str]) -> List[Any]: | |
| 53 | + def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: | |
| 54 | 54 | """ |
| 55 | 55 | Call the embedding service API. |
| 56 | 56 | |
| ... | ... | @@ -63,6 +63,7 @@ class TextEmbeddingEncoder: |
| 63 | 63 | try: |
| 64 | 64 | response = requests.post( |
| 65 | 65 | self.endpoint, |
| 66 | + params={"normalize": "true" if normalize_embeddings else "false"}, | |
| 66 | 67 | json=request_data, |
| 67 | 68 | timeout=60 |
| 68 | 69 | ) |
| ... | ... | @@ -84,7 +85,7 @@ class TextEmbeddingEncoder: |
| 84 | 85 | |
| 85 | 86 | Args: |
| 86 | 87 | sentences: Single string or list of strings to encode |
| 87 | - normalize_embeddings: Whether to normalize embeddings (ignored for service) | |
| 88 | + normalize_embeddings: Whether to request normalized embeddings from service | |
| 88 | 89 | device: Device parameter ignored for service compatibility |
| 89 | 90 | batch_size: Batch size for processing (used for service requests) |
| 90 | 91 | |
| ... | ... | @@ -103,7 +104,7 @@ class TextEmbeddingEncoder: |
| 103 | 104 | embeddings: List[Optional[np.ndarray]] = [None] * len(sentences) |
| 104 | 105 | |
| 105 | 106 | for i, text in enumerate(sentences): |
| 106 | - cached = self._get_cached_embedding(text, "generic") | |
| 107 | + cached = self._get_cached_embedding(text, "generic", normalize_embeddings) | |
| 107 | 108 | if cached is not None: |
| 108 | 109 | embeddings[i] = cached |
| 109 | 110 | else: |
| ... | ... | @@ -115,7 +116,7 @@ class TextEmbeddingEncoder: |
| 115 | 116 | |
| 116 | 117 | # If there are uncached texts, call service |
| 117 | 118 | if uncached_texts: |
| 118 | - response_data = self._call_service(request_data) | |
| 119 | + response_data = self._call_service(request_data, normalize_embeddings=normalize_embeddings) | |
| 119 | 120 | |
| 120 | 121 | # Process response |
| 121 | 122 | for i, text in enumerate(uncached_texts): |
| ... | ... | @@ -129,7 +130,7 @@ class TextEmbeddingEncoder: |
| 129 | 130 | embedding_array = np.array(embedding, dtype=np.float32) |
| 130 | 131 | if self._is_valid_embedding(embedding_array): |
| 131 | 132 | embeddings[original_idx] = embedding_array |
| 132 | - self._set_cached_embedding(text, "generic", embedding_array) | |
| 133 | + self._set_cached_embedding(text, "generic", embedding_array, normalize_embeddings) | |
| 133 | 134 | else: |
| 134 | 135 | raise ValueError( |
| 135 | 136 | f"Invalid embedding returned from service for text index {original_idx}" |
| ... | ... | @@ -144,7 +145,8 @@ class TextEmbeddingEncoder: |
| 144 | 145 | self, |
| 145 | 146 | texts: List[str], |
| 146 | 147 | batch_size: int = 32, |
| 147 | - device: str = 'cpu' | |
| 148 | + device: str = 'cpu', | |
| 149 | + normalize_embeddings: bool = True, | |
| 148 | 150 | ) -> np.ndarray: |
| 149 | 151 | """ |
| 150 | 152 | Encode a batch of texts efficiently via network service. |
| ... | ... | @@ -157,11 +159,17 @@ class TextEmbeddingEncoder: |
| 157 | 159 | Returns: |
| 158 | 160 | numpy array of embeddings |
| 159 | 161 | """ |
| 160 | - return self.encode(texts, batch_size=batch_size, device=device) | |
| 162 | + return self.encode( | |
| 163 | + texts, | |
| 164 | + batch_size=batch_size, | |
| 165 | + device=device, | |
| 166 | + normalize_embeddings=normalize_embeddings, | |
| 167 | + ) | |
| 161 | 168 | |
| 162 | - def _get_cache_key(self, query: str, language: str) -> str: | |
| 169 | + def _get_cache_key(self, query: str, language: str, normalize_embeddings: bool = True) -> str: | |
| 163 | 170 | """Generate a cache key for the query""" |
| 164 | - return f"embedding:{language}:{query}" | |
| 171 | + norm_flag = "norm1" if normalize_embeddings else "norm0" | |
| 172 | + return f"embedding:{language}:{norm_flag}:{query}" | |
| 165 | 173 | |
| 166 | 174 | def _is_valid_embedding(self, embedding: np.ndarray) -> bool: |
| 167 | 175 | """ |
| ... | ... | @@ -184,13 +192,18 @@ class TextEmbeddingEncoder: |
| 184 | 192 | return False |
| 185 | 193 | return True |
| 186 | 194 | |
| 187 | - def _get_cached_embedding(self, query: str, language: str) -> Optional[np.ndarray]: | |
| 195 | + def _get_cached_embedding( | |
| 196 | + self, | |
| 197 | + query: str, | |
| 198 | + language: str, | |
| 199 | + normalize_embeddings: bool = True, | |
| 200 | + ) -> Optional[np.ndarray]: | |
| 188 | 201 | """Get embedding from cache if exists (with sliding expiration)""" |
| 189 | 202 | if not self.redis_client: |
| 190 | 203 | return None |
| 191 | 204 | |
| 192 | 205 | try: |
| 193 | - cache_key = self._get_cache_key(query, language) | |
| 206 | + cache_key = self._get_cache_key(query, language, normalize_embeddings) | |
| 194 | 207 | cached_data = self.redis_client.get(cache_key) |
| 195 | 208 | if cached_data: |
| 196 | 209 | embedding = pickle.loads(cached_data) |
| ... | ... | @@ -216,13 +229,19 @@ class TextEmbeddingEncoder: |
| 216 | 229 | logger.error(f"Error retrieving embedding from cache: {e}") |
| 217 | 230 | return None |
| 218 | 231 | |
| 219 | - def _set_cached_embedding(self, query: str, language: str, embedding: np.ndarray) -> bool: | |
| 232 | + def _set_cached_embedding( | |
| 233 | + self, | |
| 234 | + query: str, | |
| 235 | + language: str, | |
| 236 | + embedding: np.ndarray, | |
| 237 | + normalize_embeddings: bool = True, | |
| 238 | + ) -> bool: | |
| 220 | 239 | """Store embedding in cache""" |
| 221 | 240 | if not self.redis_client: |
| 222 | 241 | return False |
| 223 | 242 | |
| 224 | 243 | try: |
| 225 | - cache_key = self._get_cache_key(query, language) | |
| 244 | + cache_key = self._get_cache_key(query, language, normalize_embeddings) | |
| 226 | 245 | serialized_data = pickle.dumps(embedding) |
| 227 | 246 | self.redis_client.setex( |
| 228 | 247 | cache_key, | ... | ... |
restart.sh
scripts/service_ctl.sh
| ... | ... | @@ -190,7 +190,14 @@ start_one() { |
| 190 | 190 | nohup bash -lc "${cmd}" > "${lf}" 2>&1 & |
| 191 | 191 | local pid=$! |
| 192 | 192 | echo "${pid}" > "${pf}" |
| 193 | - if wait_for_health "${service}"; then | |
| 193 | + # Some services (notably reranker with vLLM backend) can take longer | |
| 194 | + # to load models / compile graphs on first start. Give them a longer | |
| 195 | + # health check window to avoid false negatives. | |
| 196 | + local retries=30 | |
| 197 | + if [ "${service}" = "reranker" ]; then | |
| 198 | + retries=90 | |
| 199 | + fi | |
| 200 | + if wait_for_health "${service}" "${retries}"; then | |
| 194 | 201 | echo "[ok] ${service} healthy (pid=${pid}, log=${lf})" |
| 195 | 202 | else |
| 196 | 203 | echo "[error] ${service} health check timeout, inspect ${lf}" >&2 |
| ... | ... | @@ -320,9 +327,14 @@ resolve_targets() { |
| 320 | 327 | if [ "${START_RERANKER:-0}" = "1" ]; then targets+=("reranker"); fi |
| 321 | 328 | echo "${targets[@]}" |
| 322 | 329 | ;; |
| 323 | - stop|restart|status) | |
| 330 | + stop|status) | |
| 324 | 331 | echo "$(all_services)" |
| 325 | 332 | ;; |
| 333 | + restart) | |
| 334 | + # Restart with no explicit services should preserve start-order dependency | |
| 335 | + # behavior (e.g. tei/cnclip before embedding). | |
| 336 | + echo "$(resolve_targets start)" | |
| 337 | + ;; | |
| 326 | 338 | *) |
| 327 | 339 | echo "" |
| 328 | 340 | ;; |
| ... | ... | @@ -340,7 +352,7 @@ Usage: |
| 340 | 352 | Default target set (when no service provided): |
| 341 | 353 | start -> backend indexer frontend (+ optional by env flags) |
| 342 | 354 | stop -> all known services |
| 343 | - restart -> all known services | |
| 355 | + restart -> stop all known services, then start using the default start targets | |
| 344 | 356 | status -> all known services |
| 345 | 357 | |
| 346 | 358 | Optional startup flags: |
| ... | ... | @@ -361,7 +373,13 @@ main() { |
| 361 | 373 | shift || true |
| 362 | 374 | |
| 363 | 375 | load_env_file "${PROJECT_ROOT}/.env" |
| 376 | + local stop_targets="" | |
| 364 | 377 | local targets |
| 378 | + # For restart without explicit services, stop everything first, then start | |
| 379 | + # with dependency-aware start targets. | |
| 380 | + if [ "${action}" = "restart" ] && [ "$#" -eq 0 ]; then | |
| 381 | + stop_targets="$(resolve_targets stop)" | |
| 382 | + fi | |
| 365 | 383 | targets="$(resolve_targets "${action}" "$@")" |
| 366 | 384 | if [ -z "${targets}" ]; then |
| 367 | 385 | usage |
| ... | ... | @@ -380,7 +398,7 @@ main() { |
| 380 | 398 | done |
| 381 | 399 | ;; |
| 382 | 400 | restart) |
| 383 | - for svc in ${targets}; do | |
| 401 | + for svc in ${stop_targets:-${targets}}; do | |
| 384 | 402 | stop_one "${svc}" |
| 385 | 403 | done |
| 386 | 404 | for svc in ${targets}; do | ... | ... |