Commit 200fdddf9b8cac307bb67bd9eac4bef1eb9a0e22

Authored by tangwang
1 parent 654f20d1

embed norm

embeddings/README.md
@@ -14,7 +14,7 @@ @@ -14,7 +14,7 @@
14 - **clip-as-service 客户端**:`clip_as_service_encoder.py`(图片向量,推荐) 14 - **clip-as-service 客户端**:`clip_as_service_encoder.py`(图片向量,推荐)
15 - **向量化服务(FastAPI)**:`server.py` 15 - **向量化服务(FastAPI)**:`server.py`
16 - **统一配置**:`config.py` 16 - **统一配置**:`config.py`
17 -- **接口契约**:`protocols.ImageEncoderProtocol`(图片编码统一为 `encode_image_urls(urls, batch_size)`,本地 CN-CLIP 与 clip-as-service 均实现该接口) 17 +- **接口契约**:`protocols.ImageEncoderProtocol`(图片编码统一为 `encode_image_urls(urls, batch_size, normalize_embeddings)`,本地 CN-CLIP 与 clip-as-service 均实现该接口)
18 18
19 说明:历史上的云端 embedding 试验实现(DashScope)已从主仓库移除,当前仅维护 6005 这条统一向量服务链路。 19 说明:历史上的云端 embedding 试验实现(DashScope)已从主仓库移除,当前仅维护 6005 这条统一向量服务链路。
20 20
@@ -29,10 +29,12 @@ @@ -29,10 +29,12 @@
29 29
30 - `POST /embed/text` 30 - `POST /embed/text`
31 - 入参:`["文本1", "文本2", ...]` 31 - 入参:`["文本1", "文本2", ...]`
  32 + - 可选 query 参数:`normalize=true|false`(不传则使用服务端默认)
32 - 出参:`[[...], [...], ...]`(与输入按 index 对齐,失败直接报错) 33 - 出参:`[[...], [...], ...]`(与输入按 index 对齐,失败直接报错)
33 34
34 - `POST /embed/image` 35 - `POST /embed/image`
35 - 入参:`["url或本地路径1", ...]` 36 - 入参:`["url或本地路径1", ...]`
  37 + - 可选 query 参数:`normalize=true|false`(不传则使用服务端默认)
36 - 出参:`[[...], [...], ...]`(与输入按 index 对齐,失败直接报错) 38 - 出参:`[[...], [...], ...]`(与输入按 index 对齐,失败直接报错)
37 39
38 ### 图片向量:clip-as-service(推荐) 40 ### 图片向量:clip-as-service(推荐)
@@ -77,6 +79,7 @@ TEI_USE_GPU=0 ./scripts/start_tei_service.sh @@ -77,6 +79,7 @@ TEI_USE_GPU=0 ./scripts/start_tei_service.sh
77 79
78 - `PORT`: 服务端口(默认 6005) 80 - `PORT`: 服务端口(默认 6005)
79 - `TEXT_MODEL_ID`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE`, `TEXT_NORMALIZE_EMBEDDINGS` 81 - `TEXT_MODEL_ID`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE`, `TEXT_NORMALIZE_EMBEDDINGS`
  82 +- `IMAGE_NORMALIZE_EMBEDDINGS`(默认 true)
80 - `USE_CLIP_AS_SERVICE`, `CLIP_AS_SERVICE_SERVER`:图片向量(clip-as-service) 83 - `USE_CLIP_AS_SERVICE`, `CLIP_AS_SERVICE_SERVER`:图片向量(clip-as-service)
81 - `IMAGE_MODEL_NAME`, `IMAGE_DEVICE`:本地 CN-CLIP(当 `USE_CLIP_AS_SERVICE=false` 时) 84 - `IMAGE_MODEL_NAME`, `IMAGE_DEVICE`:本地 CN-CLIP(当 `USE_CLIP_AS_SERVICE=false` 时)
82 - TEI 相关:`TEI_USE_GPU`、`TEI_VERSION`、`TEI_MAX_BATCH_TOKENS`、`TEI_MAX_CLIENT_BATCH_SIZE`、`TEI_HEALTH_TIMEOUT_SEC` 85 - TEI 相关:`TEI_USE_GPU`、`TEI_VERSION`、`TEI_MAX_BATCH_TOKENS`、`TEI_MAX_CLIENT_BATCH_SIZE`、`TEI_HEALTH_TIMEOUT_SEC`
embeddings/clip_as_service_encoder.py
@@ -79,6 +79,7 @@ class ClipAsServiceImageEncoder: @@ -79,6 +79,7 @@ class ClipAsServiceImageEncoder:
79 self, 79 self,
80 urls: List[str], 80 urls: List[str],
81 batch_size: Optional[int] = None, 81 batch_size: Optional[int] = None,
  82 + normalize_embeddings: bool = True,
82 ) -> List[np.ndarray]: 83 ) -> List[np.ndarray]:
83 """ 84 """
84 Encode a list of image URLs to vectors. 85 Encode a list of image URLs to vectors.
@@ -117,10 +118,15 @@ class ClipAsServiceImageEncoder: @@ -117,10 +118,15 @@ class ClipAsServiceImageEncoder:
117 vec = np.asarray(row, dtype=np.float32) 118 vec = np.asarray(row, dtype=np.float32)
118 if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): 119 if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all():
119 raise RuntimeError("clip-as-service returned invalid embedding vector") 120 raise RuntimeError("clip-as-service returned invalid embedding vector")
  121 + if normalize_embeddings:
  122 + norm = float(np.linalg.norm(vec))
  123 + if not np.isfinite(norm) or norm <= 0.0:
  124 + raise RuntimeError("clip-as-service returned zero/invalid norm vector")
  125 + vec = vec / norm
120 out.append(vec) 126 out.append(vec)
121 return out 127 return out
122 128
123 - def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: 129 + def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> Optional[np.ndarray]:
124 """Encode a single image URL. Returns 1024-dim vector or None.""" 130 """Encode a single image URL. Returns 1024-dim vector or None."""
125 - results = self.encode_image_urls([url], batch_size=1) 131 + results = self.encode_image_urls([url], batch_size=1, normalize_embeddings=normalize_embeddings)
126 return results[0] if results else None 132 return results[0] if results else None
embeddings/clip_model.py
@@ -76,25 +76,27 @@ class ClipImageModel(object): @@ -76,25 +76,27 @@ class ClipImageModel(object):
76 text_features /= text_features.norm(dim=-1, keepdim=True) 76 text_features /= text_features.norm(dim=-1, keepdim=True)
77 return text_features 77 return text_features
78 78
79 - def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: 79 + def encode_image(self, image: Image.Image, normalize_embeddings: bool = True) -> Optional[np.ndarray]:
80 if not isinstance(image, Image.Image): 80 if not isinstance(image, Image.Image):
81 raise ValueError("ClipImageModel.encode_image input must be a PIL.Image") 81 raise ValueError("ClipImageModel.encode_image input must be a PIL.Image")
82 infer_data = self.preprocess(image).unsqueeze(0).to(self.device) 82 infer_data = self.preprocess(image).unsqueeze(0).to(self.device)
83 with torch.no_grad(): 83 with torch.no_grad():
84 image_features = self.model.encode_image(infer_data) 84 image_features = self.model.encode_image(infer_data)
85 - image_features /= image_features.norm(dim=-1, keepdim=True) 85 + if normalize_embeddings:
  86 + image_features /= image_features.norm(dim=-1, keepdim=True)
86 return image_features.cpu().numpy().astype("float32")[0] 87 return image_features.cpu().numpy().astype("float32")[0]
87 88
88 - def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: 89 + def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> Optional[np.ndarray]:
89 image_data = self.download_image(url) 90 image_data = self.download_image(url)
90 image = self.validate_image(image_data) 91 image = self.validate_image(image_data)
91 image = self.preprocess_image(image) 92 image = self.preprocess_image(image)
92 - return self.encode_image(image) 93 + return self.encode_image(image, normalize_embeddings=normalize_embeddings)
93 94
94 def encode_image_urls( 95 def encode_image_urls(
95 self, 96 self,
96 urls: List[str], 97 urls: List[str],
97 batch_size: Optional[int] = None, 98 batch_size: Optional[int] = None,
  99 + normalize_embeddings: bool = True,
98 ) -> List[Optional[np.ndarray]]: 100 ) -> List[Optional[np.ndarray]]:
99 """ 101 """
100 Encode a list of image URLs to vectors. Same interface as ClipAsServiceImageEncoder. 102 Encode a list of image URLs to vectors. Same interface as ClipAsServiceImageEncoder.
@@ -106,19 +108,27 @@ class ClipImageModel(object): @@ -106,19 +108,27 @@ class ClipImageModel(object):
106 Returns: 108 Returns:
107 List of vectors (or None for failed items), same length as urls. 109 List of vectors (or None for failed items), same length as urls.
108 """ 110 """
109 - return self.encode_batch(urls, batch_size=batch_size or 8) 111 + return self.encode_batch(
  112 + urls,
  113 + batch_size=batch_size or 8,
  114 + normalize_embeddings=normalize_embeddings,
  115 + )
110 116
111 - def encode_batch(self, images: List[Union[str, Image.Image]], batch_size: int = 8) -> List[Optional[np.ndarray]]: 117 + def encode_batch(
  118 + self,
  119 + images: List[Union[str, Image.Image]],
  120 + batch_size: int = 8,
  121 + normalize_embeddings: bool = True,
  122 + ) -> List[Optional[np.ndarray]]:
112 results: List[Optional[np.ndarray]] = [] 123 results: List[Optional[np.ndarray]] = []
113 for i in range(0, len(images), batch_size): 124 for i in range(0, len(images), batch_size):
114 batch = images[i : i + batch_size] 125 batch = images[i : i + batch_size]
115 for img in batch: 126 for img in batch:
116 if isinstance(img, str): 127 if isinstance(img, str):
117 - results.append(self.encode_image_from_url(img)) 128 + results.append(self.encode_image_from_url(img, normalize_embeddings=normalize_embeddings))
118 elif isinstance(img, Image.Image): 129 elif isinstance(img, Image.Image):
119 - results.append(self.encode_image(img)) 130 + results.append(self.encode_image(img, normalize_embeddings=normalize_embeddings))
120 else: 131 else:
121 results.append(None) 132 results.append(None)
122 return results 133 return results
123 134
124 -  
embeddings/config.py
@@ -37,6 +37,7 @@ class EmbeddingConfig(object): @@ -37,6 +37,7 @@ class EmbeddingConfig(object):
37 37
38 # Service behavior 38 # Service behavior
39 IMAGE_BATCH_SIZE = 8 39 IMAGE_BATCH_SIZE = 8
  40 + IMAGE_NORMALIZE_EMBEDDINGS = os.getenv("IMAGE_NORMALIZE_EMBEDDINGS", "true").lower() in ("1", "true", "yes")
40 41
41 42
42 CONFIG = EmbeddingConfig() 43 CONFIG = EmbeddingConfig()
embeddings/image_encoder.py
@@ -26,7 +26,7 @@ class CLIPImageEncoder: @@ -26,7 +26,7 @@ class CLIPImageEncoder:
26 self.endpoint = f"{self.service_url}/embed/image" 26 self.endpoint = f"{self.service_url}/embed/image"
27 logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url) 27 logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url)
28 28
29 - def _call_service(self, request_data: List[str]) -> List[Any]: 29 + def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]:
30 """ 30 """
31 Call the embedding service API. 31 Call the embedding service API.
32 32
@@ -39,6 +39,7 @@ class CLIPImageEncoder: @@ -39,6 +39,7 @@ class CLIPImageEncoder:
39 try: 39 try:
40 response = requests.post( 40 response = requests.post(
41 self.endpoint, 41 self.endpoint,
  42 + params={"normalize": "true" if normalize_embeddings else "false"},
42 json=request_data, 43 json=request_data,
43 timeout=60 44 timeout=60
44 ) 45 )
@@ -56,7 +57,7 @@ class CLIPImageEncoder: @@ -56,7 +57,7 @@ class CLIPImageEncoder:
56 """ 57 """
57 raise NotImplementedError("encode_image with PIL Image is not supported by embedding service") 58 raise NotImplementedError("encode_image with PIL Image is not supported by embedding service")
58 59
59 - def encode_image_from_url(self, url: str) -> np.ndarray: 60 + def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> np.ndarray:
60 """ 61 """
61 Generate image embedding via network service using URL. 62 Generate image embedding via network service using URL.
62 63
@@ -66,7 +67,7 @@ class CLIPImageEncoder: @@ -66,7 +67,7 @@ class CLIPImageEncoder:
66 Returns: 67 Returns:
67 Embedding vector 68 Embedding vector
68 """ 69 """
69 - response_data = self._call_service([url]) 70 + response_data = self._call_service([url], normalize_embeddings=normalize_embeddings)
70 if not response_data or len(response_data) != 1 or response_data[0] is None: 71 if not response_data or len(response_data) != 1 or response_data[0] is None:
71 raise RuntimeError(f"No image embedding returned for URL: {url}") 72 raise RuntimeError(f"No image embedding returned for URL: {url}")
72 vec = np.array(response_data[0], dtype=np.float32) 73 vec = np.array(response_data[0], dtype=np.float32)
@@ -77,7 +78,8 @@ class CLIPImageEncoder: @@ -77,7 +78,8 @@ class CLIPImageEncoder:
77 def encode_batch( 78 def encode_batch(
78 self, 79 self,
79 images: List[Union[str, Image.Image]], 80 images: List[Union[str, Image.Image]],
80 - batch_size: int = 8 81 + batch_size: int = 8,
  82 + normalize_embeddings: bool = True,
81 ) -> List[np.ndarray]: 83 ) -> List[np.ndarray]:
82 """ 84 """
83 Encode a batch of images efficiently via network service. 85 Encode a batch of images efficiently via network service.
@@ -98,7 +100,7 @@ class CLIPImageEncoder: @@ -98,7 +100,7 @@ class CLIPImageEncoder:
98 results: List[np.ndarray] = [] 100 results: List[np.ndarray] = []
99 for i in range(0, len(images), batch_size): 101 for i in range(0, len(images), batch_size):
100 batch_urls = [str(u).strip() for u in images[i:i + batch_size]] 102 batch_urls = [str(u).strip() for u in images[i:i + batch_size]]
101 - response_data = self._call_service(batch_urls) 103 + response_data = self._call_service(batch_urls, normalize_embeddings=normalize_embeddings)
102 if not response_data or len(response_data) != len(batch_urls): 104 if not response_data or len(response_data) != len(batch_urls):
103 raise RuntimeError( 105 raise RuntimeError(
104 f"Image embedding response length mismatch: expected {len(batch_urls)}, " 106 f"Image embedding response length mismatch: expected {len(batch_urls)}, "
@@ -119,6 +121,7 @@ class CLIPImageEncoder: @@ -119,6 +121,7 @@ class CLIPImageEncoder:
119 self, 121 self,
120 urls: List[str], 122 urls: List[str],
121 batch_size: Optional[int] = None, 123 batch_size: Optional[int] = None,
  124 + normalize_embeddings: bool = True,
122 ) -> List[np.ndarray]: 125 ) -> List[np.ndarray]:
123 """ 126 """
124 与 ClipImageModel / ClipAsServiceImageEncoder 一致的接口,供索引器 document_transformer 调用。 127 与 ClipImageModel / ClipAsServiceImageEncoder 一致的接口,供索引器 document_transformer 调用。
@@ -130,4 +133,8 @@ class CLIPImageEncoder: @@ -130,4 +133,8 @@ class CLIPImageEncoder:
130 Returns: 133 Returns:
131 与 urls 等长的向量列表 134 与 urls 等长的向量列表
132 """ 135 """
133 - return self.encode_batch(urls, batch_size=batch_size or 8) 136 + return self.encode_batch(
  137 + urls,
  138 + batch_size=batch_size or 8,
  139 + normalize_embeddings=normalize_embeddings,
  140 + )
embeddings/protocols.py
@@ -17,6 +17,7 @@ class ImageEncoderProtocol(Protocol): @@ -17,6 +17,7 @@ class ImageEncoderProtocol(Protocol):
17 self, 17 self,
18 urls: List[str], 18 urls: List[str],
19 batch_size: Optional[int] = None, 19 batch_size: Optional[int] = None,
  20 + normalize_embeddings: bool = True,
20 ) -> List[Optional[np.ndarray]]: 21 ) -> List[Optional[np.ndarray]]:
21 """ 22 """
22 Encode a list of image URLs to vectors. 23 Encode a list of image URLs to vectors.
embeddings/server.py
@@ -112,14 +112,24 @@ def load_models(): @@ -112,14 +112,24 @@ def load_models():
112 logger.info("All embedding models loaded successfully, service ready") 112 logger.info("All embedding models loaded successfully, service ready")
113 113
114 114
115 -def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]: 115 +def _normalize_vector(vec: np.ndarray) -> np.ndarray:
  116 + norm = float(np.linalg.norm(vec))
  117 + if not np.isfinite(norm) or norm <= 0.0:
  118 + raise RuntimeError("Embedding vector has invalid norm (must be > 0)")
  119 + return vec / norm
  120 +
  121 +
  122 +def _as_list(embedding: Optional[np.ndarray], normalize: bool = False) -> Optional[List[float]]:
116 if embedding is None: 123 if embedding is None:
117 return None 124 return None
118 if not isinstance(embedding, np.ndarray): 125 if not isinstance(embedding, np.ndarray):
119 embedding = np.array(embedding, dtype=np.float32) 126 embedding = np.array(embedding, dtype=np.float32)
120 if embedding.ndim != 1: 127 if embedding.ndim != 1:
121 embedding = embedding.reshape(-1) 128 embedding = embedding.reshape(-1)
122 - return embedding.astype(np.float32).tolist() 129 + embedding = embedding.astype(np.float32, copy=False)
  130 + if normalize:
  131 + embedding = _normalize_vector(embedding).astype(np.float32, copy=False)
  132 + return embedding.tolist()
123 133
124 134
125 @app.get("/health") 135 @app.get("/health")
@@ -134,9 +144,10 @@ def health() -> Dict[str, Any]: @@ -134,9 +144,10 @@ def health() -> Dict[str, Any]:
134 144
135 145
136 @app.post("/embed/text") 146 @app.post("/embed/text")
137 -def embed_text(texts: List[str]) -> List[Optional[List[float]]]: 147 +def embed_text(texts: List[str], normalize: Optional[bool] = None) -> List[Optional[List[float]]]:
138 if _text_model is None: 148 if _text_model is None:
139 raise RuntimeError("Text model not loaded") 149 raise RuntimeError("Text model not loaded")
  150 + effective_normalize = bool(CONFIG.TEXT_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize)
140 normalized: List[str] = [] 151 normalized: List[str] = []
141 for i, t in enumerate(texts): 152 for i, t in enumerate(texts):
142 if not isinstance(t, str): 153 if not isinstance(t, str):
@@ -152,7 +163,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: @@ -152,7 +163,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]:
152 normalized, 163 normalized,
153 batch_size=int(CONFIG.TEXT_BATCH_SIZE), 164 batch_size=int(CONFIG.TEXT_BATCH_SIZE),
154 device=CONFIG.TEXT_DEVICE, 165 device=CONFIG.TEXT_DEVICE,
155 - normalize_embeddings=bool(CONFIG.TEXT_NORMALIZE_EMBEDDINGS), 166 + normalize_embeddings=effective_normalize,
156 ) 167 )
157 except Exception as e: 168 except Exception as e:
158 logger.error("Text embedding backend failure: %s", e, exc_info=True) 169 logger.error("Text embedding backend failure: %s", e, exc_info=True)
@@ -167,7 +178,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: @@ -167,7 +178,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]:
167 ) 178 )
168 out: List[Optional[List[float]]] = [] 179 out: List[Optional[List[float]]] = []
169 for i, emb in enumerate(embs): 180 for i, emb in enumerate(embs):
170 - vec = _as_list(emb) 181 + vec = _as_list(emb, normalize=effective_normalize)
171 if vec is None: 182 if vec is None:
172 raise RuntimeError(f"Text model returned empty embedding for index {i}") 183 raise RuntimeError(f"Text model returned empty embedding for index {i}")
173 out.append(vec) 184 out.append(vec)
@@ -175,9 +186,10 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: @@ -175,9 +186,10 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]:
175 186
176 187
177 @app.post("/embed/image") 188 @app.post("/embed/image")
178 -def embed_image(images: List[str]) -> List[Optional[List[float]]]: 189 +def embed_image(images: List[str], normalize: Optional[bool] = None) -> List[Optional[List[float]]]:
179 if _image_model is None: 190 if _image_model is None:
180 raise RuntimeError("Image model not loaded") 191 raise RuntimeError("Image model not loaded")
  192 + effective_normalize = bool(CONFIG.IMAGE_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize)
181 urls: List[str] = [] 193 urls: List[str] = []
182 for i, url_or_path in enumerate(images): 194 for i, url_or_path in enumerate(images):
183 if not isinstance(url_or_path, str): 195 if not isinstance(url_or_path, str):
@@ -188,7 +200,11 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]: @@ -188,7 +200,11 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]:
188 urls.append(s) 200 urls.append(s)
189 201
190 with _image_encode_lock: 202 with _image_encode_lock:
191 - vectors = _image_model.encode_image_urls(urls, batch_size=CONFIG.IMAGE_BATCH_SIZE) 203 + vectors = _image_model.encode_image_urls(
  204 + urls,
  205 + batch_size=CONFIG.IMAGE_BATCH_SIZE,
  206 + normalize_embeddings=effective_normalize,
  207 + )
192 if vectors is None or len(vectors) != len(urls): 208 if vectors is None or len(vectors) != len(urls):
193 raise RuntimeError( 209 raise RuntimeError(
194 f"Image model response length mismatch: expected {len(urls)}, " 210 f"Image model response length mismatch: expected {len(urls)}, "
@@ -196,7 +212,7 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]: @@ -196,7 +212,7 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]:
196 ) 212 )
197 out: List[Optional[List[float]]] = [] 213 out: List[Optional[List[float]]] = []
198 for i, vec in enumerate(vectors): 214 for i, vec in enumerate(vectors):
199 - out_vec = _as_list(vec) 215 + out_vec = _as_list(vec, normalize=effective_normalize)
200 if out_vec is None: 216 if out_vec is None:
201 raise RuntimeError(f"Image model returned empty embedding for index {i}") 217 raise RuntimeError(f"Image model returned empty embedding for index {i}")
202 out.append(out_vec) 218 out.append(out_vec)
embeddings/text_encoder.py
@@ -50,7 +50,7 @@ class TextEmbeddingEncoder: @@ -50,7 +50,7 @@ class TextEmbeddingEncoder:
50 logger.warning("Failed to initialize Redis cache for embeddings: %s, continuing without cache", e) 50 logger.warning("Failed to initialize Redis cache for embeddings: %s, continuing without cache", e)
51 self.redis_client = None 51 self.redis_client = None
52 52
53 - def _call_service(self, request_data: List[str]) -> List[Any]: 53 + def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]:
54 """ 54 """
55 Call the embedding service API. 55 Call the embedding service API.
56 56
@@ -63,6 +63,7 @@ class TextEmbeddingEncoder: @@ -63,6 +63,7 @@ class TextEmbeddingEncoder:
63 try: 63 try:
64 response = requests.post( 64 response = requests.post(
65 self.endpoint, 65 self.endpoint,
  66 + params={"normalize": "true" if normalize_embeddings else "false"},
66 json=request_data, 67 json=request_data,
67 timeout=60 68 timeout=60
68 ) 69 )
@@ -84,7 +85,7 @@ class TextEmbeddingEncoder: @@ -84,7 +85,7 @@ class TextEmbeddingEncoder:
84 85
85 Args: 86 Args:
86 sentences: Single string or list of strings to encode 87 sentences: Single string or list of strings to encode
87 - normalize_embeddings: Whether to normalize embeddings (ignored for service) 88 + normalize_embeddings: Whether to request normalized embeddings from service
88 device: Device parameter ignored for service compatibility 89 device: Device parameter ignored for service compatibility
89 batch_size: Batch size for processing (used for service requests) 90 batch_size: Batch size for processing (used for service requests)
90 91
@@ -103,7 +104,7 @@ class TextEmbeddingEncoder: @@ -103,7 +104,7 @@ class TextEmbeddingEncoder:
103 embeddings: List[Optional[np.ndarray]] = [None] * len(sentences) 104 embeddings: List[Optional[np.ndarray]] = [None] * len(sentences)
104 105
105 for i, text in enumerate(sentences): 106 for i, text in enumerate(sentences):
106 - cached = self._get_cached_embedding(text, "generic") 107 + cached = self._get_cached_embedding(text, "generic", normalize_embeddings)
107 if cached is not None: 108 if cached is not None:
108 embeddings[i] = cached 109 embeddings[i] = cached
109 else: 110 else:
@@ -115,7 +116,7 @@ class TextEmbeddingEncoder: @@ -115,7 +116,7 @@ class TextEmbeddingEncoder:
115 116
116 # If there are uncached texts, call service 117 # If there are uncached texts, call service
117 if uncached_texts: 118 if uncached_texts:
118 - response_data = self._call_service(request_data) 119 + response_data = self._call_service(request_data, normalize_embeddings=normalize_embeddings)
119 120
120 # Process response 121 # Process response
121 for i, text in enumerate(uncached_texts): 122 for i, text in enumerate(uncached_texts):
@@ -129,7 +130,7 @@ class TextEmbeddingEncoder: @@ -129,7 +130,7 @@ class TextEmbeddingEncoder:
129 embedding_array = np.array(embedding, dtype=np.float32) 130 embedding_array = np.array(embedding, dtype=np.float32)
130 if self._is_valid_embedding(embedding_array): 131 if self._is_valid_embedding(embedding_array):
131 embeddings[original_idx] = embedding_array 132 embeddings[original_idx] = embedding_array
132 - self._set_cached_embedding(text, "generic", embedding_array) 133 + self._set_cached_embedding(text, "generic", embedding_array, normalize_embeddings)
133 else: 134 else:
134 raise ValueError( 135 raise ValueError(
135 f"Invalid embedding returned from service for text index {original_idx}" 136 f"Invalid embedding returned from service for text index {original_idx}"
@@ -144,7 +145,8 @@ class TextEmbeddingEncoder: @@ -144,7 +145,8 @@ class TextEmbeddingEncoder:
144 self, 145 self,
145 texts: List[str], 146 texts: List[str],
146 batch_size: int = 32, 147 batch_size: int = 32,
147 - device: str = 'cpu' 148 + device: str = 'cpu',
  149 + normalize_embeddings: bool = True,
148 ) -> np.ndarray: 150 ) -> np.ndarray:
149 """ 151 """
150 Encode a batch of texts efficiently via network service. 152 Encode a batch of texts efficiently via network service.
@@ -157,11 +159,17 @@ class TextEmbeddingEncoder: @@ -157,11 +159,17 @@ class TextEmbeddingEncoder:
157 Returns: 159 Returns:
158 numpy array of embeddings 160 numpy array of embeddings
159 """ 161 """
160 - return self.encode(texts, batch_size=batch_size, device=device) 162 + return self.encode(
  163 + texts,
  164 + batch_size=batch_size,
  165 + device=device,
  166 + normalize_embeddings=normalize_embeddings,
  167 + )
161 168
162 - def _get_cache_key(self, query: str, language: str) -> str: 169 + def _get_cache_key(self, query: str, language: str, normalize_embeddings: bool = True) -> str:
163 """Generate a cache key for the query""" 170 """Generate a cache key for the query"""
164 - return f"embedding:{language}:{query}" 171 + norm_flag = "norm1" if normalize_embeddings else "norm0"
  172 + return f"embedding:{language}:{norm_flag}:{query}"
165 173
166 def _is_valid_embedding(self, embedding: np.ndarray) -> bool: 174 def _is_valid_embedding(self, embedding: np.ndarray) -> bool:
167 """ 175 """
@@ -184,13 +192,18 @@ class TextEmbeddingEncoder: @@ -184,13 +192,18 @@ class TextEmbeddingEncoder:
184 return False 192 return False
185 return True 193 return True
186 194
187 - def _get_cached_embedding(self, query: str, language: str) -> Optional[np.ndarray]: 195 + def _get_cached_embedding(
  196 + self,
  197 + query: str,
  198 + language: str,
  199 + normalize_embeddings: bool = True,
  200 + ) -> Optional[np.ndarray]:
188 """Get embedding from cache if exists (with sliding expiration)""" 201 """Get embedding from cache if exists (with sliding expiration)"""
189 if not self.redis_client: 202 if not self.redis_client:
190 return None 203 return None
191 204
192 try: 205 try:
193 - cache_key = self._get_cache_key(query, language) 206 + cache_key = self._get_cache_key(query, language, normalize_embeddings)
194 cached_data = self.redis_client.get(cache_key) 207 cached_data = self.redis_client.get(cache_key)
195 if cached_data: 208 if cached_data:
196 embedding = pickle.loads(cached_data) 209 embedding = pickle.loads(cached_data)
@@ -216,13 +229,19 @@ class TextEmbeddingEncoder: @@ -216,13 +229,19 @@ class TextEmbeddingEncoder:
216 logger.error(f"Error retrieving embedding from cache: {e}") 229 logger.error(f"Error retrieving embedding from cache: {e}")
217 return None 230 return None
218 231
219 - def _set_cached_embedding(self, query: str, language: str, embedding: np.ndarray) -> bool: 232 + def _set_cached_embedding(
  233 + self,
  234 + query: str,
  235 + language: str,
  236 + embedding: np.ndarray,
  237 + normalize_embeddings: bool = True,
  238 + ) -> bool:
220 """Store embedding in cache""" 239 """Store embedding in cache"""
221 if not self.redis_client: 240 if not self.redis_client:
222 return False 241 return False
223 242
224 try: 243 try:
225 - cache_key = self._get_cache_key(query, language) 244 + cache_key = self._get_cache_key(query, language, normalize_embeddings)
226 serialized_data = pickle.dumps(embedding) 245 serialized_data = pickle.dumps(embedding)
227 self.redis_client.setex( 246 self.redis_client.setex(
228 cache_key, 247 cache_key,
@@ -4,4 +4,7 @@ @@ -4,4 +4,7 @@
4 4
5 cd "$(dirname "$0")" 5 cd "$(dirname "$0")"
6 6
7 -./scripts/service_ctl.sh restart 7 +START_EMBEDDING=1 START_TRANSLATOR=1 START_RERANKER=1 START_TEI=1 CNCLIP_DEVICE=cuda TEI_USE_GPU=1 ./scripts/service_ctl.sh restart
  8 +
  9 +# ./scripts/service_ctl.sh restart
  10 +
scripts/service_ctl.sh
@@ -190,7 +190,14 @@ start_one() { @@ -190,7 +190,14 @@ start_one() {
190 nohup bash -lc "${cmd}" > "${lf}" 2>&1 & 190 nohup bash -lc "${cmd}" > "${lf}" 2>&1 &
191 local pid=$! 191 local pid=$!
192 echo "${pid}" > "${pf}" 192 echo "${pid}" > "${pf}"
193 - if wait_for_health "${service}"; then 193 + # Some services (notably reranker with vLLM backend) can take longer
  194 + # to load models / compile graphs on first start. Give them a longer
  195 + # health check window to avoid false negatives.
  196 + local retries=30
  197 + if [ "${service}" = "reranker" ]; then
  198 + retries=90
  199 + fi
  200 + if wait_for_health "${service}" "${retries}"; then
194 echo "[ok] ${service} healthy (pid=${pid}, log=${lf})" 201 echo "[ok] ${service} healthy (pid=${pid}, log=${lf})"
195 else 202 else
196 echo "[error] ${service} health check timeout, inspect ${lf}" >&2 203 echo "[error] ${service} health check timeout, inspect ${lf}" >&2
@@ -320,9 +327,14 @@ resolve_targets() { @@ -320,9 +327,14 @@ resolve_targets() {
320 if [ "${START_RERANKER:-0}" = "1" ]; then targets+=("reranker"); fi 327 if [ "${START_RERANKER:-0}" = "1" ]; then targets+=("reranker"); fi
321 echo "${targets[@]}" 328 echo "${targets[@]}"
322 ;; 329 ;;
323 - stop|restart|status) 330 + stop|status)
324 echo "$(all_services)" 331 echo "$(all_services)"
325 ;; 332 ;;
  333 + restart)
  334 + # Restart with no explicit services should preserve start-order dependency
  335 + # behavior (e.g. tei/cnclip before embedding).
  336 + echo "$(resolve_targets start)"
  337 + ;;
326 *) 338 *)
327 echo "" 339 echo ""
328 ;; 340 ;;
@@ -340,7 +352,7 @@ Usage: @@ -340,7 +352,7 @@ Usage:
340 Default target set (when no service provided): 352 Default target set (when no service provided):
341 start -> backend indexer frontend (+ optional by env flags) 353 start -> backend indexer frontend (+ optional by env flags)
342 stop -> all known services 354 stop -> all known services
343 - restart -> all known services 355 + restart -> stop all known services, then start with start targets
344 status -> all known services 356 status -> all known services
345 357
346 Optional startup flags: 358 Optional startup flags:
@@ -361,7 +373,13 @@ main() { @@ -361,7 +373,13 @@ main() {
361 shift || true 373 shift || true
362 374
363 load_env_file "${PROJECT_ROOT}/.env" 375 load_env_file "${PROJECT_ROOT}/.env"
  376 + local stop_targets=""
364 local targets 377 local targets
  378 + # For restart without explicit services, stop everything first, then start
  379 + # with dependency-aware start targets.
  380 + if [ "${action}" = "restart" ] && [ "$#" -eq 0 ]; then
  381 + stop_targets="$(resolve_targets stop)"
  382 + fi
365 targets="$(resolve_targets "${action}" "$@")" 383 targets="$(resolve_targets "${action}" "$@")"
366 if [ -z "${targets}" ]; then 384 if [ -z "${targets}" ]; then
367 usage 385 usage
@@ -380,7 +398,7 @@ main() { @@ -380,7 +398,7 @@ main() {
380 done 398 done
381 ;; 399 ;;
382 restart) 400 restart)
383 - for svc in ${targets}; do 401 + for svc in ${stop_targets:-${targets}}; do
384 stop_one "${svc}" 402 stop_one "${svc}"
385 done 403 done
386 for svc in ${targets}; do 404 for svc in ${targets}; do