from pathlib import Path
from unittest.mock import MagicMock

import pytest
from mock import patch

from kyd_dataspec_gen.match_data_dictionary import (
    match_data_dictionary,
    read_data_dictionary,
)
from kyd_dataspec_gen.models import (
    Column,
    DataClassification,
    DataDictionaryMatch,
    DataSet,
    DataSource,
)

# Directory containing this test module; fixture files are resolved relative
# to it (see the "test_data" sub-directory used below).
curr_dir = Path(__file__).parent


def test_read_data_dictionary():
    """Load the sample CSV data dictionary and check how many entries it yields."""
    csv_path = curr_dir.joinpath("test_data", "test_data_dict.csv")
    entries = read_data_dictionary(csv_path)
    assert len(entries) == 9


# Input fixture for test_match_data_dictionary: a DataSource holding a single
# "Product" data set with two columns, both flagged
# DataDictionaryMatch.new_missing. "ProductKey" carries a proposed dictionary
# match ("PROD_KEY"); "SalesOrderNumber" has no proposal.
generated_info = DataSource(
    name="adventureWorks",
    description="Adventure Works database",
    location_coverage=["UK"],
    data_sets=[
        DataSet(
            data_set_name="Product",
            description="This dataset provides detailed information about various products, including their unique keys, names, standard costs, colors, and categorization.",
            columns=[
                Column(
                    col_name="ProductKey",
                    description="Unique identifier for each product.",
                    comment="Integer values ranging from 210 to 606.",
                    foreign_key=False,
                    data_classification=DataClassification.non_classified,
                    anonymised_samples=[],
                    data_dictionary_match=DataDictionaryMatch.new_missing,
                    proposed_dd_match="PROD_KEY",
                ),
                Column(
                    col_name="SalesOrderNumber",
                    description="Unique identifier for each sales order.",
                    comment="String values, all 7 characters long.",
                    foreign_key=False,
                    data_classification=DataClassification.non_classified,
                    anonymised_samples=[],
                    data_dictionary_match=DataDictionaryMatch.new_missing,
                    proposed_dd_match=None,
                ),
            ],
        ),
    ],
    relationships=[],
)
# Expected result for the "ai_client_provided" case of
# test_match_data_dictionary (it is also what the patched generate_response
# returns in that case). Same content as generated_info except for the
# "ProductKey" column, which is DataDictionaryMatch.matched with
# proposed_dd_match set to None.
expected_generated_match = DataSource(
    name="adventureWorks",
    description="Adventure Works database",
    location_coverage=["UK"],
    data_sets=[
        DataSet(
            data_set_name="Product",
            description="This dataset provides detailed information about various products, including their unique keys, names, standard costs, colors, and categorization.",
            columns=[
                Column(
                    col_name="ProductKey",
                    description="Unique identifier for each product.",
                    comment="Integer values ranging from 210 to 606.",
                    foreign_key=False,
                    data_classification=DataClassification.non_classified,
                    anonymised_samples=[],
                    data_dictionary_match=DataDictionaryMatch.matched,
                    proposed_dd_match=None,
                ),
                Column(
                    col_name="SalesOrderNumber",
                    description="Unique identifier for each sales order.",
                    comment="String values, all 7 characters long.",
                    foreign_key=False,
                    data_classification=DataClassification.non_classified,
                    anonymised_samples=[],
                    data_dictionary_match=DataDictionaryMatch.new_missing,
                    proposed_dd_match=None,
                ),
            ],
        ),
    ],
    relationships=[],
)
# Data dictionary entries handed to match_data_dictionary in
# test_match_data_dictionary. Each entry is a dict keyed by "dd-id",
# "description", "is_a", "part_of" and "examples".
# NOTE(review): "examples" is a list in the first and third entries but a
# plain string in the second — confirm which shape the production code expects.
# NOTE(review): the "is_a" value of PROD_NAME reads like the tail of its
# description rather than a kind such as "Entity"/"Identifier" — confirm this
# is intentional.
test_data_dictionary = [
    {
        "dd-id": "PROD",
        "description": "A product sold by the company",
        "is_a": "Entity",
        "part_of": "",
        "examples": [""],
    },
    {
        "dd-id": "PROD_KEY",
        "description": "Unique numerical identifier for a product",
        "is_a": "Identifier",
        "part_of": "Product",
        "examples": "210, 215, 599",
    },
    {
        "dd-id": "PROD_NAME",
        "description": "The full name of the product",
        "is_a": "often including size or specific variant",
        "part_of": "Product",
        "examples": [
            "HL Road Frame - Black, 58, Sport-100 Helmet, Black, AWC Logo Cap, Water Bottle - 30 oz."
        ],
    },
]


@pytest.mark.parametrize(
    "data_source, ai_client, data_dictionary, expected_match",
    [
        pytest.param(
            generated_info,
            MagicMock(),
            test_data_dictionary,
            expected_generated_match,
            id="ai_client_provided",
        ),
        pytest.param(
            generated_info,
            None,
            test_data_dictionary,
            generated_info,
            id="no_ai_client",
        ),
    ],
)
def test_match_data_dictionary(data_source, ai_client, data_dictionary, expected_match):
    """Check the result of match_data_dictionary with and without an AI client.

    ``generate_response`` is patched inside the module under test and is
    configured to return ``expected_match`` for both parametrized cases.
    """
    stub_target = "kyd_dataspec_gen.match_data_dictionary.generate_response"
    with patch(stub_target, return_value=expected_match):
        actual = match_data_dictionary(data_source, ai_client, data_dictionary)

    # The comparison does not depend on the patch still being active.
    assert actual == expected_match
