#!/usr/bin/env python
"""
Task 1. Process places of interest and compute wiki relevance scores.

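Reads from: ../databases/bigdata_ld2_database.sqlite3  (relative to this script)
Writes to:  data/task1_results.pickle  (relative to this script)

Environment:
  LD3_MAX_OBJECTS  process only the first N objects (default 6000);
                   set to 0 or "all" to process every object.

Run:
  mpirun --oversubscribe -np 4 python process.py   # 4 MPI ranks
  python process.py                                # single process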
"""

from mpi4py import MPI
import sys, json, os, math, pickle
import numpy as np
import sqlite3
from pathlib import Path

comm = MPI.COMM_WORLD
# Rank, world size and host name are looked up once and reused for
# logging and for partitioning the work.
rank = comm.Get_rank()
size = comm.Get_size()
name = MPI.Get_processor_name()

# All paths are anchored at this script's directory, not the working directory.
BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR.joinpath("data")
LD2_DB = BASE_DIR.parent.joinpath("databases", "bigdata_ld2_database.sqlite3")


_LD2_FEATURES_SQL = """
    SELECT ff.id AS place_rowid, ff.city_name, ff.osm_type, ff.osm_id,
           ff.lat, ff.lon, ff.name, ff.wikidata_qid, ff.historic, ff.shop,
           ff.tourism, ff.wd_label_en, ff.wd_description_en, ff.wd_instance_of_qid
    FROM final_features ff
    ORDER BY ff.city_name, ff.osm_type, ff.osm_id;
"""

_LD2_WIKI_BODY_SQL = """
    SELECT body FROM raw_response_data
    WHERE layer = 'wikidata_entity' AND ref_key = ? AND city_name = ?
    LIMIT 1;
"""


def _load_wikidata_bodies(conn, pairs):
    """Return ``{(city_name, qid): parsed_body}`` for the pairs that have a raw response.

    Pairs with no matching ``raw_response_data`` row are omitted. A stored
    body that is NULL or not valid JSON maps to ``{}``.
    """
    bodies = {}
    for city_name, qid in pairs:
        hit = conn.execute(_LD2_WIKI_BODY_SQL, (qid, city_name)).fetchone()
        if hit is None:
            continue
        try:
            bodies[(city_name, qid)] = json.loads(hit[0])
        except (TypeError, ValueError):
            # TypeError: NULL body (json.loads(None)).
            # ValueError: invalid JSON (json.JSONDecodeError) or undecodable bytes.
            bodies[(city_name, qid)] = {}
    return bodies


def _place_object(internal_id, row, wiki_bodies):
    """Build the output record for one ``final_features`` row."""
    qid = row.get("wikidata_qid")
    city = row.get("city_name")
    entity = wiki_bodies.get((city, qid)) if qid and city else None
    return {
        "internal_id": internal_id,
        "place_rowid": row["place_rowid"],
        "city_name": row["city_name"],
        "osm_type": row["osm_type"],
        "osm_id": row["osm_id"],
        "lat": row["lat"],
        "lon": row["lon"],
        "name": row["name"],
        "wikidata_qid": qid,
        "historic": row["historic"],
        "shop": row["shop"],
        "tourism": row["tourism"],
        "wd_label_en": row["wd_label_en"],
        "wd_description_en": row["wd_description_en"],
        "wd_instance_of_qid": row["wd_instance_of_qid"],
        "wikidata_entity": entity,
    }


def load_place_objects_from_ld2(db_path):
    """Load places of interest, with their Wikidata entity JSON, from LD2.

    Reads every row of ``final_features`` ordered by city name, OSM type and
    OSM id. For each row it attaches the first ``raw_response_data`` body
    stored under layer ``'wikidata_entity'`` for the row's
    ``(city_name, wikidata_qid)``.

    Args:
        db_path: path (``str`` or ``os.PathLike``) of the SQLite database file.

    Returns:
        list[dict]: one dict per ``final_features`` row, in query order.
        ``internal_id`` is the record's index in this list, and ``place_rowid``
        is ``final_features.id``. The other selected columns are copied as-is.
        ``wikidata_entity`` is one of:

        - the parsed JSON body;
        - ``{}`` when the stored body is NULL or not valid JSON;
        - ``None`` when the row lacks a QID or city name, or no matching raw
          response exists.

    Raises:
        sqlite3.Error: if the database cannot be opened or queried. The
        connection is closed on both the success and the error path.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.row_factory = sqlite3.Row
        rows = [dict(r) for r in conn.execute(_LD2_FEATURES_SQL).fetchall()]
        # Look up each distinct (city, QID) pair once, not once per row.
        pairs = {
            (r["city_name"], r["wikidata_qid"])
            for r in rows
            if r.get("wikidata_qid") and r.get("city_name")
        }
        wiki_bodies = _load_wikidata_bodies(conn, pairs)
    finally:
        conn.close()

    return [_place_object(i, r, wiki_bodies) for i, r in enumerate(rows)]


def wiki_relevance_score(obj):
    """Score how notable a place is from its Wikidata data.

    With a non-empty ``wikidata_entity`` dict the score is::

        log1p(#claims) * log1p(#sitelinks)
            + 0.02 * len(en description) + 0.01 * len(en label)

    Here ``#claims`` is the number of entries in ``claims``. ``#sitelinks`` is
    0 unless ``sitelinks`` is a dict. The English texts come from
    ``labels["en"]["value"]`` and ``descriptions["en"]["value"]``.

    Otherwise the score falls back to the flat columns: 0.5 if a
    ``wikidata_qid`` is present, plus 0.01 per ``wd_label_en`` character and
    0.005 per ``wd_description_en`` character.

    Args:
        obj: a place record as produced by ``load_place_objects_from_ld2``.

    Returns:
        float: the relevance score.
    """

    def english_value(section):
        # ``section`` is a language-keyed mapping such as entity["labels"].
        entry = section.get("en")
        if not isinstance(entry, dict):
            return ""
        return entry.get("value") or ""

    entity = obj.get("wikidata_entity")
    has_entity = isinstance(entity, dict) and bool(entity)
    if not has_entity:
        label_len = len(str(obj.get("wd_label_en") or ""))
        desc_len = len(str(obj.get("wd_description_en") or ""))
        qid_bonus = 0.5 if obj.get("wikidata_qid") else 0.0
        return qid_bonus + 0.01 * label_len + 0.005 * desc_len

    claim_count = len(entity.get("claims") or {})
    sitelinks = entity.get("sitelinks") or {}
    sitelink_count = len(sitelinks) if isinstance(sitelinks, dict) else 0
    label_text = english_value(entity.get("labels") or {})
    desc_text = english_value(entity.get("descriptions") or {})
    return float(
        math.log1p(claim_count) * math.log1p(sitelink_count)
        + 0.02 * len(desc_text)
        + 0.01 * len(label_text)
    )


def list_split(a, n):
    """Split sequence ``a`` into ``n`` contiguous slices of near-equal length.

    The first ``len(a) % n`` slices get one extra element. Slices keep the
    input order, so concatenating them reproduces ``a``. Some slices are
    empty when ``n > len(a)``.
    """
    base, extra = divmod(len(a), n)
    pieces = []
    start = 0
    for idx in range(n):
        stop = start + base + (1 if idx < extra else 0)
        pieces.append(a[start:stop])
        start = stop
    return pieces


# ========== rank 0 loads data ==========

if rank == 0:
    print("Task 1: Processing places of interest using MPI collective communication")
    print("=" * 60)

    if LD2_DB.is_file():
        objects = load_place_objects_from_ld2(LD2_DB)
        print("[rank 0] Loaded %d objects from %s" % (len(objects), LD2_DB))
    else:
        print("ERROR: LD2 database not found at %s" % LD2_DB, flush=True)
        # sys.exit() would stop only rank 0 and leave every other rank blocked
        # in comm.scatter() below; Abort terminates all ranks of the job.
        comm.Abort(1)

    # LD3_MAX_OBJECTS caps how many objects are processed. "all"
    # (case-insensitive) or 0 disables the cap; any other value must be a
    # non-negative integer.
    lim_raw = os.environ.get("LD3_MAX_OBJECTS", "6000").strip()
    if lim_raw.lower() == "all":
        lim = 0
    else:
        try:
            lim = int(lim_raw)
        except ValueError:
            lim = -1
    if lim < 0:
        print(
            "ERROR: LD3_MAX_OBJECTS must be a non-negative integer or 'all' (got %r)"
            % lim_raw,
            flush=True,
        )
        # Same reasoning as above: never leave the other ranks hanging.
        comm.Abort(1)

    if lim and len(objects) > lim:
        objects = objects[:lim]
        # Keep internal_id equal to the list index; later code relies on
        # objects[internal_id] lookups.
        for i, o in enumerate(objects):
            o["internal_id"] = i
        print(
            "[rank 0] Using first %d objects (set LD3_MAX_OBJECTS=0 for all)" % lim
        )

    the_chunks = list_split(objects, size)
else:
    the_chunks = None
    objects = None


# ========== scatter objects to all ranks ==========

# Rank 0 passes one chunk per rank (the_chunks); the other ranks pass None.
local_objs = comm.scatter(the_chunks, root=0)

received_count = len(local_objs)
print(
    "%03d/%03d (%s) received %d objects via scatter"
    % (rank, size, name, received_count)
)


# ========== each rank computes wiki relevance scores ==========

# (internal_id, score) pairs for this rank's chunk, in chunk order.
local_pairs = []
for place in local_objs:
    local_pairs.append((place["internal_id"], wiki_relevance_score(place)))


# ========== gather score pairs back to rank 0 ==========

# Rank 0 receives one list per rank, in rank order; other ranks get None.
gathered = comm.gather(local_pairs, root=0)


# ========== allgather chunk sizes ==========

# Every rank learns how many objects each rank received.
my_chunk_size = len(local_objs)
chunk_sizes = comm.allgather(my_chunk_size)
print("%03d/%03d Allgather chunk sizes: %s" % (rank, size, chunk_sizes))


# ========== Allreduce total score sum ==========

# Buffer-based (upper-case) Allreduce: every rank ends up with the global values.
my_scores = [sc for _, sc in local_pairs]

local_score_sum = np.array([sum(my_scores)], dtype=np.float64)
total_score_sum = np.zeros_like(local_score_sum)
comm.Allreduce(local_score_sum, total_score_sum, op=MPI.SUM)

local_count = np.array([len(local_pairs)], dtype=np.int64)
total_count = np.zeros_like(local_count)
comm.Allreduce(local_count, total_count, op=MPI.SUM)

print(
    "%03d/%03d Allreduce totals: %d objects, score sum %.2f"
    % (rank, size, int(total_count[0]), total_score_sum[0])
)


# ========== Allreduce max score ==========

# A rank with an empty chunk contributes 0.0 to the max reduction.
local_max = np.array([max(my_scores, default=0.0)], dtype=np.float64)
global_max = np.zeros_like(local_max)
comm.Allreduce(local_max, global_max, op=MPI.MAX)


# ========== reduce total score to rank 0 only ==========

# Lower-case comm.reduce() combines Python objects using the operator's Python
# meaning. Reducing per-rank *lists* with MPI.SUM therefore concatenated them,
# shipping every score to rank 0, instead of producing a total. Reduce each
# rank's scalar sum instead: rank 0 receives the total, other ranks get None.
local_scores_list = [sc for _, sc in local_pairs]
reduced_score_sum = comm.reduce(sum(local_scores_list), op=MPI.SUM, root=0)


# ========== assemble results on rank 0 ==========

wiki_scores = {}
if rank == 0:
    # Flatten the gathered per-rank (internal_id, score) lists into one mapping.
    wiki_scores.update(pair for part in gathered for pair in part)

    for place in objects:
        place["relevance_score"] = wiki_scores.get(place["internal_id"], 0.0)

    # total_count is at least 1 in the divisor to avoid dividing by zero.
    average = total_score_sum[0] / max(total_count[0], 1)
    print("=" * 60)
    print("[rank 0] Total objects with relevance: %d" % len(wiki_scores))
    print("[rank 0] Average relevance: %.4f" % average)
    print("[rank 0] Max relevance (Allreduce MAX): %.4f" % global_max[0])

    # Stable descending sort: ties keep their gathered order.
    ranked = sorted(wiki_scores.items(), key=lambda item: item[1], reverse=True)
    print("[rank 0] Top 10 by wiki relevance:")
    for iid, sc in ranked[:10]:
        place = objects[iid]  # internal_id doubles as the list index
        label = place.get("name") or place.get("wd_label_en") or "N/A"
        print("  %s (%s) - score: %.4f" % (label, place.get("city_name"), sc))

    results_path = DATA_DIR / "task1_results.pickle"
    DATA_DIR.mkdir(exist_ok=True)
    with results_path.open("wb") as fh:
        pickle.dump({"objects": objects, "wiki_scores": wiki_scores}, fh)
    print("[rank 0] Saved: %s" % results_path)

# Every rank waits here, so the completion message is printed only after all
# ranks have reached this point.
comm.Barrier()
if rank == 0:
    print("Task 1 complete.")
