"""Docx Table Functions"""

import logging
from typing import TYPE_CHECKING

from anyascii import anyascii
from docx.oxml.table import CT_Tbl, CT_Tc
from docx.oxml.text.paragraph import CT_P
from lxml import etree as lxml_etree

from .docx_config import DocxConfig
from .docx_constants import OutputEncodeType
from .docx_parts_para import ParaPart
from .docx_utils import clean_html_tags

if TYPE_CHECKING:
    from .kyd_docx2md import Docx2Md

logger = logging.getLogger(__name__)

# WordprocessingML main namespace in lxml "Clark notation" ("{uri}"), so it can
# be prefixed directly onto local element/attribute names for find()/get().
w_ns = "{" + "http://schemas.openxmlformats.org/wordprocessingml/2006/main" + "}"
# Qualified tag of the <w:ilvl> (list level) element found under <w:numPr>.
w_lev = w_ns + "ilvl"


class TableParts:
    """
    Table processing for Docx2Md.

    Converts python-docx ``CT_Tbl`` elements into Markdown pipe tables, into
    HTML tables (when merged cells are present), or into flattened paragraph
    text (for table sizes listed in ``config.remove_wrapping_tables``).
    """

    def __init__(self, converter_instance: "Docx2Md", config: DocxConfig) -> None:
        """
        Initialize the TableParts class.

        Args:
            converter_instance (Docx2Md): The Docx2Md instance.
            config (DocxConfig): The configuration settings.

        """
        self.config = config
        self.converter_instance = converter_instance
        # Per-table state, refreshed by measure_table() / convert_table_to_text().
        # Initialized here so the helpers can be called directly without an
        # AttributeError.
        self.max_cols = 0
        self.max_rows = 0
        self.merged_cells = False
        self.depth = 0

    # ------------------------------------------------------------------ #
    # Low-level cell/row geometry helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _vmerge_state(tcPr: lxml_etree._Element | None) -> str | None:
        """
        Classify a cell's vertical-merge status.

        Args:
            tcPr: The cell's ``w:tcPr`` element, or None.

        Returns:
            str | None: ``"restart"`` for the first cell of a vertical merge,
            ``"continue"`` for a cell merged into the one above, or None when
            the cell has no ``w:vMerge``.

        """
        if tcPr is None:
            return None
        vMerge = tcPr.find(f".//{w_ns}vMerge")
        if vMerge is None:
            return None
        # In OOXML an omitted w:val means "continue", and "continue" may also be
        # written explicitly; only "restart" starts a new merged region.
        if vMerge.get(f"{w_ns}val") == "restart":
            return "restart"
        return "continue"

    @staticmethod
    def _grid_span(tcPr: lxml_etree._Element | None) -> int:
        """
        Return how many grid columns a cell spans.

        Args:
            tcPr: The cell's ``w:tcPr`` element, or None.

        Returns:
            int: The ``w:gridSpan`` value (at least 1); 1 when absent or invalid.

        """
        if tcPr is None:
            return 1
        gridSpan = tcPr.find(f".//{w_ns}gridSpan")
        if gridSpan is None:
            return 1
        val = gridSpan.get(f"{w_ns}val")
        if val is None:
            return 1
        try:
            return max(1, int(val))
        except ValueError:
            logger.debug("Ignoring invalid gridSpan value %r", val)
            return 1

    def _grid_columns(self, tr) -> list[int]:
        """
        Compute the starting grid column of every cell in a row.

        A cell's index in ``tr.tc_lst`` only equals its grid column when no
        earlier cell spans several columns. Vertical-merge lookups must
        therefore compare grid positions, not list indexes.

        Args:
            tr: A ``w:tr`` row element (python-docx ``CT_Row``).

        Returns:
            list[int]: One starting grid column per ``w:tc`` in the row.

        """
        column = 0
        trPr = tr.trPr
        if trPr is not None:
            # w:gridBefore skips grid columns before the row's first cell.
            gridBefore = trPr.find(f"{w_ns}gridBefore")
            if gridBefore is not None:
                try:
                    column = max(0, int(gridBefore.get(f"{w_ns}val", "0")))
                except ValueError:
                    column = 0
        starts: list[int] = []
        for tc in tr.tc_lst:
            starts.append(column)
            column += self._grid_span(tc.tcPr)
        return starts

    # ------------------------------------------------------------------ #
    # Table measurement and dispatch
    # ------------------------------------------------------------------ #

    def measure_table(self, table: CT_Tbl) -> bool:
        """
        Record the table's dimensions and whether it has any merged cells.

        Side effects: sets ``self.max_rows`` (number of ``w:tr`` rows),
        ``self.max_cols`` (largest number of ``w:tc`` elements in any row;
        grid spans are not expanded) and ``self.merged_cells``.

        Args:
            table: The table to check for merged cells.

        Returns:
            bool: True if any cell carries ``w:vMerge`` or ``w:gridSpan``.

        """
        rows = table.tr_lst  # pyright: ignore[reportAttributeAccessIssue]
        self.max_rows = len(rows)
        self.max_cols = max((len(row.tc_lst) for row in rows), default=0)
        self.merged_cells = any(
            cell.tcPr is not None
            and (
                cell.tcPr.find(f".//{w_ns}vMerge") is not None
                or cell.tcPr.find(f".//{w_ns}gridSpan") is not None
            )
            for row in rows
            for cell in row.tc_lst
        )
        return self.merged_cells

    def convert_table_to_text(self, table: CT_Tbl, depth: int = 0) -> str:
        """
        Convert a table element from a docx file into text format.

        Tables whose ``(cols, rows)`` size is listed in
        ``config.remove_wrapping_tables`` are flattened. Tables with merged
        cells become HTML. All others become Markdown pipe tables.

        Args:
            table (CT_Tbl): The table element to be converted.
            depth (int): The depth of the table in nested tables. Defaults to 0.

        Returns:
            str: The text representation of the table.

        """
        self.measure_table(table)
        self.depth = depth
        logger.debug(
            "Table: Rows=%s Cols=%s Merged=%s depth=%s",
            self.max_rows,
            self.max_cols,
            self.merged_cells,
            self.depth,
        )

        # Check to see if table has been requested to be stripped
        if (self.max_cols, self.max_rows) in self.config.remove_wrapping_tables:
            return self.remove_wrapping_table(table)

        if self.merged_cells:
            logger.debug("Table has merged cells, converting to HTML.")
            return self.table_to_html(table)
        return self.convert_table_to_md(table)

    # ------------------------------------------------------------------ #
    # HTML output (tables with merged cells)
    # ------------------------------------------------------------------ #

    def table_to_html(self, table: CT_Tbl) -> str:
        """
        Convert a table element from a docx file into HTML format.

        Merged cells are handled with rowspan/colspan. The first row is
        rendered with ``<th>`` cells unless any of its cells spans several
        columns, in which case every row uses ``<td>``.

        Args:
            table (CT_Tbl): The table element to be converted.

        Returns:
            str: The HTML representation of the table, one ``<tr>`` per line.

        """
        rows = table.tr_lst  # pyright: ignore[reportAttributeAccessIssue]

        # Decide the header status for the whole first row up front, so the
        # row is not split into a mix of <th> and <td> cells.
        first_row_is_header = bool(rows) and all(
            self._grid_span(tc.tcPr) == 1
            for tc in rows[0].tc_lst
            if self._vmerge_state(tc.tcPr) != "continue"
        )

        html = ["<table>"]
        for row_index, row in enumerate(rows):
            cell_tag = "th" if row_index == 0 and first_row_is_header else "td"
            row_html = ["<tr>"]
            for col_index, cell in enumerate(row.tc_lst):
                tcPr = cell.tcPr
                if self._vmerge_state(tcPr) == "continue":
                    # Covered by a rowspan from a cell above; emit nothing.
                    continue

                rowspan, colspan = self.find_merged_cells(
                    table,
                    row_index,
                    col_index,
                    tcPr,
                )
                cell_text = self.extract_cell_text(cell, OutputEncodeType.HTML)

                attrs = ""
                if rowspan > 1:
                    attrs += f' rowspan="{rowspan}"'
                if colspan > 1:
                    attrs += f' colspan="{colspan}"'
                row_html.append(f"<{cell_tag}{attrs}>{cell_text}</{cell_tag}>")
            row_html.append("</tr>")
            html.append("".join(row_html))

        html.append("</table>")
        return "\n".join(html)

    def find_merged_cells(
        self,
        table: CT_Tbl,
        row_index: int,
        col_index: int,
        tcPr: lxml_etree._Element | None,
    ) -> tuple[int, int]:
        """
        Find merged cell information from table.

        Args:
            table: The table element to search for merged cells.
            row_index: The row index of the current cell.
            col_index: The index of the current cell within its row's
                ``tc_lst``. It is converted to a grid column internally.
            tcPr: The XML element representing the cell properties.

        Returns:
            tuple[int, int]: The rowspan and colspan of the merged cells.

        """
        if tcPr is None:
            return 1, 1

        rowspan = 1
        colspan = self._grid_span(tcPr)

        rows = table.tr_lst  # pyright: ignore[reportAttributeAccessIssue]
        if self._vmerge_state(tcPr) == "restart" and row_index < len(rows):
            starts = self._grid_columns(rows[row_index])
            # Follow the merge by grid column, not by cell index: a gridSpan
            # earlier in a row shifts the index of every later cell.
            grid_col = starts[col_index] if col_index < len(starts) else col_index
            for next_row in rows[row_index + 1 :]:
                next_starts = self._grid_columns(next_row)
                if grid_col not in next_starts:
                    # No cell starts at this grid column, so the merge ends.
                    break
                next_cell = next_row.tc_lst[next_starts.index(grid_col)]
                if self._vmerge_state(next_cell.tcPr) != "continue":
                    break  # unmerged cell or a new "restart"
                rowspan += 1

        return rowspan, colspan

    # ------------------------------------------------------------------ #
    # Markdown output
    # ------------------------------------------------------------------ #

    def determine_table_cell_padding(
        self,
        table: list[list[str]],
    ) -> tuple[list[int], list[list[int]]]:
        """
        Determine the padding for each cell in a table.

        Column widths are the longest cell text per column, capped at 35
        characters. When a body cell overflows its column width, the overflow
        is subtracted from the padding of a following cell in the same row,
        so the row's total width stays close to the header's. Rows with
        differing numbers of cells are handled by treating missing cells as
        empty strings.

        Args:
            table: Rows of cell texts; row 0 is the header row.

        Returns:
            tuple[list[int], list[list[int]]]: The per-column widths, and for
            each body row (``table[1:]``) a list of per-cell pad widths.
            Both are empty for an empty table.

        """
        if not table:
            return [], []

        num_cols = max(len(row) for row in table)
        col_widths = [0] * num_cols
        max_pad_width = 35

        recover_table_cells: list[list[int]] = []

        # Calculate the best sizes for each col (safe indexing)
        for row in table:
            for idx in range(num_cols):
                cell = row[idx] if idx < len(row) else ""
                col_widths[idx] = min(max_pad_width, max(col_widths[idx], len(cell)))

        # Now determine if a cell breaks the max size; try to recover the
        # overflow from the padding of the following cells.
        for row in table[1:]:
            rec_row: list[int] = []
            recover_table_cells.append(rec_row)

            recover_len = 0
            for idx in range(num_cols):
                cell = row[idx] if idx < len(row) else ""
                cell_len = len(cell)
                if col_widths[idx] - cell_len < 0:
                    recover_len += cell_len - col_widths[idx]
                pad_size = col_widths[idx]

                if recover_len > 0 and pad_size > cell_len - recover_len:
                    pad_size = max(0, pad_size - recover_len)
                    recover_len = 0

                rec_row.append(pad_size)

        return col_widths, recover_table_cells

    def remove_wrapping_table(self, table: CT_Tbl) -> str:
        """
        Flatten a layout ("wrapping") table into its contents.

        Paragraphs from every cell are processed in row-major order by
        ``ParaPart.process_para``. Nested tables are converted recursively at
        ``self.depth + 1``.

        Args:
            table (CT_Tbl): The table element to be flattened.

        Returns:
            str: The collected lines joined with newlines.

        """
        md_lines: list[str] = []
        in_list = False

        for tr in table.tr_lst:  # pyright: ignore[reportAttributeAccessIssue]
            for tc in tr.tc_lst:
                for content in tc.inner_content_elements:
                    if isinstance(content, CT_P):
                        para = ParaPart(self.config, content, self.depth)
                        # process_para presumably writes into md_lines; its
                        # return value is carried as the in-list state.
                        in_list = para.process_para(
                            content,
                            md_lines,
                            in_list,
                        )
                    elif isinstance(content, CT_Tbl):
                        # Add a blank line before nested tables for readability
                        if self.depth == 0:
                            md_lines.append("")
                        sub_table = TableParts(self.converter_instance, self.config)
                        markdown_table = sub_table.convert_table_to_text(
                            content,
                            self.depth + 1,
                        )
                        md_lines.append(markdown_table)
                    else:
                        logger.warning("Unknown table content: %s", type(content))
                        md_lines.append(str(content))

        return "\n".join(md_lines)

    def convert_table_to_md(self, table: CT_Tbl) -> str:
        """
        Convert a table element from a docx file into a Markdown pipe table.

        The first row becomes the header row. Every line is padded to the
        widest row's cell count, so the header, separator and body lines all
        have the same number of columns.

        Args:
            table (CT_Tbl): The table element to be converted.

        Returns:
            str: The Markdown table, one line per row. Each line is prefixed
            with ``>`` quote markers when ``self.depth`` > 0. Returns "" for
            a table with no rows.

        """
        quote_start = ">" * self.depth + " " if self.depth > 0 else ""

        first_row = 0
        first_col = 0
        rows: list[list[str]] = []
        for row_counter, tr in enumerate(
            table.tr_lst,  # pyright: ignore[reportAttributeAccessIssue]
            start=1,
        ):
            # Diagnostics only: log firstRow/firstColumn attributes found on
            # the children of w:trPr. The values are not used for output.
            if tr.trPr is not None:
                for i in tr.trPr:
                    first_row = i.get(f"{w_ns}firstRow", 0)
                    first_col = i.get(f"{w_ns}firstColumn", 0)
                    logger.debug("[%s] FR:%s FC:%s", row_counter, first_row, first_col)
            else:
                logger.debug("[%s] FR:%s FC:%s", row_counter, first_row, first_col)

            cells = [
                self.extract_cell_text(tc, OutputEncodeType.MARKDOWN)
                for tc in tr.tc_lst
            ]
            rows.append(cells)
            logger.debug(cells)

        if not rows:
            return ""

        # Column widths are capped at 35 characters by
        # determine_table_cell_padding. Body rows get per-cell widths that
        # absorb overflow from wide cells.
        col_widths, recover_table_cells = self.determine_table_cell_padding(rows)
        num_cols = len(col_widths)

        def render(cells: list[str], widths: list[int]) -> str:
            # Missing trailing cells are rendered as empty, padded cells.
            padded = (
                (cells[idx] if idx < len(cells) else "").ljust(widths[idx])
                for idx in range(num_cols)
            )
            return quote_start + "| " + " | ".join(padded) + " |"

        table_lines = [render(rows[0], col_widths)]

        # Separator line
        table_lines.append(
            quote_start
            + "| "
            + " | ".join("-" * width if width > 0 else "---" for width in col_widths)
            + " |",
        )

        for row_idx, row in enumerate(rows[1:]):
            # fall back to global col_widths if recover data missing
            row_widths = (
                recover_table_cells[row_idx]
                if row_idx < len(recover_table_cells)
                else col_widths
            )
            table_lines.append(render(row, row_widths))
        return "\n".join(table_lines)

    # ------------------------------------------------------------------ #
    # Cell text extraction
    # ------------------------------------------------------------------ #

    def _list_level(self, paragraph: CT_P) -> int:
        """
        Return the list nesting level of a list-styled paragraph.

        Uses the ``w:ilvl`` in the paragraph's numbering properties when
        present (0 if ``w:numPr`` has no ``w:ilvl``). Otherwise uses the level
        configured for its custom bullet style, otherwise 0.

        Args:
            paragraph: The list paragraph.

        Returns:
            int: The list level.

        """
        pPr = paragraph.pPr
        if pPr is not None and pPr.numPr is not None:
            level = 0
            for child in pPr.numPr:  # pyright: ignore[reportGeneralTypeIssues]
                if child.tag == w_lev:
                    level = child.val
            return level
        bullets = self.config.styles["bullets"]
        if paragraph.style in bullets:
            # If the style is in the custom bullets, use the level from the style
            return bullets[paragraph.style]
        return 0  # pragma: no cover

    def extract_cell_text(  # noqa: PLR0912, PLR0915
        self,
        tc: CT_Tc,
        encoding: OutputEncodeType,
    ) -> str:
        """
        Extract text from a table cell.

        Plain paragraphs are separated by ``<br/>``. List paragraphs (style
        "ListParagraph" or a configured custom bullet style) become nested
        ``<ul>/<li>`` markup. Embedded tables are replaced by a warning.

        Args:
            tc (CT_Tc): The table cell element.
            encoding (OutputEncodeType): The encoding type for the output.

        Returns:
            str: The text content of the table cell, without a trailing
            ``<br/>``.

        """
        cell_text = ""
        in_list = False
        list_level = 0  # level of the most recent list item
        open_lists = 0  # number of <ul> elements currently open
        bullet_styles = self.config.styles["bullets"]

        for content in tc.inner_content_elements:
            if isinstance(content, CT_Tbl):
                logger.warning("!!WARNING!! Embedded table found")
                self.config.runtime.embedded_tbl_count += 1
                cell_text += "!!WARNING!! Embedded table found"
                continue
            if not isinstance(content, CT_P):  # pragma: no cover
                cell_text += f"Unknown: {content} <br/>"
                continue

            para = ParaPart(self.config, content)
            para_text = para.extract_paragraph_text(content, encoding)
            para_style = content.style if content.style else ""
            if para_style in bullet_styles:
                # If the style is in the custom bullets, mark it as a list
                para_style = "ListParagraph"
            logger.debug("TABLE:Para Style [%s] -[%s]", para_style, para_text)

            if para_style != "ListParagraph":
                if in_list:
                    # Close every open list level, not just the innermost one.
                    cell_text += "</ul>" * open_lists
                    in_list = False
                    open_lists = 0
                    list_level = 0
                cell_text += para_text + "<br/>"
                continue

            para_list_level = self._list_level(content)

            if not in_list:
                # The first item opens exactly one list, whatever its level.
                cell_text += "<ul><li>" + para_text + "</li>"
                in_list = True
                open_lists = 1
                list_level = para_list_level
                continue

            net_change = para_list_level - list_level
            if net_change > 0:
                # Level increased: open nested lists.
                cell_text += "<ul>" * net_change
                open_lists += net_change
            elif net_change < 0:
                # Level decreased: close nested lists, keeping the outermost.
                closing = min(-net_change, open_lists - 1)
                cell_text += "</ul>" * closing
                open_lists -= closing
            list_level = para_list_level
            cell_text += "<li>" + para_text + "</li>"

        if in_list:
            logger.debug("Closing list: Level %s", list_level)
            cell_text += "</ul>" * open_lists + "<br/>"

        # NOTE(review): clean_html_tags runs once for HTML output and twice for
        # Markdown. The duplicate looks accidental, but it is kept until
        # clean_html_tags is confirmed to be idempotent.
        if encoding != OutputEncodeType.HTML:
            cell_text = clean_html_tags(cell_text)
        cell_text = clean_html_tags(cell_text)

        if self.config.ascii_only:
            cell_text = anyascii(cell_text)

        # Remove the trailing <br/>. removesuffix leaves text that does not end
        # with one (e.g. an embedded-table warning) intact instead of chopping
        # five characters.
        return cell_text.removesuffix("<br/>")
