From 200fdddf9b8cac307bb67bd9eac4bef1eb9a0e22 Mon Sep 17 00:00:00 2001 From: tangwang Date: Tue, 10 Mar 2026 17:56:28 +0800 Subject: [PATCH] embed norm --- embeddings/README.md | 5 ++++- embeddings/clip_as_service_encoder.py | 10 ++++++++-- embeddings/clip_model.py | 28 +++++++++++++++++++--------- embeddings/config.py | 1 + embeddings/image_encoder.py | 19 +++++++++++++------ embeddings/protocols.py | 1 + embeddings/server.py | 32 ++++++++++++++++++++++++-------- embeddings/text_encoder.py | 45 ++++++++++++++++++++++++++++++++------------- restart.sh | 5 ++++- scripts/service_ctl.sh | 26 ++++++++++++++++++++++---- 10 files changed, 128 insertions(+), 44 deletions(-) diff --git a/embeddings/README.md b/embeddings/README.md index ddb35bb..6c3bb9b 100644 --- a/embeddings/README.md +++ b/embeddings/README.md @@ -14,7 +14,7 @@ - **clip-as-service 客户端**:`clip_as_service_encoder.py`(图片向量,推荐) - **向量化服务(FastAPI)**:`server.py` - **统一配置**:`config.py` -- **接口契约**:`protocols.ImageEncoderProtocol`(图片编码统一为 `encode_image_urls(urls, batch_size)`,本地 CN-CLIP 与 clip-as-service 均实现该接口) +- **接口契约**:`protocols.ImageEncoderProtocol`(图片编码统一为 `encode_image_urls(urls, batch_size, normalize_embeddings)`,本地 CN-CLIP 与 clip-as-service 均实现该接口) 说明:历史上的云端 embedding 试验实现(DashScope)已从主仓库移除,当前仅维护 6005 这条统一向量服务链路。 @@ -29,10 +29,12 @@ - `POST /embed/text` - 入参:`["文本1", "文本2", ...]` + - 可选 query 参数:`normalize=true|false`(不传则使用服务端默认) - 出参:`[[...], [...], ...]`(与输入按 index 对齐,失败直接报错) - `POST /embed/image` - 入参:`["url或本地路径1", ...]` + - 可选 query 参数:`normalize=true|false`(不传则使用服务端默认) - 出参:`[[...], [...], ...]`(与输入按 index 对齐,失败直接报错) ### 图片向量:clip-as-service(推荐) @@ -77,6 +79,7 @@ TEI_USE_GPU=0 ./scripts/start_tei_service.sh - `PORT`: 服务端口(默认 6005) - `TEXT_MODEL_ID`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE`, `TEXT_NORMALIZE_EMBEDDINGS` +- `IMAGE_NORMALIZE_EMBEDDINGS`(默认 true) - `USE_CLIP_AS_SERVICE`, `CLIP_AS_SERVICE_SERVER`:图片向量(clip-as-service) - `IMAGE_MODEL_NAME`, `IMAGE_DEVICE`:本地 CN-CLIP(当 `USE_CLIP_AS_SERVICE=false` 时) - TEI 相关:`TEI_USE_GPU`、`TEI_VERSION`、`TEI_MAX_BATCH_TOKENS`、`TEI_MAX_CLIENT_BATCH_SIZE`、`TEI_HEALTH_TIMEOUT_SEC` diff --git a/embeddings/clip_as_service_encoder.py b/embeddings/clip_as_service_encoder.py index 564783f..2837067 100644 --- a/embeddings/clip_as_service_encoder.py +++ b/embeddings/clip_as_service_encoder.py @@ -79,6 +79,7 @@ class ClipAsServiceImageEncoder: self, urls: List[str], batch_size: Optional[int] = None, + normalize_embeddings: bool = True, ) -> List[np.ndarray]: """ Encode a list of image URLs to vectors. @@ -117,10 +118,15 @@ class ClipAsServiceImageEncoder: vec = np.asarray(row, dtype=np.float32) if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): raise RuntimeError("clip-as-service returned invalid embedding vector") + if normalize_embeddings: + norm = float(np.linalg.norm(vec)) + if not np.isfinite(norm) or norm <= 0.0: + raise RuntimeError("clip-as-service returned zero/invalid norm vector") + vec = vec / norm out.append(vec) return out - def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: + def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> Optional[np.ndarray]: """Encode a single image URL. Returns 1024-dim vector or None.""" - results = self.encode_image_urls([url], batch_size=1) + results = self.encode_image_urls([url], batch_size=1, normalize_embeddings=normalize_embeddings) return results[0] if results else None diff --git a/embeddings/clip_model.py b/embeddings/clip_model.py index 9beb210..835fd14 100644 --- a/embeddings/clip_model.py +++ b/embeddings/clip_model.py @@ -76,25 +76,27 @@ class ClipImageModel(object): text_features /= text_features.norm(dim=-1, keepdim=True) return text_features - def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: + def encode_image(self, image: Image.Image, normalize_embeddings: bool = True) -> Optional[np.ndarray]: if not isinstance(image, Image.Image): raise ValueError("ClipImageModel.encode_image input must be a PIL.Image") infer_data = self.preprocess(image).unsqueeze(0).to(self.device) with torch.no_grad(): image_features = self.model.encode_image(infer_data) - image_features /= image_features.norm(dim=-1, keepdim=True) + if normalize_embeddings: + image_features /= image_features.norm(dim=-1, keepdim=True) return image_features.cpu().numpy().astype("float32")[0] - def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: + def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> Optional[np.ndarray]: image_data = self.download_image(url) image = self.validate_image(image_data) image = self.preprocess_image(image) - return self.encode_image(image) + return self.encode_image(image, normalize_embeddings=normalize_embeddings) def encode_image_urls( self, urls: List[str], batch_size: Optional[int] = None, + normalize_embeddings: bool = True, ) -> List[Optional[np.ndarray]]: """ Encode a list of image URLs to vectors. Same interface as ClipAsServiceImageEncoder. @@ -106,19 +108,27 @@ class ClipImageModel(object): Returns: List of vectors (or None for failed items), same length as urls. """ - return self.encode_batch(urls, batch_size=batch_size or 8) + return self.encode_batch( + urls, + batch_size=batch_size or 8, + normalize_embeddings=normalize_embeddings, + ) - def encode_batch(self, images: List[Union[str, Image.Image]], batch_size: int = 8) -> List[Optional[np.ndarray]]: + def encode_batch( + self, + images: List[Union[str, Image.Image]], + batch_size: int = 8, + normalize_embeddings: bool = True, + ) -> List[Optional[np.ndarray]]: results: List[Optional[np.ndarray]] = [] for i in range(0, len(images), batch_size): batch = images[i : i + batch_size] for img in batch: if isinstance(img, str): - results.append(self.encode_image_from_url(img)) + results.append(self.encode_image_from_url(img, normalize_embeddings=normalize_embeddings)) elif isinstance(img, Image.Image): - results.append(self.encode_image(img)) + results.append(self.encode_image(img, normalize_embeddings=normalize_embeddings)) else: results.append(None) return results - diff --git a/embeddings/config.py b/embeddings/config.py index d02f79f..1df70aa 100644 --- a/embeddings/config.py +++ b/embeddings/config.py @@ -37,6 +37,7 @@ class EmbeddingConfig(object): # Service behavior IMAGE_BATCH_SIZE = 8 + IMAGE_NORMALIZE_EMBEDDINGS = os.getenv("IMAGE_NORMALIZE_EMBEDDINGS", "true").lower() in ("1", "true", "yes") CONFIG = EmbeddingConfig() diff --git a/embeddings/image_encoder.py b/embeddings/image_encoder.py index 3e4b6d7..728c184 100644 --- a/embeddings/image_encoder.py +++ b/embeddings/image_encoder.py @@ -26,7 +26,7 @@ class CLIPImageEncoder: self.endpoint = f"{self.service_url}/embed/image" logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url) - def _call_service(self, request_data: List[str]) -> List[Any]: + def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: """ Call the embedding service API. @@ -39,6 +39,7 @@ class CLIPImageEncoder: try: response = requests.post( self.endpoint, + params={"normalize": "true" if normalize_embeddings else "false"}, json=request_data, timeout=60 ) @@ -56,7 +57,7 @@ class CLIPImageEncoder: """ raise NotImplementedError("encode_image with PIL Image is not supported by embedding service") - def encode_image_from_url(self, url: str) -> np.ndarray: + def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> np.ndarray: """ Generate image embedding via network service using URL. @@ -66,7 +67,7 @@ class CLIPImageEncoder: Returns: Embedding vector """ - response_data = self._call_service([url]) + response_data = self._call_service([url], normalize_embeddings=normalize_embeddings) if not response_data or len(response_data) != 1 or response_data[0] is None: raise RuntimeError(f"No image embedding returned for URL: {url}") vec = np.array(response_data[0], dtype=np.float32) @@ -77,7 +78,8 @@ class CLIPImageEncoder: def encode_batch( self, images: List[Union[str, Image.Image]], - batch_size: int = 8 + batch_size: int = 8, + normalize_embeddings: bool = True, ) -> List[np.ndarray]: """ Encode a batch of images efficiently via network service. @@ -98,7 +100,7 @@ class CLIPImageEncoder: results: List[np.ndarray] = [] for i in range(0, len(images), batch_size): batch_urls = [str(u).strip() for u in images[i:i + batch_size]] - response_data = self._call_service(batch_urls) + response_data = self._call_service(batch_urls, normalize_embeddings=normalize_embeddings) if not response_data or len(response_data) != len(batch_urls): raise RuntimeError( f"Image embedding response length mismatch: expected {len(batch_urls)}, " @@ -119,6 +121,7 @@ class CLIPImageEncoder: self, urls: List[str], batch_size: Optional[int] = None, + normalize_embeddings: bool = True, ) -> List[np.ndarray]: """ 与 ClipImageModel / ClipAsServiceImageEncoder 一致的接口,供索引器 document_transformer 调用。 @@ -130,4 +133,8 @@ class CLIPImageEncoder: Returns: 与 urls 等长的向量列表 """ - return self.encode_batch(urls, batch_size=batch_size or 8) + return self.encode_batch( + urls, + batch_size=batch_size or 8, + normalize_embeddings=normalize_embeddings, + ) diff --git a/embeddings/protocols.py b/embeddings/protocols.py index c9071c6..8b7d8ac 100644 --- a/embeddings/protocols.py +++ b/embeddings/protocols.py @@ -17,6 +17,7 @@ class ImageEncoderProtocol(Protocol): self, urls: List[str], batch_size: Optional[int] = None, + normalize_embeddings: bool = True, ) -> List[Optional[np.ndarray]]: """ Encode a list of image URLs to vectors. diff --git a/embeddings/server.py b/embeddings/server.py index a1cdab4..ee16d04 100644 --- a/embeddings/server.py +++ b/embeddings/server.py @@ -112,14 +112,24 @@ def load_models(): logger.info("All embedding models loaded successfully, service ready") -def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]: +def _normalize_vector(vec: np.ndarray) -> np.ndarray: + norm = float(np.linalg.norm(vec)) + if not np.isfinite(norm) or norm <= 0.0: + raise RuntimeError("Embedding vector has invalid norm (must be > 0)") + return vec / norm + + +def _as_list(embedding: Optional[np.ndarray], normalize: bool = False) -> Optional[List[float]]: if embedding is None: return None if not isinstance(embedding, np.ndarray): embedding = np.array(embedding, dtype=np.float32) if embedding.ndim != 1: embedding = embedding.reshape(-1) - return embedding.astype(np.float32).tolist() + embedding = embedding.astype(np.float32, copy=False) + if normalize: + embedding = _normalize_vector(embedding).astype(np.float32, copy=False) + return embedding.tolist() @app.get("/health") @@ -134,9 +144,10 @@ def health() -> Dict[str, Any]: @app.post("/embed/text") -def embed_text(texts: List[str]) -> List[Optional[List[float]]]: +def embed_text(texts: List[str], normalize: Optional[bool] = None) -> List[Optional[List[float]]]: if _text_model is None: raise RuntimeError("Text model not loaded") + effective_normalize = bool(CONFIG.TEXT_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize) normalized: List[str] = [] for i, t in enumerate(texts): if not isinstance(t, str): @@ -152,7 +163,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: normalized, batch_size=int(CONFIG.TEXT_BATCH_SIZE), device=CONFIG.TEXT_DEVICE, - normalize_embeddings=bool(CONFIG.TEXT_NORMALIZE_EMBEDDINGS), + normalize_embeddings=effective_normalize, ) except Exception as e: logger.error("Text embedding backend failure: %s", e, exc_info=True) @@ -167,7 +178,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: ) out: List[Optional[List[float]]] = [] for i, emb in enumerate(embs): - vec = _as_list(emb) + vec = _as_list(emb, normalize=effective_normalize) if vec is None: raise RuntimeError(f"Text model returned empty embedding for index {i}") out.append(vec) @@ -175,9 +186,10 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: @app.post("/embed/image") -def embed_image(images: List[str]) -> List[Optional[List[float]]]: +def embed_image(images: List[str], normalize: Optional[bool] = None) -> List[Optional[List[float]]]: if _image_model is None: raise RuntimeError("Image model not loaded") + effective_normalize = bool(CONFIG.IMAGE_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize) urls: List[str] = [] for i, url_or_path in enumerate(images): if not isinstance(url_or_path, str): @@ -188,7 +200,11 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]: urls.append(s) with _image_encode_lock: - vectors = _image_model.encode_image_urls(urls, batch_size=CONFIG.IMAGE_BATCH_SIZE) + vectors = _image_model.encode_image_urls( + urls, + batch_size=CONFIG.IMAGE_BATCH_SIZE, + normalize_embeddings=effective_normalize, + ) if vectors is None or len(vectors) != len(urls): raise RuntimeError( f"Image model response length mismatch: expected {len(urls)}, " @@ -196,7 +212,7 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]: ) out: List[Optional[List[float]]] = [] for i, vec in enumerate(vectors): - out_vec = _as_list(vec) + out_vec = _as_list(vec, normalize=effective_normalize) if out_vec is None: raise RuntimeError(f"Image model returned empty embedding for index {i}") out.append(out_vec) diff --git a/embeddings/text_encoder.py b/embeddings/text_encoder.py index bfcfc52..acf54fc 100644 --- a/embeddings/text_encoder.py +++ b/embeddings/text_encoder.py @@ -50,7 +50,7 @@ class TextEmbeddingEncoder: logger.warning("Failed to initialize Redis cache for embeddings: %s, continuing without cache", e) self.redis_client = None - def _call_service(self, request_data: List[str]) -> List[Any]: + def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: """ Call the embedding service API. @@ -63,6 +63,7 @@ class TextEmbeddingEncoder: try: response = requests.post( self.endpoint, + params={"normalize": "true" if normalize_embeddings else "false"}, json=request_data, timeout=60 ) @@ -84,7 +85,7 @@ class TextEmbeddingEncoder: Args: sentences: Single string or list of strings to encode - normalize_embeddings: Whether to normalize embeddings (ignored for service) + normalize_embeddings: Whether to request normalized embeddings from service device: Device parameter ignored for service compatibility batch_size: Batch size for processing (used for service requests) @@ -103,7 +104,7 @@ class TextEmbeddingEncoder: embeddings: List[Optional[np.ndarray]] = [None] * len(sentences) for i, text in enumerate(sentences): - cached = self._get_cached_embedding(text, "generic") + cached = self._get_cached_embedding(text, "generic", normalize_embeddings) if cached is not None: embeddings[i] = cached else: @@ -115,7 +116,7 @@ class TextEmbeddingEncoder: # If there are uncached texts, call service if uncached_texts: - response_data = self._call_service(request_data) + response_data = self._call_service(request_data, normalize_embeddings=normalize_embeddings) # Process response for i, text in enumerate(uncached_texts): @@ -129,7 +130,7 @@ class TextEmbeddingEncoder: embedding_array = np.array(embedding, dtype=np.float32) if self._is_valid_embedding(embedding_array): embeddings[original_idx] = embedding_array - self._set_cached_embedding(text, "generic", embedding_array) + self._set_cached_embedding(text, "generic", embedding_array, normalize_embeddings) else: raise ValueError( f"Invalid embedding returned from service for text index {original_idx}" @@ -144,7 +145,8 @@ class TextEmbeddingEncoder: self, texts: List[str], batch_size: int = 32, - device: str = 'cpu' + device: str = 'cpu', + normalize_embeddings: bool = True, ) -> np.ndarray: """ Encode a batch of texts efficiently via network service. @@ -157,11 +159,17 @@ class TextEmbeddingEncoder: Returns: numpy array of embeddings """ - return self.encode(texts, batch_size=batch_size, device=device) + return self.encode( + texts, + batch_size=batch_size, + device=device, + normalize_embeddings=normalize_embeddings, + ) - def _get_cache_key(self, query: str, language: str) -> str: + def _get_cache_key(self, query: str, language: str, normalize_embeddings: bool = True) -> str: """Generate a cache key for the query""" - return f"embedding:{language}:{query}" + norm_flag = "norm1" if normalize_embeddings else "norm0" + return f"embedding:{language}:{norm_flag}:{query}" def _is_valid_embedding(self, embedding: np.ndarray) -> bool: """ @@ -184,13 +192,18 @@ class TextEmbeddingEncoder: return False return True - def _get_cached_embedding(self, query: str, language: str) -> Optional[np.ndarray]: + def _get_cached_embedding( + self, + query: str, + language: str, + normalize_embeddings: bool = True, + ) -> Optional[np.ndarray]: """Get embedding from cache if exists (with sliding expiration)""" if not self.redis_client: return None try: - cache_key = self._get_cache_key(query, language) + cache_key = self._get_cache_key(query, language, normalize_embeddings) cached_data = self.redis_client.get(cache_key) if cached_data: embedding = pickle.loads(cached_data) @@ -216,13 +229,19 @@ class TextEmbeddingEncoder: logger.error(f"Error retrieving embedding from cache: {e}") return None - def _set_cached_embedding(self, query: str, language: str, embedding: np.ndarray) -> bool: + def _set_cached_embedding( + self, + query: str, + language: str, + embedding: np.ndarray, + normalize_embeddings: bool = True, + ) -> bool: """Store embedding in cache""" if not self.redis_client: return False try: - cache_key = self._get_cache_key(query, language) + cache_key = self._get_cache_key(query, language, normalize_embeddings) serialized_data = pickle.dumps(embedding) self.redis_client.setex( cache_key, diff --git a/restart.sh b/restart.sh index cf2d62a..d513584 100755 --- a/restart.sh +++ b/restart.sh @@ -4,4 +4,7 @@ cd "$(dirname "$0")" -./scripts/service_ctl.sh restart +START_EMBEDDING=1 START_TRANSLATOR=1 START_RERANKER=1 START_TEI=1 CNCLIP_DEVICE=cuda TEI_USE_GPU=1 ./scripts/service_ctl.sh restart + +# ./scripts/service_ctl.sh restart + diff --git a/scripts/service_ctl.sh b/scripts/service_ctl.sh index ae629d5..f2734e4 100755 --- a/scripts/service_ctl.sh +++ b/scripts/service_ctl.sh @@ -190,7 +190,14 @@ start_one() { nohup bash -lc "${cmd}" > "${lf}" 2>&1 & local pid=$! echo "${pid}" > "${pf}" - if wait_for_health "${service}"; then + # Some services (notably reranker with vLLM backend) can take longer + # to load models / compile graphs on first start. Give them a longer + # health check window to avoid false negatives. + local retries=30 + if [ "${service}" = "reranker" ]; then + retries=90 + fi + if wait_for_health "${service}" "${retries}"; then echo "[ok] ${service} healthy (pid=${pid}, log=${lf})" else echo "[error] ${service} health check timeout, inspect ${lf}" >&2 @@ -320,9 +327,14 @@ resolve_targets() { if [ "${START_RERANKER:-0}" = "1" ]; then targets+=("reranker"); fi echo "${targets[@]}" ;; - stop|restart|status) + stop|status) echo "$(all_services)" ;; + restart) + # Restart with no explicit services should preserve start-order dependency + # behavior (e.g. tei/cnclip before embedding). + echo "$(resolve_targets start)" + ;; *) echo "" ;; @@ -340,7 +352,7 @@ Usage: Default target set (when no service provided): start -> backend indexer frontend (+ optional by env flags) stop -> all known services - restart -> all known services + restart -> stop all known services, then start with start targets status -> all known services Optional startup flags: @@ -361,7 +373,13 @@ main() { shift || true load_env_file "${PROJECT_ROOT}/.env" + local stop_targets="" local targets + # For restart without explicit services, stop everything first, then start + # with dependency-aware start targets. + if [ "${action}" = "restart" ] && [ "$#" -eq 0 ]; then + stop_targets="$(resolve_targets stop)" + fi targets="$(resolve_targets "${action}" "$@")" if [ -z "${targets}" ]; then usage @@ -380,7 +398,7 @@ main() { done ;; restart) - for svc in ${targets}; do + for svc in ${stop_targets:-${targets}}; do stop_one "${svc}" done for svc in ${targets}; do -- libgit2 0.21.2