import copy
import json
import logging
import math
import os
import re
from pathlib import Path
import fsspec

from detect_delimiter import detect
from dotenv import load_dotenv
from glom import delete as glom_delete
from glom import glom
from google.genai.types import GenerateContentConfigDict
from google import genai
from rapidfuzz import fuzz

from kyd_dataspec_gen.config import Config
from kyd_dataspec_gen.schemas.datasource_schema_template import schema_default_template
from kyd_dataspec_gen.logging_config import setup_logging
from kyd_dataspec_gen.models import DataSource, DataClassification, DataDictionaryMatch
from kyd_dataspec_gen.setup_ai import generate_response, setup_gemini_client
from kyd_dataspec_gen.primary_key_detection import (
    identify_primary_key,
    review_primary_key,
)
from kyd_dataspec_gen.profile_full_data import check_replace_profile_data
from kyd_dataspec_gen.utils import string_array_to_list
from kyd_dataspec_gen.match_reference_datasets import (
    match_reference_data_to_published_set,
    publish_ref_dataset,
)
from kyd_dataspec_gen.match_data_dictionary import (
    read_data_dictionary,
    match_data_dictionary,
)

# Populate the environment via python-dotenv before GEMINI_MODEL is read below.
load_dotenv()
setup_logging()
logger = logging.getLogger(__name__)
# Gemini model name; the GEMINI_MODEL environment variable overrides the default.
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash-preview-04-17")
# fsspec filesystem for the "gcs" protocol, created once at import time.
# NOTE(review): not referenced in the part of the module reviewed here
# (read_profiler_output builds its own filesystem) - confirm it is still needed.
fs = fsspec.filesystem("gcs")

# Shape patterns that get_data_type treats as dates. They use the notation produced by
# create_data_profile_shape: every digit is written as "9" and every ASCII letter as "X".
date_format_list = [
    "99/99/9999",
    "99.99.9999",
    "99-99-9999",
    "9999-99-99",
    "99-XXX-9999",
]


def read_profiler_output(file_path: Path) -> dict:
    """
    Reads the profiler output from a JSON file and returns it as a dictionary.

    Locations starting with ``gs:/`` are opened through fsspec; anything else is read
    from the local filesystem.

    Args:
        file_path (Path): The path to the JSON file containing the profiler output.
            A Cloud Storage location may arrive with its double slash collapsed to
            ``gs:/`` (which is what ``pathlib.Path`` does to a ``gs://`` URL) or, when
            given as a plain string, with ``gs://`` still intact. Both are accepted.

    Returns:
        dict: A dictionary representation of the JSON data from the file.

    """
    raw_path = str(file_path)
    if raw_path.startswith("gs:/"):
        # Rewrite the scheme prefix so it always has exactly two slashes. A plain
        # str.replace would turn an already-correct "gs://" into "gs:///".
        gcs_url = re.sub(r"^gs:/+", "gs://", raw_path)
        # Local names are deliberately distinct from the module-level `fs`.
        gcs_fs, gcs_path = fsspec.core.url_to_fs(gcs_url)
        with gcs_fs.open(gcs_path, "r", encoding="utf-8") as file:
            return json.load(file)
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)


def create_data_profile_shape(
    category_count: dict, total_rows: int, config: Config
) -> list[dict]:
    """
    Collapses category values into character-class patterns and summarises each pattern.

    Every key of ``category_count`` is converted to a string, its ASCII letters are
    replaced by "X" and its digits by "9". Counts of keys that end up with the same
    pattern are added together, and each pattern receives a likelihood relative to
    ``total_rows``.

    Args:
        category_count (dict): Maps a category value to the number of rows holding it.
        total_rows (int): The total number of rows in the dataset; used as the
            denominator of the likelihood.
        config (Config): Configuration object; only ``categorical_limit`` is read here.

    Returns:
        list[dict]: One entry per pattern, ordered by descending sample size:
            - "value" (str): The pattern.
            - "statistics" (dict):
                - "sampleSize" (int): Summed count of all values with this pattern.
                - "likelihood" (float): sampleSize / total_rows, rounded to 4 decimals.
        An empty list is returned when ``category_count`` is empty or when the number
        of distinct patterns exceeds ``config.categorical_limit``.

    Example:
        Input:
            category_count = {"A1": 50, "B2": 30, "C3": 20}
            total_rows = 100
        Output:
            [
                {"value": "X9", "statistics": {"sampleSize": 100, "likelihood": 1.0}}
            ]
    """
    if len(category_count) == 0:
        return []

    # Accumulate the row count per pattern; dict order keeps first-seen order.
    pattern_totals = {}
    for raw_value, occurrences in category_count.items():
        pattern = re.sub(r"[0-9]", "9", re.sub(r"[a-zA-Z]", "X", str(raw_value)))
        if pattern in pattern_totals:
            pattern_totals[pattern] += occurrences
        else:
            pattern_totals[pattern] = occurrences

    shapes = [
        {
            "value": pattern,
            "statistics": {
                "sampleSize": subtotal,
                "likelihood": round(subtotal / total_rows, 4),
            },
        }
        for pattern, subtotal in pattern_totals.items()
    ]
    # Stable sort: patterns with equal sample size stay in first-seen order.
    shapes.sort(key=lambda shape: shape["statistics"]["sampleSize"], reverse=True)

    # Too many distinct patterns means the shape is not informative; report none.
    if len(shapes) > config.categorical_limit:
        return []
    return shapes


def create_sampled_categories(category_count: dict) -> list[dict]:
    """
    Builds the sampled-category list for a column from its category counts.

    Categories are ordered by descending count. When there are more than 250 of them,
    only the 20 most frequent and the 10 least frequent are kept; otherwise all are
    returned. Each category name is converted to a string and every non-printable
    character in it is written as ``0x`` followed by its code point in hex.

    Args:
        category_count (dict): Maps a category name to its count.

    Returns:
        list[dict]: Entries of the form
            {
                "value": <encoded category name>,
                "statistics": {
                    "sampleSize": <category count>
                }
            }
            An empty list is returned for an empty input dictionary.
    """
    if len(category_count) == 0:
        return []

    def _printable(raw) -> str:
        # Keep printable characters, spell out the rest as 0x<hex code point>.
        return "".join(
            ch if ch.isprintable() else f"0x{ord(ch):02x}" for ch in str(raw)
        )

    ranked = sorted(category_count.items(), key=lambda item: item[1], reverse=True)
    if len(ranked) > 250:
        # With more than 250 entries the head and tail slices cannot overlap.
        ranked = ranked[:20] + ranked[-10:]
    return [
        {"value": _printable(name), "statistics": {"sampleSize": size}}
        for name, size in ranked
    ]


def determine_format(
    col_name: str,
    data_type: str,
    data_profile_shape: list,
    samples: list,
    config: Config,
) -> dict:
    """
    Determines the format of the data in a column based on its type, profile, and sample values.

    Args:
        col_name (str): The name of the column being analyzed.
        data_type (str): The data type of the column (e.g., "date").
        data_profile_shape (list): A list describing the shape/profile of the data.
            Nothing is inferred when it is empty.
        samples (list): A list of sample values from the column.
        config (Config): The configuration object. ``identification_column_keyword``,
            ``address_column_keyword`` and ``separator`` are read for non-date columns.

    Returns:
        dict: A dictionary containing the inferred format details. Possible keys include:
            - "dateFormat": Present for date columns. Maps each recognised date format
              pattern (e.g., "MM/DD/YYYY") to the number of samples matching it; it is
              an empty dictionary when no sample was recognised.
            - "separator": A character identified as a separator in categorical data.
        The dictionary is empty when ``samples`` or ``data_profile_shape`` is empty,
        when the column name contains the identification or address keyword, or when
        no separator was found in a non-date column.

    Notes:
        - For date columns, each sample is classified on its own. A sample that has a
          separator ("/", "-", ".") but matches none of the known layouts is ignored
          rather than being counted under the previous sample's format.
        - For two-digit/two-digit/four-digit dates the first field is taken as the day
          only when it is greater than 12; otherwise it is assumed to be the month.
        - For non-date columns, a separator is only accepted when none of the parts
          around it contains one of a small list of common English words, to avoid
          splitting sentences.
    """
    col_format = {}
    # Use common words below to check if the column is a sentence
    # as values might contain a separator
    # and we want to avoid splitting the content
    common_words = ["is", "are", "a", "on", "in", "to", "and", "this", "that"]
    if len(samples) == 0:
        logger.warning(f"Samples are empty for column: {col_name}")
        return col_format
    if len(data_profile_shape) == 0:
        return col_format

    if data_type == "date":
        date_format = {}
        for sample in samples:
            if not sample:
                continue
            # Coerce to text, as the categorical branch does, so non-string samples
            # do not break the regular expressions.
            sample = str(sample)
            match = re.search(r"[/.\-]", sample)
            if not match:
                logger.debug(f"No valid separator found in sample: {sample}")
                continue
            separator = match.group()
            # Reset for every sample so an unrecognised layout is not counted
            # under the format of an earlier sample.
            key = ""
            if re.match(r"^\d{2}[/.\-]\d{2}[/.\-]\d{4}$", sample):
                if int(sample.split(separator)[0]) > 12:
                    key = f"DD{separator}MM{separator}YYYY"
                else:
                    key = f"MM{separator}DD{separator}YYYY"
            elif re.match(r"^\d{4}[/.\-]\d{2}[/.\-]\d{2}$", sample):
                key = f"YYYY{separator}MM{separator}DD"
            elif re.match(r"^\d{2}[- ]?[A-Z]{3}[- ]?\d{4}$", sample):
                key = f"DD{separator}MMM{separator}YYYY"
            if key:
                date_format[key] = date_format.get(key, 0) + 1
        col_format["dateFormat"] = date_format
        return col_format

    # No need to look for separators in id or address columns,
    # as these columns are not supposed to be categorical
    lowered_name = col_name.lower()
    if (
        config.identification_column_keyword in lowered_name
        or config.address_column_keyword in lowered_name
    ):
        return col_format

    logger.debug(
        f"Column is considered as categorical: {col_name}, looking for separators"
    )
    separator_chars = config.separator
    if separator_chars:
        # Compile once; the pattern does not change between samples.
        separator_pattern = re.compile(f"[{separator_chars}]", re.S)
        for sample in samples:
            sample = str(sample)
            lowered_sample = sample.lower()
            for match in separator_pattern.finditer(sample):
                separator_found = match.group()
                # Split the sample around the separator, then each part on whitespace,
                # and treat the value as a sentence if any common word shows up.
                part_words = [
                    part.split() for part in lowered_sample.split(separator_found)
                ]
                looks_like_sentence = any(
                    word in words for words in part_words for word in common_words
                )
                if not looks_like_sentence:
                    col_format["separator"] = separator_found
                    logger.debug(
                        f"Separator {separator_found} found in column: {col_name}"
                    )
                    break
    if "separator" not in col_format:
        logger.debug(f"No valid separator found in column: {col_name}")
    return col_format


def get_times(times_obj: dict) -> list:
    """
    Returns the values of ``times_obj`` as a list.

    Args:
        times_obj (dict): A dictionary whose values are the times, in milliseconds,
            it took to generate this sample's statistics.

    Returns:
        list: The dictionary's values in iteration order, or an empty list when the
        dictionary is empty.
    """
    return [] if len(times_obj) == 0 else list(times_obj.values())


def get_unLikeAbility(categorical_count: dict) -> int:
    """
    Returns the unLikeAbility figure for a column.

    As implemented here, this is simply the total of all counts in
    ``categorical_count``.

    Args:
        categorical_count (dict): Maps each category to its count.

    Returns:
        int: The sum of all counts, or 0 when the dictionary is empty.
    """
    has_entries = len(categorical_count) != 0
    return sum(categorical_count.values()) if has_entries else 0


def get_data_type(col_name: str, data_profile_shape: list, data_type: str) -> str:
    """
    Refines a column's data type, promoting it to "date" where the evidence allows.

    Args:
        col_name (str): The name of the column to analyze.
        data_profile_shape (list): Profile patterns for the column. Only the first
            entry is inspected; it must provide "value" and "statistics.likelihood".
        data_type (str): The data type determined so far.

    Returns:
        str: "date" when the column name contains "date" (case-insensitive), or when
        the first pattern has a likelihood above 0.75 and its value is listed in
        ``date_format_list``; otherwise ``data_type`` unchanged.
    """
    resolved_type = "date" if "date" in col_name.lower() else data_type
    if len(data_profile_shape) > 0:
        leading_pattern = data_profile_shape[0]
        is_dominant = leading_pattern["statistics"]["likelihood"] > 0.75
        if is_dominant and leading_pattern["value"] in date_format_list:
            resolved_type = "date"
    return resolved_type


def save_reference_data_table(
    reference_data_sets: list, name: str, dataset: list[dict]
) -> str | None:
    """
    Saves a reference dataset to the reference data sets list. If a reference dataset
    with the same name already exists, it is extended with the values it does not
    contain yet.

    Args:
        reference_data_sets (list): A list of existing reference datasets, where each
            dataset is a dictionary with at least "refDataCode" and "values".
        name (str): The name of the reference dataset to be saved or updated. The
            stored code is this name prefixed with "REF-".
        dataset (list[dict]): A list of dictionaries representing the values to be
            added to the reference dataset. Each dictionary should have a "value" key.

    Returns:
        str: The reference dataset code (refDataCode) for the saved or updated dataset.

    Notes:
        - A new reference dataset stores ``dataset`` as its "values" list without
          copying it, and a debug message is logged.
        - When the dataset already exists a warning is logged. Entries whose value is
          empty or None are skipped. Values are compared with surrounding whitespace
          removed on both sides, and each value added during the call is remembered,
          so the same value is never appended twice.
    """
    ref_name = "REF-" + name
    existing_ref = next(
        (ref for ref in reference_data_sets if ref["refDataCode"] == ref_name), None
    )

    if existing_ref is None:
        reference_data_sets.append(
            {
                "refDataCode": ref_name,
                "description": f"List of {name}",
                "dataMapping": {"dataMappingName": None},
                "values": dataset,
            }
        )
        logger.debug(f"column '{name}' is categorical, saved as {ref_name}")
        return ref_name

    logger.warning(f"Reference dataset {ref_name} already exists.")

    def _comparable(raw):
        # Whitespace differences must not produce a second entry for the same value.
        return raw.strip() if isinstance(raw, str) else raw

    # Expand the existing reference dataset with the values it does not hold yet.
    known_values = {_comparable(entry["value"]) for entry in existing_ref["values"]}
    for candidate in dataset:
        raw_value = candidate["value"]
        if not raw_value:
            continue
        comparable = _comparable(raw_value)
        if comparable not in known_values:
            existing_ref["values"].append(candidate)
            known_values.add(comparable)
    return ref_name


def update_category_list(key: str, value: int, category_list: list, total_rows: int):
    """
    Adds ``value`` to the entry for ``key`` in ``category_list``, creating it if needed.

    A new key is appended with ``value`` as its sample size. For a key that is already
    present, ``value`` is added to the sample size of its first matching entry. In both
    cases the likelihood is the sample size divided by ``total_rows``, rounded to four
    decimals.

    Args:
        key (str): The category key to update or add.
        value (int): The sample size to add for the key.
        category_list (list): A list of dictionaries, each with a "value" key for the
            category name and a "statistics" key for its figures.
        total_rows (int): The total number of rows used to calculate the likelihood.

    Returns:
        None: The function modifies the `category_list` in place.
    """
    known_keys = [entry["value"] for entry in category_list]
    try:
        position = known_keys.index(key)
    except ValueError:
        # First time this key is seen: start a fresh entry.
        category_list.append(
            {
                "value": key,
                "statistics": {
                    "sampleSize": value,
                    "likelihood": round(value / total_rows, 4),
                },
            }
        )
        return

    stats = category_list[position]["statistics"]
    stats["sampleSize"] += value
    stats["likelihood"] = round(stats["sampleSize"] / total_rows, 4)


def dict_to_category_list_count(
    col_format: dict, category_count: dict, total_rows: int
) -> list:
    """
    Turns a dictionary of category counts into a category list, splitting
    multi-valued categories on the column's separator where one is configured.

    Args:
        col_format (dict): Formatting options for the column. An optional "separator"
            entry gives the delimiter used to split category names.
        category_count (dict): Maps each category name to its number of occurrences.
        total_rows (int): The total number of rows in the dataset, passed on for the
            likelihood calculation.

    Returns:
        list: The category entries built through ``update_category_list``.

    Notes:
        - A string category containing the separator is split; every part is stripped
          of surrounding whitespace and receives the full count of the category.
        - Any other string category has its non-printable characters written as
          ``0x`` followed by the code point in hex.
        - Non-string categories are converted with ``str``.
    """
    delimiter = col_format.get("separator")
    collected = []
    for raw_category, occurrences in category_count.items():
        if not isinstance(raw_category, str):
            labels = [str(raw_category)]
        elif delimiter and delimiter in raw_category:
            labels = [piece.strip() for piece in raw_category.split(delimiter)]
        else:
            labels = [
                "".join(
                    ch if ch.isprintable() else f"0x{ord(ch):02x}"
                    for ch in raw_category
                )
            ]
        for label in labels:
            update_category_list(label, occurrences, collected, total_rows)
    return collected


def save_reference_datasets(
    reference_data_sets: list,
    col_name: str,
    category_count: dict,
    total_rows: int,
    col_format: dict,
    config: Config,
) -> str | None:
    """
    Stores a column's categories as a reference dataset when there are few enough.

    The category list is built from ``category_count`` and ``col_format``. It is saved
    only when it is non-empty and has fewer entries than ``config.categorical_limit``.

    Args:
        reference_data_sets (list): A list to store reference datasets.
        col_name (str): The name of the column for which the reference dataset is generated.
        category_count (dict): A dictionary containing category counts for the column.
        total_rows (int): The total number of rows in the dataset.
        col_format (dict): A dictionary specifying the format of the column.
        config (Config): The configuration object; ``categorical_limit`` is read here.

    Returns:
        str: The name of the saved reference dataset if one was saved, otherwise None.
    """
    category_entries = dict_to_category_list_count(
        col_format, category_count, total_rows
    )
    entry_total = len(category_entries)
    logger.debug(f"col_name: {col_name}, category list count: {entry_total}")
    if not 0 < entry_total < config.categorical_limit:
        return None
    return save_reference_data_table(reference_data_sets, col_name, category_entries)


def generate_schema_information(
    global_schema: dict, content: str, ai_client: genai.Client
) -> DataSource:
    """
    Requests an AI-generated description of the data source described by ``global_schema``.

    A deep copy of the schema is trimmed to reduce the number of tokens, converted to
    a string, appended to ``content`` and sent through ``generate_response``. The
    original ``global_schema`` is not modified.

    Args:
        global_schema (dict): The global schema of the data source. Every column under
            "datasets" -> "columns" must provide
            "dataProfile" -> "statistics" -> "sampledCategories".
        content (str): The prompt text placed in front of the trimmed schema.
        ai_client (genai.Client): The AI client used for generating the description.
            When it is falsy no request is made.

    Returns:
        DataSource: The response of ``generate_response``, or a ``DataSource`` with an
        empty name, description, data sets, relationships and location coverage when
        no client was supplied.

    Notes:
        - Removed from the copy: the top-level "referenceDatasets", and per column the
          "dataProfileShape" and "categories" statistics.
        - Each column's "sampledCategories" is reduced to the list of its values.
        - The request asks for a JSON response following the ``DataSource`` schema.
    """
    ai_config: GenerateContentConfigDict = {
        "response_mime_type": "application/json",
        "response_schema": DataSource,
    }
    fallback_response = DataSource(
        name="", description="", data_sets=[], relationships=[], location_coverage=[]
    )

    # Work on a copy so the caller's schema keeps all of its detail.
    trimmed_schema = copy.deepcopy(global_schema)
    glom_delete(trimmed_schema, "referenceDatasets", ignore_missing=True)
    bulky_statistics = (
        "dataProfile.statistics.dataProfileShape",
        "dataProfile.statistics.categories",
    )
    for dataset in trimmed_schema["datasets"]:
        for column in dataset["columns"]:
            for stat_path in bulky_statistics:
                glom_delete(column, stat_path, ignore_missing=True)
            # Keep only the category values, dropping their statistics.
            column_stats = column["dataProfile"]["statistics"]
            column_stats["sampledCategories"] = [
                category["value"] for category in column_stats["sampledCategories"]
            ]

    if not ai_client:
        return fallback_response
    logger.debug("Generating column description using Gemini AI")
    return generate_response(
        ai_client,
        content + str(trimmed_schema),
        ai_config,
    )


def clean_column_stats(schema_column: dict) -> None:
    """
    Cleans the column statistics in the provided schema column dictionary by removing
    empty or irrelevant values and updating the column type if necessary.

    Args:
        schema_column (dict): A dictionary representing the schema column. It must
            provide "dataProfile" -> "statistics" and "nullOrEmpty".

    Returns:
        None: The function modifies the `schema_column` in place.

    Notes:
        - When "sampleSize" minus "nullCount" (both defaulting to 0) is 0, the column
          type is set to "empty" and the profile entries that only make sense for
          populated columns are removed.
        - "nullValues" is removed when "nullOrEmpty" is falsy; a column that never had
          the placeholder is left alone instead of raising.
        - "skewness" is removed when it is NaN. A missing or non-numeric skewness is
          left untouched.
    """
    statistics = schema_column["dataProfile"]["statistics"]
    sampled_rows_with_values = statistics.get("sampleSize", 0) - statistics.get(
        "nullCount", 0
    )
    # A column without any values needs cleaning up to be sensible
    if sampled_rows_with_values == 0:
        schema_column["type"] = "empty"
        for obsolete_path in (
            "dataProfile.order",
            "dataProfile.categorical",
            "dataProfile.statistics.unalikeability",
            "dataProfile.statistics.giniImpurity",
            "dataProfile.statistics.uniqueCount",
            "dataProfile.statistics.uniqueRatio",
            "dataProfile.statistics.min",
            "dataProfile.statistics.max",
            "dataProfile.statistics.times",
            "dataProfile.statistics.histogram",
            "dataProfile.statistics.nullTypes",
            "dataProfile.statistics.vocabulary",
            "dataProfile.statistics.categories",
        ):
            glom_delete(schema_column, obsolete_path, ignore_missing=True)

    # If there are no nulls remove the placeholder (which may already be absent)
    if not schema_column["nullOrEmpty"]:
        glom_delete(schema_column, "nullValues", ignore_missing=True)

    # If the skewness could not be computed (NaN) remove it from the dictionary
    skewness = glom(schema_column, "dataProfile.statistics.skewness", default=1)
    try:
        skewness_is_nan = math.isnan(skewness)
    except TypeError:
        # None or other non-numeric values are not NaN; keep them as they are.
        skewness_is_nan = False
    if skewness_is_nan:
        glom_delete(
            schema_column, "dataProfile.statistics.skewness", ignore_missing=True
        )


def compare_column_names(col_1: str, col_2: str, separator: str) -> bool:
    """
    Compare two column names to determine if they are similar based on a fuzzy matching ratio.

    Args:
        col_1 (str): The first column name to compare.
        col_2 (str): The second column name to compare.
        separator (str): The character used as a separator in the column names,
                         which will be replaced with a space for comparison. An empty
                         string means "no separator" and leaves the names untouched.

    Returns:
        bool: True if the partial similarity ratio between the two column names is
              greater than or equal to 80, otherwise False.
    """
    # str.replace("", " ") would insert a space between every character and distort
    # the ratio, so only substitute when a real separator was supplied.
    if separator:
        col_1 = col_1.replace(separator, " ")
        col_2 = col_2.replace(separator, " ")
    return fuzz.partial_ratio(col_1, col_2) >= 80


def compare_categories(cat_1: list, cat_2: list) -> bool:
    """
    Checks whether two category lists have the same sample sizes, position by position.

    The lists are compared in the order given; callers are expected to pass them
    already sorted by sample size. Category values themselves are not compared.

    Args:
        cat_1 (list): The first list of category dictionaries. Each dictionary is
            expected to have a "statistics" key containing a "sampleSize" field.
        cat_2 (list): The second list of category dictionaries, with the same layout.

    Returns:
        bool: True if both lists have the same length and every pair of entries at the
            same position has equal sample sizes (two empty lists therefore match),
            otherwise False.
    """
    if len(cat_1) != len(cat_2):
        return False
    # Equal sample sizes at every position mean the columns can be considered correlated.
    return all(
        left["statistics"]["sampleSize"] == right["statistics"]["sampleSize"]
        for left, right in zip(cat_1, cat_2)
    )


def check_correlated_columns(column_list: list) -> None:
    """
    Identifies and marks correlated columns in a list of column metadata.

    Two columns are treated as correlated when their names are similar
    (``compare_column_names``) and their sampled categories have matching sample
    sizes (``compare_categories``). Each column of such a pair gets the other's name
    appended to its "correlatedColumns" list.

    Args:
        column_list (list): A list of dictionaries, where each dictionary
            represents column metadata. Each dictionary is expected to have
            the following structure:
            - "name" (str): The name of the column.
            - "dataProfile" (dict): A dictionary containing the column's data
              profile, which includes:
                - "statistics" (dict): A dictionary containing statistical
                  information, including:
                    - "sampledCategories" (list): A list of sampled categories
                      for the column.

    Returns:
        None: The function modifies the input `column_list` in place by adding
        a "correlatedColumns" key to correlated columns.

    Notes:
        - The word separator used for the name comparison is the delimiter detected
          most often across the column names; on a tie the one detected first wins.
          This keeps the outcome identical from run to run. When no delimiter is
          detected an empty string is passed on.
    """
    # Detect once per column and keep only the columns where a delimiter was found.
    detected = [
        delimiter
        for delimiter in (detect(column["name"]) for column in column_list)
        if delimiter
    ]
    # dict.fromkeys keeps first-seen order, so max() breaks ties deterministically.
    separator = max(dict.fromkeys(detected), key=detected.count) if detected else ""

    for i, col_1 in enumerate(column_list):
        for col_2 in column_list[i + 1 :]:
            names_match = compare_column_names(col_1["name"], col_2["name"], separator)
            categories_match = compare_categories(
                col_1["dataProfile"]["statistics"]["sampledCategories"],
                col_2["dataProfile"]["statistics"]["sampledCategories"],
            )
            if names_match and categories_match:
                col_1.setdefault("correlatedColumns", []).append(col_2["name"])
                col_2.setdefault("correlatedColumns", []).append(col_1["name"])


def insert_generated_information(
    global_schema: dict,
    generated_info: DataSource,
    enable_anonymisation: bool = False,  # noqa: FBT001, FBT002 - Passing enable_anonymisation to anonymise_samples function and allowing the caller to easily enable or disable anonymisation without requiring additional configuration or complex parameter structures
) -> list[str]:
    """
    Merges AI-generated information into the global schema in place.

    The data source description, location coverage and relationships are updated
    first. Datasets on both sides are then sorted by name and paired by position;
    columns are paired by position as well. A pair is only merged when the names
    on both sides are equal, and existing (non-empty) descriptions, comments,
    foreign keys and data classifications in the schema are never overwritten.
    Whenever generated information cannot be lined up with the schema it is
    skipped and a warning is logged.

    Args:
        global_schema (dict): The global schema dictionary to be updated. It contains
            information about datasets and their columns.
        generated_info (DataSource): A data source object containing descriptions for
            datasets and columns, including additional metadata such as comments,
            foreign keys, data classification and data dictionary matching. Its
            "data_sets" entry is re-sorted by dataset name as a side effect.
        enable_anonymisation (bool): If True, the samples of columns classified as
            sensitive data are replaced by the generated anonymised samples (when
            any were generated).

    Returns:
        list[str]: The "<dataset name>.<column name>" identifiers of all columns whose
        data dictionary match was reported as new/missing.
    """
    logger.debug("Adding AI-generated information to the schema")
    new_missing_elements = []

    # To keep existing descriptions if ai generates a new one accidentally
    # Only update the description if it is the default one
    default_description = f"Data specification for {global_schema['name']}"
    if global_schema["description"] == default_description:
        global_schema["description"] = generated_info["description"]

    # Add/Update the location coverage to the global schema; it is only replaced
    # when the number of generated locations differs from the existing ones
    if len(global_schema["locationCoverage"]) != len(
        generated_info["location_coverage"]
    ):
        global_schema["locationCoverage"] = generated_info["location_coverage"]

    # Add the relationships between datasets to the global schema
    global_schema["relationships"] = generated_info["relationships"]

    # Sort both sides by name so that datasets can be paired by position
    global_schema["datasets"] = sorted(
        global_schema["datasets"], key=lambda x: x["name"]
    )
    generated_info["data_sets"] = sorted(
        generated_info["data_sets"], key=lambda x: x["data_set_name"]
    )
    datasets = global_schema["datasets"]
    generated_sets = generated_info["data_sets"]

    if len(datasets) != len(generated_sets):
        logger.warning(
            f"Schema has {len(datasets)} datasets but {len(generated_sets)} were generated, "
            "skipping generated dataset and column information."
        )
        return new_missing_elements

    for dataset, generated_set in zip(datasets, generated_sets):
        if dataset["name"] != generated_set["data_set_name"]:
            logger.warning(
                f"Generated dataset '{generated_set['data_set_name']}' does not match "
                f"dataset '{dataset['name']}', skipping its generated information."
            )
            continue

        if not dataset["description"]:
            dataset["description"] = generated_set["description"]

        if len(dataset["columns"]) != len(generated_set["columns"]):
            logger.warning(
                f"Dataset '{dataset['name']}' has {len(dataset['columns'])} columns but "
                f"{len(generated_set['columns'])} were generated, skipping its column information."
            )
            continue

        skipped_columns = []
        for column, generated_column in zip(
            dataset["columns"], generated_set["columns"]
        ):
            if column["name"] != generated_column["col_name"]:
                skipped_columns.append(column["name"])
                continue

            # check existing fields to avoid overwriting
            if not column["description"]:
                column["description"] = generated_column["description"]
            if not column["comment"]:
                column["comment"] = generated_column["comment"]
            if not column.get("foreignKey"):
                column["foreignKey"] = generated_column["foreign_key"]
            if not column["dataClassification"]:
                column["dataClassification"] = generated_column[
                    "data_classification"
                ].value

            # The data dictionary matching result is always taken from the generated data
            column["dataDictionaryMatching"] = {
                "result": generated_column["data_dictionary_match"].value
            }
            if (
                generated_column["data_dictionary_match"]
                == DataDictionaryMatch.new_missing
            ):
                new_missing_elements.append(f"{dataset['name']}.{column['name']}")
                column["dataDictionaryMatching"]["potentialElementMatch"] = (
                    generated_column["proposed_dd_match"]
                )

            if enable_anonymisation and (
                generated_column["data_classification"]
                == DataClassification.sensitive_data
            ):
                logger.debug(
                    f"Anonymising column {column['name']} in dataset {dataset['name']}"
                )
                # Keep the original samples when no anonymised samples were generated
                column["samples"] = generated_column[
                    "anonymised_samples"
                ] or column.get("samples", [])

        if skipped_columns:
            logger.warning(
                f"Generated column names do not match for columns {skipped_columns} "
                f"in dataset '{dataset['name']}', skipping their generated information."
            )

    # Update descriptions to the reference datasets with generated column descriptions.
    # When several generated columns share a name, the last one wins.
    generated_descriptions = {
        col["col_name"]: col["description"]
        for generated_set in generated_sets
        for col in generated_set["columns"]
    }
    for ref in global_schema["referenceDatasets"]:
        # NOTE(review): the first 4 characters of refDataCode are dropped before matching,
        # presumably a fixed prefix added when the reference dataset is created - confirm.
        column_name = ref["refDataCode"][4:]
        if column_name in generated_descriptions:
            ref["description"] = generated_descriptions[column_name]
    return new_missing_elements


def categorical_check(
    inferred_categorical: bool,  # noqa: FBT001 - inferred_categorical is a boolean
    col_data_type: str,
    col_format: dict,
    data_profile_shape: list,
    unique_count: int,
    total_rows: int,
) -> bool:
    """
    Decides whether a column is categorical, starting from the profiler's inference
    and vetoing it when the column looks like a date, free text or a high-cardinality
    column.

    Args:
        inferred_categorical (bool): Initial inference of whether the column is categorical.
        col_data_type (str): The data type of the column (e.g., "text").
        col_format (dict): Format information for the column (e.g., a "dateFormat" entry).
        data_profile_shape (list): A list of dictionaries describing the profile shape
            of the data; the "value" entry of each dictionary is inspected here.
        unique_count (int): The number of unique values in the column.
        total_rows (int): The total number of rows in the dataset.

    Returns:
        bool: False when one of the vetoes below applies, otherwise `inferred_categorical`.

    Notes:
        - A column with a "dateFormat" in its format, or with data type "text", is
          never categorical.
        - The remaining vetoes are only evaluated when `data_profile_shape` is not empty:
            - every profile shape value has more than 10 space-separated words
              (the column is treated as free text), or
            - the unique value count exceeds 50% of the total rows (the values are
              not consistent enough to form categories).
    """
    max_category_words = 10
    max_unique_share = 0.5

    if "dateFormat" in col_format or col_data_type == "text":
        return False

    # Without any profile shape there is nothing to veto the inference with
    if len(data_profile_shape) == 0:
        return inferred_categorical

    # Every value is a long sentence, so the column is free text
    if all(
        len(shape["value"].split(" ")) > max_category_words
        for shape in data_profile_shape
    ):
        return False

    # Too many distinct values relative to the number of rows
    if unique_count > total_rows * max_unique_share:
        return False

    return inferred_categorical


def reuse_current_generated_information(
    global_schema: dict, existing_schema: dict
) -> None:
    """
    Carries documentation over from an existing schema into the schema being generated.

    Nothing is copied unless both schemas are non-empty and share the same "name";
    in that case a warning is logged instead. Otherwise the data source description
    is copied, and for every dataset / column whose name exists in both schemas the
    description, comment, primaryKey and foreignKey are copied too. When a column
    has samples and they contain the same set of values as the existing column's
    samples, the existing sample list (and therefore its ordering) is reused.

    Args:
        global_schema (dict): The schema being generated, which will be updated with additional information.
        existing_schema (dict): The existing schema containing descriptions and metadata to be reused.

    Returns:
        None: The function modifies the `global_schema` in place.
    """
    same_data_source = (
        len(global_schema) > 0
        and len(existing_schema) > 0
        and global_schema.get("name") == existing_schema.get("name")
    )
    if not same_data_source:
        logger.warning(
            "Global schema name does not match existing schema name, skipping reuse of information."
        )
        return

    global_schema["description"] = existing_schema["description"]
    for dataset in global_schema["datasets"]:
        for previous_dataset in existing_schema["datasets"]:
            if dataset["name"] != previous_dataset["name"]:
                continue
            dataset["description"] = previous_dataset["description"]
            for column in dataset["columns"]:
                for previous_column in previous_dataset["columns"]:
                    if column["name"] != previous_column["name"]:
                        continue
                    column["description"] = previous_column["description"]
                    column["comment"] = previous_column["comment"]
                    column["primaryKey"] = previous_column.get("primaryKey", False)
                    column["foreignKey"] = previous_column.get("foreignKey", False)
                    # Reduce changes due to ordering of samples
                    current_samples = column.get("samples", [])
                    previous_samples = previous_column.get("samples", [])
                    if current_samples and set(current_samples) == set(
                        previous_samples
                    ):
                        column["samples"] = previous_column["samples"]


def _ref_value_sort_key(ref_value: dict) -> tuple:
    """Sort key for a reference value: whole numbers first (numerically), then text (case-insensitively)."""
    text = str(ref_value["value"])
    # isdecimal() only accepts characters that int() can parse, unlike isnumeric()
    if text.isdecimal():
        return (0, int(text), "")
    # Tuples keep numbers and text comparable; comparing int with str raises TypeError
    return (1, 0, text.upper())


def generate_data_spec(
    data_source: str,
    profile_dir: str,
    config: Config,
    output_dir: str,
    schema_path: str = "",
    existing_dataspec_path: str = "",
    verify_primary_key: bool = False,  # noqa: FBT001, FBT002 - Passing verify_primary_key to review_primary_key function and allowing the caller to easily enable or disable verification logic without requiring additional configuration or complex parameter structures
    full_data_profiling: bool = False,  # noqa: FBT001, FBT002 - Passing full_data_profiling to check_replace_profile_data function and allowing the caller to easily enable or disable full data profiling without requiring additional configuration or complex parameter structures
    raw_data_folder: str = "",
    enable_anonymisation: bool = False,  # noqa: FBT001, FBT002 - Passing enable_anonymisation to anonymise_samples function and allowing the caller to easily enable or disable anonymisation without requiring additional configuration or complex parameter structures
    published_ref_dataset_path: str = "",
    publish_new_ref_dataset: list | None = None,
    data_dictionary_path: str = "",
) -> dict:
    """
    Generates a data specification JSON file based on profiler output files.

    This function processes profiler output files located in the specified directory
    and generates a data specification JSON file. The generated specification includes
    metadata, dataset-level statistics, and column-level statistics. When an AI client
    is available, descriptions, classifications and relationships are generated and
    merged into the specification, and the primary keys are reviewed.

    Args:
        data_source (str): The name of the data source for which the specification is generated.
        profile_dir (str): The directory containing profiler output files (e.g., "p_*.json").
            A path containing "gs://" is listed through fsspec instead of the local file system.
        config (Config): Configuration object containing settings for the data specification generation.
        output_dir (str): The directory where the generated data specification JSON file will be saved.
        schema_path (str, optional): Path to a custom schema template file. If not provided,
            a default schema template is used.
        existing_dataspec_path (str, optional): Path to an existing data specification file.
            If provided, the function will attempt to reuse information from this file.
        verify_primary_key (bool, optional): If True, verifies the primary key in the generated data specification.
            Defaults to False.
        full_data_profiling (bool, optional): If True, performs full data profiling by checking and replacing
            profile data with raw data. Defaults to False.
        raw_data_folder (str, optional): The folder containing raw data files. Required if `full_data_profiling` is True.
        enable_anonymisation (bool, optional): If True, anonymises sensitive samples in the data specification. Defaults to False.
        published_ref_dataset_path (str, optional): Path to a published reference dataset file.
        publish_new_ref_dataset (list, optional): List of new reference datasets to be published.
            Its first two items are passed on to `publish_ref_dataset`.
        data_dictionary_path (str, optional): Path to a data dictionary. If provided (and an AI client
            is available), the generated information is matched against it.

    Returns:
        dict: The generated data specification, which is also written to
        "<output_dir>/<data_source>_data_spec.json".

    Raises:
        ValueError: If no profiler output files are found in `profile_dir`.
    """

    # NOTE(review): when no schema_path is given, schema_template is the shared
    # schema_default_template object itself, so the "format"/"encoding" assignments
    # further down modify that module-level template - confirm this is intended.
    schema_template = schema_default_template
    # if schema_path is provided, load the schema template from the file
    if schema_path != "":
        with open(schema_path, "r", encoding="utf-8") as file:
            schema_template = json.load(file)

    # get profiler output files, sorted so that the processing order is reproducible
    if "gs://" in profile_dir:
        gcs_fs, fs_path = fsspec.core.url_to_fs(profile_dir)
        matches = gcs_fs.glob(fs_path.rstrip("/") + "/p_*.json")
        data_profile_paths = sorted(Path("gs://" + m) for m in matches)
    else:
        data_profile_paths = sorted(Path(profile_dir).glob("p_*.json"))
    if len(data_profile_paths) == 0:
        logger.error(f"No data profile files found in {profile_dir}")
        msg = f"No data profile files found in {profile_dir}. Please run the profiler first."
        raise ValueError(msg)

    logger.debug(f"Generating data specification for {data_source}")

    ai_client = setup_gemini_client()

    global_schema = {
        "$schema": schema_path
        if schema_path != ""
        else "https://schema.kyd.ai/v1/datasource.schema.json",
        "name": data_source,
        "description": f"Data specification for {data_source}",
        "licensing": "Open Data Commons",
        "locationCoverage": ["uk"],
        "referenceDatasets": [],
        "format": set(),  # use set here in case there are multiple formats within one data source
        "encoding": "",
        "datasets": [],
        "relationships": [],
    }

    # iterate through the profiler output files
    for p in data_profile_paths:
        logger.debug(f"Processing file: {p}")

        profile = read_profiler_output(p)
        # the dataset name is the file stem without its "p_" prefix
        file_name = p.stem[p.stem.index("p_") + 2 :]
        file_type = profile["global_stats"]["file_type"]
        if full_data_profiling:
            profile = check_replace_profile_data(
                Path(f"{raw_data_folder}/{file_name}.{file_type}"), profile
            )
        # Get the data source level statistics
        global_schema["format"].add(file_type)
        # the encoding of the last processed file wins
        global_schema["encoding"] = profile["global_stats"]["encoding"]
        profile_global_stats = profile["global_stats"]
        schema_template["format"] = profile_global_stats["file_type"]
        schema_template["encoding"] = profile_global_stats["encoding"]
        # passed to identify_primary_key for every column of this file
        primary_key_list = []

        # Get the file level statistics
        data_set = {
            "name": file_name,
            "type": file_type,
            "description": "",
            "code": file_name,
            "statistics": {
                "columnCount": profile_global_stats["column_count"],
                "rowCount": profile_global_stats["row_count"],
                "samplesUsed": profile_global_stats["samples_used"],
                "uniqueRowRatio": profile_global_stats["unique_row_ratio"],
                "rowHasNullRatio": profile_global_stats["row_has_null_ratio"],
                "duplicateRowCount": profile_global_stats["duplicate_row_count"],
                "rowIsNullRatio": profile_global_stats["row_is_null_ratio"],
            },
            # the spec is complete when every row was used as a sample
            "completeDataSpec": profile_global_stats["row_count"]
            == profile_global_stats["samples_used"],
            "columns": [],
        }
        for col in profile["data_stats"]:
            description = ""
            comment = ""
            col_stats = col.get("statistics", {})
            data_profile_shape = create_data_profile_shape(
                col_stats.get("categorical_count", {}),
                col_stats.get("sample_size", 0) - col_stats.get("null_count", 0),
                config,
            )

            col_data_type = get_data_type(
                col["column_name"], data_profile_shape, col.get("data_type", "")
            )
            col_format = determine_format(
                col["column_name"],
                col_data_type,
                data_profile_shape,
                string_array_to_list(col.get("samples", [])),
                config,
            )
            categorical = categorical_check(
                col["categorical"],
                col["data_type"],
                col_format,
                data_profile_shape,
                col_stats.get("unique_count", 0),
                col_stats.get("sample_size", 0),
            )
            unique = col_stats.get("unique_ratio", 0) == 1.0
            primary_key = identify_primary_key(
                col_stats, col["column_name"], col_data_type, config, primary_key_list
            )

            column = {
                # Name of the attribute
                "name": col["column_name"].strip(),
                # Data type of the attribute
                "type": col_data_type,
                # Description of the attribute
                "description": description,
                "dataClassification": None,
                "comment": comment,
                # If the column is required in the dataset, determined by the column name and the unique count
                "required": config.identification_column_keyword
                in col["column_name"].lower()
                and unique,
                # If the column is unique, determined by the unique count and sample size
                "unique": col_stats.get("unique_count", 0)
                == col_stats.get("sample_size", 0),
                # List of examples values from the column, de-duplicated in their original order
                # (a set would make the order differ between runs)
                "samples": list(
                    dict.fromkeys(string_array_to_list(col.get("samples", [])))
                ),
                "dataElementMapping": None,
                "dataDictionaryMatching": {"result": "Not Matched"},
                # If the column can be null or empty
                "nullOrEmpty": col_stats.get("null_count", 0) > 0,
                "nullValues": string_array_to_list(col_stats.get("null_types", [])),
                # Format of the attribute (date or separator)
                "format": col_format,
                # Create a reference dataset for the column if it is categorical
                "refDataCode": save_reference_datasets(
                    global_schema["referenceDatasets"],
                    col["column_name"],
                    col_stats.get("categorical_count", {}),
                    col_stats.get("sample_size", 0) - col_stats.get("null_count", 0),
                    col_format,
                    config,
                )
                if "dateFormat" not in col_format
                and col["data_type"] != "text"
                and col["column_name"] not in primary_key_list
                else None,
                "primaryKey": primary_key,
                "alternateKey": False,
                "foreignKey": False,
                # Column data profile
                "dataProfile": {
                    "inferredDataType": col["data_type"],
                    "dataLabel": col_stats.get("data_label", ""),
                    "categorical": categorical,
                    "order": col["order"],
                    "statistics": {
                        # Number of input data samples used to generate this profile
                        "sampleSize": col_stats.get("sample_size", 0),
                        # Number of null values in the sample
                        "nullCount": col_stats.get("null_count", 0),
                        # A dict containing each null type and a respective list of the indices that it is present within this sample
                        "nullTypesIndex": {
                            k: string_array_to_list(v)
                            for k, v in col_stats.get("null_types_index", {}).items()
                        },
                        # The percentage of samples used identifying as each data_type
                        "dataTypeRepresentation": col.get(
                            "data_type_representation", {}
                        ),
                        # Minimum and maximum values of the column
                        "min": col_stats.get("min", 0),
                        "max": col_stats.get("max", 0),
                        # Mode of the entries in the sample
                        "mode": string_array_to_list(col_stats.get("mode", [])),
                        # Median of the entries in the sample
                        "median": col_stats.get("median", 0),
                        # Median absolute deviation of the entries in the sample
                        "medianAbsDeviation": col_stats.get(
                            "median_absolute_deviation", 0
                        ),
                        # The probability of incorrectly classifying a randomly chosen element in the dataset
                        # if it were randomly labelled according to the class distribution in the dataset.
                        "giniImpurity": col_stats.get("gini_impurity", 0),
                        # How often observations differ from one another
                        "unalikeability": col_stats.get("unalikeability", 0),
                        # The total of all sampled values from the column
                        "sum": col_stats.get("sum", 0),
                        # The average of all entries in the sample
                        "mean": col_stats.get("mean", 0),
                        # The variance of all entries in the sample
                        "variance": col_stats.get("variance", 0),
                        # The standard deviation of all entries in the sample
                        "stddev": col_stats.get("stddev", 0),
                        # The statistical skewness of all entries in the sample
                        "skewness": col_stats.get("skewness", 0),
                        # The statistical kurtosis of all entries in the sample
                        "kurtosis": col_stats.get("kurtosis", 0),
                        # The number of entries in this sample that have the value 0
                        "numZeros": col_stats.get("num_zeros", 0),
                        # The number of entries in this sample that have a value less than 0
                        "numNegatives": col_stats.get("num_negatives", 0),
                        # Contains histogram relevant information
                        # The number of entries within each bin and the thresholds of each bin
                        "histogram": {
                            k: string_array_to_list(v)
                            for k, v in col_stats.get("histogram", {}).items()
                        },
                        # The value at each percentile in the order they are listed based on the entries in the sample
                        "quantiles": col_stats.get("quantiles", {}),
                        # A list of the characters used within the entries in this sample
                        "vocabulary": string_array_to_list(col_stats.get("vocab", [])),
                        # Average of the data label prediction confidences across all data points sampled
                        "avgPredictions": col_stats.get("avg_predictions", {}),
                        # How the labels assigned to the data are formatted
                        "dataLabelRepresentation": col_stats.get(
                            "data_label_representation", {}
                        ),
                        # A list of each distinct category within the sample if categorial = 'true'
                        "categories": string_array_to_list(
                            col_stats.get("categories", [])
                        )
                        if categorical
                        else [],
                        # The number of distinct entries in the sample
                        "uniqueCount": col_stats.get("unique_count", 0),
                        # The proportion of the number of distinct entries in the sample
                        # to the total number of entries in the sample
                        "uniqueRatio": col_stats.get("unique_ratio", 0),
                        # Each standardised category pattern and the number of times it occurs in the sample
                        "dataProfileShape": data_profile_shape
                        if len(data_profile_shape) < config.data_profile_shapes_limit
                        else [],
                        # Sampled number of entries in each categories
                        # if categorical = 'true', limited to top 20 and bottom 10 if greater than categorical_limit
                        "sampledCategories": create_sampled_categories(
                            col_stats.get("categorical_count", {})
                        )
                        if categorical
                        else [],
                        # Number of entries sampled for each category
                        "unLikeAbility": get_unLikeAbility(
                            col_stats.get("categorical_count", {})
                        ),
                        # A dict of statistics with respect to the number of digits in a number for each sample
                        "precision": col_stats.get("precision", {}),
                        # The duration of time it took to generate this sample's statistics in milliseconds
                        "times": get_times(col_stats.get("times", [])),
                        # list of possible date-time formats
                        "format": col_stats.get("format", []),
                    },
                    # Statistics of data partitioned based on whether column value is null (index 1 of lists referenced by dict keys) or not (index 0)
                    "nullReplicationMetrics": col.get("null_replication_metrics", {}),
                },
            }
            # clean up column stats
            clean_column_stats(column)
            data_set["columns"].append(column)
        check_correlated_columns(data_set["columns"])
        global_schema["datasets"].append(data_set)
    # convert the format set to the string form of a list for json serialisation;
    # sorted so that the result does not depend on set iteration order
    global_schema["format"] = str(sorted(global_schema["format"]))
    if published_ref_dataset_path:
        global_schema = match_reference_data_to_published_set(
            global_schema, published_ref_dataset_path, config
        )
    # Sort the ref data values in alphabetical and case insensitive order
    # to group similar values together
    # This is done to make it easier to read and understand the data values
    for ref in global_schema["referenceDatasets"]:
        if "values" in ref:
            ref["values"] = sorted(ref["values"], key=_ref_value_sort_key)

    # Each fragment ends with a space so that the concatenated sentences do not run into each other
    core_prompt_content = (
        "Create a concise description of the data source with the given datasets "
        "and describe the relationships between the datasets. Then write a short description for each dataset "
        "and describe what the values in the column mean "
        "in each dataset but no need to state the data type. "
        "Summarise the statistics as an additional comment for each column from the dictionary in words "
        "and percentages. Only show top 3 frequent values if applicable. If it's a date column, provide the date format. "
        "Check all the columns and list all the countries covered in the data source in alphabetical order with no duplicates. "
        "Only include valid countries. "
        "Classify the data for each column. Examples for 'names' are first name, last name, business name. "
        "Examples for 'addresses' are personal address, business address, etc. "
        "Examples for 'individual_identifier_data' are DOB, email addresses, phone numbers, bank account numbers, credit card numbers, IP address, MAC address, etc. "
        "Examples for 'sensitive_data' are racial or ethnic origin, political opinions, SSN(Social Security Number), health data, Political Opinions, Unique Person Identifier, Gender, etc. "
        "Examples for 'non_classified' are data that does not fall into any of the above categories. "
        "ONLY anonymise the data classified as 'sensitive_data' in the 'samples' field in a similar format, but not be based off real life information. "
        "Check all the columns in all datasets and which column is a foreign key. "
        "Determine relationships between ALL datasets by matching their primary and foreign keys. Do NOT do self linking. "
        "Describe the relationship as relationship type and write cardinality in the following format, e.g. '1:M' as one-to-many, '1:0..M' as one-to-zero-or-many, '1..M:1..M', '0..M:0..M', 'M:1', 'M:M'. "
        "Output response in the same order as the input data: "
    )
    prompt_content = core_prompt_content
    if existing_dataspec_path:
        with open(existing_dataspec_path, "r", encoding="utf-8") as file:
            existing_schema = json.load(file)
        reuse_current_generated_information(global_schema, existing_schema)
        # Ask the model to only fill the gaps left after reusing the existing documentation
        prompt_content = (
            "Check the given schema and fill in the documentation with the following instruction at all levels "
            "that have missing descriptions and comments. If both description and comment are missing, "
            "review the countries, relationships, primary key and foreign key fields: "
        ) + core_prompt_content

    if ai_client:
        generated_info = generate_schema_information(
            global_schema, prompt_content, ai_client
        )
        if data_dictionary_path:
            data_dictionary = read_data_dictionary(data_dictionary_path)
            generated_info = match_data_dictionary(
                generated_info, ai_client, data_dictionary
            )
        logger.debug(f"Generated schema information: {generated_info}")
        if generated_info:
            new_missing_elements = insert_generated_information(
                global_schema, generated_info, enable_anonymisation
            )
            logger.info(
                f"New/Missing elements identified from data dictionary: {new_missing_elements}"
            )
        logger.debug("Reviewing primary keys")
        review_primary_key(
            global_schema, ai_client, config, verify_primary_key, raw_data_folder
        )

    if publish_new_ref_dataset:
        publish_ref_dataset(
            global_schema,
            publish_new_ref_dataset[0],
            publish_new_ref_dataset[1],
            config,
        )

    logger.debug(f"Writing data specification to {data_source}_data_spec.json")
    # write the data spec to the output file
    output_path = Path(f"{output_dir}/{data_source}_data_spec.json")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as file:
        json.dump(global_schema, file, indent=4)
    return global_schema
