from pathlib import Path
from kyd_dataspec_gen.utils import read_dataset, string_array_to_list
import polars as pl
from scipy.stats import median_abs_deviation

import logging
import math

logger = logging.getLogger(__name__)


def gini_impurity(categories: dict, sample_size: int) -> float | None:
    """
    Calculates the Gini impurity for a set of categorical data.
    Taking gini impurity calculation from data profiler as reference.
    Gini impurity is a measure of how often a randomly chosen element from the set would be incorrectly labeled if it was randomly labeled according to the distribution of labels in the set.
    Args:
        categories (dict): A dictionary containing a "count" key, whose value is a list of counts for each category.
        sample_size (int): The total number of samples.
    Returns:
        float | None: The Gini impurity value rounded to 4 decimal places, or None if sample_size is 0 or categories is empty.
    """
    if sample_size == 0 or len(categories) == 0:
        return None
    gini_sum: float = 0
    for i in categories["count"]:
        gini_sum += (i / sample_size) * (1 - (i / sample_size))
    return round(gini_sum, 4)


def unalikeability(categories: dict, sample_size: int) -> float:
    """
    Calculates the unalikeability measure for a categorical distribution.
    Taking unalikeability calculation from data profiler as reference.

    Unalikeability quantifies the probability that two randomly selected items from a sample belong to different categories.
    A higher value indicates greater diversity among categories.

    Args:
        categories (dict): A dictionary containing a key "count" whose value is a list of counts for each category.
        sample_size (int): The total number of samples (sum of all category counts).

    Returns:
        float: The unalikeability value, rounded to 4 decimal places. Returns 0.0 when
            sample_size is 0 or 1, because no pair of distinct items can be drawn and the
            denominator of the formula would be zero.
    """
    # n**2 - n is the number of ordered pairs of distinct items; it is zero for
    # n in (0, 1), which previously caused a ZeroDivisionError.
    pair_count = sample_size**2 - sample_size
    if pair_count == 0:
        return 0.0
    unalike_sum = sum((sample_size - count) * count for count in categories["count"])
    return round(unalike_sum / pair_count, 4)


def get_stats(df: pl.DataFrame, col_name: str) -> dict:
    """
    Compute various statistical metrics for a specified column in a Polars DataFrame.
    Args:
        df (pl.DataFrame): The input Polars DataFrame containing the data.
        col_name (str): The name of the column for which statistics are computed.
    Returns:
        dict: A dictionary containing the following statistics for the specified column:
            - null_count (float): Number of null values in the column.
            - histogram: Histogram of the column values.
            - median (float): Median value of the column.
            - variance (float | None): Variance of the column values.
            - skewness (float | None): Skewness of the column values.
            - kurtosis (float | None): Kurtosis of the column values.
            - median_abs_deviation (float): Median absolute deviation of the column values.
            - min (float): Minimum value in the column.
            - max (float): Maximum value in the column.
            - stdv (float | None): Standard deviation of the column values.
            - q_0 (float): 25th quantile (first quartile) of the column.
            - q_1 (float): 50th quantile (second quartile/median) of the column.
            - q_2 (float): 75th quantile (third quartile) of the column.
            - num_zeros (int): Number of zero values in the column.
            - num_negatives (int): Number of negative values in the column.
    Notes:
        - NaN values are handled and replaced with 0.0 in the output statistics.
        - A rounded statistic that Polars reports as missing (None) is returned as None
          instead of raising a TypeError from `round()`.
        - The histogram is generated using the Polars Series `hist()` method.
        - Null values are filled with 0 before the median absolute deviation is computed.
    """

    # Summarise only the requested column; describing the whole frame would
    # recompute every other column's summary on each call.
    # NOTE(review): the positional row lookups below assume the describe() row
    # order count, null_count, mean, std, min, 25%, 50%, 75%, max - confirm
    # against the pinned Polars version.
    stats = df.select(col_name).describe()
    target = pl.col(col_name)

    def _zero_if_nan(val):
        # NaN would not survive JSON-style serialisation cleanly, so report 0.0.
        if isinstance(val, float) and math.isnan(val):
            return 0.0
        return val

    def _round4(val):
        # Polars yields None for statistics it cannot compute; keep that marker
        # rather than crashing inside round().
        if val is None:
            return None
        return _zero_if_nan(round(val, 4))

    def _aggregate(expr):
        # Evaluate a single-value aggregation expression against the frame.
        return df.select(expr).item()

    return {
        "null_count": _zero_if_nan(stats.item(1, col_name)),
        "histogram": df.select(col_name).to_series().hist(),
        "median": _zero_if_nan(_aggregate(target.median())),
        "variance": _round4(_aggregate(target.var())),
        "skewness": _round4(_aggregate(target.skew())),
        "kurtosis": _round4(_aggregate(target.kurtosis())),
        "median_abs_deviation": _zero_if_nan(
            median_abs_deviation(df.get_column(col_name).fill_null(0).to_list())
        ),
        "min": _zero_if_nan(stats.item(4, col_name)),
        "max": _zero_if_nan(stats.item(8, col_name)),
        "stdv": _round4(stats.item(3, col_name)),
        "q_0": _zero_if_nan(stats.item(5, col_name)),
        "q_1": _zero_if_nan(stats.item(6, col_name)),
        "q_2": _zero_if_nan(stats.item(7, col_name)),
        "num_zeros": df.filter(target == 0).height,
        "num_negatives": df.filter(target < 0).height,
    }


def check_replace_profile_data(raw_dataset_path: Path, profile: dict) -> dict:
    """
    Checks and updates the data profile for a given dataset to ensure it reflects the full raw data.
    This function compares the current profile's sample usage with the total row count of the dataset.
    If the profile does not already use the full dataset, it recalculates and updates various global and
    per-column statistics using the complete dataset. These statistics include counts, ratios, null values,
    unique values, and other descriptive statistics. For categorical columns, additional metrics such as
    Gini impurity and unalikeability are computed.

    The profile dictionary is updated in place and also returned. If the dataset that is read
    contains no rows, no statistic can be recomputed, so the profile is returned unchanged.
    Args:
        raw_dataset_path (Path): The file path to the raw dataset.
        profile (dict): The existing data profile dictionary to be checked and potentially updated.
    Returns:
        dict: The updated data profile dictionary reflecting statistics computed from the full dataset.
    """
    logger.info(
        "Checking and replacing data profile for dataset at %s", raw_dataset_path
    )
    dataset_profile = profile["global_stats"]
    if dataset_profile["row_count"] == dataset_profile["samples_used"]:
        logger.debug(
            "Current data spec uses full raw data. Skipping data profile check."
        )
        return profile
    df = read_dataset(raw_dataset_path)
    row_total = df.height
    if row_total == 0:
        # Every ratio and mean below divides by the row count; bail out before
        # mutating the profile instead of failing with ZeroDivisionError.
        logger.warning(
            "Dataset at %s has no rows. Leaving data profile unchanged.",
            raw_dataset_path,
        )
        return profile
    dataset_profile["samples_used"] = dataset_profile["row_count"]
    dataset_profile["unique_row_ratio"] = df.filter(df.is_unique()).height / row_total
    dataset_profile["duplicate_row_count"] = df.filter(df.is_duplicated()).height
    dataset_profile["row_is_null_ratio"] = (
        df.filter(pl.all_horizontal(pl.all().is_null())).height / row_total
    )
    dataset_profile["row_has_null_ratio"] = (
        df.filter(pl.any_horizontal(pl.all().is_null())).height / row_total
    )
    for col in profile["data_stats"]:
        name = col["column_name"]
        logger.debug("Processing column: %s", name)
        col_stats = col["statistics"]
        col_stats["sample_size"] = row_total
        non_null_rows = df.select(name).filter(pl.col(name).is_not_null()).height
        if non_null_rows == 0:
            logger.debug(
                "Column %s has no non-null values. Updating 'null types index'.",
                name,
            )
            col_stats["null_count"] = row_total
            null_types = string_array_to_list(col_stats["null_types"])
            if isinstance(null_types, list) and len(null_types) == len(
                col_stats["null_types_index"]
            ):
                for null_type in null_types:
                    if null_type in col_stats["null_types_index"]:
                        # Every row of the column is null, so each known null
                        # type is indexed at every row position.
                        col_stats["null_types_index"][null_type] = list(
                            range(row_total)
                        )
            continue
        values = df.get_column(name)
        if col["data_type"] in ("int", "float"):
            col_sum = df.select(name).sum().item()
        else:
            # Non-numeric columns are summed by value length; nulls contribute 0.
            col_sum = sum(len(s) for s in values.to_list() if s is not None)
        mode = df.select(pl.col(name).mode()).to_series().to_list()
        if df.schema[name] == pl.String:
            converted_column = df.select(name).with_columns(
                pl.col(name).str.len_chars().alias(name)
            )  # Replace non-numerical column with character count
            new_stats = get_stats(converted_column, name)
        else:
            new_stats = get_stats(df, name)
        col_stats["null_count"] = int(new_stats["null_count"])
        col_stats["min"] = new_stats["min"]
        col_stats["max"] = new_stats["max"]
        # When every row is its own mode there is no meaningful mode. The sort
        # key orders a None entry (should Polars report null as a mode) last so
        # that sorting never compares None with a real value.
        col_stats["mode"] = (
            sorted(mode, key=lambda v: (v is None, v))
            if len(mode) != row_total
            else []
        )
        col_stats["median"] = new_stats["median"]
        col_stats["sum"] = col_sum
        col_stats["mean"] = round(col_sum / row_total, 4)
        col_stats["variance"] = new_stats["variance"]
        col_stats["stddev"] = new_stats["stdv"]
        col_stats["skewness"] = new_stats["skewness"]
        col_stats["kurtosis"] = new_stats["kurtosis"]
        col_stats["histogram"]["bin_edges"] = (
            new_stats["histogram"].select("breakpoint").to_series().to_list()
        )
        col_stats["histogram"]["bin_counts"] = (
            new_stats["histogram"].select("count").to_series().to_list()
        )
        col_stats["quantiles"]["0"] = new_stats["q_0"]
        col_stats["quantiles"]["1"] = new_stats["q_1"]
        col_stats["quantiles"]["2"] = new_stats["q_2"]
        col_stats["median_abs_deviation"] = round(new_stats["median_abs_deviation"], 4)
        # Times no longer apply here and assigned as an empty dict as it's the duration of time it took to generate the global statistics with the data profiler
        col_stats["times"] = {}
        # NOTE(review): approx_n_unique is an approximate count, so unique_count
        # and unique_ratio are estimates - confirm that is acceptable here.
        col_stats["unique_count"] = df.select(pl.approx_n_unique(name)).item()
        col_stats["unique_ratio"] = round(
            col_stats["unique_count"] / col_stats["sample_size"], 4
        )
        col_stats["num_zeros"] = new_stats["num_zeros"]
        col_stats["num_negatives"] = new_stats["num_negatives"]
        if col["categorical"]:
            pl_categorical_count = values.value_counts().to_dict(as_series=False)
            # Pair each distinct value with its occurrence count.
            categorical_count = dict(
                zip(pl_categorical_count[name], pl_categorical_count["count"])
            )
            col_stats["categories"] = list(categorical_count.keys())
            col_stats["gini_impurity"] = gini_impurity(
                pl_categorical_count, col_stats["sample_size"]
            )
            col_stats["unalikeability"] = unalikeability(
                pl_categorical_count, col_stats["sample_size"]
            )
            col_stats["categorical_count"] = categorical_count
        # Lazy formatting: the column dict (histogram, category counts) can be
        # large, so only render it when debug logging is actually enabled.
        logger.debug("Updated statistics for column: %s", col)
    return profile
