# i2i_content_similar.py
"""
i2i - 基于ES向量的内容相似索引
从Elasticsearch获取商品向量,计算两种相似度:
1. 基于名称文本向量的相似度
2. 基于图片向量的相似度
"""
import json
import os
import argparse
import pandas as pd
from datetime import datetime, timedelta
from elasticsearch import Elasticsearch
from db_service import create_db_connection
from config.offline_config import DB_CONFIG, OUTPUT_DIR
from scripts.debug_utils import (
    setup_debug_logger, log_processing_step, 
    save_readable_index, fetch_name_mappings
)

# ES connection settings.
# NOTE(security): credentials are hard-coded as fallbacks and committed to
# source control. They can be overridden via environment variables; move the
# secrets out of the repository (env file / secret manager) as soon as possible.
ES_CONFIG = {
    'host': os.environ.get('ES_HOST', 'http://localhost:9200'),
    'index_name': os.environ.get('ES_INDEX_NAME', 'spu'),
    'username': os.environ.get('ES_USERNAME', 'essa'),
    'password': os.environ.get('ES_PASSWORD', '4hOaLaf41y2VuI8y')
}

# Algorithm parameters
TOP_N = 50            # number of similar items kept per item
KNN_K = 100           # number of hits requested from each knn query
KNN_CANDIDATES = 200  # knn candidate pool size (num_candidates)


def get_active_items(engine):
    """
    Return the item_ids that had at least one key interaction
    (click / contactFactory / addToPool / addToCart / purchase)
    within the last 365 days.
    """
    cutoff = (datetime.now() - timedelta(days=365)).strftime('%Y-%m-%d')

    sql_query = f"""
    SELECT DISTINCT
        se.item_id
    FROM 
        sensors_events se
    WHERE 
        se.event IN ('click', 'contactFactory', 'addToPool', 'addToCart', 'purchase')
        AND se.create_time >= '{cutoff}'
        AND se.item_id IS NOT NULL
    """

    return pd.read_sql(sql_query, engine)['item_id'].tolist()


def connect_es():
    """Build and return an Elasticsearch client configured from ES_CONFIG."""
    return Elasticsearch(
        [ES_CONFIG['host']],
        basic_auth=(ES_CONFIG['username'], ES_CONFIG['password']),
        verify_certs=False,  # local / self-signed cluster
        request_timeout=30,
    )


def get_item_vectors(es, item_id):
    """
    Fetch one item's vector document from ES.

    Args:
        es: Elasticsearch client.
        item_id: item identifier; matched as a term against the document _id.

    Returns:
        dict with keys: _id, name_zh, embedding_name_zh, embedding_pic_h14,
        on_sell_days_boost — or None when the item is missing or the ES query
        fails.
    """
    # Only the ES call itself is wrapped: transport/query failures degrade to
    # "item not found", but bugs in response parsing are no longer silently
    # swallowed (the original broad try/except hid them).
    try:
        response = es.search(
            index=ES_CONFIG['index_name'],
            body={
                "query": {
                    "term": {
                        "_id": str(item_id)
                    }
                },
                "_source": {
                    "includes": ["_id", "name_zh", "embedding_name_zh", "embedding_pic_h14", "on_sell_days_boost"]
                }
            }
        )
    except Exception:
        # NOTE(review): consider logging here — failed lookups are invisible.
        return None

    hits = response['hits']['hits']
    if not hits:
        return None

    hit = hits[0]
    source = hit['_source']
    return {
        '_id': hit['_id'],
        'name_zh': source.get('name_zh', ''),
        'embedding_name_zh': source.get('embedding_name_zh'),
        'embedding_pic_h14': source.get('embedding_pic_h14'),
        'on_sell_days_boost': source.get('on_sell_days_boost', 1.0)
    }


def find_similar_by_vector(es, vector, field_name, k=KNN_K, num_candidates=KNN_CANDIDATES):
    """
    Run a knn query and return similar items with popularity-boosted scores.

    Args:
        es: Elasticsearch client.
        vector: query vector.
        field_name: vector field name (embedding_name_zh or embedding_pic_h14.vector).
        k: number of results to return.
        num_candidates: knn candidate pool size.

    Returns:
        List of (item_id, boosted_score, name_zh) tuples; empty list when the
        ES query fails.
    """
    # Only the ES call is wrapped: transport/query failures degrade to "no
    # neighbours", but parsing bugs are no longer silently swallowed
    # (the original broad try/except hid them).
    try:
        response = es.search(
            index=ES_CONFIG['index_name'],
            body={
                "knn": {
                    "field": field_name,
                    "query_vector": vector,
                    "k": k,
                    "num_candidates": num_candidates
                },
                "_source": ["_id", "name_zh", "on_sell_days_boost"],
                "size": k
            }
        )
    except Exception:
        # NOTE(review): consider logging — failed knn queries are invisible.
        return []

    results = []
    for hit in response['hits']['hits']:
        base_score = hit['_score']

        # on_sell_days_boost is a mild popularity multiplier; values that are
        # missing, None, or outside the expected [0.9, 1.1] band are
        # neutralised to 1.0 (no boost).
        boost = hit['_source'].get('on_sell_days_boost', 1.0)
        if boost is None or boost < 0.9 or boost > 1.1:
            boost = 1.0

        results.append((
            hit['_id'],
            base_score * boost,
            hit['_source'].get('name_zh', '')
        ))
    return results


def generate_similarity_index(es, active_items, vector_field, field_name, logger, top_n=50):
    """
    Build the similarity index for a single vector type.

    Args:
        es: Elasticsearch client.
        active_items: list of active item ids.
        vector_field: ES vector field (embedding_name_zh or embedding_pic_h14).
        field_name: short field label ('name' or 'pic'); kept for interface
            compatibility, not used internally.
        logger: logger instance.
        top_n: number of similar items kept per item.

    Returns:
        dict: {item_id: [(similar_id, score, name), ...]}
    """
    index = {}
    total = len(active_items)

    for position, item_id in enumerate(active_items, start=1):
        if position % 100 == 0:
            logger.info(f"处理进度: {position}/{total} ({position / total * 100:.1f}%)")

        # Fetch this item's vector document; skip items missing from ES.
        item_data = get_item_vectors(es, item_id)
        if not item_data:
            continue

        # Pull the query vector out of the document. Picture embeddings are
        # stored as a list of {"vector": [...]} entries; use the first one.
        query_vector = None
        if vector_field == 'embedding_name_zh':
            query_vector = item_data.get('embedding_name_zh')
        elif vector_field == 'embedding_pic_h14':
            pics = item_data.get('embedding_pic_h14')
            if pics and isinstance(pics, list) and len(pics) > 0:
                query_vector = pics[0].get('vector') if isinstance(pics[0], dict) else None

        if not query_vector:
            continue

        # knn search over the right field (picture vectors live under .vector).
        # Scores already include the on_sell_days_boost applied downstream.
        knn_field = f"{vector_field}.vector" if vector_field == 'embedding_pic_h14' else vector_field
        candidates = find_similar_by_vector(es, query_vector, knn_field)

        # Drop the item itself and cap the list at top_n entries.
        kept = []
        for sim_id, sim_score, sim_name in candidates:
            if sim_id != str(item_id):
                kept.append((sim_id, sim_score, sim_name))
            if len(kept) >= top_n:
                break

        if kept:
            index[item_id] = kept

    return index


def save_index_file(result, es, output_file, logger):
    """
    Write the similarity index to a tab-separated text file.

    Line format: item_id \t item_name \t similar_id1:score1,similar_id2:score2,...
    """
    logger.info(f"保存索引到: {output_file}")

    with open(output_file, 'w', encoding='utf-8') as out:
        for item_id, similar_items in result.items():
            if not similar_items:
                continue

            # Look up the source item's display name (one extra ES round-trip
            # per item; falls back to 'Unknown' when the lookup fails).
            item_data = get_item_vectors(es, item_id)
            item_name = item_data.get('name_zh', 'Unknown') if item_data else 'Unknown'

            sim_str = ','.join(f'{sim_id}:{score:.4f}' for sim_id, score, _ in similar_items)
            out.write(f'{item_id}\t{item_name}\t{sim_str}\n')

    logger.info(f"索引保存完成,共 {len(result)} 个商品")


def main():
    """Entry point: build the two ES-vector content similarity indexes.

    Pipeline: read active items from the DB, query ES for per-item vectors,
    run knn searches for name-text and picture vectors, and write one index
    file per vector type into OUTPUT_DIR (plus human-readable dumps when
    --debug is passed).
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Generate content-based similarity using ES vectors')
    parser.add_argument('--debug', action='store_true', help='Enable debug mode with readable output')
    parser.add_argument('--top_n', type=int, default=50, help='Number of similar items per item (default: 50)')
    args = parser.parse_args()
    
    # Use the top_n value supplied on the command line
    top_n = args.top_n
    
    # Set up the logger
    logger = setup_debug_logger('i2i_content_similar', debug=args.debug)
    
    logger.info("="*80)
    logger.info("开始生成基于ES向量的内容相似索引")
    logger.info(f"ES地址: {ES_CONFIG['host']}")
    logger.info(f"索引名: {ES_CONFIG['index_name']}")
    logger.info(f"Top N: {top_n}")
    logger.info("="*80)
    
    # Create the database connection
    log_processing_step(logger, "连接数据库")
    engine = create_db_connection(
        DB_CONFIG['host'],
        DB_CONFIG['port'],
        DB_CONFIG['database'],
        DB_CONFIG['username'],
        DB_CONFIG['password']
    )
    
    # Fetch items with any key interaction in the last year
    log_processing_step(logger, "获取最近1年有过行为的商品")
    active_items = get_active_items(engine)
    logger.info(f"找到 {len(active_items)} 个活跃商品")
    
    # Connect to Elasticsearch
    log_processing_step(logger, "连接Elasticsearch")
    es = connect_es()
    logger.info("ES连接成功")
    
    # Generate the two similarity indexes (output files are date-stamped)
    date_str = datetime.now().strftime("%Y%m%d")
    
    # Fetch item-name mappings, only needed for debug output
    name_mappings = {}
    if args.debug:
        log_processing_step(logger, "获取物品名称映射")
        name_mappings = fetch_name_mappings(engine, debug=True)
    
    # 1. Index based on name-text vectors
    log_processing_step(logger, "生成基于名称文本向量的相似索引")
    name_result = generate_similarity_index(
        es, active_items, 'embedding_name_zh', 'name', logger, top_n=top_n
    )
    name_output = os.path.join(OUTPUT_DIR, f'i2i_content_name_{date_str}.txt')
    save_index_file(name_result, es, name_output, logger)
    
    # In debug mode, also save a human-readable version
    if args.debug and name_result:
        log_processing_step(logger, "保存i2i_content_name可读格式")
        # Convert to the {key: [(sim_id, score), ...]} shape save_readable_index expects
        readable_data = {}
        for item_id, similar_items in name_result.items():
            readable_data[f"i2i:content_name:{item_id}"] = [
                (sim_id, score) for sim_id, score, _ in similar_items
            ]
        save_readable_index(
            name_output,
            readable_data,
            name_mappings,
            description='i2i:content_name'
        )
    
    # 2. Index based on picture vectors
    log_processing_step(logger, "生成基于图片向量的相似索引")
    pic_result = generate_similarity_index(
        es, active_items, 'embedding_pic_h14', 'pic', logger, top_n=top_n
    )
    pic_output = os.path.join(OUTPUT_DIR, f'i2i_content_pic_{date_str}.txt')
    save_index_file(pic_result, es, pic_output, logger)
    
    # In debug mode, also save a human-readable version
    if args.debug and pic_result:
        log_processing_step(logger, "保存i2i_content_pic可读格式")
        # Convert to the {key: [(sim_id, score), ...]} shape save_readable_index expects
        readable_data = {}
        for item_id, similar_items in pic_result.items():
            readable_data[f"i2i:content_pic:{item_id}"] = [
                (sim_id, score) for sim_id, score, _ in similar_items
            ]
        save_readable_index(
            pic_output,
            readable_data,
            name_mappings,
            description='i2i:content_pic'
        )
    
    logger.info("="*80)
    logger.info("完成!生成了两份内容相似索引:")
    logger.info(f"  1. 名称向量索引: {name_output} ({len(name_result)} 个商品)")
    logger.info(f"  2. 图片向量索引: {pic_output} ({len(pic_result)} 个商品)")
    logger.info("="*80)


if __name__ == '__main__':  # script entry point
    main()