#include <iostream>
#include <fstream>
#include <utility>
#include <string>
#include <map>
#include <set>
#include <vector>
#include <algorithm>
#include <functional>
#include <string.h>
#include <time.h>
#include <unordered_map>
#include <iterator>
#include <queue>
#include <numeric>
#include "utils.h"
#include "BitMap.h"

// Maximum number of neighbours written per item; the list-length histogram
// in main has max_sim_list_len + 1 buckets (lengths 0..max_sim_list_len).
int max_sim_list_len = 300;

using namespace std;

// Item identifier: an unsigned integer of at least 64 bits.
using item_id_t = unsigned long long;

// 比较函数，用于排序时按item_id_t来比较
bool compare_i2ulist_map_iters2(const unordered_map<item_id_t, vector<int>>::const_iterator &a, 
                               const unordered_map<item_id_t, vector<int>>::const_iterator &b) {
    return a->first < b->first;
}

// Sort predicate for similarity lists: true when `a` has the strictly larger
// similarity score, so sorting yields descending similarity (tie order is
// unspecified).
bool compare_pairs2(const pair<item_id_t, float> &a, const pair<item_id_t, float> &b) {
    return b.second < a.second;
}

int main(int argc, char *argv[]) {

    float alpha = 0.5;
    float threshold = 0.5;
    int show_progress = 0;

    if (argc < 4) {
        cerr << "usage " << argv[0] << " alpha threshold show_progress(0/1)" << endl;
        return -1;
    }

    alpha = atof(argv[1]);
    threshold = atof(argv[2]);
    show_progress = atoi(argv[3]);

    cerr << currentTimetoStr() << " start... " << endl;
    cerr << " alpha " << alpha << endl;
    cerr << " threshold " << threshold << endl;

    unordered_map<item_id_t, vector<int>> i2u_map;
    i2u_map.reserve(160000);

    string line_buff;
    const string delimiters(",");

    vector<string> field_segs;
    vector<vector<item_id_t>> groups;  // Changed to store item_id_t
    groups.reserve(2000000);
    vector<item_id_t> item_list;

    vector<int> items_intersection_buffer;
    vector<int> users_intersection_buffer;
    users_intersection_buffer.reserve(2000);

    pair<item_id_t, vector<int>> pair_entry;
    pair<unordered_map<item_id_t, vector<int>>::iterator, bool> ins_i2u_ret;

    while (getline(cin, line_buff)) {
        // 格式是一个json，所以要把开头和结尾的括号去掉
        line_buff.erase(0, line_buff.find_first_not_of("{"));
        line_buff.erase(line_buff.find_last_not_of("}") + 1);
        field_segs.clear();
        split(field_segs, line_buff, delimiters);

        item_list.clear();
        for (size_t i = 0; i < field_segs.size(); i++) {
            const char *seg_pos = strchr(field_segs[i].c_str(), ':');
            if (seg_pos == NULL || (seg_pos - field_segs[i].c_str() >= field_segs[i].length())) break;

            float value = atof(seg_pos + 1);
            if (value > threshold) {
                // 开头有一个双引号
                item_id_t item_id = strtoull(field_segs[i].c_str() + 1, NULL, 10);
                item_list.push_back(item_id);
            }
        }

        if (item_list.size() < 2) continue;
        // 排序
        sort(item_list.begin(), item_list.end());

        // append本次的itemlist
        int idx = groups.size();
        groups.push_back(item_list);  // item_list is now of type item_id_t
        // 合入i2u索引
        for (vector<item_id_t>::const_iterator iter = item_list.begin(); iter != item_list.end(); ++iter) {
            pair_entry.first = *iter;
            ins_i2u_ret = i2u_map.insert(pair_entry);
            ins_i2u_ret.first->second.push_back(idx);
        }
    }

    int items_num = i2u_map.size();
    int users_num = groups.size();
    cerr << currentTimetoStr() << " items num: " << i2u_map.size() << endl;
    cerr << currentTimetoStr() << " users num: " << groups.size() << endl;
    cerr << currentTimetoStr() << " sort.." << endl;

    vector<unordered_map<item_id_t, vector<int>>::const_iterator> sorted_i_ulist_pairs;

    for (unordered_map<item_id_t, vector<int>>::iterator iter = i2u_map.begin(); iter != i2u_map.end(); ++iter) {
        sorted_i_ulist_pairs.push_back(iter);
        sort(iter->second.begin(), iter->second.end());
    }
    cerr << currentTimetoStr() << " sort finished" << endl;

    sort(sorted_i_ulist_pairs.begin(), sorted_i_ulist_pairs.end(), compare_i2ulist_map_iters2);

    if (items_num < 2) return -1;

    vector<pair<item_id_t, float>> sim_list_buff;
    unordered_map<item_id_t, vector<pair<item_id_t, float>>> sim_matrix;
    sim_matrix.reserve(items_num);

    int idx = 0;

    BitMap user_bm(users_num);
    bool use_bitmap;
    vector<int> sim_list_len_statis;
    sim_list_len_statis.resize(max_sim_list_len + 1);

    for (int i = 1; i < sorted_i_ulist_pairs.size(); ++i) {
        unordered_map<item_id_t, vector<int>>::const_iterator pair_i = sorted_i_ulist_pairs[i];
        if (show_progress) {
            fprintf(stderr, "\r%d of %d", idx++, items_num);
        }
        sim_list_buff.clear();

        use_bitmap = pair_i->second.size() > 50;

        if (use_bitmap) {
            for (vector<int>::const_iterator iter_pair_i = pair_i->second.begin(); iter_pair_i != pair_i->second.end(); ++iter_pair_i) {
                user_bm.Set(*iter_pair_i);
            }
        }

        for (int j = 0; j < i; ++j) {
            unordered_map<item_id_t, vector<int>>::const_iterator pair_j = sorted_i_ulist_pairs[j];
            users_intersection_buffer.clear();

            if (use_bitmap) {
                for (vector<int>::const_iterator iter_pair_j = pair_j->second.begin(); iter_pair_j != pair_j->second.end(); ++iter_pair_j) {
                    if (user_bm.Existed(*iter_pair_j)) {
                        users_intersection_buffer.push_back(*iter_pair_j);
                    }
                }
            } else {
                set_intersection(pair_i->second.begin(), pair_i->second.end(), pair_j->second.begin(), pair_j->second.end(), back_inserter(users_intersection_buffer));
            }

            if (users_intersection_buffer.size() < 2) continue;

            float sim_of_item_i_j = 0.0;
            for (vector<int>::const_iterator user_i = users_intersection_buffer.begin() + 1;
                 user_i != users_intersection_buffer.end();
                 ++user_i) {

                const vector<item_id_t> &item_list_of_user_i = groups[*user_i];

                for (vector<int>::const_iterator user_j = users_intersection_buffer.begin();
                     user_j != user_i;
                     ++user_j) {

                    const vector<item_id_t> &item_list_of_user_j = groups[*user_j];
                    items_intersection_buffer.clear();
                    set_intersection(item_list_of_user_i.begin(), item_list_of_user_i.end(), item_list_of_user_j.begin(), item_list_of_user_j.end(), back_inserter(items_intersection_buffer));

                    sim_of_item_i_j += 1.0 / (alpha + items_intersection_buffer.size());
                }
            }
            sim_list_buff.push_back(make_pair(pair_j->first, sim_of_item_i_j));
        }

        sim_matrix[pair_i->first] = sim_list_buff;
        for (auto &p : sim_list_buff) {
            sim_matrix[p.first].push_back(make_pair(pair_i->first, p.second));
        }
        if (use_bitmap) {
            for (vector<int>::const_iterator iter_pair_i = pair_i->second.begin(); iter_pair_i != pair_i->second.end(); ++iter_pair_i) {
                user_bm.ResetRoughly(*iter_pair_i);
            }
        }
    }

    for (auto &p : sim_matrix) {
        vector<pair<item_id_t, float>> &sim_list = p.second;
        int sim_list_len = p.second.size();
        if (sim_list_len > 0) {
            sort(sim_list.begin(), sim_list.end(), compare_pairs2);

            cout << p.first << "\t" << sim_list[0].first << ":" << sim_list[0].second;

            if (sim_list_len > max_sim_list_len) {
                sim_list_len = max_sim_list_len;
            }

            sim_list_len_statis[sim_list_len] += 1;

            for (int i = 1; i < sim_list_len; i++) {
                cout << ',' << sim_list[i].first << ':' << sim_list[i].second;
            }
            cout << endl;
        }
    }

    int sum_groups = accumulate(sim_list_len_statis.begin(), sim_list_len_statis.end(), 0);
    cerr << currentTimetoStr() << " write sim matrix finished" << endl;
    cerr << currentTimetoStr() << " print stats info of sim matrix... " << sim_list_len_statis.size() << endl;
    cerr << currentTimetoStr() << " total keys: " << sum_groups << endl;

    int accumulate = 0;
    for (int i = sim_list_len_statis.size() - 1; i >= 0; i--) {
        accumulate += sim_list_len_statis[i];
        fprintf(stderr, "simlist_len %4d, num %4d, accumulate %6d accumulated_rate %5.2f%%\n",
                i, sim_list_len_statis[i], accumulate, 100.0 * accumulate / sum_groups);
    }

    return 0;
}
