diff --git a/embeddings/README.md b/embeddings/README.md index 3a1deca..b6418f1 100644 --- a/embeddings/README.md +++ b/embeddings/README.md @@ -33,8 +33,3 @@ - `TEXT_MODEL_DIR`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE` - `IMAGE_MODEL_NAME`, `IMAGE_DEVICE` -### 目录说明(旧文件) - -旧的 `vector_service/` 目录与 `*_encoder__local.py` 文件已经废弃,统一由本目录实现与维护。 - - diff --git a/embeddings/server.py b/embeddings/server.py index d26c4c5..32bac6b 100644 --- a/embeddings/server.py +++ b/embeddings/server.py @@ -6,6 +6,7 @@ API (simple list-in, list-out; aligned by index; failures -> null): - POST /embed/image body: ["url_or_path1", ...] -> [[...], null, ...] """ +import logging import threading from typing import Any, Dict, List, Optional @@ -16,38 +17,47 @@ from embeddings.config import CONFIG from embeddings.bge_model import BgeTextModel from embeddings.clip_model import ClipImageModel +logger = logging.getLogger(__name__) app = FastAPI(title="SearchEngine Embedding Service", version="1.0.0") -_text_model = None -_image_model = None - -_text_init_lock = threading.Lock() -_image_init_lock = threading.Lock() +# Models are loaded at startup, not lazily +_text_model: Optional[BgeTextModel] = None +_image_model: Optional[ClipImageModel] = None _text_encode_lock = threading.Lock() _image_encode_lock = threading.Lock() -def _get_text_model(): - global _text_model - if _text_model is None: - with _text_init_lock: - if _text_model is None: - _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR) - return _text_model +@app.on_event("startup") +def load_models(): + """Load models at service startup to avoid first-request latency.""" + global _text_model, _image_model + logger.info("Loading embedding models at startup...") -def _get_image_model(): - global _image_model - if _image_model is None: - with _image_init_lock: - if _image_model is None: - _image_model = ClipImageModel( - model_name=CONFIG.IMAGE_MODEL_NAME, - device=CONFIG.IMAGE_DEVICE, - ) - return _image_model + # Load text model + try: + logger.info(f"Loading text model: {CONFIG.TEXT_MODEL_DIR}") + _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR) + logger.info("Text model loaded successfully") + except Exception as e: + logger.error(f"Failed to load text model: {e}", exc_info=True) + raise + + # Load image model + try: + logger.info(f"Loading image model: {CONFIG.IMAGE_MODEL_NAME} (device: {CONFIG.IMAGE_DEVICE})") + _image_model = ClipImageModel( + model_name=CONFIG.IMAGE_MODEL_NAME, + device=CONFIG.IMAGE_DEVICE, + ) + logger.info("Image model loaded successfully") + except Exception as e: + logger.error(f"Failed to load image model: {e}", exc_info=True) + raise + + logger.info("All embedding models loaded successfully, service ready") def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]: @@ -62,12 +72,18 @@ def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]: @app.get("/health") def health() -> Dict[str, Any]: - return {"status": "ok"} + """Health check endpoint. Returns status and model loading state.""" + return { + "status": "ok", + "text_model_loaded": _text_model is not None, + "image_model_loaded": _image_model is not None, + } @app.post("/embed/text") def embed_text(texts: List[str]) -> List[Optional[List[float]]]: - model = _get_text_model() + if _text_model is None: + raise RuntimeError("Text model not loaded") out: List[Optional[List[float]]] = [None] * len(texts) indexed_texts: List[tuple] = [] @@ -87,7 +103,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: batch_texts = [t for _, t in indexed_texts] try: with _text_encode_lock: - embs = model.encode_batch( + embs = _text_model.encode_batch( batch_texts, batch_size=int(CONFIG.TEXT_BATCH_SIZE), device=CONFIG.TEXT_DEVICE ) for j, (idx, _t) in enumerate(indexed_texts): @@ -100,7 +116,8 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: @app.post("/embed/image") def embed_image(images: List[str]) -> List[Optional[List[float]]]: - model = _get_image_model() + if _image_model is None: + raise RuntimeError("Image model not loaded") out: List[Optional[List[float]]] = [None] * len(images) with _image_encode_lock: @@ -113,7 +130,7 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]: url_or_path = url_or_path.strip() if not url_or_path: continue - emb = model.encode_image_from_url(url_or_path) + emb = _image_model.encode_image_from_url(url_or_path) out[i] = _as_list(emb) except Exception: out[i] = None -- libgit2 0.21.2