import pandas as pd
import math
import os
from collections import defaultdict
from sqlalchemy import create_engine
from db_service import create_db_connection
import argparse
from datetime import datetime
import json
import logging

def setup_logger(level=logging.INFO):
    """Return the ``tag_category_similar`` logger, configuring it on first use.

    Every call ensures the ``logs`` directory exists and sets the logger's
    level to ``level``. If the logger has no handlers yet, two handlers are
    attached at ``level``, both using the
    ``asctime - name - levelname - message`` format:

    * a UTF-8 ``FileHandler`` writing to
      ``logs/tag_category_similar_<YYYYMMDD>.log`` (date taken at setup time);
    * a console ``StreamHandler``.

    If handlers are already attached, the logger is returned as-is so that
    repeated calls do not duplicate output.

    Args:
        level: Logging level for the logger and its handlers. Defaults to
            ``logging.INFO``, the previously hard-coded value.

    Returns:
        The configured ``logging.Logger``.
    """
    logs_dir = 'logs'
    os.makedirs(logs_dir, exist_ok=True)

    logger = logging.getLogger('tag_category_similar')
    logger.setLevel(level)

    # Avoid stacking duplicate handlers when called more than once.
    if logger.handlers:
        return logger

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    log_file = os.path.join(logs_dir, f'tag_category_similar_{datetime.now().strftime("%Y%m%d")}.log')
    file_handler = logging.FileHandler(log_file, encoding='utf-8')
    console_handler = logging.StreamHandler()

    # File handler first, then console, both sharing level and format.
    for handler in (file_handler, console_handler):
        handler.setLevel(level)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger

# Character mapping used by clean_text_field. Line breaks and tabs become a
# space because the readable output is one record per line with a tab between
# the key and its similar-category list; double quotes are doubled CSV-style.
_TEXT_FIELD_TRANSLATION = str.maketrans({'\r': ' ', '\n': ' ', '\t': ' ', '"': '""'})


def clean_text_field(text):
    """Return ``text`` as a single-line string safe for delimited output.

    Missing values (anything ``pd.isna`` reports as missing, e.g. ``None`` or
    NaN) become ``''``. Otherwise the value is converted with ``str()``;
    carriage returns, newlines and tabs are replaced with a space, each ``"``
    is doubled (CSV-style escaping), and surrounding whitespace is stripped.
    """
    if pd.isna(text):
        return ''
    # Single pass over the string instead of chained str.replace calls.
    return str(text).translate(_TEXT_FIELD_TRANSLATION).strip()

# Parse command-line arguments.
parser = argparse.ArgumentParser(description='计算基于订单的分类相似度（Tag Similarity）')
parser.add_argument('--lookback_days', type=int, default=180, help='回溯天数，默认180天')
parser.add_argument('--top_n', type=int, default=50, help='每个分类保留的相似分类数量，默认50')
parser.add_argument('--debug', action='store_true', help='开启debug模式')
args = parser.parse_args()

# Both values must be positive:
# - lookback_days is interpolated into an SQL INTERVAL;
# - top_n is used as a list-slice bound, where a negative value would silently
#   drop entries from the end instead of keeping the first N.
if args.lookback_days <= 0:
    parser.error('--lookback_days 必须为正整数')
if args.top_n <= 0:
    parser.error('--top_n 必须为正整数')

# Initialise the logger.
logger = setup_logger()

# setup_logger() configures the logger and its handlers at INFO. That level
# filters out every logger.debug(...) call guarded by --debug, so lower it to
# DEBUG here to make the flag actually emit its diagnostics.
if args.debug:
    logger.setLevel(logging.DEBUG)
    for handler in logger.handlers:
        handler.setLevel(logging.DEBUG)

# Connection settings for the bpms database. Each value may be overridden by
# an environment variable (BPMS_HOST, BPMS_PORT, BPMS_DATABASE,
# BPMS_USERNAME, BPMS_PASSWORD) so credentials need not live in source code.
# NOTE(security): the fallback values below, including the password, are
# committed in plain text. Rotate the password and drop these defaults once
# every deployment provides the environment variables.
bpms_host = os.environ.get('BPMS_HOST', '120.76.244.158')
bpms_port = os.environ.get('BPMS_PORT', '3325')
bpms_database = os.environ.get('BPMS_DATABASE', 'bpms')
bpms_username = os.environ.get('BPMS_USERNAME', 'PRD_M1_190311')
bpms_password = os.environ.get('BPMS_PASSWORD', 'WTF)xdbqtW!4gwA7')

# Create the database connection via the project's db_service helper.
engine = create_db_connection(bpms_host, bpms_port, bpms_database, bpms_username, bpms_password)

# SQL query: per-order category baskets within the lookback window.
# - One row per (sp.code, psm.name, bb.code) group, i.e. PO code, sales region
#   and customer code.
# - `商品信息` is a GROUP_CONCAT of DISTINCT "<category id>:<category name>"
#   strings. The category id is the second '.'-separated segment of
#   prd_goods.category_id (presumably the second-level category — confirm).
# - `下单货时间` is the earliest order_time among the group's items.
# - Only items with quantity > 0, spi.is_delete = 0 and bb.is_delete = 0 are
#   kept, ordered within the last `lookback_days` days.
# NOTE(review): interpolating args.lookback_days into this f-string is safe
# only because argparse parses it with type=int.
# NOTE(review): `bb.is_delete = 0` in WHERE discards rows where the LEFT JOIN
# to buy_buyer found no match, so that join effectively behaves as an INNER
# JOIN — confirm this is intended.
# NOTE(review): on MySQL, GROUP_CONCAT output is silently truncated at
# group_concat_max_len. Category names containing ',' would also break the
# later split on ','. Confirm neither happens in practice.
sql_query = f"""
SELECT
  sp.code AS `PO单号`,
  psm.name AS `区域`,
  bb.code AS `客户编码`,
  GROUP_CONCAT(DISTINCT CONCAT(pc_1.id, ':', pc_1.name)) AS `商品信息`,
  MIN(spi.order_time) AS `下单货时间`
FROM sale_po sp
INNER JOIN sale_po_item spi ON sp.id = spi.po_id
LEFT JOIN buy_buyer bb ON bb.id = sp.buyer_id
LEFT JOIN prd_goods pg ON pg.id = spi.spu_id
LEFT JOIN prd_category AS pc_1 ON pc_1.id = SUBSTRING_INDEX(SUBSTRING_INDEX(pg.category_id, '.', 2), '.', -1)
LEFT JOIN pub_sale_market_setting psms ON psms.country_code = bb.countries
LEFT JOIN pub_sale_market psm ON psms.sale_market_id = psm.id
WHERE spi.quantity > 0
  AND spi.is_delete = 0
  AND bb.is_delete = 0
  AND spi.order_time >= DATE_SUB(NOW(), INTERVAL {args.lookback_days} DAY)
GROUP BY sp.code, psm.name, bb.code;
"""

# Announce the run and its parameters.
banner = "=" * 80
for banner_line in (banner, "Tag分类相似度计算开始", banner):
    logger.info(banner_line)
logger.info(f"参数配置: lookback_days={args.lookback_days}, top_n={args.top_n}")

if args.debug:
    param_summary = f"lookback_days={args.lookback_days}, top_n={args.top_n}"
    logger.debug("[DEBUG] 参数配置: " + param_summary)
    logger.debug("[DEBUG] 开始查询数据库...")

# Take the run's start time before querying, so the total elapsed time
# reported at the end also covers the database query.
run_started_at = datetime.now()

# Run the SQL query and load the result into a pandas DataFrame. Its columns
# are exactly the aliases selected in sql_query.
df = pd.read_sql(sql_query, engine)

# Run statistics, filled in progressively by the rest of the script.
stats = {
    'start_time': run_started_at,
    'total_orders': len(df),
    'unique_regions': 0,
    'unique_customers': 0,
    'total_categories': 0,
    'categories_with_similarities': 0,
    'total_similarity_pairs': 0,
    'avg_similarities_per_category': 0,
    'file_stats': {}
}

logger.info(f"数据库查询完成，共 {len(df)} 条订单记录")

if args.debug:
    logger.debug(f"[DEBUG] 查询完成，共 {len(df)} 条订单记录")

# Build per-category counts from the order baskets. Each 商品信息 value is a
# comma-separated list of "<id>:<name>" pairs.
#   cat_id_to_name : category id -> name (the last name seen wins)
#   freq[c]        : number of orders containing category c
#   cooccur[a][b]  : number of orders containing both a and b (a != b)
cat_id_to_name = {}
cooccur = defaultdict(lambda: defaultdict(int))
freq = defaultdict(int)

# Distinct regions and customers in the query result (0 if the column is absent).
for stat_key, column in (('unique_regions', '区域'), ('unique_customers', '客户编码')):
    stats[stat_key] = df[column].nunique() if column in df.columns else 0

for goods_info in df['商品信息']:
    # Orders without any category information contribute nothing.
    if pd.isna(goods_info):
        continue

    order_cats = set()
    for token in str(goods_info).split(','):
        raw_id, sep, raw_name = token.strip().partition(':')
        if not sep:
            # Blank fragments and fragments lacking the ':' separator are ignored.
            continue
        category_id = raw_id.strip()
        cat_id_to_name[category_id] = raw_name.strip()
        order_cats.add(category_id)

    for a in order_cats:
        freq[a] += 1
        for b in order_cats:
            if b != a:
                cooccur[a][b] += 1

stats['total_categories'] = len(freq)

# Cosine similarity on order co-occurrence:
#   sim(a, b) = cooccur[a][b] / (sqrt(freq[a]) * sqrt(freq[b]))
# where freq[x] counts orders containing x and cooccur[a][b] counts orders
# containing both. Only the top_n most similar categories are kept per category.
logger.info("开始计算分类相似度...")

if args.debug:
    logger.debug(f"[DEBUG] 开始计算分类相似度...")

result = {}
total_similarity_pairs = 0
for c1, neighbours in cooccur.items():
    sqrt_f1 = math.sqrt(freq[c1])
    sim_scores = []
    for c2, numerator in neighbours.items():
        denominator = sqrt_f1 * math.sqrt(freq[c2])
        if denominator != 0:
            sim_scores.append((c2, numerator / denominator))
    # Sort by descending score and break ties by category id. Category ids are
    # str values gathered in sets, whose iteration order changes between runs
    # under hash randomization; without the tie-break, equal scores would come
    # out in a different order on each run.
    sim_scores.sort(key=lambda x: (-x[1], x[0]))
    top_scores = sim_scores[:args.top_n]
    result[c1] = top_scores
    total_similarity_pairs += len(top_scores)

# Update run statistics.
stats['categories_with_similarities'] = len(result)
stats['total_similarity_pairs'] = total_similarity_pairs
stats['avg_similarities_per_category'] = total_similarity_pairs / len(result) if result else 0

logger.info(f"相似度计算完成，共 {len(result)} 个分类有相似推荐")

if args.debug:
    logger.debug(f"[DEBUG] 相似度计算完成，共 {len(result)} 个分类有相似推荐")
    # Distinct categories appearing as similar targets in any list.
    unique_cats = {cat for sims in result.values() for cat, _ in sims}
    logger.debug(f"[DEBUG] 唯一分类数: {len(unique_cats)}")

# Output locations. Both file names share the tag_category_similar_<YYYYMMDD> stem:
#   output/<stem>.txt                 - machine-readable index
#   output/debug/<stem>_readable.txt  - human-readable companion
date_str = datetime.now().strftime('%Y%m%d')
output_dir = 'output'
debug_dir = os.path.join(output_dir, 'debug')
for required_dir in (output_dir, debug_dir):
    os.makedirs(required_dir, exist_ok=True)

file_stem = f'tag_category_similar_{date_str}'
output_file = os.path.join(output_dir, file_stem + '.txt')
debug_file = os.path.join(debug_dir, file_stem + '_readable.txt')

# Machine-readable index file: one line per category, sorted by category id
# (string order), each formatted as
#   <category_id>\t<similar_id>:<score>,<similar_id>:<score>,...
# with scores printed to 4 decimals.
logger.info(f"开始写入输出文件: {output_file}")

if args.debug:
    logger.debug(f"[DEBUG] 开始写入文件: {output_file}")

with open(output_file, 'w', encoding='utf-8') as out:
    out.writelines(
        f'{cat_id}\t' + ','.join(f'{sim_id}:{score:.4f}' for sim_id, score in sims) + '\n'
        for cat_id, sims in sorted(result.items())
    )

logger.info(f"输出文件写入完成: {output_file}")

# Record the size of the file just written.
if os.path.exists(output_file):
    n_bytes = os.path.getsize(output_file)
    stats['file_stats']['output_file'] = {
        'path': output_file,
        'size_bytes': n_bytes,
        'size_mb': round(n_bytes / 1024 / 1024, 2),
    }

# Human-readable companion of the index file. It starts with a header block;
# each row then shows every id as "<id>:<name>". Names come from
# cat_id_to_name ('Unknown' when absent) and are sanitised by clean_text_field.
logger.info(f"开始写入可读文件: {debug_file}")

if args.debug:
    logger.debug(f"[DEBUG] 开始写入可读文件: {debug_file}")


def _readable_label(category_id):
    """Return "<id>:<cleaned name>" for a category id."""
    name = cat_id_to_name.get(category_id, 'Unknown')
    return f'{category_id}:{clean_text_field(name)}'


with open(debug_file, 'w', encoding='utf-8') as fh:
    rule = '=' * 80
    header_lines = [
        rule,
        '明文索引文件',
        f'生成时间: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}',
        '描述: tag_category_similar (分类相似度)',
        f'总索引数: {len(result)}',
        rule,
        '',
    ]
    fh.write('\n'.join(header_lines) + '\n')

    # Row format: <id>:<name>\t<sim_id>:<sim_name>:<score>,...
    for cat_id, sims in sorted(result.items()):
        entries = ','.join(f'{_readable_label(sim_id)}:{score:.4f}' for sim_id, score in sims)
        fh.write(f'{_readable_label(cat_id)}\t{entries}\n')

logger.info(f"可读文件写入完成: {debug_file}")

# Record the size of the readable file.
if os.path.exists(debug_file):
    n_bytes = os.path.getsize(debug_file)
    stats['file_stats']['debug_file'] = {
        'path': debug_file,
        'size_bytes': n_bytes,
        'size_mb': round(n_bytes / 1024 / 1024, 2),
    }

# Finalise timing, then log a summary of the run.
stats['end_time'] = datetime.now()
stats['processing_time_seconds'] = (stats['end_time'] - stats['start_time']).total_seconds()

time_fmt = '%Y-%m-%d %H:%M:%S'
separator = "=" * 80
# (label, stats key) pairs logged with thousands separators, in display order.
count_lines = (
    ("总订单数", 'total_orders'),
    ("唯一区域数", 'unique_regions'),
    ("唯一客户数", 'unique_customers'),
    ("总分类数", 'total_categories'),
    ("有相似度的分类数", 'categories_with_similarities'),
    ("总相似度对数量", 'total_similarity_pairs'),
)

logger.info(separator)
logger.info("Tag分类相似度计算 - 关键统计信息")
logger.info(separator)

logger.info("📊 数据概览:")
for label, key in count_lines:
    logger.info(f"  - {label}: {stats[key]:,}")
logger.info(f"  - 平均每分类相似数: {stats['avg_similarities_per_category']:.1f}")

logger.info("📁 输出文件:")
for file_type, info in stats['file_stats'].items():
    logger.info(f"  - {file_type}: {info['path']}")
    logger.info(f"    大小: {info['size_mb']} MB ({info['size_bytes']:,} bytes)")

logger.info("⏱️ 处理时间:")
logger.info(f"  - 开始时间: {stats['start_time'].strftime(time_fmt)}")
logger.info(f"  - 结束时间: {stats['end_time'].strftime(time_fmt)}")
logger.info(f"  - 总耗时: {stats['processing_time_seconds']:.2f} 秒")

logger.info("✅ Tag相似度计算完成")
logger.info(f"  - 输出文件: {output_file}")
logger.info(f"  - 可读文件: {debug_file}")
logger.info(separator)
