"""
Management command to compute similarity between places using KNN.

Computes three types of similarities:
1. Structural KNN (same city) - based on categorical tags
2. Image KNN (same city) - based on image feature vectors
3. Image KNN (other cities) - based on image feature vectors

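Work is parallelised with a multiprocessing pool: one task per city for
structural KNN and one task per image for image KNN. Results are stored as
SimilarPlace rows.

Usage:
    python manage.py compute_similarities
    python manage.py compute_similarities --method structural
    python manage.py compute_similarities --method image_same_city
    python manage.py compute_similarities --method image_other_city
    python manage.py compute_similarities --clear --k-structural 10 --k-image 5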
"""
import os
import json
import numpy as np
from multiprocessing import Pool, cpu_count
from django.core.management.base import BaseCommand
from django.db import connection


# Module-level functions for multiprocessing (must be picklable)

def _process_structural_city(args):
    """Run structural (tag one-hot) KNN over the places of a single city.

    Runs inside a pool worker, so it is a module-level function taking one
    picklable tuple ``(city_id, city_name, records, k)``. Only ``records``
    and ``k`` are used. Each record is a dict with a ``'place_id'`` key plus
    the categorical tag columns listed in ``tag_columns``.

    Returns a list of ``(place_id, neighbour_place_id, score, 'structural')``
    tuples. Rows whose one-hot encoding sums to zero (no tags) are never
    queried and are excluded from the candidate mask passed to
    ``knn_search``. A neighbour equal to the query row is skipped.
    """
    _city_id, _city_name, records, k = args
    from processing.algorithms.knn import build_onehot_matrix, knn_search, distance_to_score

    tag_columns = [
        'historic', 'tourism', 'amenity', 'shop',
        'leisure', 'man_made', 'memorial', 'artwork_type',
    ]

    if len(records) < 2:
        return []

    onehot, _feature_names, _unused = build_onehot_matrix(records, tag_columns)

    # Boolean mask of rows carrying at least one tag; it doubles as the
    # candidate restriction handed to knn_search.
    has_tags = np.sum(onehot, axis=1) > 0
    tagged_rows = np.where(has_tags)[0]
    if len(tagged_rows) < 2:
        return []

    # NOTE(review): sqrt(n_features) is used as the score normaliser; if each
    # row has at most one active column per tag, the true maximum distance is
    # sqrt(2 * len(tag_columns)) — confirm against build_onehot_matrix.
    max_distance = float(np.sqrt(onehot.shape[1]))

    pairs = []
    for row in tagged_rows:
        hits = knn_search(row, onehot, k=k, city_mask=has_tags)
        source_id = records[row]['place_id']
        pairs.extend(
            (
                source_id,
                records[hit]['place_id'],
                distance_to_score(dist, max_distance=max_distance),
                'structural',
            )
            for hit, dist in hits
            if hit != row
        )
    return pairs


def _process_image_knn(args):
    """Find the nearest image neighbours of one image. Called by multiprocessing.

    ``args`` is the tuple
    ``(img_idx, all_features, all_place_ids, all_city_ids, target_city_id, k, method)``:

    * ``img_idx`` – row of ``all_features`` used as the query.
    * ``all_features`` – 2-D float array, one feature vector per row.
    * ``all_place_ids`` / ``all_city_ids`` – per-row place / city ids aligned
      with ``all_features``.
    * ``target_city_id`` – unused; kept so the tuple layout stays stable.
    * ``k`` – maximum number of neighbours to return (``k <= 0`` -> none).
    * ``method`` – ``'image_same_city'`` keeps candidates in the query's city;
      any other value keeps candidates in *other* cities.

    Returns a list of ``(query_place_id, neighbour_place_id, score, method)``
    ordered by increasing Euclidean distance. The query's own place (i.e.
    any other image of the same place) and images without a place are
    excluded *before* ranking, so they cannot crowd real neighbours out of
    the top k. Several images of one neighbour place may still appear; the
    caller de-duplicates them.
    """
    img_idx, all_features, all_place_ids, all_city_ids, _target_city_id, k, method = args
    from processing.algorithms.knn import distance_to_score

    query_place = all_place_ids[img_idx]
    # A result keyed on a missing place can never be saved; skip the work.
    if k <= 0 or query_place is None:
        return []

    query_vec = all_features[img_idx]
    query_city = all_city_ids[img_idx]

    if method == 'image_same_city':
        city_ok = [cid == query_city for cid in all_city_ids]
    else:  # image_other_city
        city_ok = [cid != query_city for cid in all_city_ids]

    candidate_mask = np.array(
        [
            ok and pid is not None and pid != query_place
            for ok, pid in zip(city_ok, all_place_ids)
        ],
        dtype=bool,
    )
    # The query row is already excluded via its place id; keep this explicit.
    candidate_mask[img_idx] = False

    candidate_indices = np.flatnonzero(candidate_mask)
    if candidate_indices.size == 0:
        return []

    diff = all_features[candidate_indices] - query_vec
    distances = np.sqrt(np.sum(diff ** 2, axis=1))

    # k_actual >= 1 here, so argpartition's kth index is valid.
    k_actual = min(k, candidate_indices.size)
    top_local = np.argpartition(distances, k_actual - 1)[:k_actual]
    top_local = top_local[np.argsort(distances[top_local], kind='stable')]

    # NOTE(review): sqrt(dim) is the largest possible distance only if every
    # feature lies in [0, 1] — confirm against the feature extractor.
    max_dist = float(np.sqrt(all_features.shape[1]))

    results = []
    for local_idx in top_local:
        neighbor_place_id = all_place_ids[int(candidate_indices[local_idx])]
        score = distance_to_score(float(distances[local_idx]), max_distance=max_dist)
        results.append((query_place, neighbor_place_id, score, method))
    return results


class Command(BaseCommand):
    """Fill ``SimilarPlace`` with KNN neighbours computed in a process pool.

    Supported methods (``--method``): ``structural`` (tag one-hot KNN within
    each city), ``image_same_city`` and ``image_other_city`` (image
    feature-vector KNN). Per-task work is delegated to the module-level
    workers ``_process_structural_city`` and ``_process_image_knn``; results
    are written with ``bulk_create(ignore_conflicts=True)``.
    """

    help = 'Compute place similarities using KNN algorithms with multiprocessing'

    # Place tag columns fetched for structural KNN (the structural worker
    # one-hot encodes the same columns).
    STRUCTURAL_FIELDS = (
        'historic', 'tourism', 'amenity', 'shop',
        'leisure', 'man_made', 'memorial', 'artwork_type',
    )
    # Rows per INSERT statement when saving similarities.
    SAVE_BATCH_SIZE = 1000

    def add_arguments(self, parser):
        parser.add_argument(
            '--method',
            type=str,
            default='all',
            choices=['all', 'structural', 'image_same_city', 'image_other_city'],
            help='Which similarity method to compute',
        )
        parser.add_argument(
            '--k-structural',
            type=int,
            default=10,
            help='Number of structural neighbors',
        )
        parser.add_argument(
            '--k-image',
            type=int,
            default=5,
            help='Number of image neighbors',
        )
        parser.add_argument(
            '--clear',
            action='store_true',
            help='Clear existing similarity data before computing',
        )
        parser.add_argument(
            '--workers',
            type=int,
            default=None,
            help='Number of worker processes (default: number of CPUs)',
        )

    def handle(self, *args, **options):
        from django.core.management.base import CommandError
        from places.models import SimilarPlace

        method = options['method']
        k_struct = options['k_structural']
        k_img = options['k_image']
        requested_workers = options.get('workers')

        # Validate before --clear so bad arguments cannot delete data and
        # then fail half-way.
        if k_struct < 1 or k_img < 1:
            raise CommandError('--k-structural and --k-image must be at least 1')
        if requested_workers is not None and requested_workers < 1:
            raise CommandError('--workers must be at least 1')

        if options['clear']:
            self.stdout.write('Clearing existing similarity data...')
            if method == 'all':
                SimilarPlace.objects.all().delete()
            else:
                SimilarPlace.objects.filter(method=method).delete()
            self.stdout.write('Cleared.')

        n_workers = requested_workers or max(1, cpu_count())
        self.stdout.write(f'Using {n_workers} worker process(es)')

        if method in ('all', 'structural'):
            self.stdout.write('\n=== Computing Structural KNN ===')
            self._compute_structural(k_struct, n_workers)

        if method in ('all', 'image_same_city'):
            self.stdout.write('\n=== Computing Image KNN (same city) ===')
            self._compute_image_knn('image_same_city', k_img, n_workers)

        if method in ('all', 'image_other_city'):
            self.stdout.write('\n=== Computing Image KNN (other cities) ===')
            self._compute_image_knn('image_other_city', k_img, n_workers)

        self.stdout.write(self.style.SUCCESS('\nSimilarity computation complete!'))

    def _compute_structural(self, k, n_workers):
        """Run structural KNN once per city in a pool and save the results."""
        from collections import defaultdict

        from places.models import City, Place

        # One query for every place instead of one query per city (N+1).
        records_by_city = defaultdict(list)
        place_rows = Place.objects.filter(city__isnull=False).values(
            'id', 'city_id', *self.STRUCTURAL_FIELDS
        )
        for row in place_rows:
            record = {'place_id': row['id']}
            record.update((field, row[field]) for field in self.STRUCTURAL_FIELDS)
            records_by_city[row['city_id']].append(record)

        city_args = []
        for city in City.objects.all():
            records = records_by_city.get(city.id, [])
            city_args.append((city.id, city.name, records, k))
            self.stdout.write(f'  Queued {city.name} ({len(records)} places)')

        if not city_args:
            self.stdout.write('  No cities found.')
            return

        # Close the DB connection before forking; children must not share it.
        # Django reopens it on the next query in this process.
        connection.close()

        workers = min(n_workers, len(city_args))
        self.stdout.write(f'Running structural KNN with {workers} workers...')
        with Pool(processes=workers) as pool:
            # chunksize=1: city sizes vary widely, so hand them out one by one.
            results_list = pool.map(_process_structural_city, city_args, chunksize=1)

        all_results = [row for results in results_list for row in results]

        self.stdout.write(f'Saving {len(all_results)} structural similarity records...')
        saved = self._save_similarities(all_results, 'structural')
        self.stdout.write(f'  Saved {saved} structural similarities')

    def _load_image_features(self):
        """Load and validate image feature vectors for places.

        Returns ``(features, place_ids, city_ids)`` with ``features`` a
        ``(n, d)`` float32 array aligned with both id lists, or ``None`` when
        nothing usable exists (a message is written in that case). Vectors
        that fail to parse, are not 1-D, are empty, contain non-finite values,
        or differ in length from the first valid vector are skipped instead
        of being left as zero rows that would pollute the KNN.
        """
        from places.models import ImageFeature

        rows = list(
            ImageFeature.objects.filter(place__isnull=False)
            .values_list('place_id', 'city_id', 'feature_vector')
        )
        if not rows:
            self.stdout.write('  No image features found. Run load_data first.')
            return None

        self.stdout.write(f'  Loaded {len(rows)} image features')

        vectors, place_ids, city_ids = [], [], []
        dim = None
        skipped = 0
        for place_id, city_id, raw in rows:
            try:
                # Accept JSON text as well as an already-decoded list.
                data = json.loads(raw) if isinstance(raw, (str, bytes, bytearray)) else raw
                vec = np.asarray(data, dtype=np.float32)
            except (TypeError, ValueError):
                skipped += 1
                continue
            if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all():
                skipped += 1
                continue
            if dim is None:
                dim = vec.size
            elif vec.size != dim:
                skipped += 1
                continue
            vectors.append(vec)
            place_ids.append(place_id)
            city_ids.append(city_id)

        if skipped:
            self.stdout.write(f'  Skipped {skipped} malformed or mismatched feature vectors')

        if not vectors:
            self.stdout.write('  No valid feature vectors found.')
            return None

        return np.stack(vectors), place_ids, city_ids

    def _compute_image_knn(self, method, k, n_workers):
        """Run image KNN for every loaded image and save de-duplicated results."""
        loaded = self._load_image_features()
        if loaded is None:
            return
        all_features, all_place_ids, all_city_ids = loaded
        n = len(all_place_ids)

        # Every task references the same arrays; pickle memoises them, so each
        # chunk sent to a worker carries one copy rather than one per image.
        args_list = [
            (i, all_features, all_place_ids, all_city_ids, all_city_ids[i], k, method)
            for i in range(n)
        ]

        self.stdout.write(f'  Computing {method} KNN for {n} images...')

        # Close DB before forking
        connection.close()

        with Pool(processes=min(n_workers, n)) as pool:
            results_list = pool.map(_process_image_knn, args_list)

        # Several images per place -> keep the best score per
        # (main_place, similar_place, method).
        best_scores = {}
        for results in results_list:
            for main_id, sim_id, score, m in results:
                if main_id is None or sim_id is None or main_id == sim_id:
                    continue
                key = (main_id, sim_id, m)
                if key not in best_scores or best_scores[key] < score:
                    best_scores[key] = score

        final_results = [
            (main_id, sim_id, score, m)
            for (main_id, sim_id, m), score in best_scores.items()
        ]
        self.stdout.write(f'  Saving {len(final_results)} {method} similarity records...')
        saved = self._save_similarities(final_results, method)
        self.stdout.write(f'  Saved {saved} {method} similarities')

    def _save_similarities(self, results, method):
        """Insert ``(main_id, sim_id, score, method)`` rows as ``SimilarPlace``.

        Rows referencing unknown places, self-pairs and repeated
        ``(main_id, sim_id, method)`` keys are dropped. Insertion uses
        ``ignore_conflicts=True``, so rows clashing with existing ones are
        skipped by the database. Returns the number of rows submitted.
        """
        from places.models import SimilarPlace, Place

        if not results:
            return 0

        known_place_ids = set(Place.objects.values_list('id', flat=True))

        to_create = []
        seen = set()
        for main_id, sim_id, score, m in results:
            if main_id == sim_id:
                continue
            if main_id not in known_place_ids or sim_id not in known_place_ids:
                continue
            key = (main_id, sim_id, m)
            if key in seen:
                continue
            seen.add(key)
            to_create.append(SimilarPlace(
                main_place_id=main_id,
                similar_place_id=sim_id,
                score=score,
                method=m,
            ))

        SimilarPlace.objects.bulk_create(
            to_create, batch_size=self.SAVE_BATCH_SIZE, ignore_conflicts=True
        )

        # With ignore_conflicts the DB may skip some rows, so this is an
        # upper bound on what was actually inserted.
        self.stdout.write(
            f'  Submitted {len(to_create)} records for method={method} '
            f'(conflicting rows ignored)'
        )
        return len(to_create)
