# clip_rest_api.py 6.15 KB
#!/usr/bin/env python3
"""
CN-CLIP REST API 包装器

提供 HTTP 接口,支持 curl 调用
"""

from flask import Flask, request, jsonify
from flask_cors import CORS
from clip_client import Client
import numpy as np
import traceback

app = Flask(__name__)
# Allow cross-origin requests so browser pages served from another origin can call the API.
CORS(app)


def _connect_backend():
    """Build the CN-CLIP gRPC client used by every endpoint.

    Prints a success or failure line for the operator. Returns the client on
    success and None when construction (or the success print) raises; the
    endpoints check for None and answer 503.

    NOTE(review): this only proves the Client object was constructed; if
    clip_client connects lazily, the backend may still be unreachable here —
    confirm before relying on the success message.
    """
    try:
        backend_client = Client('grpc://localhost:51000')
        print("✓ 已连接到 CN-CLIP 服务 (grpc://localhost:51000)")
    except Exception as exc:
        print(f"✗ 连接失败: {exc}")
        print("请先启动 CN-CLIP 服务: ./scripts/start_cnclip_service.sh")
        backend_client = None
    return backend_client


# Shared CN-CLIP client, or None when it could not be created at startup.
client = _connect_backend()


@app.route('/health', methods=['GET'])
def health():
    """
    Health check.

    Returns JSON with ``status``, ``service`` and ``backend``. ``status`` is
    ``'ok'`` when the module-level CN-CLIP client was created at startup and
    ``'error'`` otherwise. In the error case the HTTP status is 503 (not 200),
    so ``curl -f`` and health probes detect the failure without parsing the
    body.
    """
    healthy = bool(client)
    body = jsonify({
        'status': 'ok' if healthy else 'error',
        'service': 'cnclip-rest-api',
        'backend': 'grpc://localhost:51000',
    })
    return body, (200 if healthy else 503)


@app.route('/encode/text', methods=['POST'])
def encode_text():
    """
    Encode a batch of texts with CN-CLIP.

    Request body (JSON):
    {
        "texts": ["文本1", "文本2"]
    }

    Response (JSON):
    {
        "count": 2,
        "shape": [2, 1024],
        "embeddings": [[...], [...]]
    }

    Status codes:
        400 -- body is not a JSON object, or ``texts`` is missing/empty or is
               not a list of strings.
        503 -- the CN-CLIP client was not created at startup.
        500 -- encoding or serialization failed; message in ``error``.
    """
    if not client:
        return jsonify({'error': 'CN-CLIP 服务未连接'}), 503

    # silent=True yields None for a missing/malformed JSON body instead of
    # raising, so bad input is reported as 400 rather than 500.
    data = request.get_json(silent=True)
    texts = data.get('texts', []) if isinstance(data, dict) else []

    if not texts:
        return jsonify({'error': '缺少 texts 参数'}), 400
    # Require a JSON array of strings; anything else is a client error.
    if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
        return jsonify({'error': 'texts 必须是字符串列表'}), 400

    try:
        embeddings = np.asarray(client.encode(texts))
        return jsonify({
            'count': len(texts),
            # ndarray.shape is a plain tuple (it has no .tolist()), so convert
            # it with list() to get a JSON array.
            'shape': list(embeddings.shape),
            'embeddings': embeddings.tolist(),
        })
    except Exception as e:
        print(f"错误: {e}")
        print(traceback.format_exc())
        return jsonify({'error': str(e)}), 500


@app.route('/encode/image', methods=['POST'])
def encode_image():
    """
    Encode a batch of images with CN-CLIP.

    Request body (JSON):
    {
        "images": ["https://example.com/image.jpg", "/path/to/local.jpg"]
    }

    Response (JSON):
    {
        "count": 2,
        "shape": [2, 1024],
        "embeddings": [[...], [...]]
    }

    Status codes:
        400 -- body is not a JSON object, or ``images`` is missing/empty or is
               not a list of strings.
        503 -- the CN-CLIP client was not created at startup.
        500 -- encoding or serialization failed; message in ``error``.
    """
    if not client:
        return jsonify({'error': 'CN-CLIP 服务未连接'}), 503

    # silent=True yields None for a missing/malformed JSON body instead of
    # raising, so bad input is reported as 400 rather than 500.
    data = request.get_json(silent=True)
    images = data.get('images', []) if isinstance(data, dict) else []

    if not images:
        return jsonify({'error': '缺少 images 参数'}), 400
    # Require a JSON array of strings (URLs or paths); anything else is a client error.
    if not isinstance(images, list) or not all(isinstance(i, str) for i in images):
        return jsonify({'error': 'images 必须是字符串列表'}), 400

    # NOTE(review): entries may be remote URLs or local paths (see docstring).
    # If clip_client resolves them in this process, any caller can make the
    # server fetch arbitrary URLs or read local files — restrict inputs before
    # exposing this endpoint beyond localhost.
    try:
        embeddings = np.asarray(client.encode(images))
        return jsonify({
            'count': len(images),
            # ndarray.shape is a plain tuple (it has no .tolist()), so convert
            # it with list() to get a JSON array.
            'shape': list(embeddings.shape),
            'embeddings': embeddings.tolist(),
        })
    except Exception as e:
        print(f"错误: {e}")
        print(traceback.format_exc())
        return jsonify({'error': str(e)}), 500


@app.route('/encode/mixed', methods=['POST'])
def encode_mixed():
    """
    Encode a mixed batch of texts and image URLs/paths with CN-CLIP.

    Request body (JSON):
    {
        "data": ["文本", "https://example.com/image.jpg"]
    }

    Response (JSON):
    {
        "count": 2,
        "shape": [2, 1024],
        "embeddings": [[...], [...]]
    }

    Status codes:
        400 -- body is not a JSON object, or ``data`` is missing/empty or is
               not a list of strings.
        503 -- the CN-CLIP client was not created at startup.
        500 -- encoding or serialization failed; message in ``error``.
    """
    if not client:
        return jsonify({'error': 'CN-CLIP 服务未连接'}), 503

    # silent=True yields None for a missing/malformed JSON body instead of
    # raising, so bad input is reported as 400 rather than 500.
    data = request.get_json(silent=True)
    mixed_data = data.get('data', []) if isinstance(data, dict) else []

    if not mixed_data:
        return jsonify({'error': '缺少 data 参数'}), 400
    # Require a JSON array of strings; anything else is a client error.
    if not isinstance(mixed_data, list) or not all(isinstance(d, str) for d in mixed_data):
        return jsonify({'error': 'data 必须是字符串列表'}), 400

    try:
        embeddings = np.asarray(client.encode(mixed_data))
        return jsonify({
            'count': len(mixed_data),
            # ndarray.shape is a plain tuple (it has no .tolist()), so convert
            # it with list() to get a JSON array.
            'shape': list(embeddings.shape),
            'embeddings': embeddings.tolist(),
        })
    except Exception as e:
        print(f"错误: {e}")
        print(traceback.format_exc())
        return jsonify({'error': str(e)}), 500


@app.route('/similarity', methods=['POST'])
def similarity():
    """
    Cosine similarity between one query text and optional images and/or texts.

    Request body (JSON):
    {
        "text": "查询文本",
        "images": ["url1", "url2"],
        "texts": ["文本1", "文本2"]
    }

    Response (JSON). The image keys are present only when ``images`` is
    non-empty, and the text keys only when ``texts`` is non-empty; with
    neither, the response is ``{}``:
    {
        "image_similarities": [0.8, 0.3],
        "image_urls": ["url1", "url2"],
        "text_similarities": [0.9, 0.2],
        "texts": ["文本1", "文本2"]
    }

    Status codes:
        400 -- body is not a JSON object, ``text`` is missing/empty or not a
               string, or ``images``/``texts`` is not a list of strings.
        503 -- the CN-CLIP client was not created at startup.
        500 -- encoding or similarity computation failed; message in ``error``.
    """
    if not client:
        return jsonify({'error': 'CN-CLIP 服务未连接'}), 503

    # silent=True yields None for a missing/malformed JSON body instead of
    # raising; a non-object body is treated as empty so it gets a 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    query_text = data.get('text', '')
    # An explicit JSON null (or other falsy value) behaves like an omitted list.
    images = data.get('images') or []
    texts = data.get('texts') or []

    if not query_text:
        return jsonify({'error': '缺少 text 参数'}), 400
    if not isinstance(query_text, str):
        return jsonify({'error': 'text 必须是字符串'}), 400
    for field, values in (('images', images), ('texts', texts)):
        if not isinstance(values, list) or not all(isinstance(v, str) for v in values):
            return jsonify({'error': f'{field} 必须是字符串列表'}), 400

    # NOTE(review): image entries are caller-supplied URLs/paths; if clip_client
    # resolves them in this process, restrict them before exposing the API.
    try:
        # Imported inside the function so the module (and the other endpoints)
        # load even when scikit-learn is not installed.
        from sklearn.metrics.pairwise import cosine_similarity

        query_embedding = client.encode([query_text])
        result = {}

        if images:
            image_embeddings = client.encode(images)
            # Row 0: the single query against every image.
            scores = cosine_similarity(query_embedding, image_embeddings)[0]
            result['image_similarities'] = scores.tolist()
            result['image_urls'] = images

        if texts:
            text_embeddings = client.encode(texts)
            scores = cosine_similarity(query_embedding, text_embeddings)[0]
            result['text_similarities'] = scores.tolist()
            result['texts'] = texts

        return jsonify(result)

    except Exception as e:
        print(f"错误: {e}")
        print(traceback.format_exc())
        return jsonify({'error': str(e)}), 500


@app.errorhandler(404)
def not_found(error):
    """Answer unknown routes with a JSON error body and HTTP 404."""
    payload = {'error': '接口不存在'}
    status_code = 404
    return jsonify(payload), status_code


@app.errorhandler(500)
def internal_error(error):
    """Answer unhandled server errors with a generic JSON body and HTTP 500."""
    payload = {'error': '服务器内部错误'}
    status_code = 500
    return jsonify(payload), status_code


if __name__ == '__main__':
    import os

    # Bind address and port can be overridden from the environment; the
    # defaults are the previously hard-coded values.
    # NOTE: 6000 is on the WHATWG Fetch "bad port" list, so browsers refuse to
    # connect to it (CORS does not help); curl is unaffected. Set
    # CNCLIP_API_PORT to serve browser clients.
    host = os.environ.get('CNCLIP_API_HOST', '0.0.0.0')
    port = int(os.environ.get('CNCLIP_API_PORT', '6000'))
    # The interactive Werkzeug debugger can execute arbitrary Python for anyone
    # who can reach it, and the default host listens on all interfaces, so
    # debug mode is opt-in (CNCLIP_API_DEBUG=1/true/yes) instead of always on.
    debug = os.environ.get('CNCLIP_API_DEBUG', '').strip().lower() in ('1', 'true', 'yes')

    print("\n" + "=" * 60)
    print("CN-CLIP REST API 服务")
    print("=" * 60)
    print(f"\n服务地址: http://localhost:{port}")
    print("\n可用接口:")
    # /health is registered with methods=['GET'] only.
    print("  GET  /health              - 健康检查")
    print("  POST /encode/text         - 编码文本")
    print("  POST /encode/image        - 编码图像")
    print("  POST /encode/mixed        - 混合编码")
    print("  POST /similarity          - 计算相似度")
    print("\n示例:")
    print(f"  curl http://localhost:{port}/health")
    print(
        f"  curl -X POST http://localhost:{port}/encode/text "
        "-H 'Content-Type: application/json' "
        "-d '{\"texts\": [\"测试文本\"]}'"
    )
    print("\n" + "=" * 60)
    print()

    app.run(
        host=host,
        port=port,
        debug=debug,
        use_reloader=False  # avoid a duplicate startup
    )