Commit 0a3764c437d6bddab876e9bba9923aa3012dfd91
1 parent
7bfb9946
优化embedding模型加载
Showing
2 changed files
with
44 additions
and
32 deletions
Show diff stats
embeddings/README.md
embeddings/server.py
| ... | ... | @@ -6,6 +6,7 @@ API (simple list-in, list-out; aligned by index; failures -> null): |
| 6 | 6 | - POST /embed/image body: ["url_or_path1", ...] -> [[...], null, ...] |
| 7 | 7 | """ |
| 8 | 8 | |
| 9 | +import logging | |
| 9 | 10 | import threading |
| 10 | 11 | from typing import Any, Dict, List, Optional |
| 11 | 12 | |
| ... | ... | @@ -16,38 +17,47 @@ from embeddings.config import CONFIG |
| 16 | 17 | from embeddings.bge_model import BgeTextModel |
| 17 | 18 | from embeddings.clip_model import ClipImageModel |
| 18 | 19 | |
| 20 | +logger = logging.getLogger(__name__) | |
| 19 | 21 | |
| 20 | 22 | app = FastAPI(title="SearchEngine Embedding Service", version="1.0.0") |
| 21 | 23 | |
| 22 | -_text_model = None | |
| 23 | -_image_model = None | |
| 24 | - | |
| 25 | -_text_init_lock = threading.Lock() | |
| 26 | -_image_init_lock = threading.Lock() | |
| 24 | +# Models are loaded at startup, not lazily | |
| 25 | +_text_model: Optional[BgeTextModel] = None | |
| 26 | +_image_model: Optional[ClipImageModel] = None | |
| 27 | 27 | |
| 28 | 28 | _text_encode_lock = threading.Lock() |
| 29 | 29 | _image_encode_lock = threading.Lock() |
| 30 | 30 | |
| 31 | 31 | |
| 32 | -def _get_text_model(): | |
| 33 | - global _text_model | |
| 34 | - if _text_model is None: | |
| 35 | - with _text_init_lock: | |
| 36 | - if _text_model is None: | |
| 37 | - _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR) | |
| 38 | - return _text_model | |
@app.on_event("startup")
def load_models() -> None:
    """Load the text and image embedding models once at service startup.

    Eager loading avoids first-request latency from lazy initialization.
    Any constructor failure is logged with its traceback and re-raised so
    a broken deployment fails fast instead of serving half-initialized.

    NOTE(review): ``@app.on_event`` is deprecated in recent FastAPI in
    favor of lifespan handlers; kept here to preserve the interface.
    """
    global _text_model, _image_model

    logger.info("Loading embedding models at startup...")

    # Load text model.
    try:
        # Lazy %-style args: formatting is skipped if INFO is disabled.
        logger.info("Loading text model: %s", CONFIG.TEXT_MODEL_DIR)
        _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR)
        logger.info("Text model loaded successfully")
    except Exception:
        # logger.exception records the traceback; re-raise to abort startup.
        logger.exception("Failed to load text model")
        raise

    # Load image model.
    try:
        logger.info(
            "Loading image model: %s (device: %s)",
            CONFIG.IMAGE_MODEL_NAME,
            CONFIG.IMAGE_DEVICE,
        )
        _image_model = ClipImageModel(
            model_name=CONFIG.IMAGE_MODEL_NAME,
            device=CONFIG.IMAGE_DEVICE,
        )
        logger.info("Image model loaded successfully")
    except Exception:
        logger.exception("Failed to load image model")
        raise

    logger.info("All embedding models loaded successfully, service ready")
| 51 | 61 | |
| 52 | 62 | |
| 53 | 63 | def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]: |
| ... | ... | @@ -62,12 +72,18 @@ def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]: |
| 62 | 72 | |
@app.get("/health")
def health() -> Dict[str, Any]:
    """Health check endpoint: report service status and whether each model is loaded."""
    payload: Dict[str, Any] = {"status": "ok"}
    payload["text_model_loaded"] = _text_model is not None
    payload["image_model_loaded"] = _image_model is not None
    return payload
| 66 | 81 | |
| 67 | 82 | |
| 68 | 83 | @app.post("/embed/text") |
| 69 | 84 | def embed_text(texts: List[str]) -> List[Optional[List[float]]]: |
| 70 | - model = _get_text_model() | |
| 85 | + if _text_model is None: | |
| 86 | + raise RuntimeError("Text model not loaded") | |
| 71 | 87 | out: List[Optional[List[float]]] = [None] * len(texts) |
| 72 | 88 | |
| 73 | 89 | indexed_texts: List[tuple] = [] |
| ... | ... | @@ -87,7 +103,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: |
| 87 | 103 | batch_texts = [t for _, t in indexed_texts] |
| 88 | 104 | try: |
| 89 | 105 | with _text_encode_lock: |
| 90 | - embs = model.encode_batch( | |
| 106 | + embs = _text_model.encode_batch( | |
| 91 | 107 | batch_texts, batch_size=int(CONFIG.TEXT_BATCH_SIZE), device=CONFIG.TEXT_DEVICE |
| 92 | 108 | ) |
| 93 | 109 | for j, (idx, _t) in enumerate(indexed_texts): |
| ... | ... | @@ -100,7 +116,8 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]: |
| 100 | 116 | |
| 101 | 117 | @app.post("/embed/image") |
| 102 | 118 | def embed_image(images: List[str]) -> List[Optional[List[float]]]: |
| 103 | - model = _get_image_model() | |
| 119 | + if _image_model is None: | |
| 120 | + raise RuntimeError("Image model not loaded") | |
| 104 | 121 | out: List[Optional[List[float]]] = [None] * len(images) |
| 105 | 122 | |
| 106 | 123 | with _image_encode_lock: |
| ... | ... | @@ -113,7 +130,7 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]: |
| 113 | 130 | url_or_path = url_or_path.strip() |
| 114 | 131 | if not url_or_path: |
| 115 | 132 | continue |
| 116 | - emb = model.encode_image_from_url(url_or_path) | |
| 133 | + emb = _image_model.encode_image_from_url(url_or_path) | |
| 117 | 134 | out[i] = _as_list(emb) |
| 118 | 135 | except Exception: |
| 119 | 136 | out[i] = None | ... | ... |