Commit 0a3764c437d6bddab876e9bba9923aa3012dfd91

Authored by tangwang
1 parent 7bfb9946

优化embedding模型加载

Showing 2 changed files with 44 additions and 32 deletions   Show diff stats
embeddings/README.md
@@ -33,8 +33,3 @@
33 - `TEXT_MODEL_DIR`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE` 33 - `TEXT_MODEL_DIR`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE`
34 - `IMAGE_MODEL_NAME`, `IMAGE_DEVICE` 34 - `IMAGE_MODEL_NAME`, `IMAGE_DEVICE`
35 35
36 -### 目录说明(旧文件)  
37 -  
38 -旧的 `vector_service/` 目录与 `*_encoder__local.py` 文件已经废弃,统一由本目录实现与维护。  
39 -  
40 -  
embeddings/server.py
@@ -6,6 +6,7 @@ API (simple list-in, list-out; aligned by index; failures -> null):
6 - POST /embed/image body: ["url_or_path1", ...] -> [[...], null, ...] 6 - POST /embed/image body: ["url_or_path1", ...] -> [[...], null, ...]
7 """ 7 """
8 8
  9 +import logging
9 import threading 10 import threading
10 from typing import Any, Dict, List, Optional 11 from typing import Any, Dict, List, Optional
11 12
@@ -16,38 +17,47 @@ from embeddings.config import CONFIG
16 from embeddings.bge_model import BgeTextModel 17 from embeddings.bge_model import BgeTextModel
17 from embeddings.clip_model import ClipImageModel 18 from embeddings.clip_model import ClipImageModel
18 19
  20 +logger = logging.getLogger(__name__)
19 21
20 app = FastAPI(title="SearchEngine Embedding Service", version="1.0.0") 22 app = FastAPI(title="SearchEngine Embedding Service", version="1.0.0")
21 23
22 -_text_model = None  
23 -_image_model = None  
24 -  
25 -_text_init_lock = threading.Lock()  
26 -_image_init_lock = threading.Lock() 24 +# Models are loaded at startup, not lazily
  25 +_text_model: Optional[BgeTextModel] = None
  26 +_image_model: Optional[ClipImageModel] = None
27 27
28 _text_encode_lock = threading.Lock() 28 _text_encode_lock = threading.Lock()
29 _image_encode_lock = threading.Lock() 29 _image_encode_lock = threading.Lock()
30 30
31 31
32 -def _get_text_model():  
33 - global _text_model  
34 - if _text_model is None:  
35 - with _text_init_lock:  
36 - if _text_model is None:  
37 - _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR)  
38 - return _text_model 32 +@app.on_event("startup")
  33 +def load_models():
  34 + """Load models at service startup to avoid first-request latency."""
  35 + global _text_model, _image_model
39 36
  37 + logger.info("Loading embedding models at startup...")
40 38
41 -def _get_image_model():  
42 - global _image_model  
43 - if _image_model is None:  
44 - with _image_init_lock:  
45 - if _image_model is None:  
46 - _image_model = ClipImageModel(  
47 - model_name=CONFIG.IMAGE_MODEL_NAME,  
48 - device=CONFIG.IMAGE_DEVICE,  
49 - )  
50 - return _image_model 39 + # Load text model
  40 + try:
  41 + logger.info(f"Loading text model: {CONFIG.TEXT_MODEL_DIR}")
  42 + _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR)
  43 + logger.info("Text model loaded successfully")
  44 + except Exception as e:
  45 + logger.error(f"Failed to load text model: {e}", exc_info=True)
  46 + raise
  47 +
  48 + # Load image model
  49 + try:
  50 + logger.info(f"Loading image model: {CONFIG.IMAGE_MODEL_NAME} (device: {CONFIG.IMAGE_DEVICE})")
  51 + _image_model = ClipImageModel(
  52 + model_name=CONFIG.IMAGE_MODEL_NAME,
  53 + device=CONFIG.IMAGE_DEVICE,
  54 + )
  55 + logger.info("Image model loaded successfully")
  56 + except Exception as e:
  57 + logger.error(f"Failed to load image model: {e}", exc_info=True)
  58 + raise
  59 +
  60 + logger.info("All embedding models loaded successfully, service ready")
51 61
52 62
53 def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]: 63 def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]:
@@ -62,12 +72,18 @@ def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]:
62 72
63 @app.get("/health") 73 @app.get("/health")
64 def health() -> Dict[str, Any]: 74 def health() -> Dict[str, Any]:
65 - return {"status": "ok"} 75 + """Health check endpoint. Returns status and model loading state."""
  76 + return {
  77 + "status": "ok",
  78 + "text_model_loaded": _text_model is not None,
  79 + "image_model_loaded": _image_model is not None,
  80 + }
66 81
67 82
68 @app.post("/embed/text") 83 @app.post("/embed/text")
69 def embed_text(texts: List[str]) -> List[Optional[List[float]]]: 84 def embed_text(texts: List[str]) -> List[Optional[List[float]]]:
70 - model = _get_text_model() 85 + if _text_model is None:
  86 + raise RuntimeError("Text model not loaded")
71 out: List[Optional[List[float]]] = [None] * len(texts) 87 out: List[Optional[List[float]]] = [None] * len(texts)
72 88
73 indexed_texts: List[tuple] = [] 89 indexed_texts: List[tuple] = []
@@ -87,7 +103,7 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]:
87 batch_texts = [t for _, t in indexed_texts] 103 batch_texts = [t for _, t in indexed_texts]
88 try: 104 try:
89 with _text_encode_lock: 105 with _text_encode_lock:
90 - embs = model.encode_batch( 106 + embs = _text_model.encode_batch(
91 batch_texts, batch_size=int(CONFIG.TEXT_BATCH_SIZE), device=CONFIG.TEXT_DEVICE 107 batch_texts, batch_size=int(CONFIG.TEXT_BATCH_SIZE), device=CONFIG.TEXT_DEVICE
92 ) 108 )
93 for j, (idx, _t) in enumerate(indexed_texts): 109 for j, (idx, _t) in enumerate(indexed_texts):
@@ -100,7 +116,8 @@ def embed_text(texts: List[str]) -> List[Optional[List[float]]]:
100 116
101 @app.post("/embed/image") 117 @app.post("/embed/image")
102 def embed_image(images: List[str]) -> List[Optional[List[float]]]: 118 def embed_image(images: List[str]) -> List[Optional[List[float]]]:
103 - model = _get_image_model() 119 + if _image_model is None:
  120 + raise RuntimeError("Image model not loaded")
104 out: List[Optional[List[float]]] = [None] * len(images) 121 out: List[Optional[List[float]]] = [None] * len(images)
105 122
106 with _image_encode_lock: 123 with _image_encode_lock:
@@ -113,7 +130,7 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]:
113 url_or_path = url_or_path.strip() 130 url_or_path = url_or_path.strip()
114 if not url_or_path: 131 if not url_or_path:
115 continue 132 continue
116 - emb = model.encode_image_from_url(url_or_path) 133 + emb = _image_model.encode_image_from_url(url_or_path)
117 out[i] = _as_list(emb) 134 out[i] = _as_list(emb)
118 except Exception: 135 except Exception:
119 out[i] = None 136 out[i] = None