Commit 0a3764c437d6bddab876e9bba9923aa3012dfd91

Authored by tangwang
1 parent 7bfb9946

优化embedding模型加载

Showing 2 changed files with 44 additions and 32 deletions   Show diff stats
embeddings/README.md
... ... @@ -33,8 +33,3 @@
33 33 - `TEXT_MODEL_DIR`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE`
34 34 - `IMAGE_MODEL_NAME`, `IMAGE_DEVICE`
35 35  
36   -### 目录说明(旧文件)
37   -
38   -旧的 `vector_service/` 目录与 `*_encoder__local.py` 文件已经废弃,统一由本目录实现与维护。
39   -
40   -
... ...
embeddings/server.py
... ... @@ -6,6 +6,7 @@ API (simple list-in, list-out; aligned by index; failures -> null):
6 6 - POST /embed/image body: ["url_or_path1", ...] -> [[...], null, ...]
7 7 """
8 8  
  9 +import logging
9 10 import threading
10 11 from typing import Any, Dict, List, Optional
11 12  
... ... @@ -16,38 +17,47 @@ from embeddings.config import CONFIG
16 17 from embeddings.bge_model import BgeTextModel
17 18 from embeddings.clip_model import ClipImageModel
18 19  
  20 +logger = logging.getLogger(__name__)
19 21  
20 22 app = FastAPI(title="SearchEngine Embedding Service", version="1.0.0")
21 23  
22   -_text_model = None
23   -_image_model = None
24   -
25   -_text_init_lock = threading.Lock()
26   -_image_init_lock = threading.Lock()
  24 +# Models are loaded at startup, not lazily
  25 +_text_model: Optional[BgeTextModel] = None
  26 +_image_model: Optional[ClipImageModel] = None
27 27  
28 28 _text_encode_lock = threading.Lock()
29 29 _image_encode_lock = threading.Lock()
30 30  
31 31  
32   -def _get_text_model():
33   - global _text_model
34   - if _text_model is None:
35   - with _text_init_lock:
36   - if _text_model is None:
37   - _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR)
38   - return _text_model
  32 +@app.on_event("startup")
  33 +def load_models():
  34 + """Load models at service startup to avoid first-request latency."""
  35 + global _text_model, _image_model
39 36  
  37 + logger.info("Loading embedding models at startup...")
40 38  
41   -def _get_image_model():
42   - global _image_model
43   - if _image_model is None:
44   - with _image_init_lock:
45   - if _image_model is None:
46   - _image_model = ClipImageModel(
47   - model_name=CONFIG.IMAGE_MODEL_NAME,
48   - device=CONFIG.IMAGE_DEVICE,
49   - )
50   - return _image_model
  39 + # Load text model
  40 + try:
  41 + logger.info(f"Loading text model: {CONFIG.TEXT_MODEL_DIR}")
  42 + _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR)
  43 + logger.info("Text model loaded successfully")
  44 + except Exception as e:
  45 + logger.error(f"Failed to load text model: {e}", exc_info=True)
  46 + raise
  47 +
  48 + # Load image model
  49 + try:
  50 + logger.info(f"Loading image model: {CONFIG.IMAGE_MODEL_NAME} (device: {CONFIG.IMAGE_DEVICE})")
  51 + _image_model = ClipImageModel(
  52 + model_name=CONFIG.IMAGE_MODEL_NAME,
  53 + device=CONFIG.IMAGE_DEVICE,
  54 + )
  55 + logger.info("Image model loaded successfully")
  56 + except Exception as e:
  57 + logger.error(f"Failed to load image model: {e}", exc_info=True)
  58 + raise
  59 +
  60 + logger.info("All embedding models loaded successfully, service ready")
51 61  
52 62  
53 63 def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]:
... ... @@ -62,12 +72,18 @@ def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]:
62 72  
63 73 @app.get("/health")
64 74 def health() -> Dict[str, Any]:
65   - return {"status": "ok"}
  75 + """Health check endpoint. Returns status and model loading state."""
  76 + return {
  77 + "status": "ok",
  78 + "text_model_loaded": _text_model is not None,
  79 + "image_model_loaded": _image_model is not None,
  80 + }
66 81  
67 82  
68 83 @app.post("/embed/text")
69 84 def embed_text(texts: List[str]) -> List[Optional[List[float]]]:
70   - model = _get_text_model()
  85 + if _text_model is None:
  86 + raise RuntimeError("Text model not loaded")
71 87 out: List[Optional[List[float]]] = [None] * len(texts)
72 88  
73 89 indexed_texts: List[tuple] = []
... ... @@ -87,7 +103,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]:
87 103 batch_texts = [t for _, t in indexed_texts]
88 104 try:
89 105 with _text_encode_lock:
90   - embs = model.encode_batch(
  106 + embs = _text_model.encode_batch(
91 107 batch_texts, batch_size=int(CONFIG.TEXT_BATCH_SIZE), device=CONFIG.TEXT_DEVICE
92 108 )
93 109 for j, (idx, _t) in enumerate(indexed_texts):
... ... @@ -100,7 +116,8 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]:
100 116  
101 117 @app.post("/embed/image")
102 118 def embed_image(images: List[str]) -> List[Optional[List[float]]]:
103   - model = _get_image_model()
  119 + if _image_model is None:
  120 + raise RuntimeError("Image model not loaded")
104 121 out: List[Optional[List[float]]] = [None] * len(images)
105 122  
106 123 with _image_encode_lock:
... ... @@ -113,7 +130,7 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]:
113 130 url_or_path = url_or_path.strip()
114 131 if not url_or_path:
115 132 continue
116   - emb = model.encode_image_from_url(url_or_path)
  133 + emb = _image_model.encode_image_from_url(url_or_path)
117 134 out[i] = _as_list(emb)
118 135 except Exception:
119 136 out[i] = None
... ...