#!/usr/bin/env python
"""
Task 2. Store computed data in SQLite (schema matches ld3_object / ld3_edge).

Reads from:  data/task1_results.pickle
             data/task3_edges.pickle      (optional)
             data/task4_pagerank.pickle   (optional)

Writes to:   ../databases/ld3_recomputed.sqlite3

After running this, reload Django data with:
  LD3_DB_PATH=databases/ld3_recomputed.sqlite3 python manage.py load_data --clear

Run after process.py, graph.py, and pagerank_mpi.py:
  python database.py
"""

import sqlite3, pickle, os, json
from pathlib import Path

# All locations are derived from this script's own directory, so the result
# does not depend on the current working directory.
BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR.joinpath("data")
DB_PATH = BASE_DIR.parent.joinpath("databases", "ld3_recomputed.sqlite3")

# Start-of-run banner.
banner_rule = "=" * 60
print("Task 2: Storing data in SQLite")
print(banner_rule)


# ========== Load task results from pickle files ==========

def _load_pickle(path):
    """Unpickle and return the object stored at *path*.

    NOTE: unpickling can execute arbitrary code, so these files must only
    ever be ones produced by this project's own pipeline scripts.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


task1_path = DATA_DIR / "task1_results.pickle"
task3_path = DATA_DIR / "task3_edges.pickle"
task4_path = DATA_DIR / "task4_pagerank.pickle"

# Defaults used when an optional input file is absent.
edges = []
pageranks = {}

# Task 1 output is the only mandatory input. Stop here (non-zero exit status)
# when it is missing: continuing would delete any existing database further
# down and replace it with one that contains no objects at all.
if not task1_path.exists():
    raise SystemExit("ERROR: %s not found. Run process.py first." % task1_path)

task1_data = _load_pickle(task1_path)
objects = task1_data["objects"]
wiki_scores = task1_data["wiki_scores"]
print("Loaded %d objects from task1_results.pickle" % len(objects))

# Edges (task 3) are optional; without them ld3_edge simply stays empty.
if task3_path.exists():
    task3_data = _load_pickle(task3_path)
    edges = task3_data["edges"]
    print("Loaded %d edges from task3_edges.pickle" % len(edges))
else:
    print("WARNING: %s not found. Run graph.py first (optional)." % task3_path)

# PageRank values (task 4) are optional; objects without one get NULL.
if task4_path.exists():
    task4_data = _load_pickle(task4_path)
    pageranks = task4_data["pageranks"]
    print(
        "Loaded %d pagerank values from task4_pagerank.pickle" % len(pageranks)
    )
else:
    print(
        "WARNING: %s not found. Run pagerank_mpi.py first (optional)." % task4_path
    )


# ========== Create SQLite database ==========

# Always rebuild from scratch: delete any database file left by an earlier run.
if DB_PATH.exists():
    print("Removing old database: %s" % DB_PATH)
    DB_PATH.unlink()

# Create the target directory if it is not there yet (only the last path
# component is created; its parent must already exist).
db_dir = DB_PATH.parent
db_dir.mkdir(exist_ok=True)

# Connecting creates the database file; all later statements go through `c`.
db = sqlite3.connect(DB_PATH)
c = db.cursor()


# ========== Create schema ==========

# DDL for the output database:
#   * ld3_object - one row per object, keyed by internal_id; wiki_relevance
#     defaults to 0.0 and pagerank may be NULL.
#   * ld3_edge   - one row per (src_id, dst_id, edge_type); weight defaults
#     to 1.0.
# followed by the lookup indexes on both tables.
SCHEMA_SQL = """
    PRAGMA foreign_keys = ON;

    CREATE TABLE IF NOT EXISTS ld3_object (
        internal_id INTEGER PRIMARY KEY,
        place_rowid INTEGER,
        city_name TEXT,
        osm_type TEXT,
        osm_id INTEGER,
        lat REAL,
        lon REAL,
        name TEXT,
        wikidata_qid TEXT,
        historic TEXT,
        shop TEXT,
        tourism TEXT,
        wd_label_en TEXT,
        wd_description_en TEXT,
        wd_instance_of_qid TEXT,
        wiki_relevance REAL NOT NULL DEFAULT 0.0,
        pagerank REAL
    );

    CREATE TABLE IF NOT EXISTS ld3_edge (
        src_id INTEGER NOT NULL,
        dst_id INTEGER NOT NULL,
        edge_type TEXT NOT NULL,
        weight REAL NOT NULL DEFAULT 1.0,
        PRIMARY KEY (src_id, dst_id, edge_type)
    );

    CREATE INDEX IF NOT EXISTS idx_ld3_object_city ON ld3_object(city_name);
    CREATE INDEX IF NOT EXISTS idx_ld3_object_qid ON ld3_object(wikidata_qid);
    CREATE INDEX IF NOT EXISTS idx_ld3_edge_src ON ld3_edge(src_id);
    CREATE INDEX IF NOT EXISTS idx_ld3_edge_dst ON ld3_edge(dst_id);
    CREATE INDEX IF NOT EXISTS idx_ld3_object_pr ON ld3_object(pagerank);
    CREATE INDEX IF NOT EXISTS idx_ld3_object_wr ON ld3_object(wiki_relevance);
"""

# executescript runs every statement of the script in one call.
c.executescript(SCHEMA_SQL)


# ========== Insert objects ==========

# Columns copied straight from each object dict, in the order the INSERT
# statement lists them between internal_id and the two score columns.
# A key that is absent from the dict is stored as NULL.
_COPIED_COLUMNS = (
    "place_rowid",
    "city_name",
    "osm_type",
    "osm_id",
    "lat",
    "lon",
    "name",
    "wikidata_qid",
    "historic",
    "shop",
    "tourism",
    "wd_label_en",
    "wd_description_en",
    "wd_instance_of_qid",
)


def _object_row(obj):
    """Build the 17-value parameter tuple for one ld3_object row.

    internal_id is mandatory (KeyError when missing). The scores are looked
    up by internal_id: a missing wiki score becomes 0.0, a missing pagerank
    becomes NULL.
    """
    oid = obj["internal_id"]
    copied = tuple(obj.get(col) for col in _COPIED_COLUMNS)
    return (oid,) + copied + (wiki_scores.get(oid, 0.0), pageranks.get(oid))


# One statement, executed once per object; rows are built lazily.
c.executemany(
        """
        INSERT INTO ld3_object (
            internal_id, place_rowid, city_name, osm_type, osm_id,
            lat, lon, name, wikidata_qid, historic, shop, tourism,
            wd_label_en, wd_description_en, wd_instance_of_qid,
            wiki_relevance, pagerank
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
        """,
    map(_object_row, objects),
)


# ========== Insert edges ==========

# Each edge is a 4-item record (src_id, dst_id, edge_type, weight). OR REPLACE
# means a later record with the same (src_id, dst_id, edge_type) key
# overwrites the earlier one instead of raising.
c.executemany(
        """
        INSERT OR REPLACE INTO ld3_edge (src_id, dst_id, edge_type, weight)
        VALUES (?, ?, ?, ?);
        """,
    ((src, dst, kind, w) for src, dst, kind, w in edges),
)

# Persist the object and edge inserts.
db.commit()


# ========== Print statistics ==========

def _scalar(sql):
    """Run a query that yields a single value and return that value."""
    return c.execute(sql).fetchone()[0]


count_obj = _scalar("SELECT COUNT(*) FROM ld3_object")
count_edge = _scalar("SELECT COUNT(*) FROM ld3_edge")
count_cities = _scalar("SELECT COUNT(DISTINCT city_name) FROM ld3_object")

# Overall totals.
print("=" * 60)
for template, value in (
    ("Database: %s", DB_PATH),
    ("Objects:  %d", count_obj),
    ("Edges:    %d", count_edge),
    ("Cities:   %d", count_cities),
):
    print(template % value)

# Per-city object counts, alphabetically by city name.
per_city = c.execute(
    "SELECT city_name, COUNT(*) FROM ld3_object "
    "GROUP BY city_name ORDER BY city_name"
)
for city, n_objects in per_city:
    print("  %s: %d objects" % (city, n_objects))

# The PageRank listing is only printed when pagerank data was loaded.
if pageranks:
    print("\nTop 10 by PageRank:")
    top_pagerank_sql = (
        "SELECT name, city_name, pagerank, wiki_relevance FROM ld3_object "
        "WHERE pagerank IS NOT NULL ORDER BY pagerank DESC LIMIT 10"
    )
    for name, city, pr, wr in c.execute(top_pagerank_sql).fetchall():
        print("  %s (%s) - PR: %.6f, WR: %.4f" % (name, city, pr, wr))

print("\nTop 10 by Wiki Relevance:")
top_relevance_sql = (
    "SELECT name, city_name, wiki_relevance, pagerank FROM ld3_object "
    "ORDER BY wiki_relevance DESC LIMIT 10"
)
for name, city, wr, pr in c.execute(top_relevance_sql).fetchall():
    # pagerank is nullable, so show a placeholder when it is NULL.
    if pr is None:
        pr_text = "N/A"
    else:
        pr_text = "%.6f" % pr
    print("  %s (%s) - WR: %.4f, PR: %s" % (name, city, wr, pr_text))

# Number of edges per edge type; the section is skipped when there are none.
edge_type_counts = c.execute(
    "SELECT edge_type, COUNT(*) FROM ld3_edge "
    "GROUP BY edge_type ORDER BY edge_type"
).fetchall()
if len(edge_type_counts) > 0:
    print("\nEdge types:")
    for kind, total in edge_type_counts:
        print("  %s: %d" % (kind, total))

# All reads are done; release the connection before the closing banner.
db.close()

for footer_line in (
    "=" * 60,
    "Task 2 complete.",
    "\nTo reload Django with the new scores:",
    "  cd .. && LD3_DB_PATH=databases/ld3_recomputed.sqlite3 python manage.py load_data --clear",
):
    print(footer_line)