""" i2i - 基于ES向量的内容相似索引 从Elasticsearch获取商品向量,计算两种相似度: 1. 基于名称文本向量的相似度 2. 基于图片向量的相似度 """ import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) import json import pandas as pd from datetime import datetime, timedelta from elasticsearch import Elasticsearch from db_service import create_db_connection from offline_tasks.config.offline_config import DB_CONFIG, OUTPUT_DIR from offline_tasks.scripts.debug_utils import setup_debug_logger, log_processing_step # ES配置 ES_CONFIG = { 'host': 'http://localhost:9200', 'index_name': 'spu', 'username': 'essa', 'password': '4hOaLaf41y2VuI8y' } # 算法参数 TOP_N = 50 # 每个商品返回的相似商品数量 KNN_K = 100 # knn查询返回的候选数 KNN_CANDIDATES = 200 # knn查询的候选池大小 def get_active_items(engine): """ 获取最近1年有过行为的item列表 """ one_year_ago = (datetime.now() - timedelta(days=365)).strftime('%Y-%m-%d') sql_query = f""" SELECT DISTINCT se.item_id FROM sensors_events se WHERE se.event IN ('click', 'contactFactory', 'addToPool', 'addToCart', 'purchase') AND se.create_time >= '{one_year_ago}' AND se.item_id IS NOT NULL """ df = pd.read_sql(sql_query, engine) return df['item_id'].tolist() def connect_es(): """连接到Elasticsearch""" es = Elasticsearch( [ES_CONFIG['host']], basic_auth=(ES_CONFIG['username'], ES_CONFIG['password']), verify_certs=False, request_timeout=30 ) return es def get_item_vectors(es, item_id): """ 从ES获取商品的向量数据 Returns: dict with keys: _id, name_zh, embedding_name_zh, embedding_pic_h14 或 None if not found """ try: response = es.search( index=ES_CONFIG['index_name'], body={ "query": { "term": { "_id": str(item_id) } }, "_source": { "includes": ["_id", "name_zh", "embedding_name_zh", "embedding_pic_h14"] } } ) if response['hits']['hits']: hit = response['hits']['hits'][0] return { '_id': hit['_id'], 'name_zh': hit['_source'].get('name_zh', ''), 'embedding_name_zh': hit['_source'].get('embedding_name_zh'), 'embedding_pic_h14': hit['_source'].get('embedding_pic_h14') } return None except Exception as e: return None def find_similar_by_vector(es, vector, field_name, k=KNN_K, num_candidates=KNN_CANDIDATES): """ 使用knn查询找到相似的items Args: es: Elasticsearch客户端 vector: 查询向量 field_name: 向量字段名 (embedding_name_zh 或 embedding_pic_h14.vector) k: 返回的结果数 num_candidates: 候选池大小 Returns: List of (item_id, score) tuples """ try: response = es.search( index=ES_CONFIG['index_name'], body={ "knn": { "field": field_name, "query_vector": vector, "k": k, "num_candidates": num_candidates }, "_source": ["_id", "name_zh"], "size": k } ) results = [] for hit in response['hits']['hits']: results.append(( hit['_id'], hit['_score'], hit['_source'].get('name_zh', '') )) return results except Exception as e: return [] def generate_similarity_index(es, active_items, vector_field, field_name, logger): """ 生成一种向量的相似度索引 Args: es: Elasticsearch客户端 active_items: 活跃商品ID列表 vector_field: 向量字段名 (embedding_name_zh 或 embedding_pic_h14) field_name: 字段简称 (name 或 pic) logger: 日志记录器 Returns: dict: {item_id: [(similar_id, score, name), ...]} """ result = {} total = len(active_items) for idx, item_id in enumerate(active_items): if (idx + 1) % 100 == 0: logger.info(f"处理进度: {idx + 1}/{total} ({(idx + 1) / total * 100:.1f}%)") # 获取该商品的向量 item_data = get_item_vectors(es, item_id) if not item_data: continue # 提取向量 if vector_field == 'embedding_name_zh': query_vector = item_data.get('embedding_name_zh') elif vector_field == 'embedding_pic_h14': pic_data = item_data.get('embedding_pic_h14') if pic_data and isinstance(pic_data, list) and len(pic_data) > 0: query_vector = 


def generate_similarity_index(es, active_items, vector_field, field_name, logger):
    """
    Generate a similarity index for one vector type.

    Args:
        es: Elasticsearch client
        active_items: list of active item IDs
        vector_field: vector field name (embedding_name_zh or embedding_pic_h14)
        field_name: short field label (name or pic)
        logger: logger

    Returns:
        dict: {item_id: [(similar_id, score, name), ...]}
    """
    result = {}
    total = len(active_items)

    for idx, item_id in enumerate(active_items):
        if (idx + 1) % 100 == 0:
            logger.info(f"Progress: {idx + 1}/{total} ({(idx + 1) / total * 100:.1f}%)")

        # Fetch this item's vectors
        item_data = get_item_vectors(es, item_id)
        if not item_data:
            continue

        # Extract the query vector
        if vector_field == 'embedding_name_zh':
            query_vector = item_data.get('embedding_name_zh')
        elif vector_field == 'embedding_pic_h14':
            pic_data = item_data.get('embedding_pic_h14')
            if pic_data and isinstance(pic_data, list) and len(pic_data) > 0:
                query_vector = pic_data[0].get('vector') if isinstance(pic_data[0], dict) else None
            else:
                query_vector = None
        else:
            query_vector = None

        if not query_vector:
            continue

        # Run the knn query (the item itself is filtered out afterwards)
        knn_field = f"{vector_field}.vector" if vector_field == 'embedding_pic_h14' else vector_field
        similar_items = find_similar_by_vector(es, query_vector, knn_field)

        # Drop the item itself and keep at most TOP_N results
        filtered_items = []
        for sim_id, score, name in similar_items:
            if sim_id != str(item_id):
                filtered_items.append((sim_id, score, name))
            if len(filtered_items) >= TOP_N:
                break

        if filtered_items:
            result[item_id] = filtered_items

    return result


def save_index_file(result, es, output_file, logger):
    """
    Save the index file.

    Format: item_id \t item_name \t similar_id1:score1,similar_id2:score2,...
    """
    logger.info(f"Saving index to: {output_file}")

    with open(output_file, 'w', encoding='utf-8') as f:
        for item_id, similar_items in result.items():
            if not similar_items:
                continue

            # Look up the current item's name
            item_data = get_item_vectors(es, item_id)
            item_name = item_data.get('name_zh', 'Unknown') if item_data else 'Unknown'

            # Format the list of similar items
            sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score, _ in similar_items])

            f.write(f'{item_id}\t{item_name}\t{sim_str}\n')

    logger.info(f"Index saved, {len(result)} items in total")


def main():
    """Main entry point."""
    # Set up the logger
    logger = setup_debug_logger('i2i_content_similar', debug=True)

    logger.info("=" * 80)
    logger.info("Start building content-similarity indexes from ES vectors")
    logger.info(f"ES host: {ES_CONFIG['host']}")
    logger.info(f"Index name: {ES_CONFIG['index_name']}")
    logger.info(f"Top N: {TOP_N}")
    logger.info("=" * 80)

    # Create the database connection
    log_processing_step(logger, "Connecting to the database")
    engine = create_db_connection(
        DB_CONFIG['host'],
        DB_CONFIG['port'],
        DB_CONFIG['database'],
        DB_CONFIG['username'],
        DB_CONFIG['password']
    )

    # Get active items
    log_processing_step(logger, "Fetching items with behaviour events in the last year")
    active_items = get_active_items(engine)
    logger.info(f"Found {len(active_items)} active items")

    # Connect to ES
    log_processing_step(logger, "Connecting to Elasticsearch")
    es = connect_es()
    logger.info("ES connection established")

    # Build the two similarity indexes
    date_str = datetime.now().strftime("%Y%m%d")

    # 1. Based on the name-text vector
    log_processing_step(logger, "Building the similarity index from name-text vectors")
    name_result = generate_similarity_index(
        es, active_items, 'embedding_name_zh', 'name', logger
    )
    name_output = os.path.join(OUTPUT_DIR, f'i2i_content_name_{date_str}.txt')
    save_index_file(name_result, es, name_output, logger)

    # 2. Based on the image vector
    log_processing_step(logger, "Building the similarity index from image vectors")
    pic_result = generate_similarity_index(
        es, active_items, 'embedding_pic_h14', 'pic', logger
    )
    pic_output = os.path.join(OUTPUT_DIR, f'i2i_content_pic_{date_str}.txt')
    save_index_file(pic_result, es, pic_output, logger)

    logger.info("=" * 80)
    logger.info("Done! Built two content-similarity indexes:")
    logger.info(f"  1. Name-vector index: {name_output} ({len(name_result)} items)")
    logger.info(f"  2. Image-vector index: {pic_output} ({len(pic_result)} items)")
    logger.info("=" * 80)


if __name__ == '__main__':
    main()
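

# --- Illustrative only (not used by this script) ----------------------------
# A minimal sketch of how a downstream consumer might parse one of the index
# files written by save_index_file(). Each line has the form
#     item_id \t item_name \t sim_id1:score1,sim_id2:score2,...
# The file path and the helper name below are hypothetical examples.
#
# def load_i2i_index(path):
#     index = {}
#     with open(path, encoding='utf-8') as f:
#         for line in f:
#             item_id, _item_name, sims = line.rstrip('\n').split('\t')
#             index[item_id] = [
#                 (sim_id, float(score))
#                 for sim_id, score in (pair.split(':', 1) for pair in sims.split(','))
#             ]
#     return index
#
# name_index = load_i2i_index('i2i_content_name_20240101.txt')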