"""
Fetch accurate per-POI images from the Openverse API.

For every place in the database it builds a targeted query from the place name
(plus city for disambiguation), searches Openverse, ranks the results by how well
their title/tags match the place name, and downloads only confident matches.
Images are saved under MEDIA_ROOT/poi/<city_slug>/ and linked via PlaceImage.

Usage:
    python manage.py fetch_poi_images                      # top 50 places per city
    python manage.py fetch_poi_images --city berlin --limit 100
    python manage.py fetch_poi_images --per-place 3 --compute-features
    python manage.py fetch_poi_images --client-id ID --client-secret SECRET
    python manage.py fetch_poi_images --dry-run            # preview matches only
    python manage.py fetch_poi_images --clear              # remove old images first
"""
import os
import re
import time
import json
import threading
import unicodedata
import requests
from urllib.parse import urlparse
from concurrent.futures import ThreadPoolExecutor, as_completed
from django.core.management.base import BaseCommand
from django.conf import settings


# Root of the Openverse REST API; endpoint paths (e.g. "/images/",
# "/auth_tokens/token/") are appended to it by the Command methods.
API_BASE = "https://api.openverse.org/v1"

# Default headers installed on every requests.Session created by
# Command._session, so they are sent with both API calls and image downloads.
BROWSER_HEADERS = {
    "User-Agent": (
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 "
        "(KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
    ),
    "Accept": "application/json, text/plain, */*",
    "Accept-Language": "en-US,en;q=0.9",
}

# File extensions Command._ext_from_url accepts from a result URL; any other
# (or missing) extension is saved as "jpg".
VALID_EXT = {"jpg", "jpeg", "png", "gif", "webp"}

# Country context per city slug — sharpens Openverse full-text retrieval.
# Cities missing from this map simply get no country term in their queries.
CITY_COUNTRY = {
    'london': 'United Kingdom',
    'berlin': 'Germany',
    'new_york': 'United States',
    'paris': 'France',
}

# Tokens too generic to prove a result actually matches the place.
# Compared against the output of tokenize(), so entries are lower-case ASCII.
STOPWORDS = {
    "the", "of", "and", "de", "la", "le", "el", "von", "der", "die", "das",
    "a", "an", "at", "in", "on", "to", "for", "monument", "memorial",
    "statue", "museum", "church", "park", "square", "house", "building",
}


def strip_accents(text):
    """Return *text* with all Unicode combining marks removed.

    The string is first decomposed with NFKD so that accented letters become
    a base character followed by combining marks; the marks are then dropped.
    """
    decomposed = unicodedata.normalize('NFKD', text)
    base_chars = [ch for ch in decomposed if not unicodedata.combining(ch)]
    return ''.join(base_chars)


def tokenize(text):
    """Split *text* into case-folded, accent-free ASCII alphanumeric tokens.

    ``None`` and empty input yield an empty list. Every run of characters
    outside ``a-z0-9`` acts as a separator, and single-character tokens are
    discarded.
    """
    # casefold() rather than lower(): lower() leaves the German sharp s
    # untouched and NFKD does not decompose it, so the ASCII-only split below
    # would cut names such as "Schloßplatz" into "schlo" + "platz" and they
    # would never match the common "ss" spelling in result titles.
    text = strip_accents((text or '').casefold())
    return [t for t in re.split(r'[^a-z0-9]+', text) if len(t) > 1]


def name_tokens(name):
    """Return the tokens of a place name that are not generic stopwords.

    When every token is a stopword the complete token list is returned
    instead, so an all-generic name still yields something to match on.
    """
    all_tokens = tokenize(name)
    significant = [tok for tok in all_tokens if tok not in STOPWORDS]
    if significant:
        return significant
    return all_tokens


class Command(BaseCommand):
    help = 'Fetch accurate per-POI images from Openverse and link them to places'

    def add_arguments(self, parser):
        parser.add_argument('--city', type=str, default=None,
                            help='Limit to a single city slug')
        parser.add_argument('--limit', type=int, default=50,
                            help='Max places to process per city (highest score first). Use 0 for ALL places.')
        parser.add_argument('--workers', type=int, default=8,
                            help='Concurrent network workers. Higher = faster but more likely to hit rate limits.')
        parser.add_argument('--per-place', type=int, default=2,
                            help='Number of images to download per place')
        parser.add_argument('--min-match', type=float, default=0.5,
                            help='Minimum fraction of name tokens that must appear in a result (0-1)')
        parser.add_argument('--client-id', type=str, default=os.environ.get('OPENVERSE_CLIENT_ID', ''),
                            help='Openverse OAuth client id (optional, higher rate limits)')
        parser.add_argument('--client-secret', type=str, default=os.environ.get('OPENVERSE_CLIENT_SECRET', ''),
                            help='Openverse OAuth client secret (optional)')
        parser.add_argument('--compute-features', action='store_true',
                            help='Compute KMeans colour feature vectors for downloaded images')
        parser.add_argument('--features-only', action='store_true',
                            help='Skip all network work — only compute missing KMeans features for already-downloaded images')
        parser.add_argument('--clear', action='store_true',
                            help='Delete existing PlaceImage/ImageFeature for processed places first')
        parser.add_argument('--only-new', action='store_true',
                            help='Skip places that already have at least one image saved')
        parser.add_argument('--dry-run', action='store_true',
                            help='Show matched results without downloading')
        parser.add_argument('--sleep', type=float, default=0.0,
                            help='Seconds to wait between API searches (default 0)')
        parser.add_argument('--exclude-wikimedia', action='store_true',
                            help='Skip images hosted on wikimedia.org or wikipedia.org')

    def handle(self, *args, **opt):
        """Run the command: plan places in worker threads, commit on this thread.

        For each selected city the places are ordered by ``-interest_score``,
        optionally truncated to ``--limit`` and filtered by ``--only-new``,
        then submitted to a thread pool running ``_plan_place``. Each finished
        plan is written to the database by ``_commit_plan`` on the calling
        thread. ``--features-only`` skips all of that and delegates to
        ``_run_features_only``.

        A worker that raises is reported on stderr and skipped, so a single
        bad place cannot abort the run and discard work already in flight.
        """
        from places.models import City, Place, PlaceImage

        # Per-thread storage for the requests.Session handed out by _session().
        self._local = threading.local()
        self.token = None
        self.media_root = settings.MEDIA_ROOT

        cities = City.objects.all()
        if opt['city']:
            cities = cities.filter(slug=opt['city'])
            if not cities.exists():
                self.stderr.write(f"No city with slug '{opt['city']}'")
                return

        if opt['features_only']:
            self._run_features_only(cities, opt)
            return

        if opt['client_id'] and opt['client_secret']:
            self.token = self._get_token(opt['client_id'], opt['client_secret'])

        workers = max(1, opt['workers'])
        total_downloaded = 0
        total_errors = 0

        for city in cities:
            qs = Place.objects.filter(city=city).order_by('-interest_score')
            if opt['limit'] and opt['limit'] > 0:
                qs = qs[:opt['limit']]
            places = list(qs)

            if opt['only_new']:
                already_done = set(
                    PlaceImage.objects.filter(place__city=city)
                    .values_list('place_id', flat=True).distinct()
                )
                places = [p for p in places if p.id not in already_done]

            self.stdout.write(self.style.MIGRATE_HEADING(
                f"\n=== {city.name}: processing {len(places)} places "
                f"({workers} workers) ==="))

            out_dir = os.path.join(self.media_root, 'poi', city.slug)
            if not opt['dry_run']:
                os.makedirs(out_dir, exist_ok=True)

            done = 0
            # Network search + download run in parallel; DB writes stay on the
            # main thread because SQLite allows only one writer.
            with ThreadPoolExecutor(max_workers=workers) as ex:
                futures = {
                    ex.submit(self._plan_place, place, city, out_dir, opt): place
                    for place in places
                }
                for fut in as_completed(futures):
                    place = futures[fut]
                    done += 1
                    try:
                        plan = fut.result()
                    except Exception as e:
                        # Future.result() re-raises whatever the worker raised.
                        # Letting it propagate would end the command while the
                        # pool still finishes the remaining downloads, whose
                        # results would then never be committed.
                        total_errors += 1
                        self.stderr.write(self.style.ERROR(
                            f"  [error] {place.name}: {e!r}"))
                    else:
                        total_downloaded += self._commit_plan(place, plan, opt)
                    if done % 50 == 0:
                        self.stdout.write(f"  ...{done}/{len(places)} places, "
                                          f"{total_downloaded} images so far")

        self.stdout.write(self.style.SUCCESS(
            f"\nDone. Downloaded {total_downloaded} images for places."))
        if total_errors:
            self.stderr.write(self.style.WARNING(
                f"{total_errors} places were skipped because of errors."))

    def _run_features_only(self, cities, opt):
        """Compute KMeans feature vectors for downloaded images that have none.

        For each city, selects ``PlaceImage`` rows under ``poi/`` whose path
        has no ``ImageFeature`` yet, computes a feature vector from the file
        on disk and stores it. Missing files and failed computations are
        reported as warnings and skipped. No network access is performed.
        """
        from places.models import PlaceImage, ImageFeature
        from processing.algorithms.kmeans import compute_image_feature

        for city in cities:
            paths_with_features = ImageFeature.objects.values('image_path')
            pending = (
                PlaceImage.objects
                .filter(place__city=city, image_path__startswith='poi/')
                .exclude(image_path__in=paths_with_features)
                .select_related('place')
            )

            total = pending.count()
            heading = f"\n=== {city.name}: computing features for {total} images ==="
            self.stdout.write(self.style.MIGRATE_HEADING(heading))

            computed = 0
            for done, record in enumerate(pending.iterator(), start=1):
                rel_path = record.image_path
                abs_path = os.path.join(self.media_root, rel_path)

                if not os.path.exists(abs_path):
                    self.stdout.write(self.style.WARNING(
                        f"  [missing file] {rel_path}"))
                    continue

                vec = compute_image_feature(abs_path, k=10)
                if vec is None:
                    self.stdout.write(self.style.WARNING(
                        f"  [feature failed] {rel_path}"))
                    continue

                owner = record.place
                category = owner.get_primary_category() if owner else 'general'
                feature_fields = {
                    'city': city,
                    'category_name': category,
                    'feature_vector': json.dumps(vec.tolist()),
                    'place': owner,
                }
                ImageFeature.objects.update_or_create(
                    image_path=rel_path,
                    defaults=feature_fields,
                )
                computed += 1

                # Progress is only reported after a successful computation.
                if done % 50 == 0:
                    self.stdout.write(f"  ...{done}/{total} images, {computed} features computed")

            self.stdout.write(self.style.SUCCESS(
                f"  Done: {computed}/{total} feature vectors computed for {city.name}"))

    def _session(self):
        """Return the calling thread's ``requests.Session``, creating it lazily.

        Sessions live on ``self._local`` (a ``threading.local``), so each
        thread gets its own instance carrying ``BROWSER_HEADERS``.
        """
        local = self._local
        existing = getattr(local, 'session', None)
        if existing is not None:
            return existing

        fresh = requests.Session()
        fresh.headers.update(BROWSER_HEADERS)
        local.session = fresh
        return fresh

    # ---- Openverse OAuth (optional) ----
    def _get_token(self, client_id, client_secret):
        """Exchange OAuth client credentials for an Openverse access token.

        Returns the access token string, or ``None`` when the request fails,
        the server answers with a non-200 status, the body is not valid JSON,
        or the response carries no ``access_token``. Failures are reported on
        stderr; they never raise.
        """
        try:
            r = self._session().post(f"{API_BASE}/auth_tokens/token/", data={
                'grant_type': 'client_credentials',
                'client_id': client_id,
                'client_secret': client_secret,
            }, timeout=30)
            if r.status_code != 200:
                self.stderr.write(f"Token request failed: {r.status_code}")
                return None
            token = r.json().get('access_token')
        except (requests.RequestException, ValueError) as e:
            # ValueError covers a 200 response whose body is not JSON.
            self.stderr.write(f"Token error: {e}")
            return None

        # Only announce success once a token is actually in hand; a 200 body
        # without one would otherwise be reported as "acquired".
        if not token:
            self.stderr.write("Token request failed: response contained no access_token")
            return None
        self.stdout.write(self.style.SUCCESS("Openverse OAuth token acquired"))
        return token

    # ---- per-place pipeline ----
    # _plan_place runs in worker threads: network search + download + feature
    # math only (NO database access). It returns a plan that the main thread
    # commits via _commit_plan.
    def _queries(self, place, city):
        """Build the search strings for *place*, most specific first.

        Candidates, in order: name + category + city + country, name + city +
        country, name + city, and the bare name. Empty parts are left out, a
        category of ``'Place'`` is treated as empty, and duplicate or empty
        strings are dropped while keeping the first occurrence.
        """
        name = place.name
        country = CITY_COUNTRY.get(city.slug, '')
        category = place.get_primary_category()
        if category == 'Place':
            category = ''

        part_lists = (
            [name, category, city.name, country],
            [name, city.name, country],
            [name, city.name],
        )
        candidates = [' '.join(filter(None, parts)) for parts in part_lists]
        candidates.append(name)

        # dict.fromkeys keeps insertion order, giving an order-preserving dedupe.
        return list(dict.fromkeys(q for q in candidates if q))

    def _plan_place(self, place, city, out_dir, opt):
        """Search, rank and (unless dry-run) download images for one place.

        Runs in a worker thread and performs no database writes; the returned
        plan dict is consumed by ``_commit_plan``. Its shape depends on
        ``status``:

        * ``{'status': 'no_match', 'queries': [...]}``: no query produced a
          result meeting ``opt['min_match']``.
        * ``{'status': 'dry_run', 'best': {...}}``: the top-ranked result,
          nothing downloaded.
        * ``{'status': 'ok', 'downloads': [...]}``: one entry per attempted
          download, either a failure record (``failed``, ``reason``, ``url``,
          ``title``, ``domain``) or a success record (``rel_path``,
          ``caption``, ``category``, ``feature_vector``, ``domain``).
        """
        wanted = name_tokens(place.name)
        pause = opt['sleep']

        # Search Openverse across all its sources, richest query first,
        # falling back to simpler queries if no confident match is found.
        queries = self._queries(place, city)
        ranked = []
        for q in queries:
            results = self._search(q)
            # --sleep is documented as a pause between API searches, so apply
            # it after every search. Previously it ran once after the
            # downloads and was skipped entirely on no-match and dry-run.
            if pause:
                time.sleep(pause)
            ranked = self._rank(results, wanted, opt['min_match'], opt['exclude_wikimedia'])
            if ranked:
                break

        if not ranked:
            return {'status': 'no_match', 'queries': queries}

        if opt['dry_run']:
            return {'status': 'dry_run', 'best': ranked[0]}

        # Loop-invariant work hoisted out of the download loop.
        compute_image_feature = None
        if opt['compute_features']:
            from processing.algorithms.kmeans import compute_image_feature
        category = place.get_primary_category()
        # A negative count would otherwise slice from the end of the list.
        per_place = max(0, opt['per_place'])

        downloads = []
        for i, res in enumerate(ranked[:per_place]):
            ext = self._ext_from_url(res['url'])
            fname = f"{place.id}_{i}.{ext}"
            abs_path = os.path.join(out_dir, fname)
            rel_path = os.path.join('poi', city.slug, fname)

            domain = urlparse(res['url']).netloc or res['url']
            err = self._download(res['url'], abs_path)
            if err:
                downloads.append({'failed': True, 'reason': err, 'url': res['url'], 'title': res['title'], 'domain': domain})
                continue

            vec = None
            if compute_image_feature is not None:
                v = compute_image_feature(abs_path, k=10)
                if v is not None:
                    vec = json.dumps(v.tolist())

            downloads.append({
                'rel_path': rel_path,
                'caption': res['title'][:500],
                'category': category,
                'feature_vector': vec,
                'domain': domain,
            })

        return {'status': 'ok', 'downloads': downloads}

    def _commit_plan(self, place, plan, opt):
        """Write a worker's results to the DB (main thread only).

        ``plan`` is the dict returned by ``_plan_place``. ``no_match`` and
        ``dry_run`` plans are only reported on stdout. For ``ok`` plans each
        successful download is upserted as a ``PlaceImage`` (plus an
        ``ImageFeature`` when a feature vector was computed), keyed by
        ``image_path``.

        Returns the number of images saved for *place*.
        """
        from places.models import PlaceImage, ImageFeature

        if plan['status'] == 'no_match':
            tried = ' | '.join(dict.fromkeys(plan['queries']))  # deduped
            self.stdout.write(
                self.style.WARNING(f"  [no match] {place.name}")
                + f"\n    Tried: {tried}"
            )
            return 0
        if plan['status'] == 'dry_run':
            best = plan['best']
            self.stdout.write(
                f"  [match {best['score']:.2f}] {place.name}  ->  {best['title']!r}")
            return 0

        has_replacement = any(not d.get('failed') for d in plan['downloads'])

        # Clear old rows only when at least one new image is ready to replace
        # them. If every download failed, wiping the existing images would
        # leave the place with none at all; places with no match already
        # keep theirs.
        if opt['clear'] and has_replacement:
            ImageFeature.objects.filter(place=place).delete()
            PlaceImage.objects.filter(place=place).delete()

        downloaded = 0
        for d in plan['downloads']:
            domain = d.get('domain', '')
            label = f"{place.name}  ({domain})" if domain else place.name
            if d.get('failed'):
                self.stdout.write(
                    self.style.WARNING(f"  [download failed] {label} | {d['reason']}")
                )
                continue
            PlaceImage.objects.update_or_create(
                image_path=d['rel_path'],
                defaults={'place': place, 'caption': d['caption']},
            )
            if d['feature_vector'] is not None:
                ImageFeature.objects.update_or_create(
                    image_path=d['rel_path'],
                    defaults={
                        'city': place.city,
                        'category_name': d['category'],
                        'feature_vector': d['feature_vector'],
                        'place': place,
                    },
                )
            self.stdout.write(self.style.SUCCESS(f"  [saved] {label}"))
            downloaded += 1

        # Test the saved count rather than an empty list: when every download
        # failed the list is non-empty but nothing was saved.
        if downloaded == 0:
            self.stdout.write(f"  [no images saved] {place.name}")
        return downloaded

    def _search(self, query, page_size=20, max_retries=3):
        """Query the Openverse image search endpoint and return its results.

        Args:
            query: Full-text search string.
            page_size: Results to request; forced to 50 when an OAuth token
                is in use.
            max_retries: How many times to retry after an HTTP 429 before
                giving up.

        Returns:
            The ``results`` list from the JSON response, or ``[]`` on any
            request error, non-200 status, non-JSON body, or when rate
            limiting persists past ``max_retries``. Never raises for network
            failures.
        """
        headers = {}
        if self.token:
            headers['Authorization'] = f"Bearer {self.token}"
            page_size = 50
        params = {'q': query, 'page_size': page_size, 'mature': 'false'}

        default_wait, max_wait = 30, 120
        retries = max(0, max_retries)

        # Bounded loop instead of unbounded recursion: a long-lasting rate
        # limit previously kept the worker retrying every 30 s indefinitely.
        for attempt in range(retries + 1):
            try:
                r = self._session().get(f"{API_BASE}/images/", headers=headers,
                                        params=params, timeout=30)
            except requests.RequestException:
                return []

            if r.status_code == 429:
                if attempt == retries:
                    break
                # Prefer the server's Retry-After (integer seconds). Fall back
                # to the old fixed wait if it is absent or not a number, and
                # cap it so one worker is never parked for hours.
                try:
                    wait = int(r.headers.get('Retry-After', ''))
                except (TypeError, ValueError):
                    wait = default_wait
                time.sleep(min(max(wait, 1), max_wait))
                continue

            if r.status_code != 200:
                return []
            try:
                return r.json().get('results', [])
            except ValueError:
                # A 200 response whose body is not JSON.
                return []
        return []

    _WIKIMEDIA_DOMAINS = ('wikimedia.org', 'wikipedia.org')

    def _rank(self, results, wanted, min_match, exclude_wikimedia=False):
        """Score each result by how many place-name tokens it contains."""
        if not wanted:
            return []
        scored = []
        for item in results:
            url = item.get('url')
            if not url:
                continue
            if exclude_wikimedia:
                host = urlparse(url).netloc.lower()
                if any(host == d or host.endswith('.' + d) for d in self._WIKIMEDIA_DOMAINS):
                    continue
            haystack = ' '.join(filter(None, [
                item.get('title', ''),
                ' '.join(t.get('name', '') for t in (item.get('tags') or [])),
                item.get('creator', ''),
                item.get('source', ''),
            ]))
            htoks = set(tokenize(haystack))
            hits = sum(1 for w in wanted if w in htoks)
            score = hits / len(wanted)
            if score >= min_match:
                scored.append({
                    'url': url,
                    'title': item.get('title') or 'Untitled',
                    'score': score,
                })
        scored.sort(key=lambda x: x['score'], reverse=True)
        return scored

    @staticmethod
    def _ext_from_url(url):
        tail = url.split('?')[0].rsplit('.', 1)
        if len(tail) == 2 and tail[1].lower() in VALID_EXT:
            return tail[1].lower()
        return 'jpg'

    def _download(self, url, abs_path):
        """Download url to abs_path. Returns None on success, error string on failure.

        The body is streamed to ``abs_path + '.part'`` and moved into place
        only once it has been received completely and is non-empty. A failed
        download therefore never leaves a truncated file at ``abs_path`` and
        never overwrites a file that was already there.
        """
        tmp_path = abs_path + '.part'
        try:
            # With stream=True the connection stays checked out until the body
            # is fully read or the response is closed; the context manager
            # releases it on every path, including the early non-200 return.
            with self._session().get(url, timeout=30, stream=True) as r:
                if r.status_code != 200:
                    return f"HTTP {r.status_code}"
                with open(tmp_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            if os.path.getsize(tmp_path) == 0:
                return "empty file (0 bytes)"
            os.replace(tmp_path, abs_path)
            return None  # success
        except requests.exceptions.ConnectionError as e:
            return f"connection error: {e}"
        except requests.exceptions.Timeout:
            return "timed out after 30s"
        except requests.exceptions.RequestException as e:
            return f"request error: {e}"
        except OSError as e:
            return f"file write error: {e}"
        finally:
            # Best-effort cleanup of a leftover partial file. After a
            # successful os.replace() the temp path no longer exists.
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
