Blame view

embeddings/text_encoder.py 4.94 KB
be52af70   tangwang   first commit
1
  """
325eec03   tangwang   1. 日志、配置基础设施,使用优化
2
  Text embedding encoder using network service.
be52af70   tangwang   first commit
3
  
325eec03   tangwang   1. 日志、配置基础设施,使用优化
4
  Generates embeddings via HTTP API service running on localhost:5001.
be52af70   tangwang   first commit
5
6
7
  """
  
  import sys
325eec03   tangwang   1. 日志、配置基础设施,使用优化
8
  import requests
be52af70   tangwang   first commit
9
10
  import time
  import threading
be52af70   tangwang   first commit
11
  import numpy as np
325eec03   tangwang   1. 日志、配置基础设施,使用优化
12
13
14
15
  import logging
  from typing import List, Union, Dict, Any
  
  # Module-level logger, named after this module, used by BgeEncoder below.
  logger = logging.getLogger(__name__)
be52af70   tangwang   first commit
16
17
18
19
  
  
  class BgeEncoder:
      """
      Singleton text encoder backed by a remote embedding HTTP service.

      Instance creation is guarded by a class-level lock, so every call to
      ``BgeEncoder(...)`` returns the same object.  The ``service_url`` given
      to the first call configures that object; later calls return the
      existing instance and their ``service_url`` argument is not applied.
      """

      # Width of the zero vectors used as fallback rows / fallback result.
      EMBEDDING_DIM = 1024
      # Timeout, in seconds, passed to each HTTP request.
      REQUEST_TIMEOUT = 60
      # Response fields inspected for an embedding, in priority order.
      _EMBEDDING_FIELDS = ("embedding_zh", "embedding_en", "embedding_ru")

      _instance = None
      _lock = threading.Lock()

      def __new__(cls, service_url='http://localhost:5001'):
          """
          Return the shared encoder, creating it on the first call.

          Args:
              service_url: Base URL of the embedding service. Only used when
                  the shared instance does not exist yet.
          """
          with cls._lock:
              if cls._instance is None:
                  cls._instance = super(BgeEncoder, cls).__new__(cls)
                  logger.info(f"Creating BgeEncoder instance with service URL: {service_url}")
                  cls._instance.service_url = service_url
                  # Drop trailing slashes so the endpoint never contains "//".
                  base_url = str(service_url).rstrip('/')
                  cls._instance.endpoint = f"{base_url}/embedding/generate_embeddings"
          return cls._instance

      def _call_service(self, request_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
          """
          Call the embedding service API.

          Args:
              request_data: List of dictionaries with id and text fields

          Returns:
              The decoded JSON body of the response (expected to be a list of
              dictionaries with id and embedding fields).

          Raises:
              requests.exceptions.RequestException: If the request fails or
                  the service answers with an HTTP error status. The error is
                  logged before being re-raised.
          """
          try:
              response = requests.post(
                  self.endpoint,
                  json=request_data,
                  timeout=self.REQUEST_TIMEOUT
              )
              response.raise_for_status()
              return response.json()
          except requests.exceptions.RequestException as e:
              logger.error(f"BgeEncoder service request failed: {e}", exc_info=True)
              raise

      def encode(
          self,
          sentences: Union[str, List[str]],
          normalize_embeddings: bool = True,
          device: str = 'cpu',
          batch_size: int = 32
      ) -> np.ndarray:
          """
          Encode text into embeddings via network service.

          Args:
              sentences: Single string or list of strings to encode
              normalize_embeddings: Accepted for interface compatibility; not
                  used by this implementation
              device: Accepted for interface compatibility; not used by this
                  implementation
              batch_size: Maximum number of texts sent per service request.
                  A non-positive value sends everything in one request.

          Returns:
              float32 numpy array with one row per input text. Texts for which
              the service returned no embedding get a zero row of width
              ``EMBEDDING_DIM``. If anything fails, an all-zero array of shape
              ``(len(sentences), EMBEDDING_DIM)`` is returned instead of
              raising.
          """
          # Convert single string to list; materialise other iterables once so
          # they can be measured and walked more than once.
          if isinstance(sentences, str):
              sentences = [sentences]
          else:
              sentences = list(sentences)

          total = len(sentences)
          if total == 0:
              # Nothing to encode: skip the network round trip and keep the
              # result two-dimensional, matching the fallback shape.
              return np.zeros((0, self.EMBEDDING_DIM), dtype=np.float32)

          # Each text goes out as the Chinese field; English and Russian are
          # left empty for now (could be enhanced with language detection).
          request_data = [
              {
                  "id": str(i),
                  "name_zh": text,
                  "name_en": None,
                  "name_ru": None
              }
              for i, text in enumerate(sentences)
          ]

          chunk_size = batch_size if batch_size and batch_size > 0 else total

          try:
              # Index response items by id; the first item seen for an id wins.
              items_by_id = {}
              for start in range(0, total, chunk_size):
                  chunk = request_data[start:start + chunk_size]
                  for item in self._call_service(chunk):
                      items_by_id.setdefault(str(item.get("id")), item)

              embeddings = []
              for i, text in enumerate(sentences):
                  response_item = items_by_id.get(str(i))
                  if response_item is None:
                      logger.warning(f"No response found for text {i}")
                      embeddings.append([0.0] * self.EMBEDDING_DIM)
                      continue

                  # Try Chinese embedding first, then English, then Russian
                  embedding = None
                  for lang in self._EMBEDDING_FIELDS:
                      if response_item.get(lang) is not None:
                          embedding = response_item[lang]
                          break

                  if embedding is None:
                      logger.warning(f"No embedding found for text {i}: {text[:50]}...")
                      embedding = [0.0] * self.EMBEDDING_DIM
                  embeddings.append(embedding)

              return np.array(embeddings, dtype=np.float32)

          except Exception as e:
              # Deliberate best-effort: callers get zero vectors, not an error.
              logger.error(f"Failed to encode texts: {e}", exc_info=True)
              return np.zeros((total, self.EMBEDDING_DIM), dtype=np.float32)

      def encode_batch(
          self,
          texts: List[str],
          batch_size: int = 32,
          device: str = 'cpu'
      ) -> np.ndarray:
          """
          Encode a batch of texts via network service.

          Thin wrapper that forwards to ``encode``.

          Args:
              texts: List of texts to encode
              batch_size: Maximum number of texts sent per service request
              device: Accepted for interface compatibility; not used

          Returns:
              numpy array of embeddings, as returned by ``encode``
          """
          return self.encode(texts, batch_size=batch_size, device=device)