"""
KMeans algorithm for image color palette extraction.
Manual implementation without sklearn.
"""
import numpy as np
import random


def load_image_pixels(image_path):
    """
    Load an image and return its pixels as an (N, 3) float32 numpy array.

    The image is converted to RGB and shrunk in place so that neither side
    exceeds 100 pixels before the pixel grid is flattened.

    Args:
        image_path: path to the image file

    Returns:
        np.ndarray of shape (N, 3) with RGB values in [0, 255], or None on
        any error (including Pillow not being importable)
    """
    try:
        # Imported lazily inside the try so a missing Pillow yields None
        # instead of breaking the import of this module.
        from PIL import Image

        # Pillow >= 9.1 exposes the filters on Image.Resampling; older
        # releases only have them as module-level constants on Image.
        lanczos = getattr(Image, 'Resampling', Image).LANCZOS

        # The context manager releases the underlying file handle as soon as
        # the pixel data has been converted into an independent RGB image.
        with Image.open(image_path) as source:
            img = source.convert('RGB')

        # Resize to max 100x100 for speed
        img.thumbnail((100, 100), lanczos)
        return np.array(img, dtype=np.float32).reshape(-1, 3)
    except Exception:
        # Best-effort loader: callers treat None as "image unusable".
        return None


def _pairwise_sq_dists(points, centers):
    """Return the (N, K) matrix of squared Euclidean distances from points to centers."""
    diffs = points[:, np.newaxis, :] - centers[np.newaxis, :, :]
    return np.sum(diffs ** 2, axis=2)


def kmeans(pixels, k=10, max_iter=20, tol=1.0, seed=42):
    """
    Manual KMeans implementation with KMeans++ seeding.

    Args:
        pixels: array-like of shape (N, 3); ndarrays are used as-is, other
            sequences (e.g. nested lists) are converted with np.asarray
        k: number of clusters (capped at N when there are fewer pixels)
        max_iter: maximum iterations; when <= 0 the seeded centroids are
            returned together with the assignments they induce
        tol: convergence tolerance (largest centroid shift, Euclidean)
        seed: random seed

    Returns:
        (centroids, assignments)
        - centroids: float32 np.ndarray of shape (min(k, N), 3); for empty
          input, an all-zero array of shape (k, 3)
        - assignments: int32 np.ndarray of shape (N,) with cluster indices,
          computed against the centroids of the last assignment step
    """
    # Accept any array-like input; an ndarray passes through unchanged.
    pixels = np.asarray(pixels)
    n = len(pixels)
    if n == 0:
        return np.zeros((k, 3), dtype=np.float32), np.zeros(n, dtype=np.int32)

    k = min(k, n)
    rng = random.Random(seed)

    # KMeans++ initialization: first centroid uniformly at random, each
    # further one with probability proportional to its squared distance to
    # the nearest centroid picked so far.
    centroids = [pixels[rng.randint(0, n - 1)].copy()]
    for _ in range(1, k):
        centroid_arr = np.array(centroids, dtype=np.float32)
        min_sq_dists = np.min(_pairwise_sq_dists(pixels, centroid_arr), axis=1)

        total = float(np.sum(min_sq_dists))
        if total == 0:
            # Every pixel coincides with an existing centroid; fall back to
            # a uniform pick.
            idx = rng.randint(0, n - 1)
        else:
            # Weighted random selection via the cumulative distribution.
            cumprobs = np.cumsum(min_sq_dists / total)
            idx = int(np.searchsorted(cumprobs, rng.random()))
            # Rounding can leave cumprobs[-1] slightly below 1, so the
            # search may land one past the end.
            idx = min(idx, n - 1)
        centroids.append(pixels[idx].copy())

    centroids = np.array(centroids, dtype=np.float32)

    if max_iter <= 0:
        # No refinement requested: still label every pixel against the
        # seeded centroids so the two return values agree with each other.
        labels = np.argmin(_pairwise_sq_dists(pixels, centroids), axis=1)
        return centroids, labels.astype(np.int32)

    assignments = np.zeros(n, dtype=np.int32)

    for _ in range(max_iter):
        # Assign each pixel to nearest centroid
        sq_dists = _pairwise_sq_dists(pixels, centroids)
        new_assignments = np.argmin(sq_dists, axis=1).astype(np.int32)

        # Update centroids
        new_centroids = np.zeros_like(centroids)
        for j in range(k):
            mask = new_assignments == j
            if np.any(mask):
                new_centroids[j] = np.mean(pixels[mask], axis=0)
            else:
                # Re-initialize empty cluster to random pixel
                new_centroids[j] = pixels[rng.randint(0, n - 1)]

        # Check convergence on the largest single-centroid movement
        shift = float(np.max(np.sqrt(np.sum((new_centroids - centroids) ** 2, axis=1))))
        centroids = new_centroids
        assignments = new_assignments

        if shift < tol:
            break

    return centroids, assignments


def build_feature_vector(centroids, assignments, k):
    """
    Order cluster centroids by population (largest first) and flatten them
    into a single normalized RGB feature vector.

    Args:
        centroids: np.ndarray of shape (k_actual, 3)
        assignments: np.ndarray of shape (N,) holding cluster indices
        k: number of clusters the output must describe

    Returns:
        float32 np.ndarray of length k*3 with values scaled by 1/255;
        zero-padded when fewer than k centroids were supplied, truncated to
        the k most populated clusters when more were supplied
    """
    n_clusters = len(centroids)

    # Cluster populations; negating them turns the ascending argsort into a
    # largest-first ranking.
    sizes = np.bincount(assignments, minlength=n_clusters)
    ranked = centroids[np.argsort(-sizes)]

    # Positive: rows that must be padded; negative: surplus rows to drop.
    missing = k - n_clusters
    if missing > 0:
        filler = np.zeros((missing, 3), dtype=np.float32)
        ranked = np.concatenate([ranked, filler], axis=0)
    elif missing < 0:
        ranked = ranked[:k]

    # Scale channel values from [0, 255] into [0, 1]
    return (ranked.flatten() / 255.0).astype(np.float32)


def compute_image_feature(image_path, k=10):
    """
    Run the whole pipeline for one image: load pixels, cluster them with
    KMeans, and build the palette feature vector.

    Args:
        image_path: path to the image file
        k: number of color clusters

    Returns:
        np.ndarray of length k*3, or None if the image could not be loaded,
        has no pixels, or clustering/feature construction raised
    """
    pixels = load_image_pixels(image_path)

    # Nothing to cluster: loader failed or produced an empty pixel array.
    if pixels is None:
        return None
    if len(pixels) == 0:
        return None

    try:
        palette, labels = kmeans(pixels, k=k)
        feature = build_feature_vector(palette, labels, k)
    except Exception:
        # Best-effort contract: any failure is reported as None.
        return None
    return feature
