Blame view

embeddings/server.py 3.33 KB
7bfb9946   tangwang   向量化模块
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
  """
  Embedding service (FastAPI).
  
  API (simple list-in, list-out; aligned by index; failures -> null):
  - POST /embed/text   body: ["text1", "text2", ...] -> [[...], null, ...]
  - POST /embed/image  body: ["url_or_path1", ...]  -> [[...], null, ...]
  """
  
  import threading
  from typing import Any, Dict, List, Optional
  
  import numpy as np
  from fastapi import FastAPI
  
  from embeddings.config import CONFIG
  from embeddings.bge_model import BgeTextModel
  from embeddings.clip_model import ClipImageModel
  
  
  # FastAPI application object; the route decorators in this module register
  # their handlers on it.
  app = FastAPI(title="SearchEngine Embedding Service", version="1.0.0")
  
  # Model singletons. Both start as None and are populated on first use by
  # _get_text_model() / _get_image_model().
  _text_model = None
  _image_model = None
  
  # Guard the one-time construction of each model inside the _get_*_model()
  # helpers.
  _text_init_lock = threading.Lock()
  _image_init_lock = threading.Lock()
  
  # Held around the encode calls in the /embed/text and /embed/image handlers,
  # so at most one encode per model runs at a time within this process.
  _text_encode_lock = threading.Lock()
  _image_encode_lock = threading.Lock()
  
  
  def _get_text_model():
      """Return the shared text model, constructing it on first use.

      The instance lives in the module-level ``_text_model``. Construction is
      guarded by ``_text_init_lock``, and the ``None`` check is repeated once
      the lock is held so a caller that waited on the lock does not build a
      second instance.
      """
      global _text_model
      if _text_model is not None:
          # Already initialised: skip the lock entirely.
          return _text_model
      with _text_init_lock:
          # Re-check under the lock; another caller may have finished
          # construction while this one was waiting.
          if _text_model is None:
              _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR)
          return _text_model
  
  
  def _get_image_model():
      """Return the shared image model, constructing it on first use.

      The instance lives in the module-level ``_image_model``. Construction is
      guarded by ``_image_init_lock``, and the ``None`` check is repeated once
      the lock is held so a caller that waited on the lock does not build a
      second instance.
      """
      global _image_model
      if _image_model is not None:
          # Already initialised: skip the lock entirely.
          return _image_model
      with _image_init_lock:
          # Re-check under the lock; another caller may have finished
          # construction while this one was waiting.
          if _image_model is None:
              model_kwargs = {
                  "model_name": CONFIG.IMAGE_MODEL_NAME,
                  "device": CONFIG.IMAGE_DEVICE,
              }
              _image_model = ClipImageModel(**model_kwargs)
          return _image_model
  
  
  def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]:
      if embedding is None:
          return None
      if not isinstance(embedding, np.ndarray):
          embedding = np.array(embedding, dtype=np.float32)
      if embedding.ndim != 1:
          embedding = embedding.reshape(-1)
      return embedding.astype(np.float32).tolist()
  
  
  @app.get("/health")
  def health() -> Dict[str, Any]:
      """Health-check endpoint.

      Returns a constant payload and does not touch either model.
      """
      payload: Dict[str, Any] = {"status": "ok"}
      return payload
  
  
  @app.post("/embed/text")
  def embed_text(texts: List[str]) -> List[Optional[List[float]]]:
      """Embed a list of texts; the result is aligned with the input by index.

      Args:
          texts: Request body, a list of strings.

      Returns:
          A list with the same length as ``texts``. Position ``i`` holds the
          embedding of ``texts[i]`` as a list of floats, or ``None`` when the
          entry was ``None`` or blank after stripping, or when encoding
          failed.

      Encoding failures are never raised to the caller: they are logged with
      a traceback and the affected positions stay ``None``.
      """
      # Function-scope import: the module itself does not import logging.
      import logging

      model = _get_text_model()
      out: List[Optional[List[float]]] = [None] * len(texts)

      # Keep (original index, cleaned text) pairs so that skipped entries do
      # not shift the alignment of the output.
      indexed_texts: List[tuple] = []
      for i, t in enumerate(texts):
          if t is None:
              continue
          if not isinstance(t, str):
              t = str(t)
          t = t.strip()
          if not t:
              continue
          indexed_texts.append((i, t))

      if not indexed_texts:
          return out

      batch_texts = [t for _, t in indexed_texts]
      try:
          # Only the encode call itself is serialised by the lock.
          with _text_encode_lock:
              embs = model.encode_batch(
                  batch_texts, batch_size=int(CONFIG.TEXT_BATCH_SIZE), device=CONFIG.TEXT_DEVICE
              )
          for j, (idx, _t) in enumerate(indexed_texts):
              out[idx] = _as_list(embs[j])
      except Exception:
          # Best-effort contract: failures surface as None in the response,
          # but the cause is logged instead of being silently discarded.
          logging.getLogger(__name__).exception(
              "Text embedding failed for a batch of %d text(s)", len(batch_texts)
          )
      return out
  
  
  @app.post("/embed/image")
  def embed_image(images: List[str]) -> List[Optional[List[float]]]:
      """Embed a list of images; the result is aligned with the input by index.

      Args:
          images: Request body, a list of strings; each one is passed to the
              model's ``encode_image_from_url`` after stripping.

      Returns:
          A list with the same length as ``images``. Position ``i`` holds the
          embedding of ``images[i]`` as a list of floats, or ``None`` when the
          entry was ``None`` or blank after stripping, or when encoding that
          entry failed.

      Failures are handled per item: one bad entry is logged with a traceback
      and left as ``None`` without affecting the other entries.
      """
      # Function-scope import: the module itself does not import logging.
      import logging

      model = _get_image_model()
      out: List[Optional[List[float]]] = [None] * len(images)

      # The lock is held for the whole request, so everything
      # encode_image_from_url does for each entry runs while holding it.
      with _image_encode_lock:
          for i, url_or_path in enumerate(images):
              try:
                  if url_or_path is None:
                      continue
                  if not isinstance(url_or_path, str):
                      url_or_path = str(url_or_path)
                  url_or_path = url_or_path.strip()
                  if not url_or_path:
                      continue
                  emb = model.encode_image_from_url(url_or_path)
                  out[i] = _as_list(emb)
              except Exception:
                  # Best-effort contract: this slot stays None, but the cause
                  # is logged instead of being silently discarded. %r keeps
                  # control characters in the input from breaking log lines.
                  out[i] = None
                  logging.getLogger(__name__).exception(
                      "Image embedding failed for item %d (%r)", i, url_or_path
                  )
      return out