#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.11"
# ///

# {{@@ header() @@}}
#
# age encryption / decryption helpers
# based on https://github.com/ryantm/agenix

from __future__ import annotations

import argparse
import re
import subprocess
import sys
from pathlib import Path

# Marker lines delimiting the user public key block of secrets/secrets.nix.
# gen_user_key() passes this pair to update_keys(), which locates each marker
# by substring match and rewrites the lines between them.
USER_BEGIN_MARKER = "#-----BEGIN USER PUBLIC KEYS-----"
USER_END_MARKER = "#------END USER PUBLIC KEYS------"
# Marker lines delimiting the system (host) public key block; get_host_key()
# passes this pair to update_keys().
SYSTEM_BEGIN_MARKER = "#-----BEGIN SYSTEM PUBLIC KEYS-----"
SYSTEM_END_MARKER = "#------END SYSTEM PUBLIC KEYS------"


def normalize_public_key(raw_key: str) -> str:
    """Reduce a public key line to its ``<type> <key>`` part.

    Anything after the second whitespace-separated field (typically the key
    comment) is discarded, and the two kept fields are joined by one space.

    Args:
        raw_key: Public key text, e.g. the contents of a ``.pub`` file.

    Returns:
        The key type and key material separated by a single space.

    Raises:
        ValueError: If ``raw_key`` has fewer than two whitespace-separated fields.
    """
    # Splitting on arbitrary whitespace already ignores leading/trailing blanks;
    # only the first two fields are needed, so stop splitting after them.
    fields = raw_key.split(None, 2)
    if len(fields) >= 2:
        return " ".join(fields[:2])
    raise ValueError(
        'Public key must contain at least "<type> <key>" (e.g. "ssh-ed25519 AAAAC3N...").'
    )


def find_marker(lines: list[str], marker: str) -> int:
    """Return the index of the first line that contains ``marker``.

    Args:
        lines: Lines to search, in order.
        marker: Text to look for; a substring match anywhere in a line counts.

    Returns:
        Zero-based index of the first matching line.

    Raises:
        ValueError: If no line contains ``marker``.
    """
    position = next(
        (line_number for line_number, text in enumerate(lines) if marker in text),
        None,
    )
    if position is None:
        raise ValueError(f'Marker "{marker}" was not found.')
    return position


def _entry_key_name(entry: str) -> str | None:
    """Return the attribute name of a ``name = value`` line, or ``None`` for any other line."""
    stripped = entry.strip()
    if not stripped or stripped.startswith("#") or "=" not in stripped:
        return None
    return stripped.split("=", 1)[0].strip() or None


def update_keys(
    file_path: Path,
    start_marker: str,
    end_marker: str,
    new_key_line: str,
    list_name: str,
) -> None:
    """Add or replace a key entry in a marker-delimited block and refresh the key list.

    The lines between ``start_marker`` and ``end_marker`` are rewritten as a
    sorted block with one entry per key name. ``new_key_line`` replaces any
    existing entry with the same name. Blank lines inside the block are
    dropped; other lines that are not ``name = value`` entries (e.g. ``#``
    comments) are kept but never treated as key names. Every single-line
    ``<list_name> = [ ... ];`` statement outside the block is regenerated from
    the resulting key names, keeping its indentation.

    Args:
        file_path: File to update in place (read and written as UTF-8).
        start_marker: Text identifying the line that opens the key block.
        end_marker: Text identifying the line that closes the key block; it is
            searched for after the start marker only.
        new_key_line: Complete ``name = value`` line to insert, without newline.
        list_name: Name of the list statement to regenerate.

    Raises:
        ValueError: If either marker cannot be found.
        OSError: If the file cannot be read or written.
    """
    lines = file_path.read_text(encoding="utf-8").splitlines(keepends=True)

    start_index = find_marker(lines, start_marker)
    end_index = find_marker(lines[start_index + 1 :], end_marker) + start_index + 1

    # Seed the mapping with the new entry so that it wins over an existing
    # entry of the same name instead of depending on how the lines sort.
    new_key_name = new_key_line.split("=", 1)[0].strip()
    entries_by_name: dict[str, str] = {new_key_name: new_key_line}
    other_lines: set[str] = set()
    for raw_line in lines[start_index + 1 : end_index]:
        entry = raw_line.rstrip("\n")
        key_name = _entry_key_name(entry)
        if key_name is not None:
            entries_by_name.setdefault(key_name, entry)
        elif entry.strip():
            # Comments and other non-entry lines stay in the block but must not
            # leak into the generated list.
            other_lines.add(entry)

    block_lines = sorted([*other_lines, *entries_by_name.values()])
    # The list follows the same order as the entries in the rewritten block.
    ordered_keys = [
        key_name
        for key_name, _ in sorted(entries_by_name.items(), key=lambda item: item[1])
    ]
    key_list = " ".join(ordered_keys)

    list_pattern = re.compile(rf"^\s*{re.escape(list_name)}\s*=\s*\[.*\];\s*$")

    def refresh_list(line: str) -> str:
        """Regenerate ``line`` if it is the list statement, else return it unchanged."""
        if not list_pattern.match(line):
            return line
        indent = re.match(r"^\s*", line).group(0)
        return f"{indent}{list_name} = [ {key_list} ];\n"

    output_lines = [refresh_list(line) for line in lines[:start_index]]
    output_lines.append(lines[start_index])
    output_lines.extend(f"{entry}\n" for entry in block_lines)
    output_lines.append(lines[end_index])
    output_lines.extend(refresh_list(line) for line in lines[end_index + 1 :])

    file_path.write_text("".join(output_lines), encoding="utf-8")


# Short type names accepted by ``ssh-keyscan -t`` mapped to the key type names
# that appear in its output.
_KEYSCAN_SHORT_TYPE_NAMES = {
    "dsa": ("ssh-dss",),
    "ecdsa": ("ecdsa-sha2-nistp256", "ecdsa-sha2-nistp384", "ecdsa-sha2-nistp521"),
    "ecdsa-sk": ("sk-ecdsa-sha2-nistp256@openssh.com",),
    "ed25519": ("ssh-ed25519",),
    "ed25519-sk": ("sk-ssh-ed25519@openssh.com",),
    "rsa": ("ssh-rsa",),
}


def parse_ssh_keyscan_output(raw_output: str, requested_types: str) -> str:
    """Pick one ``<type> <key>`` pair from ``ssh-keyscan`` output.

    Selection order:

    1. An ``ssh-ed25519`` key, if no type was requested or Ed25519 is among the
       requested types.
    2. The key whose type comes first in ``requested_types`` (ties broken by
       position in the output).
    3. The first key in the output.

    Requested types may be full key type names (``ssh-rsa``) or the short
    names ``ssh-keyscan -t`` accepts (``rsa``, ``ed25519``, ``ecdsa``, ...).

    Args:
        raw_output: Standard output of ``ssh-keyscan``.
        requested_types: Comma-separated key types in order of preference; may
            be empty.

    Returns:
        The selected key as ``"<type> <key>"``.

    Raises:
        ValueError: If the output contains no ``host type key`` line.
    """
    preferred_types = [item.strip() for item in requested_types.split(",") if item.strip()]

    # Expand short names so the ranking is keyed by the type names that
    # actually show up in the scan output.
    type_rank: dict[str, int] = {}
    for rank, requested in enumerate(preferred_types):
        for full_name in _KEYSCAN_SHORT_TYPE_NAMES.get(requested.lower(), (requested,)):
            type_rank[full_name] = rank

    candidates: list[tuple[int, str, str]] = []
    for index, raw_line in enumerate(raw_output.splitlines()):
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue

        # Expected layout: "<host> <key type> <key material>".
        fields = line.split()
        if len(fields) < 3:
            continue

        candidates.append((index, fields[1], fields[2]))

    if not candidates:
        raise ValueError("No valid SSH key lines found in ssh-keyscan output.")

    # Prefer modern Ed25519 keys over RSA whenever both are available.
    if not type_rank or "ssh-ed25519" in type_rank:
        for _, key_type, key_material in candidates:
            if key_type == "ssh-ed25519":
                return f"{key_type} {key_material}"

    ranked_candidates = [
        (type_rank[key_type], output_order, key_type, key_material)
        for output_order, key_type, key_material in candidates
        if key_type in type_rank
    ]
    if ranked_candidates:
        _, _, key_type, key_material = min(ranked_candidates)
        return f"{key_type} {key_material}"

    # Nothing matched the requested types: fall back to the first key found.
    _, key_type, key_material = candidates[0]
    return f"{key_type} {key_material}"


def gen_user_key(name: str, public_key: str | None, working_directory: Path) -> None:
    """Register a user public key in ``secrets/secrets.nix``.

    When ``public_key`` is ``None`` a new passphrase-less Ed25519 key pair is
    generated with ``ssh-keygen`` at ``~/.ssh/<name>`` and its public half is
    used; otherwise the supplied key is used. The key is normalized to
    ``<type> <key>`` and written to the user key block and the ``users`` list.

    Args:
        name: Key name; used as the attribute name and, when generating, as
            the private key file name.
        public_key: Existing public key to register, or ``None`` to generate one.
        working_directory: Directory that contains ``secrets/secrets.nix``.

    Raises:
        subprocess.CalledProcessError: If ``ssh-keygen`` exits with an error.
        ValueError: If the public key is malformed or a marker is missing.
        OSError: If a key file or ``secrets.nix`` cannot be read or written.
    """
    input_file = working_directory / "secrets" / "secrets.nix"

    if public_key is None:
        print(f"generating new keys for host {name}")
        private_key_file = Path.home() / ".ssh" / name
        subprocess.run(
            [
                "ssh-keygen",
                "-t",
                "ed25519",
                "-f",
                str(private_key_file),
                "-C",
                f"agenix@{name}",
                "-N",
                "",
            ],
            check=True,
        )

        print(f"getting user public key for user {name}")
        # ssh-keygen stores the public key at "<private key path>.pub". Append
        # the extension rather than using with_suffix(), which would replace an
        # existing suffix and look for "host.pub" when name is "host.example".
        public_key_file = private_key_file.with_name(f"{private_key_file.name}.pub")
        user_public_key = normalize_public_key(public_key_file.read_text(encoding="utf-8"))
    else:
        user_public_key = normalize_public_key(public_key)

    user_key_line = f'  {name} = "{user_public_key}";'
    update_keys(input_file, USER_BEGIN_MARKER, USER_END_MARKER, user_key_line, "users")


def get_host_key(name: str, target: str, key_type: str, working_directory: Path) -> None:
    """Fetch a host public key with ``ssh-keyscan`` and register it in ``secrets/secrets.nix``.

    Args:
        name: Attribute name under which the host key is stored.
        target: Host to scan, passed to ``ssh-keyscan`` as-is.
        key_type: Key type(s) passed to ``ssh-keyscan -t`` and used to choose
            among the returned keys.
        working_directory: Directory that contains ``secrets/secrets.nix``.

    Raises:
        subprocess.CalledProcessError: If ``ssh-keyscan`` exits with an error.
        ValueError: If no key could be parsed or a marker is missing.
        OSError: If ``secrets.nix`` cannot be read or written.
    """
    print(f"getting host public key for host {name}")

    scan_command = ["ssh-keyscan", "-t", key_type, target]
    scan_result = subprocess.run(scan_command, capture_output=True, text=True, check=True)
    host_public_key = parse_ssh_keyscan_output(scan_result.stdout, key_type)

    secrets_file = working_directory / "secrets" / "secrets.nix"
    update_keys(
        secrets_file,
        SYSTEM_BEGIN_MARKER,
        SYSTEM_END_MARKER,
        f'  {name} = "{host_public_key}";',
        "systems",
    )


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the key helper script.

    The parser takes one positional ``subcommand`` (``gen-user-key`` or
    ``get-host-key``) and the options ``-k/--public-key``, ``-n/--name``,
    ``-p/--path``, ``-t/--target`` and ``--type`` (default ``ssh-ed25519``).

    Returns:
        The configured parser; the program name is taken from ``sys.argv[0]``.
    """
    cli = argparse.ArgumentParser(
        prog=Path(sys.argv[0]).name,
        usage="%(prog)s < gen-user-key [argument ...] | get-host-key [argument ...] >",
        description="age encryption / decryption helpers (based on https://github.com/ryantm/agenix)",
        formatter_class=argparse.RawTextHelpFormatter,
    )

    subcommand_help = (
        "gen-user-key: generate or add a user key to secrets.nix\n"
        "get-host-key: fetch a host key via ssh-keyscan and add it to secrets.nix"
    )
    cli.add_argument("subcommand", choices=["gen-user-key", "get-host-key"], help=subcommand_help)

    # Options are registered in this order because it is the order shown in --help.
    option_specs = (
        (
            ("-k", "--public-key"),
            {"help": 'provide a public key instead of generating one (format: "ssh-ed25519 AAAAC3N...")'},
        ),
        (("-n", "--name"), {"help": "key name, usually the hostname"}),
        (
            ("-p", "--path"),
            {"help": "path to the root directory for nixOS configuration files (defaults to current directory)"},
        ),
        (("-t", "--target"), {"help": "hostname/FQDN/IP to query via ssh-keyscan"}),
        (
            ("--type",),
            {"default": "ssh-ed25519", "help": 'ssh-keyscan key type(s), defaults to "ssh-ed25519"'},
        ),
    )
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)

    return cli


def main(argv: list[str] | None = None) -> int:
    """Parse the command line and run the requested subcommand.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``0`` on success, including when help is printed because no arguments
        were given.

    Raises:
        SystemExit: Raised by argparse on invalid arguments or when a required
            option for the chosen subcommand is missing.
    """
    parser = build_parser()
    cli_args = argv if argv is not None else sys.argv[1:]
    if not cli_args:
        parser.print_help()
        return 0

    args = parser.parse_args(cli_args)
    if args.path:
        working_directory = Path(args.path).expanduser().resolve()
    else:
        working_directory = Path.cwd()

    subcommand = args.subcommand
    if subcommand not in ("gen-user-key", "get-host-key"):
        # Safety net: argparse's ``choices`` already rejects other values.
        parser.error("Wrong sub command, use -h to print the help.")
        return 4

    # Both subcommands need a key name; parser.error() exits the program.
    if not args.name:
        parser.error('Error, missing option "-n/--name"')

    if subcommand == "gen-user-key":
        gen_user_key(args.name, args.public_key, working_directory)
    else:
        if not args.target:
            parser.error('Error, missing option "-t/--target"')
        get_host_key(args.name, args.target, args.type, working_directory)
    return 0


if __name__ == "__main__":
    # Script entry point: expected failures (failed subprocess, invalid key or
    # missing marker, file system error) are reported on stderr with exit
    # status 1; otherwise the exit status is main()'s return value.
    try:
        exit_status = main()
    except (subprocess.CalledProcessError, ValueError, OSError) as exc:
        print(str(exc), file=sys.stderr)
        exit_status = 1
    raise SystemExit(exit_status)
