# test_es_connection.py
"""
测试Elasticsearch连接和向量查询
用于验证ES配置和向量字段是否正确
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from elasticsearch import Elasticsearch
import json

# Elasticsearch connection settings used by every check in this script.
# Each value can be overridden through an environment variable so the script
# can target another cluster/index without editing the source; the literals
# below are only fallbacks.
# NOTE(review): a real password is committed as the fallback for
# ES_PASSWORD -- move it to the environment / a secrets store and rotate it.
ES_CONFIG = {
    'host': os.environ.get('ES_HOST', 'http://localhost:9200'),
    'index_name': os.environ.get('ES_INDEX_NAME', 'spu'),
    'username': os.environ.get('ES_USERNAME', 'essa'),
    'password': os.environ.get('ES_PASSWORD', '4hOaLaf41y2VuI8y'),
}


def test_connection():
    """Build an Elasticsearch client from ES_CONFIG and check that it responds.

    The client's ``info()`` call is used as the connectivity probe; on success
    the cluster name and version number it reports are printed.

    Returns:
        The Elasticsearch client, or None if creating it or calling
        ``info()`` raised (the error is printed).
    """
    banner = "=" * 80
    print(banner)
    print("测试Elasticsearch连接")
    print(banner)

    try:
        client = Elasticsearch(
            [ES_CONFIG['host']],
            basic_auth=(ES_CONFIG['username'], ES_CONFIG['password']),
            verify_certs=False,
            request_timeout=30,
        )

        cluster_info = client.info()
        print("✓ ES连接成功!")
        print(f"  集群名称: {cluster_info['cluster_name']}")
        print(f"  版本: {cluster_info['version']['number']}")
    except Exception as exc:
        print(f"✗ ES连接失败: {exc}")
        return None

    return client


def test_index_exists(es):
    """Check that the configured index exists and print its document count.

    Args:
        es: Elasticsearch client returned by ``test_connection``.

    Returns:
        True if the index exists and the count call succeeded; False if the
        index is missing or either call raised (the error is printed).
    """
    separator = "=" * 80
    print("\n" + separator)
    print("测试索引是否存在")
    print(separator)

    try:
        index_name = ES_CONFIG['index_name']
        # Guard clause: bail out early when the index is absent.
        if not es.indices.exists(index=index_name):
            print(f"✗ 索引 '{index_name}' 不存在")
            return False

        print(f"✓ 索引 '{index_name}' 存在")
        doc_count = es.count(index=index_name)['count']
        print(f"  文档数量: {doc_count}")
        return True
    except Exception as exc:
        print(f"✗ 查询索引失败: {exc}")
        return False


def test_mapping(es):
    """Check that the expected fields exist in the index mapping.

    For every field in ``fields_to_check`` this prints whether it is present
    and its mapping ``type``. When the field definition has no top-level
    ``type`` key, the whole definition is printed as JSON instead.

    Args:
        es: Elasticsearch client returned by ``test_connection``.

    Returns:
        True if the mapping was fetched and every checked field is present,
        False otherwise (errors are printed, not raised).
    """
    print("\n" + "="*80)
    print("测试向量字段映射")
    print("="*80)

    try:
        index_name = ES_CONFIG['index_name']
        mapping = es.indices.get_mapping(index=index_name)

        # If index_name is an alias, the response is keyed by the concrete
        # backing index name(s) rather than the alias, so a direct lookup
        # would raise KeyError; fall back to the first returned index.
        if index_name in mapping:
            index_mapping = mapping[index_name]
        else:
            index_mapping = next(iter(mapping.values()), {})
        # An index without explicit properties has no 'properties' key.
        properties = index_mapping.get('mappings', {}).get('properties', {})

        # 检查关键字段 (key fields to check)
        fields_to_check = ['name_zh', 'embedding_name_zh', 'embedding_pic_h14']

        all_present = True
        for field in fields_to_check:
            field_def = properties.get(field)
            if field_def is None:
                print(f"✗ 字段 '{field}' 不存在")
                all_present = False
                continue

            print(f"✓ 字段 '{field}' 存在")
            field_type = field_def.get('type', field_def)
            if isinstance(field_type, dict):
                print(f"  类型: {json.dumps(field_type, indent=2)}")
            else:
                print(f"  类型: {field_type}")

        return all_present
    except Exception as e:
        print(f"✗ 获取mapping失败: {e}")
        return False


def test_query_item(es, item_id="3302275"):
    """Fetch one product document by ``_id`` and report its embedding fields.

    Prints the product name. It also prints the length and first five
    values of ``embedding_name_zh``. For ``embedding_pic_h14`` it does the
    same using the ``vector`` key of the first list element, when that
    element is a dict.

    Args:
        es: Elasticsearch client returned by ``test_connection``.
        item_id: document ``_id`` to look up.

    Returns:
        The hit's ``_source`` dict, or None if nothing matched or the search
        raised (the error is printed).
    """
    line = "=" * 80
    print("\n" + line)
    print(f"测试查询商品 {item_id}")
    print(line)

    request_body = {
        "query": {"term": {"_id": item_id}},
        "_source": {
            "includes": ["_id", "name_zh", "embedding_name_zh", "embedding_pic_h14"]
        },
    }

    try:
        hits = es.search(index=ES_CONFIG['index_name'], body=request_body)['hits']['hits']
        if not hits:
            print(f"✗ 未找到商品 {item_id}")
            return None

        top_hit = hits[0]
        print(f"✓ 找到商品 {item_id}")
        source = top_hit['_source']
        print(f"  名称: {source.get('name_zh', 'N/A')}")

        # Name embedding: a flat list of floats.
        name_vector = source.get('embedding_name_zh')
        if not name_vector:
            print("  ✗ 名称向量不存在")
        else:
            print(f"  名称向量维度: {len(name_vector)}")
            print(f"  名称向量示例: {name_vector[:5]}...")

        # Picture embedding: only the first list element's 'vector' is inspected.
        pic_data = source.get('embedding_pic_h14')
        if not (isinstance(pic_data, list) and pic_data):
            print("  ✗ 图片数据不存在")
        else:
            first_pic = pic_data[0]
            pic_vector = first_pic.get('vector') if isinstance(first_pic, dict) else None
            if not pic_vector:
                print("  ✗ 图片向量不存在")
            else:
                print(f"  图片向量维度: {len(pic_vector)}")
                print(f"  图片向量示例: {pic_vector[:5]}...")

        return source
    except Exception as exc:
        print(f"✗ 查询商品失败: {exc}")
        return None


def _extract_pic_vector(pic_data):
    """Return the 'vector' of the first element of ``pic_data``, or None.

    ``pic_data`` is the ``embedding_pic_h14`` value from a document's
    ``_source``. None is returned unless it is a non-empty list whose first
    element is a dict.
    """
    if isinstance(pic_data, list) and pic_data and isinstance(pic_data[0], dict):
        return pic_data[0].get('vector')
    return None


def _run_knn(es, label, field, query_vector, k=5, num_candidates=10):
    """Run one KNN search against ``field`` and print the returned hits.

    ``label`` is the display word ("名称" name / "图片" picture) interpolated
    into the printed messages.

    Returns:
        True if the search (and printing its hits) succeeded, False if it
        raised (the error is printed).
    """
    try:
        print(f"\n测试{label}向量KNN查询...")
        response = es.search(
            index=ES_CONFIG['index_name'],
            body={
                "knn": {
                    "field": field,
                    "query_vector": query_vector,
                    "k": k,
                    "num_candidates": num_candidates,
                },
                "_source": ["_id", "name_zh"],
                "size": k,
            }
        )

        hits = response['hits']['hits']
        print(f"✓ {label}向量KNN查询成功")
        print(f"  找到 {len(hits)} 个相似商品:")
        for idx, hit in enumerate(hits, 1):
            print(f"    {idx}. ID: {hit['_id']}, 名称: {hit['_source'].get('name_zh', 'N/A')}, 分数: {hit['_score']:.4f}")
        return True
    except Exception as e:
        print(f"✗ {label}向量KNN查询失败: {e}")
        return False


def test_knn_query(es, item_id="3302275", item_data=None):
    """Run KNN searches seeded with a product's own name and picture vectors.

    Args:
        es: Elasticsearch client returned by ``test_connection``.
        item_id: document ``_id`` whose vectors are used as query vectors.
        item_data: the document's ``_source`` if the caller already fetched
            it. When None, it is fetched via ``test_query_item``.

    Returns:
        False if the product could not be loaded or any KNN search failed,
        True otherwise.
    """
    print("\n" + "="*80)
    print(f"测试KNN查询(商品 {item_id})")
    print("="*80)

    # Reuse the caller's document to avoid a second query and a duplicated
    # item report; only fetch when nothing was passed in.
    if item_data is None:
        item_data = test_query_item(es, item_id)
    if not item_data:
        print("无法获取商品向量,跳过KNN测试")
        return False

    ok = True

    # Name-vector KNN query
    name_vector = item_data.get('embedding_name_zh')
    if name_vector:
        ok = _run_knn(es, "名称", "embedding_name_zh", name_vector) and ok

    # Picture-vector KNN query. The field path is 'embedding_pic_h14.vector',
    # matching the per-element 'vector' key read from _source.
    pic_vector = _extract_pic_vector(item_data.get('embedding_pic_h14'))
    if pic_vector:
        ok = _run_knn(es, "图片", "embedding_pic_h14.vector", pic_vector) and ok

    return ok


def main():
    """Run the connection, index, mapping, item-lookup and KNN checks in order.

    Returns:
        Process exit code. 1 if the connection or index check fails, else 0.
        The mapping, item-lookup and KNN results are only printed; they do
        not affect the exit code.
    """
    print("\n" + "="*80)
    print("Elasticsearch向量查询测试")
    print("="*80)

    # 1. Connect; nothing else can run without a client.
    es = test_connection()
    if es is None:
        return 1

    # 2. The index must exist; every later step queries it.
    if not test_index_exists(es):
        return 1

    # 3. Mapping report (informational only).
    test_mapping(es)

    # 4. Fetch a sample product. If this default ID does not exist the lookup
    #    reports failure; change test_item_id to a real product ID.
    test_item_id = "3302275"
    print(f"\n提示: 如果商品ID {test_item_id} 不存在,请修改 test_item_id 变量为实际的商品ID")

    item_data = test_query_item(es, test_item_id)

    # 5. KNN checks, reusing the document fetched in step 4 so it is not
    #    queried (and reported) a second time.
    if item_data:
        test_knn_query(es, test_item_id, item_data=item_data)

    print("\n" + "="*80)
    print("测试完成!")
    print("="*80)
    print("\n如果所有测试都通过,可以运行:")
    print("  python scripts/i2i_content_similar.py")
    print("\n")

    return 0


if __name__ == '__main__':
    # sys is already imported at module top; use main()'s return value as the
    # process exit code.
    sys.exit(main())