#!/usr/bin/env python3
"""
CN-CLIP REST API wrapper.

Exposes a small HTTP/JSON interface (easy to call with curl) in front of a
CN-CLIP gRPC service reached through ``clip_client.Client``.
"""

import os
import traceback

import numpy as np
from flask import Flask, request, jsonify
from flask_cors import CORS
from clip_client import Client

# Address of the CN-CLIP gRPC backend this wrapper forwards requests to.
CNCLIP_SERVER = 'grpc://localhost:51000'

app = Flask(__name__)
CORS(app)  # allow cross-origin requests

# Create the CN-CLIP client once at import time. If construction fails the
# API still starts, but every encode/similarity endpoint answers 503.
# NOTE(review): Client construction may not open the gRPC connection eagerly,
# so reaching the success branch does not necessarily prove the backend is
# reachable — confirm against clip_client's behavior.
try:
    client = Client(CNCLIP_SERVER)
    print(f"✓ 已连接到 CN-CLIP 服务 ({CNCLIP_SERVER})")
except Exception as e:
    print(f"✗ 连接失败: {e}")
    print("请先启动 CN-CLIP 服务: ./scripts/start_cnclip_service.sh")
    client = None


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------

def _service_unavailable():
    """Return the 503 response used when no CN-CLIP client is available."""
    return jsonify({'error': 'CN-CLIP 服务未连接'}), 503


def _json_body():
    """Return the request body parsed as a JSON object (``dict``).

    ``force=True`` parses the body even without an ``application/json``
    Content-Type (a plain ``curl -d`` sends form encoding), and
    ``silent=True`` yields ``None`` for malformed JSON instead of raising.
    Anything that is not a JSON object is treated as an empty object, so the
    callers' "missing parameter" checks produce a 400 rather than a 500.
    """
    data = request.get_json(force=True, silent=True)
    return data if isinstance(data, dict) else {}


def _is_str_list(value):
    """Return True if *value* is a list whose elements are all strings."""
    return isinstance(value, list) and all(isinstance(v, str) for v in value)


def _get_str_list(data, key):
    """Fetch ``data[key]`` as a non-empty list of strings.

    Returns ``(value, None)`` on success, or ``(None, error)`` where *error*
    is a ready-to-return ``(response, 400)`` pair.
    """
    value = data.get(key)
    if not value:
        return None, (jsonify({'error': f'缺少 {key} 参数'}), 400)
    if not _is_str_list(value):
        return None, (jsonify({'error': f'{key} 参数必须是字符串数组'}), 400)
    return value, None


def _embeddings_response(items):
    """Encode *items* with the CN-CLIP client and build the JSON reply."""
    embeddings = np.asarray(client.encode(items))
    return jsonify({
        'count': len(items),
        # ndarray.shape is a tuple (no .tolist()); convert it explicitly.
        'shape': list(embeddings.shape),
        'embeddings': embeddings.tolist(),
    })


def _server_error(e):
    """Print *e* and its traceback, then return a JSON 500 response.

    Must be called from inside an ``except`` block so that
    ``traceback.format_exc()`` sees the exception being handled.
    """
    print(f"错误: {e}")
    print(traceback.format_exc())
    return jsonify({'error': str(e)}), 500


def _cosine_similarity(query, candidates):
    """Return the cosine-similarity matrix between the rows of two 2-D arrays.

    Rows with zero norm are left as zero vectors (similarity 0), matching
    ``sklearn.metrics.pairwise.cosine_similarity``.
    """
    query = np.asarray(query, dtype=np.float64)
    candidates = np.asarray(candidates, dtype=np.float64)

    def _normalize(rows):
        norms = np.linalg.norm(rows, axis=1, keepdims=True)
        norms[norms == 0] = 1.0  # avoid division by zero for all-zero rows
        return rows / norms

    return _normalize(query) @ _normalize(candidates).T


# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------

@app.route('/health', methods=['GET'])
def health():
    """Health check: report whether a CN-CLIP client was created."""
    return jsonify({
        'status': 'ok' if client is not None else 'error',
        'service': 'cnclip-rest-api',
        'backend': CNCLIP_SERVER,
    })


@app.route('/encode/text', methods=['POST'])
def encode_text():
    """
    Encode a batch of texts.

    Request body:
        {"texts": ["text 1", "text 2"]}

    Response (shape example; the embedding width depends on the model):
        {"count": 2, "shape": [2, 1024], "embeddings": [[...], [...]]}
    """
    if client is None:
        return _service_unavailable()
    texts, error = _get_str_list(_json_body(), 'texts')
    if error:
        return error
    try:
        return _embeddings_response(texts)
    except Exception as e:
        return _server_error(e)


@app.route('/encode/image', methods=['POST'])
def encode_image():
    """
    Encode a batch of images given as URLs or local file paths.

    Request body:
        {"images": ["https://example.com/image.jpg", "/path/to/local.jpg"]}

    Response (shape example; the embedding width depends on the model):
        {"count": 2, "shape": [2, 1024], "embeddings": [[...], [...]]}
    """
    if client is None:
        return _service_unavailable()
    images, error = _get_str_list(_json_body(), 'images')
    if error:
        return error
    try:
        return _embeddings_response(images)
    except Exception as e:
        return _server_error(e)


@app.route('/encode/mixed', methods=['POST'])
def encode_mixed():
    """
    Encode a mixed batch of texts and image URLs/paths in one call.

    Request body:
        {"data": ["some text", "https://example.com/image.jpg"]}

    Response (shape example; the embedding width depends on the model):
        {"count": 2, "shape": [2, 1024], "embeddings": [[...], [...]]}
    """
    if client is None:
        return _service_unavailable()
    mixed_data, error = _get_str_list(_json_body(), 'data')
    if error:
        return error
    try:
        return _embeddings_response(mixed_data)
    except Exception as e:
        return _server_error(e)


@app.route('/similarity', methods=['POST'])
def similarity():
    """
    Compute cosine similarity between a query text and images and/or texts.

    Request body ("images" and "texts" are both optional):
        {"text": "query text", "images": ["url1", "url2"],
         "texts": ["text 1", "text 2"]}

    Response (keys present only for the candidate lists that were given):
        {"image_similarities": [0.8, 0.3], "image_urls": [...],
         "text_similarities": [0.9, 0.2], "texts": [...]}
    """
    if client is None:
        return _service_unavailable()

    data = _json_body()
    query_text = data.get('text', '')
    images = data.get('images') or []
    texts = data.get('texts') or []

    if not query_text:
        return jsonify({'error': '缺少 text 参数'}), 400
    if not isinstance(query_text, str):
        return jsonify({'error': 'text 参数必须是字符串'}), 400
    for key, value in (('images', images), ('texts', texts)):
        if not _is_str_list(value):
            return jsonify({'error': f'{key} 参数必须是字符串数组'}), 400

    try:
        # Encode the query once and compare it against each candidate list.
        query_embedding = client.encode([query_text])
        result = {}

        if images:
            image_embeddings = client.encode(images)
            scores = _cosine_similarity(query_embedding, image_embeddings)[0]
            result['image_similarities'] = scores.tolist()
            result['image_urls'] = images

        if texts:
            text_embeddings = client.encode(texts)
            scores = _cosine_similarity(query_embedding, text_embeddings)[0]
            result['text_similarities'] = scores.tolist()
            result['texts'] = texts

        return jsonify(result)
    except Exception as e:
        return _server_error(e)


@app.errorhandler(404)
def not_found(error):
    """Return a JSON body for unknown routes."""
    return jsonify({'error': '接口不存在'}), 404


@app.errorhandler(500)
def internal_error(error):
    """Return a JSON body for unhandled server errors."""
    return jsonify({'error': '服务器内部错误'}), 500


if __name__ == '__main__':
    # The Werkzeug interactive debugger allows code execution from the
    # browser, so it must not be enabled by default while listening on
    # 0.0.0.0. Opt in explicitly with CNCLIP_API_DEBUG=1.
    debug = os.environ.get('CNCLIP_API_DEBUG') == '1'

    print("\n" + "=" * 60)
    print("CN-CLIP REST API 服务")
    print("=" * 60)
    print("\n服务地址: http://localhost:6000")
    print("\n可用接口:")
    print(" GET /health - 健康检查")
    print(" POST /encode/text - 编码文本")
    print(" POST /encode/image - 编码图像")
    print(" POST /encode/mixed - 混合编码")
    print(" POST /similarity - 计算相似度")
    print("\n示例:")
    print(" curl http://localhost:6000/health")
    print(" curl -X POST http://localhost:6000/encode/text -H 'Content-Type: application/json' -d '{\"texts\": [\"测试文本\"]}'")
    print("\n" + "=" * 60)
    print()

    app.run(
        host='0.0.0.0',
        port=6000,
        debug=debug,
        use_reloader=False,  # avoid starting the service (and client) twice
    )