#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import subprocess
from pathlib import Path


def _collect_torch_info() -> dict[str, object]:
    """Return PyTorch/CUDA build info and, when CUDA is available, a device probe.

    ``torch`` is imported here (not at module level) so that argument parsing,
    including ``--help``, runs before the torch import.

    When CUDA reports itself available, device details are recorded and a small
    256x256 matmul is run on the default CUDA device. Any ``RuntimeError``
    raised along the way is stored under ``"probe_error"`` instead of aborting.
    For example, this happens when the installed build ships no kernels for the
    device's compute capability. Whatever was collected before the failure is
    kept, so the caller can still emit a full report.
    """
    import torch

    info: dict[str, object] = {
        "torch": torch.__version__,
        "cuda_runtime": torch.version.cuda,
        "cuda_available": torch.cuda.is_available(),
        "device_count": torch.cuda.device_count(),
    }
    if not info["cuda_available"]:
        return info

    try:
        info.update(
            {
                "device_name": torch.cuda.get_device_name(0),
                "capability": list(torch.cuda.get_device_capability(0)),
                "arch_list": torch.cuda.get_arch_list(),
                "memory_before": torch.cuda.memory_allocated(0),
            }
        )
        x = torch.ones((256, 256), device="cuda")
        y = x @ x
        # CUDA kernels launch asynchronously; synchronizing here makes a launch
        # failure surface inside this try block instead of at a later call.
        torch.cuda.synchronize()
        info.update(
            {
                "probe_tensor_device": str(y.device),
                "memory_after": torch.cuda.memory_allocated(0),
                "probe_sum": float(y.cpu().sum()),
            }
        )
    except RuntimeError as exc:
        info["probe_error"] = f"{type(exc).__name__}: {exc}"
    return info


def _query_nvidia_smi(timeout: float = 30.0) -> str:
    """Return ``nvidia-smi``'s header-less CSV GPU summary, or a short diagnostic.

    stderr is merged into stdout and the exit status is not checked. A failing
    ``nvidia-smi`` therefore still reports its own error text.

    Args:
        timeout: Seconds to wait for ``nvidia-smi`` before giving up, so a hung
            driver cannot stall the whole probe.
    """
    try:
        smi = subprocess.run(
            [
                "nvidia-smi",
                "--query-gpu=index,name,compute_cap,driver_version,memory.total",
                "--format=csv,noheader",
            ],
            check=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            timeout=timeout,
        )
    except FileNotFoundError:
        return "nvidia-smi not found"
    except subprocess.TimeoutExpired:
        return f"nvidia-smi timed out after {timeout:g}s"
    except OSError as exc:
        # e.g. the binary exists but is not executable.
        return f"nvidia-smi could not be run: {exc}"
    return smi.stdout.strip()


def main() -> None:
    """Probe PyTorch CUDA support and report the findings as JSON.

    The report is always printed to stdout. When ``--output`` is given, it is
    also written to that path, creating parent directories as needed.
    """
    parser = argparse.ArgumentParser(description="Probe PyTorch CUDA on the Quadro M4000.")
    parser.add_argument(
        "--output",
        type=Path,
        help="also write the JSON report to this file",
    )
    args = parser.parse_args()

    result = _collect_torch_info()
    result["nvidia_smi"] = _query_nvidia_smi()

    rendered = json.dumps(result, indent=2)
    print(rendered)
    if args.output is not None:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        args.output.write_text(rendered + "\n", encoding="utf-8")


if __name__ == "__main__":
    # Script entry point: importing this module must not trigger the probe.
    main()
