#!/usr/bin/env python3
import os

# Default the XPU fallback switch to "0" unless the caller already exported a
# value. This runs before `import torch` so the variable is already in the
# environment when torch loads (presumably it controls CPU fallback for ops
# missing on XPU and is read at import/first use — confirm against torch docs).
if "PYTORCH_ENABLE_XPU_FALLBACK" not in os.environ:
    os.environ["PYTORCH_ENABLE_XPU_FALLBACK"] = "0"

import torch


def _report_xpu_devices() -> None:
    """Print the number of XPU devices, then each device's index and name."""
    count = torch.xpu.device_count()
    print(f"xpu_device_count={count}")
    for index in range(count):
        print(f"xpu[{index}]={torch.xpu.get_device_name(index)}")


def _run_xpu_probe(device: torch.device) -> None:
    """Run a small matmul on ``device`` and print its placement and memory use.

    The allocation, the synchronization and both memory readings all target
    the same explicit ``device``. The reported before/after numbers therefore
    describe the device that actually ran the work.
    """
    before = torch.xpu.memory_allocated(device)
    x = torch.ones((256, 256), device=device)
    y = x @ x
    # Block until the work queued on this device has finished before reporting.
    torch.xpu.synchronize(device)
    print(f"probe_tensor_device={y.device}")
    print(f"xpu_memory_before={before}")
    print(f"xpu_memory_after={torch.xpu.memory_allocated(device)}")
    print(f"probe_sum={float(y.cpu().sum())}")


def main() -> None:
    """Print a diagnostic report of the PyTorch build and its XPU backend.

    The report covers these items, in order:
    - the ``PYTORCH_ENABLE_XPU_FALLBACK`` environment value;
    - the torch version;
    - CUDA availability;
    - whether ``torch.xpu`` exists;
    - whether an XPU is available.

    When an XPU is available, it also lists every XPU device. It then runs a
    256x256 matmul on XPU device 0 and reports:
    - the result's device;
    - the memory allocated on that device before and after the probe;
    - the sum of the result.
    """
    print(f"PYTORCH_ENABLE_XPU_FALLBACK={os.environ.get('PYTORCH_ENABLE_XPU_FALLBACK')}")
    print(f"torch={torch.__version__}")
    print(f"cuda_available={torch.cuda.is_available()}")
    has_xpu = hasattr(torch, "xpu")
    print(f"has_xpu={has_xpu}")
    # Short-circuit so torch.xpu is only touched when the module exists.
    xpu_available = has_xpu and torch.xpu.is_available()
    print(f"xpu_available={xpu_available}")
    if not xpu_available:
        return
    _report_xpu_devices()
    # Pin one explicit device rather than mixing the "current" device (used by
    # device="xpu") with the hard-coded index 0 used for memory statistics.
    _run_xpu_probe(torch.device("xpu", 0))


if __name__ == "__main__":
    # main() returns None, so the process exits with status 0 on success;
    # any exception raised by main() propagates unchanged.
    raise SystemExit(main())
