#!/usr/bin/env python3
"""Resolve a PyTorch XPU torch/torchaudio pair from the public XPU wheel index."""
from __future__ import annotations

import argparse
import html
import re
import urllib.parse
import urllib.request
from pathlib import Path


# Default wheel index; main() uses it as the default for --index-url.
XPU_INDEX = "https://download.pytorch.org/whl/xpu"
# Up to three leading dot-separated integers (major[.minor[.patch]]);
# components that are absent come back from .groups() as None.
RELEASE_RE = re.compile(r"^(\d+)(?:\.(\d+))?(?:\.(\d+))?")
# Pre-release marker (dev, a, b or rc, case-insensitive) with an optional number,
# e.g. ".dev20240101", "a0" or "rc1". The marker must be preceded by the start,
# a separator or a digit and followed by the end, a separator or a digit, so a
# letter that is part of a longer word (e.g. the "b" in "beta") is not matched.
PRE_RELEASE_RE = re.compile(r"(?:^|[.\-_]|\d)(dev|a|b|rc)(\d*)(?=$|[.\-_]|\d)", re.IGNORECASE)
# Post-release marker ("post" plus an optional number), same boundary rules.
POST_RELEASE_RE = re.compile(r"(?:^|[.\-_]|\d)post(\d*)(?=$|[.\-_]|\d)", re.IGNORECASE)
# Sort rank of each pre-release stage. parse_version ranks final releases as 4
# and post-releases as 5, so every pre-release sorts below its final release.
PRE_RELEASE_PRIORITY = {"dev": 0, "a": 1, "b": 2, "rc": 3}


def fetch_index(url: str) -> str:
    """Download *url* and return its body with HTML entities and %-escapes decoded.

    The request identifies itself with the ``pip`` user agent and uses a
    15 second timeout. The body is always decoded as UTF-8, and undecodable
    bytes are replaced; the response's declared charset is not consulted.
    Decoding both escape layers turns a link such as ``torch-2.5.1%2Bxpu``
    into ``torch-2.5.1+xpu``.

    Raises:
        RuntimeError: wrapping any exception raised while opening or reading
            the URL (the original exception is chained as ``__cause__``).
    """
    req = urllib.request.Request(url, headers={"User-Agent": "pip"})
    try:
        with urllib.request.urlopen(req, timeout=15) as resp:
            raw = resp.read()
        text = raw.decode("utf-8", errors="replace")
    except Exception as err:
        raise RuntimeError(f"failed to fetch index from {url}: {err}") from err
    unescaped = html.unescape(text)
    return urllib.parse.unquote(unescaped)


def versions_from_page(package: str, decoded_page: str) -> list[str]:
    """Return the distinct ``+xpu`` versions of *package* named in *decoded_page*.

    Versions are returned in order of first appearance. A match needs:

    - the distribution name, followed by a dash;
    - at least three dot-separated numeric release components;
    - an optional suffix of letters, digits or ``_.-``;
    - a trailing ``+xpu``.

    Matching is case-insensitive.

    *decoded_page* must already be HTML-unescaped and percent-decoded (as
    ``fetch_index`` returns it). The pattern looks for a literal ``+``, so an
    encoded ``%2B`` would not match.
    """
    # Wheel filenames escape separators in the distribution name as "_"
    # (PEP 427), while project names are usually written with "-" (PEP 503).
    # So any run of "-", "_" or "." in *package* matches any such run on the page.
    name_parts = [re.escape(part) for part in re.split(r"[-_.]+", package.lower())]
    name_pattern = r"[-_.]+".join(name_parts)
    package_pattern = re.compile(
        # The look-behind stops the name from matching as the tail of a longer
        # name (e.g. "torch" inside "pytorch").
        rf"(?<![A-Za-z0-9_.-]){name_pattern}-(\d+\.\d+\.\d+(?:[a-zA-Z0-9_.-]*)\+xpu)",
        re.IGNORECASE,
    )
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(package_pattern.findall(decoded_page)))


def simple_index_versions(package: str, index_url: str) -> list[str]:
    """List the ``+xpu`` versions of *package* available from *index_url*.

    The index root page is tried first, in case it lists wheel files
    directly. If that fetch fails, or yields no versions, the per-project page
    ``<index_url>/<normalized-name>/`` is fetched instead. The name is
    normalized as in PEP 503: runs of ``-``, ``_`` and ``.`` become ``-`` and
    the result is lower-cased.

    Raises:
        RuntimeError: if the per-project page cannot be fetched (a failed
            root-page fetch is ignored).
    """
    root_url = f"{index_url.rstrip('/')}/"
    try:
        versions = versions_from_page(package, fetch_index(root_url))
    except RuntimeError:
        # Best effort only: the per-project page below is the authoritative source.
        versions = []
    if versions:
        return versions

    normalized = re.sub(r"[-_.]+", "-", package).lower()
    package_url = f"{root_url}{urllib.parse.quote(normalized, safe='')}/"
    return versions_from_page(package, fetch_index(package_url))


def parse_version(version: str) -> tuple[int, ...]:
    """Build a sort key ``(major, minor, patch, stage, serial)`` for a version string.

    A trailing ``+xpu`` is ignored. Only the first three release components
    are considered, and missing ones count as 0.

    ``stage`` ranks the release kind: dev=0, a=1, b=2, rc=3, final=4 and
    post=5. ``serial`` is the number attached to the marker (0 if absent).
    Strings without a leading integer map to ``(0, 0, 0, 0, 0)``.

    NOTE(review): a pre-release marker always wins over a post-release marker.
    So e.g. ``X.post1.dev0`` is ranked as a dev release below ``X``, which
    differs from PEP 440 ordering.
    """
    core = version.removesuffix("+xpu")
    release = RELEASE_RE.match(core)
    if not release:
        return (0, 0, 0, 0, 0)
    major, minor, patch = (int(piece or 0) for piece in release.groups())

    pre = PRE_RELEASE_RE.search(core)
    if pre:
        stage = PRE_RELEASE_PRIORITY[pre.group(1).lower()]
        serial = int(pre.group(2) or 0)
    else:
        post = POST_RELEASE_RE.search(core)
        stage, serial = (5, int(post.group(1) or 0)) if post else (4, 0)
    return (major, minor, patch, stage, serial)


def resolve(index_url: str, require_matching: bool) -> tuple[str, str, str]:
    """Choose the torch and torchaudio versions to pin from *index_url*.

    With *require_matching*, the result is the newest torch version whose
    exact version string is also available for torchaudio. Without it, the
    newest torch and the newest torchaudio are returned independently.
    "Newest" is decided by ``parse_version``.

    Returns:
        ``(torch_version, torchaudio_version, mode)``, where *mode* is a short
        human-readable description of how the pair was chosen.

    Raises:
        RuntimeError: if either package has no ``+xpu`` versions, or if
            *require_matching* is set and no common version exists.
    """
    torch_versions = simple_index_versions("torch", index_url)
    audio_versions = simple_index_versions("torchaudio", index_url)
    if not (torch_versions and audio_versions):
        raise RuntimeError(f"missing XPU wheels: torch={torch_versions}, torchaudio={audio_versions}")

    newest_torch_first = sorted(torch_versions, key=parse_version, reverse=True)
    if not require_matching:
        newest_audio = sorted(audio_versions, key=parse_version, reverse=True)[0]
        return newest_torch_first[0], newest_audio, "newest torch plus newest torchaudio"

    available_audio = set(audio_versions)
    paired = next((candidate for candidate in newest_torch_first if candidate in available_audio), None)
    if paired is None:
        raise RuntimeError(f"no matching torch/torchaudio XPU pair: torch={torch_versions}, torchaudio={audio_versions}")
    return paired, paired, "newest matching torch/torchaudio pair"


def main() -> None:
    """Command-line entry point: resolve the pair and write a requirements file.

    The file written to ``--output`` holds three lines:

    - ``--extra-index-url`` pointing at the index;
    - an exact ``torch==`` pin;
    - an exact ``torchaudio==`` pin.

    Parent directories are created as needed. The chosen mode, both versions
    and the output path are then printed as ``key=value`` lines.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--index-url", default=XPU_INDEX)
    parser.add_argument("--output", required=True)
    parser.add_argument("--allow-mismatched-torchaudio", action="store_true")
    args = parser.parse_args()

    require_matching = not args.allow_mismatched_torchaudio
    torch_version, audio_version, mode = resolve(args.index_url, require_matching)

    requirement_lines = (
        f"--extra-index-url {args.index_url}",
        f"torch=={torch_version}",
        f"torchaudio=={audio_version}",
    )
    destination = Path(args.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text("".join(f"{line}\n" for line in requirement_lines), encoding="utf-8")

    summary = {"mode": mode, "torch": torch_version, "torchaudio": audio_version, "output": destination}
    for key, value in summary.items():
        print(f"{key}={value}")


if __name__ == "__main__":
    main()
