#!/usr/bin/env python
"""
Task 3. Create connection graph between objects.

Edges come from internal references (Wikidata claim values that point to
another loaded object), shared classifiers (Wikidata instance-of, OSM
historic=*/shop=* tags), and lexical similarity (Jaccard on text tokens).

Reads from:  data/task1_results.pickle
Writes to:   data/task3_edges.pickle

Run after process.py:
  python graph.py
"""

import json, pickle, re, os
from collections import defaultdict
from pathlib import Path

BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR / "data"

print("Task 3: Building connection graph between objects")
print("=" * 60)


# ========== Load task 1 results ==========

task1_path = DATA_DIR / "task1_results.pickle"
if not task1_path.exists():
    # The bare `exit()` builtin is injected by the `site` module for the
    # interactive shell and may be absent (e.g. `python -S`). Raising
    # SystemExit with a message prints it to stderr and exits with status 1.
    raise SystemExit("ERROR: %s not found. Run process.py first." % task1_path)

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load task1_results.pickle produced by this pipeline's process.py.
with open(task1_path, "rb") as f:
    task1_data = pickle.load(f)

objects = task1_data["objects"]
print("Loaded %d objects" % len(objects))


# ========== tokenize text ==========


def _tokenize(text):
    return set(re.findall(r"[a-z]{4,}", (text or "").lower()))


# ========== extract Q-ids from wikidata claims ==========


def _qids_in_claims(ent):
    if not ent or not isinstance(ent, dict):
        return set()
    out = set()
    for _prop, stmts in (ent.get("claims") or {}).items():
        for st in stmts or []:
            if not isinstance(st, dict):
                continue
            snak = st.get("mainsnak") or {}
            dv = snak.get("datavalue") or {}
            val = dv.get("value")
            if isinstance(val, dict):
                qid = val.get("id")
                if isinstance(qid, str) and qid.startswith("Q"):
                    out.add(qid)
    return out


# ========== Build graph ==========

n = len(objects)

# Wikidata Q-id -> internal ids of every loaded object carrying that Q-id
# (a list, since several objects may share one Q-id).
qid_to_ids = defaultdict(list)
for obj in objects:
    wd_qid = obj.get("wikidata_qid")
    if not wd_qid:
        continue
    qid_to_ids[wd_qid].append(obj["internal_id"])

# Accumulated (src, dst, edge_type, weight) tuples, plus the
# (src, dst, edge_type) keys already emitted so add_edge can skip duplicates.
edges = []
seen = set()


def add_edge(a, b, et, w):
    """Record the directed edge a -> b of type *et* with weight *w*.

    Self-loops (a == b) are ignored. Duplicates are detected on (a, b, et),
    so one ordered pair may carry several edges of different types, and
    for a repeated key the first weight wins. Appends to the module-level
    ``edges`` list and ``seen`` set.
    """
    if a != b:
        key = (a, b, et)
        if key not in seen:
            seen.add(key)
            edges.append((a, b, et, w))


# 1) Wikidata structural: A -> B if B's Q-id appears in A's claim object values
print("Building wikidata claim reference edges...")
for obj in objects:
    src = obj["internal_id"]
    referenced = _qids_in_claims(obj.get("wikidata_entity"))
    for qid in referenced:
        # Every loaded object holding the referenced Q-id, except A itself.
        targets = [t for t in qid_to_ids.get(qid, []) if t != src]
        for dst in targets:
            add_edge(src, dst, "wikidata_claim_ref", 1.0)
print("  After wikidata_claim_ref: %d edges" % len(edges))


# 2) Same instance-of classifier
print("Building shared instance-of edges...")
# Each object links only to the next `instance_of_k` higher ids in its
# instance-of group, not to the whole group. A full clique is O(k^2) edges per
# group, all held in memory before the out-degree cap below. That cap keeps at
# most max_edges_per_node (32) edges per source and breaks equal-weight ties by
# edge type, then target id, so it would keep only this prefix anyway.
# Keep instance_of_k >= max_edges_per_node to leave the final graph unchanged.
instance_of_k = 32
inst_groups = defaultdict(list)
for o in objects:
    inst = o.get("wd_instance_of_qid")
    if inst:
        inst_groups[inst].append(o["internal_id"])
for _inst, ids in inst_groups.items():
    ids = sorted(set(ids))
    for pos, a in enumerate(ids):
        # Only a -> b with a < b, as before: one direction per pair.
        for b in ids[pos + 1 : pos + 1 + instance_of_k]:
            add_edge(a, b, "shared_instance_of", 0.25)
print("  After shared_instance_of: %d edges" % len(edges))


# 3) Lexical k-NN on English label + description
print("Building lexical Jaccard k-NN edges...")
lexical_k = 5

# Token set per object, built from the English label/description found in the
# Wikidata entity JSON and the flattened wd_* fields, plus the object's name.
texts = {}
for o in objects:
    i = o["internal_id"]
    parts = []
    ent = o.get("wikidata_entity")
    if isinstance(ent, dict):
        lab = (ent.get("labels") or {}).get("en")
        if isinstance(lab, dict):
            parts.append(lab.get("value") or "")
        desc = (ent.get("descriptions") or {}).get("en")
        if isinstance(desc, dict):
            parts.append(desc.get("value") or "")
    parts.append(o.get("wd_label_en") or "")
    parts.append(o.get("wd_description_en") or "")
    parts.append(o.get("name") or "")
    texts[i] = _tokenize(" ".join(parts))

# Inverted index token -> ids. Jaccard is > 0 only when two token sets share
# a word, so scoring just the objects reached through the index gives exactly
# the neighbours an all-pairs scan would, without visiting every pair.
token_index = defaultdict(set)
for i, toks in texts.items():
    for tok in toks:
        token_index[tok].add(i)

for o in objects:
    i = o["internal_id"]
    ti = texts[i]
    if not ti:
        continue
    candidates = set()
    for tok in ti:
        candidates |= token_index[tok]
    candidates.discard(i)
    # Both sets are non-empty here, so the union is never empty.
    scored = [(len(ti & texts[j]) / len(ti | texts[j]), j) for j in candidates]
    # Highest similarity first; ties broken by smaller id for determinism.
    scored.sort(key=lambda t: (-t[0], t[1]))
    for jaccard, j in scored[:lexical_k]:
        add_edge(i, j, "lexical_jaccard", jaccard)
print("  After lexical_jaccard: %d edges" % len(edges))


# 4) OSM classifiers: chain within same historic=* / shop=* bucket
print("Building OSM classifier chain edges...")
hist_groups = defaultdict(list)
shop_groups = defaultdict(list)
for obj in objects:
    oid = obj["internal_id"]
    for tag, buckets in (("historic", hist_groups), ("shop", shop_groups)):
        if obj.get(tag):
            buckets[str(obj[tag])].append(oid)
for buckets in (hist_groups, shop_groups):
    for members in buckets.values():
        ordered = sorted(members)
        # Link consecutive ids only (a chain, not a clique); a single-member
        # bucket produces no pairs.
        for left, right in zip(ordered, ordered[1:]):
            add_edge(left, right, "shared_osm_classifier", 0.15)
print("  After shared_osm_classifier: %d edges" % len(edges))


# ========== Cap total out-degree per node ==========

max_edges_per_node = 32
outgoing = defaultdict(list)
for src_id, dst_id, etype, weight in edges:
    outgoing[src_id].append((dst_id, etype, weight))
capped = []
for src_id, candidates in outgoing.items():
    # Strongest edges first; ties broken by edge type, then by target id.
    ranked = sorted(candidates, key=lambda e: (-e[2], e[1], e[0]))
    capped.extend(
        (src_id, dst_id, etype, weight)
        for dst_id, etype, weight in ranked[:max_edges_per_node]
    )
edges = capped


# ========== Build forward/backward graph ==========

# Adjacency lists keyed by node id. A target appears once per edge, so a pair
# linked by several edge types is listed several times.
graph_f = {}
graph_b = {}
for head, tail, _etype, _weight in edges:
    graph_f.setdefault(head, []).append(tail)
    graph_b.setdefault(tail, []).append(head)


# ========== Save results ==========

out_path = DATA_DIR / "task3_edges.pickle"
DATA_DIR.mkdir(exist_ok=True)
payload = {
    "edges": edges,
    "graph_forward": graph_f,
    "graph_backward": graph_b,
    "n_objects": n,
}
with open(out_path, "wb") as fh:
    pickle.dump(payload, fh)

print("=" * 60)
print("Total edges: %d" % len(edges))
print("Nodes with outgoing edges: %d" % len(graph_f))
print("Nodes with incoming edges: %d" % len(graph_b))

# Per-type edge counts, reported in alphabetical order of edge type.
edge_types = defaultdict(int)
for edge in edges:
    edge_types[edge[2]] += 1
for etype in sorted(edge_types):
    print("  %s: %d" % (etype, edge_types[etype]))

print("Saved: %s" % out_path)
print("Task 3 complete.")
