Blame view

embeddings/bge_model.py 2.44 KB
7bfb9946   tangwang   向量化模块
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
  """
  BGE-M3 local text embedding implementation.
  
  Internal model implementation used by the embedding service.
  """
  
  import threading
  from typing import List, Union
  
  import numpy as np
  from sentence_transformers import SentenceTransformer
  from modelscope import snapshot_download
  
  
  class BgeTextModel(object):
      """
      Process-wide singleton text encoder backed by a local BGE-M3 model.

      The first instantiation downloads the model via ModelScope
      ``snapshot_download`` and loads it with ``SentenceTransformer``. Every
      later instantiation returns that same instance, so the ``model_dir``
      argument of later calls is ignored. Singleton creation is guarded by a
      class-level lock.

      NOTE(review): ``encode`` moves the shared ``self.model`` between devices,
      so concurrent ``encode`` calls that request (or fall back to) different
      devices can interfere with one another.
      """

      _instance = None
      _lock = threading.Lock()

      # Minimum free CUDA memory (bytes) required to encode on the GPU when
      # ``device="cuda"``; below this the call runs on CPU instead.
      _MIN_FREE_CUDA_BYTES = 1024 * 1024 * 1024  # 1 GiB

      def __new__(cls, model_dir: str = "Xorbits/bge-m3"):
          """
          Return the shared instance, loading the model on first use.

          Args:
              model_dir: ModelScope model id to download and load. Only the call
                  that actually creates the singleton uses it.

          Raises:
              Whatever ``snapshot_download`` / ``SentenceTransformer`` raise. On
              failure nothing is cached, so a later call retries the load.
          """
          # Fast path: once the singleton is published it is fully initialised,
          # so no lock is needed to read it.
          instance = cls._instance
          if instance is not None:
              return instance
          with cls._lock:
              if cls._instance is None:
                  # Load completely before publishing. A failed download/load
                  # must not leave an instance without ``model`` cached for
                  # every subsequent caller.
                  instance = super(BgeTextModel, cls).__new__(cls)
                  instance.model = SentenceTransformer(snapshot_download(model_dir))
                  cls._instance = instance
              return cls._instance

      def encode(
          self,
          sentences: Union[str, List[str]],
          normalize_embeddings: bool = True,
          device: str = "cuda",
          batch_size: int = 32,
      ) -> np.ndarray:
          """
          Encode text(s) into embeddings, preferring the requested device.

          Args:
              sentences: A single string or a list of strings.
              normalize_embeddings: Forwarded to ``SentenceTransformer.encode``.
              device: Target device; ``"gpu"`` is accepted as an alias for
                  ``"cuda"``. For plain ``"cuda"`` the call runs on CPU when CUDA
                  is unavailable or less than ``_MIN_FREE_CUDA_BYTES`` is free.
              batch_size: Forwarded to ``SentenceTransformer.encode``.

          Returns:
              The result of ``SentenceTransformer.encode`` for these arguments.

          Raises:
              The CPU-side error when encoding fails on CPU. This happens either
              directly or after a failed attempt on another device, which is
              retried once on CPU.
          """
          if device == "gpu":
              device = "cuda"

          try:
              if device == "cuda" and not self._cuda_has_headroom():
                  device = "cpu"
              return self._encode_on(sentences, device, normalize_embeddings, batch_size)
          except Exception:
              if device == "cpu":
                  raise
              # Best-effort fallback: any failure on a non-CPU device (OOM,
              # driver error, ...) is retried once on CPU.
              return self._encode_on(sentences, "cpu", normalize_embeddings, batch_size)

      def _cuda_has_headroom(self) -> bool:
          """Return True if CUDA is available with at least ``_MIN_FREE_CUDA_BYTES`` free."""
          import torch

          if not torch.cuda.is_available():
              return False
          # ``mem_get_info`` reports what the driver can still hand out on the
          # current device. That figure includes memory cached by PyTorch's
          # allocator and memory held by other processes, both of which
          # ``total_memory - memory_allocated()`` ignores.
          free_bytes, _total_bytes = torch.cuda.mem_get_info()
          return free_bytes >= self._MIN_FREE_CUDA_BYTES

      def _encode_on(self, sentences, device, normalize_embeddings, batch_size):
          """Move the shared model to ``device`` and encode ``sentences`` there."""
          self.model = self.model.to(device)
          return self.model.encode(
              sentences,
              normalize_embeddings=normalize_embeddings,
              device=device,
              show_progress_bar=False,
              batch_size=batch_size,
          )

      def encode_batch(self, texts: List[str], batch_size: int = 32, device: str = "cuda") -> np.ndarray:
          """Encode a list of texts; thin wrapper around ``encode`` with default normalisation."""
          return self.encode(texts, batch_size=batch_size, device=device)