#!/usr/bin/env python3
import argparse
import json
import logging
from pathlib import Path

import pystache
import semver

logger = logging.getLogger(__name__)


def _manifest_version(manifest_path: Path) -> semver.VersionInfo:
    """Return the semantic version encoded in a ``redistrib_<version>.json`` filename."""
    # e.g. "redistrib_12.8.0.json" -> stem "redistrib_12.8.0" -> "12.8.0"
    return semver.VersionInfo.parse(manifest_path.stem.split("_")[1])


def _find_manifest(package_dir: Path) -> Path:
    """Locate the redistrib manifest to use for *package_dir*.

    ``debian/`` is searched first. The package root is consulted only when
    ``debian/`` holds no ``redistrib_*.json`` file. Among the candidates found,
    the one with the highest version in its filename wins.

    Raises:
        FileNotFoundError: if neither location contains a manifest.
    """
    possible_manifests = list((package_dir / "debian").glob("redistrib_*.json"))
    logger.debug(f"Possible manifests in debian/: {[p.name for p in possible_manifests]}")
    if not possible_manifests:
        possible_manifests = list(package_dir.glob("redistrib_*.json"))
        logger.debug(f"Possible manifests in package root: {[p.name for p in possible_manifests]}")
        if not possible_manifests:
            raise FileNotFoundError("No redistrib_*.json manifest file found")
    return max(possible_manifests, key=_manifest_version)


def _require(template_context: dict[str, str], key: str, component: str) -> str:
    """Return ``template_context[key]`` or raise a descriptive ``KeyError``.

    This uses an explicit raise instead of ``assert`` so the check still runs under ``python -O``.
    """
    value = template_context.get(key)
    if value is None:
        raise KeyError(f"{component} component not found in manifest")
    return value


def _build_template_context(manifest_data: dict) -> dict[str, str]:
    """Build the mustache rendering context from a parsed redistrib manifest.

    Every component entry ``<key>`` that is a dict carrying ``name``,
    ``license`` and ``version`` contributes ``<key>_version``. Derived CUDA and
    driver values are written afterwards, so they take precedence over any
    component-derived entry with the same name.

    Raises:
        KeyError: if ``release_label`` is missing, or if the ``nvidia_driver``,
            ``nsight_compute`` or ``nsight_systems`` components are missing.
    """
    cuda_version = semver.VersionInfo.parse(manifest_data["release_label"])
    template_context: dict[str, str] = {}

    for key, value in manifest_data.items():
        # Skip top-level metadata
        if key in ("release_date", "release_label", "release_product"):
            continue

        # Skip any non-dict values (shouldn't happen, but be defensive)
        if not isinstance(value, dict):
            logger.warning(f"Skipping non-dict component: {key}")
            continue

        # Only entries carrying all required fields are treated as components
        if not all(field in value for field in ("name", "license", "version")):
            logger.warning(f"Skipping malformed component: {key}")
            continue

        template_context[f"{key}_version"] = value["version"]

    nvidia_driver_version = _require(template_context, "nvidia_driver_version", "Nvidia Driver")
    # nvidia_driver_version will be something like 580.95.05; nvidia_driver_series will be just 580
    template_context["nvidia_driver_series"] = nvidia_driver_version.split(".")[0]
    template_context["cuda_version"] = str(cuda_version)
    template_context["cuda_major"] = str(cuda_version.major)
    template_context["cuda_minor"] = str(cuda_version.minor)
    template_context["cuda_major_dot_minor"] = f"{cuda_version.major}.{cuda_version.minor}"
    template_context["cuda_major_dash_minor"] = f"{cuda_version.major}-{cuda_version.minor}"
    # NOTE(review): major * 10 + minor collides once a minor reaches 10
    # (e.g. 12.10 and 13.0 both give 130). Kept as-is for compatibility.
    template_context["cuda_priority"] = str(cuda_version.major * 10 + cuda_version.minor)

    # Keep only the first three dot-separated components of the Nsight versions
    for key, component in (
        ("nsight_compute_version", "Nsight Compute"),
        ("nsight_systems_version", "Nsight Systems"),
    ):
        full_version = _require(template_context, key, component)
        template_context[key] = ".".join(full_version.split(".")[:3])

    return template_context


def _filename_replacements(template_context: dict[str, str]) -> list[tuple[str, str]]:
    """Return the ordered (placeholder, value) pairs applied to template filenames.

    CUDA_MAJOR_DASH_MINOR and CUDA_MAJOR_DOT_MINOR come before CUDA_MAJOR.
    Otherwise the shorter placeholder would consume their prefix. Values come
    from *template_context*, so filenames match the rendered contents; in
    particular, both use the truncated Nsight versions.
    """
    return [
        ("CUDA_MAJOR_DASH_MINOR", template_context["cuda_major_dash_minor"]),
        ("CUDA_MAJOR_DOT_MINOR", template_context["cuda_major_dot_minor"]),
        ("CUDA_MAJOR", template_context["cuda_major"]),
        ("CUDA_MINOR", template_context["cuda_minor"]),
        ("NSIGHT_COMPUTE_VERSION", template_context["nsight_compute_version"]),
        ("NSIGHT_SYSTEMS_VERSION", template_context["nsight_systems_version"]),
    ]


def _render_templates(template_root: Path, template_context: dict[str, str]) -> None:
    """Render every ``*.template`` file under *template_root* next to itself, minus the suffix.

    Filename placeholders are substituted in order. File contents are rendered
    with pystache in strict mode, so a tag missing from *template_context*
    raises instead of rendering as an empty string.
    """
    filename_replacements = _filename_replacements(template_context)
    logger.debug(f"Filename replacements: {filename_replacements}")

    # The renderer's configuration never changes, so build it once for all files
    strict_renderer = pystache.Renderer(missing_tags="strict")

    for root, _, files in template_root.walk():
        for filename in sorted(files):
            if not filename.endswith(".template"):
                continue

            output_filename = filename.removesuffix(".template")
            if not output_filename:
                # A file literally named ".template" would otherwise target the directory itself
                logger.warning(f"Skipping template with empty output name: {root / filename}")
                continue

            # Apply filename replacements in order
            for placeholder, replacement in filename_replacements:
                output_filename = output_filename.replace(placeholder, replacement)

            template_path = root / filename
            output_path = root / output_filename
            logger.debug(f"Processing template: {template_path} -> {output_path}")

            template_content = template_path.read_text(encoding="utf-8")
            rendered_content = strict_renderer.render(template_content, template_context)
            output_path.write_text(rendered_content, encoding="utf-8")


def replace_templates(package_dir: Path) -> None:
    """Render the Debian packaging templates of *package_dir* from its redistrib manifest.

    The newest ``redistrib_*.json`` is located, first in ``debian/`` and then in
    the package root. A context of component versions and derived CUDA values is
    built from it. Every ``*.template`` file under ``debian/`` is then rendered
    to a sibling file with the suffix removed and filename placeholders substituted.

    Args:
        package_dir: Package root that contains the ``debian/`` directory.

    Raises:
        FileNotFoundError: if no ``redistrib_*.json`` manifest is found.
        KeyError: if the manifest lacks ``release_label`` or a required component.
    """
    manifest_path = _find_manifest(package_dir)
    with manifest_path.open("r", encoding="utf-8") as f:
        manifest_data = json.load(f)

    template_context = _build_template_context(manifest_data)
    logger.debug(f"Template context: {template_context}")

    # Only templates under debian/ are rendered, even when the manifest came from the package root
    _render_templates(package_dir / "debian", template_context)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Replace templates in package directory.")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.")
    parser.add_argument(
        "--package-dir",
        type=Path,
        default=None,
        help="Package root containing debian/ (default: three directory levels above this script).",
    )

    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    if args.package_dir is not None:
        package_dir = args.package_dir.resolve()
    else:
        # Default: the directory three levels above this script file
        package_dir = Path(__file__).parent.parent.parent.resolve()

    replace_templates(package_dir)
