#!/usr/bin/env python
"""
Task 4. Distributed PageRank using MPI collective communication.

Reads from:  data/task1_results.pickle, data/task3_edges.pickle
Writes to:   data/task4_pagerank.pickle

Run after process.py (Task 1) and graph.py (Task 3), either under MPI or as a
single process:
  mpirun --oversubscribe -np 4 python pagerank_mpi.py
  python pagerank_mpi.py
"""

from mpi4py import MPI
import sys, pickle, os
import numpy as np
import itertools
from pathlib import Path

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
name = MPI.Get_processor_name()  # processor name, shown in per-rank log lines

# Inputs and outputs live in the "data" directory next to this script.
BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR.joinpath("data")


# ========== rank 0 loads data ==========

if rank == 0:
    print("Task 4: Distributed PageRank using MPI")
    print("=" * 60)

    task1_path = DATA_DIR / "task1_results.pickle"
    task3_path = DATA_DIR / "task3_edges.pickle"

    # Every other rank is about to block in comm.bcast() below, so a plain
    # sys.exit() here would stop only rank 0 and leave the job hanging.
    # comm.Abort() terminates all processes in the communicator instead.
    for required, producer in (
        (task1_path, "process.py"),
        (task3_path, "graph.py"),
    ):
        if not required.exists():
            print(
                "ERROR: %s not found. Run %s first." % (required, producer),
                file=sys.stderr,
                flush=True,
            )
            comm.Abort(1)

    # NOTE: pickle.load can execute arbitrary code; these files are expected
    # to be produced locally by the earlier tasks -- never feed untrusted data.
    with open(task1_path, "rb") as f:
        task1_data = pickle.load(f)
    with open(task3_path, "rb") as f:
        task3_data = pickle.load(f)

    objects = task1_data["objects"]
    wiki_scores = task1_data["wiki_scores"]
    edges = task3_data["edges"]
    graph_f = task3_data["graph_forward"]
    graph_b = task3_data["graph_backward"]
    N_real = task3_data["n_objects"]

    # With zero objects N would be 0: range(0, 0, 0) raises below and
    # (1 - d) / N divides by zero on every rank. Stop the whole job cleanly.
    if N_real < 1:
        print(
            "ERROR: %s contains no objects; nothing to rank." % task3_path,
            file=sys.stderr,
            flush=True,
        )
        comm.Abort(1)

    # Pad N up to a multiple of size so every rank owns an equal-length
    # slice; indices >= N_real are padding and are dropped before reporting.
    N = ((N_real + size - 1) // size) * size

    # Uniform starting distribution over all (padded) nodes.
    PR = np.ones(N) / N

    # Half-open [lower, higher) index range owned by each rank, in rank order.
    chunk = N // size
    ranges_lower = np.arange(0, N, chunk)
    ranges_higher = ranges_lower + chunk

    print(
        "[rank 0] Objects: %d, Edges: %d, N (padded): %d"
        % (len(objects), len(edges), N)
    )
else:
    # Non-root ranks receive these via bcast/scatter below (objects and
    # wiki_scores are only ever read on rank 0).
    N = None
    N_real = None
    ranges_lower = None
    ranges_higher = None
    PR = None
    graph_f = None
    graph_b = None
    wiki_scores = None
    objects = None


# ========== bcast shared data ==========

# Ship everything the iterations need from rank 0 to all ranks in a single
# collective; objects and wiki_scores are only read on rank 0 afterwards.
N, N_real, PR, graph_f, graph_b = comm.bcast(
    (N, N_real, PR, graph_f, graph_b), root=0
)

# Damping factor and the per-node teleport term added on every iteration.
d = 0.85
const1 = (1 - d) / N

if rank == 0:
    # ranges_lower / ranges_higher are still the full per-rank arrays here.
    print("start articles:", ranges_lower)
    print("end articles:  ", ranges_higher, "not inclusive")
    print("**********")

comm.Barrier()


# ========== scatter ranges ==========

# Hand each rank its own half-open [lower, higher) slice of node indices.
# Only rank 0 holds the per-rank arrays; everyone else passes None.
my_bounds = comm.scatter(
    list(zip(ranges_lower, ranges_higher)) if rank == 0 else None, root=0
)
ranges_lower, ranges_higher = my_bounds
n_local = ranges_higher - ranges_lower

print(
    "%d/%d (%s) node processing [%d-%d) %d articles"
    % (rank, size, name, ranges_lower, ranges_higher, n_local)
)
comm.Barrier()


# ========== PageRank iterations ==========

max_iter = 60
tol = 1e-6


def _has_out_links(node):
    """Return True if ``node`` has at least one outgoing edge in ``graph_f``.

    Shared by the in-link sum and the dangling-node sum so the two can never
    disagree: a node either distributes its rank along its out-links or is
    treated as dangling, never neither.
    """
    return node in graph_f and len(graph_f[node]) > 0


for num_iter in range(max_iter):

    # Rebuild the full PR vector on every rank from the per-rank slices.
    # Slices are equal length (N is padded to a multiple of size), so the
    # buffer-based Allgather applies. A fresh receive buffer is used because
    # the send slice is a view into the current PR and must not alias it.
    PR_full = np.empty(N, dtype=np.float64)
    comm.Allgather(
        np.ascontiguousarray(PR[ranges_lower:ranges_higher], dtype=np.float64),
        PR_full,
    )
    PR = PR_full

    if num_iter > 0:
        # PR and PR_old are complete, byte-identical vectors on every rank,
        # so each rank computes the same global max change (and takes the
        # same break decision) without another collective.
        max_diff = float(np.max(np.abs(PR - PR_old)))

        if rank == 0 and num_iter % 10 == 0:
            print("@ iter %d max_diff = %e" % (num_iter, max_diff))

        if max_diff < tol:
            if rank == 0:
                print("Converged at iter %d, max_diff = %e" % (num_iter, max_diff))
            break

    PR_old = PR.copy()

    # Link contribution: each in-neighbour j passes PR_old[j] / out_degree(j).
    # A node with no in-links gets no link contribution at all; it keeps only
    # the teleport term and the dangling share added below.
    for i in range(ranges_lower, ranges_higher):
        s = 0.0
        if i in graph_b:
            for j in graph_b[i]:
                if _has_out_links(j):
                    s += PR_old[j] / len(graph_f[j])
        PR[i] = const1 + d * s

    # Rank held by dangling nodes (no out-links, including empty out-lists)
    # would otherwise leak out of the system; spread it uniformly instead.
    dangling_sum_local = np.array(
        [
            sum(
                PR_old[i]
                for i in range(ranges_lower, ranges_higher)
                if not _has_out_links(i)
            )
        ],
        dtype=np.float64,
    )
    dangling_sum_global = np.zeros(1, dtype=np.float64)
    comm.Allreduce(dangling_sum_local, dangling_sum_global, op=MPI.SUM)
    PR[ranges_lower:ranges_higher] += d * dangling_sum_global[0] / N

    if rank == 0 and num_iter % 10 == 0:
        print("Iter %d finished" % num_iter)
else:
    # Loop exhausted without hitting the tolerance; results are still usable
    # but should not be reported as converged.
    if rank == 0:
        print("Did not converge within %d iterations (tol = %e)" % (max_iter, tol))


# ========== gather final PR ==========

def _display_name(obj):
    """Return the best available human-readable label for ``obj``, or "N/A"."""
    return obj.get("name") or obj.get("wd_label_en") or "N/A"


PR_final = comm.gather(PR[ranges_lower:ranges_higher], root=0)

if rank == 0:
    # Slices arrive in rank order, which is also ascending index order.
    PR_final = np.concatenate(PR_final).astype(np.float64)

    # Drop the padding indices and renormalise the real objects to sum to 1.
    pr_array = PR_final[:N_real].astype(np.float64)
    total = pr_array.sum()
    if total > 0:
        pr_array = pr_array / total

    pageranks = {i: float(pr_array[i]) for i in range(N_real)}

    if N_real >= 2:
        wiki_vec = np.array(
            [wiki_scores.get(i, 0.0) for i in range(N_real)], dtype=np.float64
        )
        # Pearson is undefined when either series is constant.
        if np.std(wiki_vec) > 1e-12 and np.std(pr_array) > 1e-12:
            corr = float(np.corrcoef(wiki_vec, pr_array)[0, 1])
        else:
            corr = float("nan")
        # argsort().argsort() turns scores into 0-based rank positions.
        wiki_rank = wiki_vec.argsort().argsort()
        pr_rank = pr_array.argsort().argsort()
        mad = float(np.mean(np.abs(wiki_rank - pr_rank)))
    else:
        corr = float("nan")
        mad = float("nan")

    print("=" * 60)
    print("Pearson correlation(wiki_relevance, PageRank) = %.4f" % corr)
    print("Mean |rank order delta| = %.4f" % mad)

    print("\nTop 10 by PageRank:")
    top_pr = sorted(pageranks.items(), key=lambda x: -x[1])[:10]
    for iid, pr_val in top_pr:
        obj = objects[iid]
        print(
            "  %s (%s) - PR: %.6f, WR: %.4f"
            % (
                _display_name(obj),
                obj.get("city_name"),
                pr_val,
                wiki_scores.get(iid, 0.0),
            )
        )

    print("\nTop 10 by Wiki Relevance:")
    top_wr = sorted(wiki_scores.items(), key=lambda x: -x[1])[:10]
    for iid, wr_val in top_wr:
        obj = objects[iid]
        print(
            "  %s (%s) - WR: %.4f, PR: %.6f"
            % (
                _display_name(obj),
                obj.get("city_name"),
                wr_val,
                pageranks.get(iid, 0.0),
            )
        )

    out_path = DATA_DIR / "task4_pagerank.pickle"
    DATA_DIR.mkdir(exist_ok=True)
    with open(out_path, "wb") as f:
        pickle.dump(
            {"pageranks": pageranks, "correlation": corr, "mad": mad}, f
        )
    print("\nSaved: %s" % out_path)

comm.Barrier()
if rank == 0:
    print("Task 4 complete.")
