"""create_pdfs."""

from __future__ import annotations

import logging
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from multiprocessing import cpu_count
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar

from PIL import Image

if TYPE_CHECKING:
    from collections.abc import Callable, Mapping, Sequence

# Generic return type produced by the function that is executed in parallel.
R = TypeVar("R")

# Execution modes: "normal" collects all task exceptions and keeps going;
# "early_error" re-raises the first exception encountered (see _parallelize_base).
modes = Literal["normal", "early_error"]


@dataclass
class ExecutorResults(Generic[R]):
    """Outcome of a parallel run: successful results paired with raised exceptions."""

    # Return values of the tasks that completed successfully, in submission order.
    results: list[R]
    # Exceptions captured from the tasks that failed.
    exceptions: list[BaseException]

    def __repr__(self) -> str:
        """Return a compact one-line summary of results and exceptions."""
        return "results={} exceptions={}".format(self.results, self.exceptions)


def _parallelize_base(
    executor_type: type[ProcessPoolExecutor],
    func: Callable[..., R],
    kwargs_list: Sequence[Mapping[str, Any]],
    max_workers: int | None,
    progress_tracker: int | None,
    mode: modes,
) -> ExecutorResults[R]:
    """Run ``func`` once per kwargs mapping on the given executor.

    Args:
        executor_type: Executor class to instantiate (e.g. ``ProcessPoolExecutor``).
        func: Callable executed once for each mapping in ``kwargs_list``.
        kwargs_list: Keyword-argument mappings, one per task.
        max_workers: Worker count forwarded to the executor (``None`` = executor default).
        progress_tracker: Log progress every N completed tasks (``None``/0 disables).
        mode: ``"normal"`` collects all exceptions and continues; ``"early_error"``
            re-raises the first exception and abandons pending tasks.

    Returns:
        ExecutorResults: Successful results and captured exceptions, both in
        submission order.

    Raises:
        BaseException: The first task exception, when ``mode == "early_error"``.
    """
    total_work = len(kwargs_list)

    results: list[R] = []
    exceptions: list[BaseException] = []

    # The futures must be consumed *inside* the with-block: leaving the
    # context manager calls shutdown(wait=True), which blocks until every
    # task has finished — that would defeat both incremental progress
    # logging and the "early_error" cancellation below.
    with executor_type(max_workers=max_workers) as executor:
        futures = [executor.submit(func, **kwargs) for kwargs in kwargs_list]

        for index, future in enumerate(futures, 1):
            if exception := future.exception():
                # Lazy %-style args avoid formatting when the level is disabled.
                logging.error("%r raised %s", future, exception.__class__.__name__)
                exceptions.append(exception)
                if mode == "early_error":
                    # Drop the not-yet-started tasks and propagate immediately.
                    executor.shutdown(wait=False, cancel_futures=True)
                    raise exception
                continue

            results.append(future.result())

            if progress_tracker and index % progress_tracker == 0:
                logging.info("Progress: %d/%d", index, total_work)

    return ExecutorResults(results, exceptions)


def process_executor_unchecked(
    func: Callable[..., R],
    kwargs_list: Sequence[Mapping[str, Any]],
    max_workers: int | None,
    progress_tracker: int | None,
    mode: modes = "normal",
) -> ExecutorResults[R]:
    """Run a function with multiple argument sets in parallel processes.

    Note: this function does not check if the number of workers is greater than
    the number of CPUs. This can cause the system to become unresponsive.

    Args:
        func (Callable[..., R]): Function to run in parallel; must be picklable
            since it is dispatched to worker processes.
        kwargs_list (Sequence[Mapping[str, Any]]): List of dictionaries with the
            keyword arguments for each call.
        max_workers (int | None): Number of worker processes to use;
            ``None`` lets ``ProcessPoolExecutor`` pick its default.
        progress_tracker (int | None): Log progress every N completed tasks;
            ``None`` or 0 disables progress logging.
        mode (modes, optional): ``"normal"`` collects all exceptions;
            ``"early_error"`` re-raises the first one. Defaults to ``"normal"``.

    Returns:
        ExecutorResults[R]: Successful results and captured exceptions, in
        submission order.
    """
    return _parallelize_base(
        executor_type=ProcessPoolExecutor,
        func=func,
        kwargs_list=kwargs_list,
        max_workers=max_workers,
        progress_tracker=progress_tracker,
        mode=mode,
    )


def create_pdf(input_dir: Path, output_pdf: Path) -> None:
    """Create a single PDF from all PNM images in the input directory.

    Pages are ordered by sorted file name.

    Args:
        input_dir (Path): Directory containing ``*.pnm`` image files.
        output_pdf (Path): Destination PDF path; must not already exist.

    Raises:
        FileExistsError: If ``output_pdf`` already exists.
        ValueError: If ``input_dir`` contains no PNM files.
    """
    if output_pdf.exists():
        error = f"{output_pdf} already exists"
        raise FileExistsError(error)

    pnm_files = sorted(input_dir.glob("*.pnm"))

    if not pnm_files:
        error = "No PNM files found"
        raise ValueError(error)

    # Open the first page under a context manager so its file handle is
    # released even if save() raises; convert() returns a new in-memory
    # image, so closing the source is safe. Image.open accepts Path directly.
    with Image.open(pnm_files[0]) as first_page:
        first_image = first_page.convert("RGB")

    first_image.save(
        output_pdf,
        "PDF",
        save_all=True,
        # Remaining pages are opened lazily while Pillow writes the PDF.
        append_images=(Image.open(pnm_file).convert("RGB") for pnm_file in pnm_files[1:]),
        quality=20,
        optimize=True,
    )


def main() -> None:
    """Convert every scan directory under the source tree into one PDF each."""
    # Without this, the executor's logging.info progress messages are invisible
    # at the default WARNING level (and logging.error goes to an unconfigured root).
    logging.basicConfig(level=logging.INFO)

    print("start")
    scans_dir = Path("/zfs/media/share/chours_music/titled")

    output_dir = Path("./pdfs")
    # create_pdf saves into this directory; create it up front so every
    # worker does not fail with FileNotFoundError on a fresh checkout.
    output_dir.mkdir(parents=True, exist_ok=True)

    scan_dirs = [scan_dir for scan_dir in scans_dir.iterdir() if scan_dir.is_dir()]

    print(f"scan_dirs={scan_dirs}")

    process_executor_unchecked(
        func=create_pdf,
        kwargs_list=[
            {"input_dir": scan_dir, "output_pdf": output_dir / f"{scan_dir.name}.pdf"} for scan_dir in scan_dirs
        ],
        max_workers=cpu_count(),
        progress_tracker=100,
    )

    print("done")


# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()
