#!/usr/bin/env python3

# Copyright: 2026 Hector CAO <hector.cao@canonical.com>
# SPDX-License-Identifier: GPL-3.0-or-later

"""ubuntu-virt-helper: manage base and hwe stacks."""

import argparse
import subprocess
import sys

# Source packages considered part of the virt stack
_VIRT_SOURCES = {"qemu",
                 "qemu-hwe",
                 "edk2",
                 "edk2-hwe",
                 "seabios",
                 "seabios-hwe",
                 "libvirt",
                 "libvirt-hwe"}

def get_all_virt_installed_packages():
    """Get list of all installed packages with their source and version.
    
    Returns:
        variant: variant package (ubuntu-virt or ubuntu-virt-hwe)
        List of installed virt-related packages.
    """
    variant = None
    virt_packages = []

    mark_manual = get_package_manual()

    result = subprocess.run(
        ["dpkg-query", "-W", "-f=${binary:Package}\t${Version}\t${source:Package}\t${Status}\n"],
        capture_output=True,
        text=True,
        check=False,
    )
    for line in result.stdout.splitlines():
        parts = line.split("\t", 3)
        if len(parts) == 4:
            pkg, version, src, status = parts
        # only consider packages that are installed (status contains "installed")
        if "installed" not in status:
            continue
        # detect variant
        # dpkg-query might return package name with :arch suffix (e.g libvirt0:amd64)
        pkg_name = pkg.split(":")[0]
        if pkg_name == 'ubuntu-virt-hwe' or pkg_name == 'ubuntu-virt':
            if variant:
                print(f"Warning: Multiple ubuntu-virt variants detected")
            else:
                variant = pkg_name
        # collect packages with virt-related sources
        if src in _VIRT_SOURCES:
            virt_packages.append((pkg, version, src, pkg_name in mark_manual))
    return variant, virt_packages

def get_package_manual():
    """Retrieve lists of packages marked as manual.

    Returns:
        List of packages marked as manual.
    """
    manual_packages = set()

    # Get manually installed packages
    result = subprocess.run(
        ["apt-mark", "showmanual"],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode == 0:
        manual_packages = set(result.stdout.strip().splitlines())

    return manual_packages

def cmd_status(args):
    """Show currently installed variant."""
    
    variant, packages = get_all_virt_installed_packages()

    if variant == "ubuntu-virt":
        print("Installed variant: base")
    elif variant == "ubuntu-virt-hwe":
        print("Installed variant: hwe")
    else:
        print("Installed variant: none")
    
    if packages:
        print("\n{} packages:".format(len(packages)))
        for pkg, version, src, manual in packages:
            if args.verbose:
                manual_str = "manual" if manual else "auto"
                print(f"  - {pkg} ({version}, src:{src}, {manual_str})")
            else:
                print(f"  - {pkg} ({version}, src:{src})")
    else:
        print("\nNo packages installed")
    
    return 0


def cmd_switch(args):
    """Switch to the other virt variant."""
    
    variant, packages = get_all_virt_installed_packages()

    if variant == "ubuntu-virt":
        print(f"Switching from base to hwe variant...")
    elif variant == "ubuntu-virt-hwe":
        print(f"Switching from hwe to base variant...")
    else:
        print("No virt variant currently installed. Cannot switch.")
        return 1

    # build the counterpart packages, skip special packages only available in 1 variant
    counterpart_packages = []
    skipped = {'ubuntu-helper-virt-hwe'}
    for pkg, version, src, manual in packages:
        if pkg in skipped:
            continue

        pkg_name = pkg.split(":")[0]
        pkg_arch = pkg.split(":")[1] if ":" in pkg else None

        if variant == "ubuntu-virt":
            pkg_name = pkg_name + '-hwe'
        elif variant == "ubuntu-virt-hwe" and pkg_name.endswith("-hwe"):
            pkg_name = pkg_name[:-4]
        counterpart_pkg = pkg_name + (f":{pkg_arch}" if pkg_arch else "")

        # check of the counterpart package exists in the repositories before adding it to the list
        result = subprocess.run(
            ["apt-cache", "show", counterpart_pkg],
            capture_output=True,
            text=True,
            check=False,        )
        if result.returncode != 0:
            print(f"Warning: counterpart package {counterpart_pkg} not found in repositories, skipping")
            continue
        counterpart_packages.append((counterpart_pkg, manual))

    # Run apt install for the target variant
    apt_install_args = args.apt_install_args or []
    if apt_install_args and apt_install_args[0] == "--":
        apt_install_args = apt_install_args[1:]

    try:
        apt_install_cmd = ["apt", "install"] + apt_install_args + [p[0] for p in counterpart_packages]
        if args.verbose:
            print("Running: {}".format(" ".join(apt_install_cmd)))
        result = subprocess.run(
            apt_install_cmd,
            check=False,
        )
        if result.returncode != 0:
            return result.returncode
    except Exception as e:
        print(f"Error: Failed to run apt install: {e}")
        return 1

    # restore auto marks if not marked as manual in the original variant
    # this is best effort so ignore all errors and just print a warning if it fails
    auto_pkgs = [p[0] for p in counterpart_packages if not p[1]]
    print("Restoring auto marks for packages: {}".format(", ".join(auto_pkgs)))
    try:
        result = subprocess.run(
            ["apt-mark", "auto"] + auto_pkgs,
            check=False,
        )
    except Exception as e:
        print(f"Error: Failed to run apt-mark auto: {e}")

    return 0

def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(
        prog="ubuntu_virt_helper",
        description="Detect ubuntu-virt package variants",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Enable verbose output",
    )
    
    subparsers = parser.add_subparsers(dest="command", help="Command")
    
    # Status command
    status = subparsers.add_parser("status", help="Show installed variant")
    
    # Switch command
    switch = subparsers.add_parser("switch", help="Switch to the other virt variant")
    switch.add_argument(
        "apt_install_args",
        nargs=argparse.REMAINDER,
        help="Arguments passed through to apt install (use '--' before them)",
    )
    
    args = parser.parse_args()
    
    if not args.command:
        # Default to status
        args.command = "status"
    
    if args.command == "status":
        return cmd_status(args)
    elif args.command == "switch":
        return cmd_switch(args)
    else:
        parser.print_help()
        return 1


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(f"ubuntu_virt_helper: warning: unexpected error: {e}", file=sys.stderr)
        raise SystemExit(0)
