Commit 0a3764c437d6bddab876e9bba9923aa3012dfd91
1 parent: 7bfb9946
优化embedding模型加载 (Optimize embedding model loading)
Showing 2 changed files with 44 additions and 32 deletions.
embeddings/README.md
| @@ -33,8 +33,3 @@ | @@ -33,8 +33,3 @@ | ||
| 33 | - `TEXT_MODEL_DIR`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE` | 33 | - `TEXT_MODEL_DIR`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE` |
| 34 | - `IMAGE_MODEL_NAME`, `IMAGE_DEVICE` | 34 | - `IMAGE_MODEL_NAME`, `IMAGE_DEVICE` |
| 35 | 35 | ||
| 36 | -### 目录说明(旧文件) | ||
| 37 | - | ||
| 38 | -旧的 `vector_service/` 目录与 `*_encoder__local.py` 文件已经废弃,统一由本目录实现与维护。 | ||
| 39 | - | ||
| 40 | - |
embeddings/server.py
| @@ -6,6 +6,7 @@ API (simple list-in, list-out; aligned by index; failures -> null): | @@ -6,6 +6,7 @@ API (simple list-in, list-out; aligned by index; failures -> null): | ||
| 6 | - POST /embed/image body: ["url_or_path1", ...] -> [[...], null, ...] | 6 | - POST /embed/image body: ["url_or_path1", ...] -> [[...], null, ...] |
| 7 | """ | 7 | """ |
| 8 | 8 | ||
| 9 | +import logging | ||
| 9 | import threading | 10 | import threading |
| 10 | from typing import Any, Dict, List, Optional | 11 | from typing import Any, Dict, List, Optional |
| 11 | 12 | ||
| @@ -16,38 +17,47 @@ from embeddings.config import CONFIG | @@ -16,38 +17,47 @@ from embeddings.config import CONFIG | ||
| 16 | from embeddings.bge_model import BgeTextModel | 17 | from embeddings.bge_model import BgeTextModel |
| 17 | from embeddings.clip_model import ClipImageModel | 18 | from embeddings.clip_model import ClipImageModel |
| 18 | 19 | ||
| 20 | +logger = logging.getLogger(__name__) | ||
| 19 | 21 | ||
| 20 | app = FastAPI(title="SearchEngine Embedding Service", version="1.0.0") | 22 | app = FastAPI(title="SearchEngine Embedding Service", version="1.0.0") |
| 21 | 23 | ||
| 22 | -_text_model = None | ||
| 23 | -_image_model = None | ||
| 24 | - | ||
| 25 | -_text_init_lock = threading.Lock() | ||
| 26 | -_image_init_lock = threading.Lock() | 24 | +# Models are loaded at startup, not lazily |
| 25 | +_text_model: Optional[BgeTextModel] = None | ||
| 26 | +_image_model: Optional[ClipImageModel] = None | ||
| 27 | 27 | ||
| 28 | _text_encode_lock = threading.Lock() | 28 | _text_encode_lock = threading.Lock() |
| 29 | _image_encode_lock = threading.Lock() | 29 | _image_encode_lock = threading.Lock() |
| 30 | 30 | ||
| 31 | 31 | ||
| 32 | -def _get_text_model(): | ||
| 33 | - global _text_model | ||
| 34 | - if _text_model is None: | ||
| 35 | - with _text_init_lock: | ||
| 36 | - if _text_model is None: | ||
| 37 | - _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR) | ||
| 38 | - return _text_model | 32 | +@app.on_event("startup") |
| 33 | +def load_models(): | ||
| 34 | + """Load models at service startup to avoid first-request latency.""" | ||
| 35 | + global _text_model, _image_model | ||
| 39 | 36 | ||
| 37 | + logger.info("Loading embedding models at startup...") | ||
| 40 | 38 | ||
| 41 | -def _get_image_model(): | ||
| 42 | - global _image_model | ||
| 43 | - if _image_model is None: | ||
| 44 | - with _image_init_lock: | ||
| 45 | - if _image_model is None: | ||
| 46 | - _image_model = ClipImageModel( | ||
| 47 | - model_name=CONFIG.IMAGE_MODEL_NAME, | ||
| 48 | - device=CONFIG.IMAGE_DEVICE, | ||
| 49 | - ) | ||
| 50 | - return _image_model | 39 | + # Load text model |
| 40 | + try: | ||
| 41 | + logger.info(f"Loading text model: {CONFIG.TEXT_MODEL_DIR}") | ||
| 42 | + _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR) | ||
| 43 | + logger.info("Text model loaded successfully") | ||
| 44 | + except Exception as e: | ||
| 45 | + logger.error(f"Failed to load text model: {e}", exc_info=True) | ||
| 46 | + raise | ||
| 47 | + | ||
| 48 | + # Load image model | ||
| 49 | + try: | ||
| 50 | + logger.info(f"Loading image model: {CONFIG.IMAGE_MODEL_NAME} (device: {CONFIG.IMAGE_DEVICE})") | ||
| 51 | + _image_model = ClipImageModel( | ||
| 52 | + model_name=CONFIG.IMAGE_MODEL_NAME, | ||
| 53 | + device=CONFIG.IMAGE_DEVICE, | ||
| 54 | + ) | ||
| 55 | + logger.info("Image model loaded successfully") | ||
| 56 | + except Exception as e: | ||
| 57 | + logger.error(f"Failed to load image model: {e}", exc_info=True) | ||
| 58 | + raise | ||
| 59 | + | ||
| 60 | + logger.info("All embedding models loaded successfully, service ready") | ||
| 51 | 61 | ||
| 52 | 62 | ||
| 53 | def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]: | 63 | def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]: |
| @@ -62,12 +72,18 @@ def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]: | @@ -62,12 +72,18 @@ def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]: | ||
| 62 | 72 | ||
| 63 | @app.get("/health") | 73 | @app.get("/health") |
| 64 | def health() -> Dict[str, Any]: | 74 | def health() -> Dict[str, Any]: |
| 65 | - return {"status": "ok"} | 75 | + """Health check endpoint. Returns status and model loading state.""" |
| 76 | + return { | ||
| 77 | + "status": "ok", | ||
| 78 | + "text_model_loaded": _text_model is not None, | ||
| 79 | + "image_model_loaded": _image_model is not None, | ||
| 80 | + } | ||
| 66 | 81 | ||
| 67 | 82 | ||
| 68 | @app.post("/embed/text") | 83 | @app.post("/embed/text") |
| 69 | def embed_text(texts: List[str]) -> List[Optional[List[float]]]: | 84 | def embed_text(texts: List[str]) -> List[Optional[List[float]]]: |
| 70 | - model = _get_text_model() | 85 | + if _text_model is None: |
| 86 | + raise RuntimeError("Text model not loaded") | ||
| 71 | out: List[Optional[List[float]]] = [None] * len(texts) | 87 | out: List[Optional[List[float]]] = [None] * len(texts) |
| 72 | 88 | ||
| 73 | indexed_texts: List[tuple] = [] | 89 | indexed_texts: List[tuple] = [] |
| @@ -87,7 +103,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: | @@ -87,7 +103,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: | ||
| 87 | batch_texts = [t for _, t in indexed_texts] | 103 | batch_texts = [t for _, t in indexed_texts] |
| 88 | try: | 104 | try: |
| 89 | with _text_encode_lock: | 105 | with _text_encode_lock: |
| 90 | - embs = model.encode_batch( | 106 | + embs = _text_model.encode_batch( |
| 91 | batch_texts, batch_size=int(CONFIG.TEXT_BATCH_SIZE), device=CONFIG.TEXT_DEVICE | 107 | batch_texts, batch_size=int(CONFIG.TEXT_BATCH_SIZE), device=CONFIG.TEXT_DEVICE |
| 92 | ) | 108 | ) |
| 93 | for j, (idx, _t) in enumerate(indexed_texts): | 109 | for j, (idx, _t) in enumerate(indexed_texts): |
| @@ -100,7 +116,8 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: | @@ -100,7 +116,8 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: | ||
| 100 | 116 | ||
| 101 | @app.post("/embed/image") | 117 | @app.post("/embed/image") |
| 102 | def embed_image(images: List[str]) -> List[Optional[List[float]]]: | 118 | def embed_image(images: List[str]) -> List[Optional[List[float]]]: |
| 103 | - model = _get_image_model() | 119 | + if _image_model is None: |
| 120 | + raise RuntimeError("Image model not loaded") | ||
| 104 | out: List[Optional[List[float]]] = [None] * len(images) | 121 | out: List[Optional[List[float]]] = [None] * len(images) |
| 105 | 122 | ||
| 106 | with _image_encode_lock: | 123 | with _image_encode_lock: |
| @@ -113,7 +130,7 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]: | @@ -113,7 +130,7 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]: | ||
| 113 | url_or_path = url_or_path.strip() | 130 | url_or_path = url_or_path.strip() |
| 114 | if not url_or_path: | 131 | if not url_or_path: |
| 115 | continue | 132 | continue |
| 116 | - emb = model.encode_image_from_url(url_or_path) | 133 | + emb = _image_model.encode_image_from_url(url_or_path) |
| 117 | out[i] = _as_list(emb) | 134 | out[i] = _as_list(emb) |
| 118 | except Exception: | 135 | except Exception: |
| 119 | out[i] = None | 136 | out[i] = None |