Blame view

examples/clip_rest_api.py 6.15 KB
40f1e391   tangwang   cnclip
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
  #!/usr/bin/env python3
  """
  CN-CLIP REST API 包装器
  
  提供 HTTP 接口,支持 curl 调用
  """
  
  from flask import Flask, request, jsonify
  from flask_cors import CORS
  from clip_client import Client
  import numpy as np
  import traceback
  
  # Flask application object; the route decorators below register handlers on it.
  app = Flask(__name__)
  CORS(app)  # allow cross-origin requests

  # Create the CN-CLIP client once at import time. On failure `client` is set
  # to None, which every handler below checks before doing any work.
  # NOTE(review): this only guards construction of Client; whether a network
  # connection is actually established at this point depends on clip_client
  # -- confirm before relying on the "connected" message.
  try:
      client = Client('grpc://localhost:51000')
      print("✓ 已连接到 CN-CLIP 服务 (grpc://localhost:51000)")
  except Exception as e:
      print(f"✗ 连接失败: {e}")
      print("请先启动 CN-CLIP 服务: ./scripts/start_cnclip_service.sh")
      client = None
  
  
  @app.route('/health', methods=['GET'])
  def health():
      """Health check.

      Reports 'ok' when the module-level backend client exists and 'error'
      when it is missing, together with the service name and backend address.
      """
      status = 'error' if not client else 'ok'
      payload = {
          'status': status,
          'service': 'cnclip-rest-api',
          'backend': 'grpc://localhost:51000'
      }
      return jsonify(payload)
  
  
  @app.route('/encode/text', methods=['POST'])
  def encode_text():
      """
      Encode a batch of texts into embeddings.

      Request body (JSON object):
      {
          "texts": ["text 1", "text 2"]
      }

      Response (200), values are illustrative:
      {
          "count": 2,
          "shape": [2, 1024],
          "embeddings": [[...], [...]]
      }

      Errors:
          400 -- the body is not a JSON object, or "texts" is missing, empty,
                 or not a list of strings.
          500 -- the backend raised while encoding.
          503 -- the backend client is unavailable (module-level client is None).
      """
      if not client:
          return jsonify({'error': 'CN-CLIP 服务未连接'}), 503

      try:
          # silent=True yields None for a missing/invalid JSON body instead of
          # raising, so malformed requests become a 400 below, not a 500.
          payload = request.get_json(silent=True)
          texts = payload.get('texts') if isinstance(payload, dict) else None

          if not texts:
              return jsonify({'error': '缺少 texts 参数'}), 400
          if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
              return jsonify({'error': 'texts 必须是字符串列表'}), 400

          # Encode via the backend.
          # assumes client.encode returns a numpy ndarray for a list of str -- TODO confirm
          embeddings = client.encode(texts)

          return jsonify({
              'count': len(texts),
              # ndarray.shape is a plain tuple (it has no .tolist()); turn it
              # into a list so it serializes as a JSON array.
              'shape': list(embeddings.shape),
              'embeddings': embeddings.tolist()
          })

      except Exception as e:
          print(f"错误: {e}")
          print(traceback.format_exc())
          return jsonify({'error': str(e)}), 500
  
  
  @app.route('/encode/image', methods=['POST'])
  def encode_image():
      """
      Encode a batch of images into embeddings.

      Request body (JSON object); entries are image URLs or local paths:
      {
          "images": ["https://example.com/image.jpg", "/path/to/local.jpg"]
      }

      Response (200), values are illustrative:
      {
          "count": 2,
          "shape": [2, 1024],
          "embeddings": [[...], [...]]
      }

      Errors:
          400 -- the body is not a JSON object, or "images" is missing, empty,
                 or not a list of strings.
          500 -- the backend raised while encoding.
          503 -- the backend client is unavailable (module-level client is None).
      """
      if not client:
          return jsonify({'error': 'CN-CLIP 服务未连接'}), 503

      try:
          # silent=True yields None for a missing/invalid JSON body instead of
          # raising, so malformed requests become a 400 below, not a 500.
          payload = request.get_json(silent=True)
          images = payload.get('images') if isinstance(payload, dict) else None

          if not images:
              return jsonify({'error': '缺少 images 参数'}), 400
          if not isinstance(images, list) or not all(isinstance(i, str) for i in images):
              return jsonify({'error': 'images 必须是字符串列表'}), 400

          # Encode via the backend.
          # assumes client.encode returns a numpy ndarray for a list of str -- TODO confirm
          embeddings = client.encode(images)

          return jsonify({
              'count': len(images),
              # ndarray.shape is a plain tuple (it has no .tolist()); turn it
              # into a list so it serializes as a JSON array.
              'shape': list(embeddings.shape),
              'embeddings': embeddings.tolist()
          })

      except Exception as e:
          print(f"错误: {e}")
          print(traceback.format_exc())
          return jsonify({'error': str(e)}), 500
  
  
  @app.route('/encode/mixed', methods=['POST'])
  def encode_mixed():
      """
      Encode a mixed batch (texts and images together) into embeddings.

      Request body (JSON object):
      {
          "data": ["some text", "https://example.com/image.jpg"]
      }

      Response (200), values are illustrative:
      {
          "count": 2,
          "shape": [2, 1024],
          "embeddings": [[...], [...]]
      }

      Errors:
          400 -- the body is not a JSON object, or "data" is missing, empty,
                 or not a list of strings.
          500 -- the backend raised while encoding.
          503 -- the backend client is unavailable (module-level client is None).
      """
      if not client:
          return jsonify({'error': 'CN-CLIP 服务未连接'}), 503

      try:
          # silent=True yields None for a missing/invalid JSON body instead of
          # raising, so malformed requests become a 400 below, not a 500.
          payload = request.get_json(silent=True)
          mixed_data = payload.get('data') if isinstance(payload, dict) else None

          if not mixed_data:
              return jsonify({'error': '缺少 data 参数'}), 400
          if not isinstance(mixed_data, list) or not all(isinstance(d, str) for d in mixed_data):
              return jsonify({'error': 'data 必须是字符串列表'}), 400

          # Encode via the backend.
          # assumes client.encode returns a numpy ndarray for a list of str -- TODO confirm
          embeddings = client.encode(mixed_data)

          return jsonify({
              'count': len(mixed_data),
              # ndarray.shape is a plain tuple (it has no .tolist()); turn it
              # into a list so it serializes as a JSON array.
              'shape': list(embeddings.shape),
              'embeddings': embeddings.tolist()
          })

      except Exception as e:
          print(f"错误: {e}")
          print(traceback.format_exc())
          return jsonify({'error': str(e)}), 500
  
  
  @app.route('/similarity', methods=['POST'])
  def similarity():
      """
      Compute cosine similarity between a query text and candidate images/texts.

      Request body (JSON object); "images" and "texts" are both optional:
      {
          "text": "query text",
          "images": ["url1", "url2"],
          "texts": ["text 1", "text 2"]
      }

      Response (200); each group of keys is present only when the matching
      candidate list was supplied and non-empty (values are illustrative):
      {
          "image_similarities": [0.8, 0.3],
          "image_urls": ["url1", "url2"],
          "text_similarities": [0.9, 0.2],
          "texts": ["text 1", "text 2"]
      }

      Errors:
          400 -- "text" is missing/empty or not a string, or "images"/"texts"
                 is not a list of strings.
          500 -- the backend or the similarity computation raised.
          503 -- the backend client is unavailable (module-level client is None).
      """
      if not client:
          return jsonify({'error': 'CN-CLIP 服务未连接'}), 503

      try:
          # silent=True yields None for a missing/invalid JSON body instead of
          # raising; a non-object body is treated like an empty object so the
          # request is rejected with the regular 400 below.
          payload = request.get_json(silent=True)
          if not isinstance(payload, dict):
              payload = {}

          query_text = payload.get('text', '')
          # `or []` also maps an explicit JSON null to "no candidates".
          images = payload.get('images') or []
          texts = payload.get('texts') or []

          if not query_text:
              return jsonify({'error': '缺少 text 参数'}), 400
          if not isinstance(query_text, str):
              return jsonify({'error': 'text 必须是字符串'}), 400
          for field, values in (('images', images), ('texts', texts)):
              if not isinstance(values, list) or not all(isinstance(v, str) for v in values):
                  return jsonify({'error': f'{field} 必须是字符串列表'}), 400

          # Imported lazily so scikit-learn is only needed by this endpoint.
          from sklearn.metrics.pairwise import cosine_similarity

          # Encode the query as a single-row batch.
          query_embedding = client.encode([query_text])

          result = {}

          # Similarity of the query against each image; row 0 is the only query row.
          if images:
              image_embeddings = client.encode(images)
              similarities = cosine_similarity(query_embedding, image_embeddings)[0]
              result['image_similarities'] = similarities.tolist()
              result['image_urls'] = images

          # Similarity of the query against each candidate text.
          if texts:
              text_embeddings = client.encode(texts)
              similarities = cosine_similarity(query_embedding, text_embeddings)[0]
              result['text_similarities'] = similarities.tolist()
              result['texts'] = texts

          return jsonify(result)

      except Exception as e:
          print(f"错误: {e}")
          print(traceback.format_exc())
          return jsonify({'error': str(e)}), 500
  
  
  @app.errorhandler(404)
  def not_found(error):
      """Answer 404 errors with a JSON error body instead of the default page."""
      body = jsonify({'error': '接口不存在'})
      return body, 404
  
  
  @app.errorhandler(500)
  def internal_error(error):
      """Answer 500 errors with a generic JSON error body."""
      body = jsonify({'error': '服务器内部错误'})
      return body, 500
  
  
  if __name__ == '__main__':
      import os

      # Startup banner listing the endpoints and example curl commands.
      print("\n" + "=" * 60)
      print("CN-CLIP REST API 服务")
      print("=" * 60)
      print("\n服务地址: http://localhost:6000")
      print("\n可用接口:")
      # /health is registered with methods=['GET'], so advertise it as GET.
      print("  GET  /health              - 健康检查")
      print("  POST /encode/text         - 编码文本")
      print("  POST /encode/image        - 编码图像")
      print("  POST /encode/mixed        - 混合编码")
      print("  POST /similarity          - 计算相似度")
      print("\n示例:")
      print("  curl http://localhost:6000/health")
      print("  curl -X POST http://localhost:6000/encode/text -H 'Content-Type: application/json' -d '{\"texts\": [\"测试文本\"]}'")
      print("\n" + "=" * 60)
      print()

      # The server binds to all interfaces, so the interactive debugger must
      # not be on by default: it lets anyone who can reach the port execute
      # code. Debug mode is opt-in through the FLASK_DEBUG environment variable.
      debug_flag = os.environ.get('FLASK_DEBUG', '').lower()
      debug_enabled = debug_flag not in ('', '0', 'false', 'no')

      app.run(
          host='0.0.0.0',
          port=6000,
          debug=debug_enabled,
          use_reloader=False  # avoid starting the process twice
      )