#!/usr/bin/env python3
from __future__ import annotations

import argparse
import os
import sys
from pathlib import Path


# Default value of --prompt-prefix. build_prompt_text() prepends it to the
# voice's prompt.txt transcript unless that transcript already contains the
# "<|endofprompt|>" marker.
DEFAULT_PROMPT_PREFIX = "You are a helpful assistant.<|endofprompt|>"


def configure_paths(repo_root: Path) -> None:
    """Prepend the repo root and its vendored Matcha-TTS checkout to ``sys.path``.

    Both entries are inserted at index 0. The Matcha-TTS directory is inserted
    last, so it ends up first on the path, ahead of the repo root.
    """
    for entry in (repo_root, repo_root / "third_party" / "Matcha-TTS"):
        sys.path.insert(0, str(entry))


def load_text(args: argparse.Namespace) -> str:
    """Return the text to synthesize, taken from ``--text`` or ``--text-file``.

    ``args.text`` is used when it is not ``None``. Otherwise ``args.text_file``
    is read as UTF-8. Leading and trailing whitespace is stripped.

    Raises:
        ValueError: if neither source is given, or if the text is empty after
            stripping. Empty text would otherwise be sent to the model only
            after the expensive model load.
        OSError: propagated if ``--text-file`` cannot be read.
    """
    if args.text is not None:
        raw = args.text
    elif args.text_file is not None:
        raw = Path(args.text_file).read_text(encoding="utf-8")
    else:
        raise ValueError("provide --text or --text-file")
    text = raw.strip()
    if not text:
        raise ValueError("input text is empty")
    return text


def build_prompt_text(raw_text: str, prompt_prefix: str) -> str:
    """Return the stripped transcript, prefixed with *prompt_prefix* if needed.

    If the stripped text already contains ``<|endofprompt|>`` it is returned
    unchanged. Otherwise *prompt_prefix* is prepended to it.
    """
    cleaned = raw_text.strip()
    has_marker = "<|endofprompt|>" in cleaned
    return cleaned if has_marker else f"{prompt_prefix}{cleaned}"


def first_parameter_device(cosyvoice):
    """Return the ``.device`` of the first parameter found on ``cosyvoice.model``.

    The ``llm``, ``flow`` and ``hift`` attributes of the model are checked in
    that order. An attribute is skipped when it has no callable
    ``parameters``, when calling it raises ``TypeError``, when it returns
    ``None``, or when it yields no items. Returns ``None`` if there is no
    model or no parameter is found.
    """
    model = getattr(cosyvoice, "model", None)
    if model is None:
        return None
    missing = object()  # sentinel distinguishing "no items" from any real item
    for name in ("llm", "flow", "hift"):
        get_params = getattr(getattr(model, name, None), "parameters", None)
        if not callable(get_params):
            continue
        try:
            params = get_params()
        except TypeError:
            continue
        first = missing if params is None else next(iter(params), missing)
        if first is not missing:
            return first.device
    return None


def choose_backend(torch, requested: str) -> str:
    """Resolve the backend name to use.

    Any value other than ``"auto"`` is returned unchanged. ``"auto"`` resolves
    to ``"cuda"`` if CUDA is available, else ``"xpu"`` if the torch module
    exposes an available XPU, else ``"cpu"``.
    """
    resolved = requested
    if resolved == "auto":
        if torch.cuda.is_available():
            resolved = "cuda"
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            resolved = "xpu"
        else:
            resolved = "cpu"
    return resolved


def xpu_probe(torch) -> None:
    """Print XPU availability and, if available, run a small matmul smoke test.

    When an XPU is available, this prints the following, in order:
    - device count;
    - name of device 0;
    - device of the probe result;
    - memory allocated on device 0 before and after the probe;
    - sum of the 512x512 ones @ ones product.
    """
    available = hasattr(torch, "xpu") and torch.xpu.is_available()
    print(f"xpu_available={available}")
    if not available:
        return
    xpu = torch.xpu
    mem_before = xpu.memory_allocated(0)
    ones = torch.ones((512, 512), device="xpu")
    product = ones @ ones
    # Wait for the device work to finish before reporting.
    xpu.synchronize()
    print(f"xpu_device_count={xpu.device_count()}")
    print(f"xpu_device_name={xpu.get_device_name(0)}")
    print(f"probe_tensor_device={product.device}")
    print(f"xpu_memory_before_probe={mem_before}")
    print(f"xpu_memory_after_probe={xpu.memory_allocated(0)}")
    print(f"probe_sum={float(product.cpu().sum())}")


def cuda_probe(torch) -> None:
    """Print CUDA availability and device count; if available, run a matmul smoke test.

    When CUDA is available, this prints the following, in order:
    - name, compute capability and arch list of device 0;
    - device of the probe result;
    - memory allocated on device 0 before and after the probe;
    - sum of the 512x512 ones @ ones product.
    """
    cuda = torch.cuda
    available = cuda.is_available()
    print(f"cuda_available={available}")
    print(f"cuda_device_count={cuda.device_count()}")
    if not available:
        return
    mem_before = cuda.memory_allocated(0)
    ones = torch.ones((512, 512), device="cuda")
    product = ones @ ones
    # Wait for the device work to finish before reporting.
    cuda.synchronize()
    print(f"cuda_device_name={cuda.get_device_name(0)}")
    print(f"cuda_capability={cuda.get_device_capability(0)}")
    print(f"cuda_arch_list={cuda.get_arch_list()}")
    print(f"probe_tensor_device={product.device}")
    print(f"cuda_memory_before_probe={mem_before}")
    print(f"cuda_memory_after_probe={cuda.memory_allocated(0)}")
    print(f"probe_sum={float(product.cpu().sum())}")


def synthesize(cosyvoice, text: str, args, prompt_wav: Path, prompt_text: str):
    """Run non-streaming CosyVoice inference and return the model's result.

    - ``args.mode == "instruct"``: calls ``inference_instruct2`` with
      ``args.instruction`` as the second argument.
    - Any other mode: calls ``inference_zero_shot`` with *prompt_text*.

    Both calls receive the prompt audio path as a string, an empty
    ``zero_shot_spk_id``, ``stream=False`` and ``speed=args.speed``.
    """
    shared_kwargs = {"zero_shot_spk_id": "", "stream": False, "speed": args.speed}
    if args.mode == "instruct":
        infer, guidance = cosyvoice.inference_instruct2, args.instruction
    else:
        infer, guidance = cosyvoice.inference_zero_shot, prompt_text
    return infer(text, guidance, str(prompt_wav), **shared_kwargs)


def _build_parser() -> argparse.ArgumentParser:
    """Return the argument parser for this script's command-line interface."""
    parser = argparse.ArgumentParser(description="Generate speech with the Furina prompt voice.")
    parser.add_argument("--repo-root", default=".")
    parser.add_argument("--model-dir", default="pretrained_models/Fun-CosyVoice3-0.5B")
    parser.add_argument("--voice-dir", default="voices/furina_en")
    parser.add_argument("--backend", choices=["auto", "xpu", "cuda"], default="auto")
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--text")
    input_group.add_argument("--text-file")
    parser.add_argument("--output", default="outputs/furina.wav")
    parser.add_argument("--mode", choices=["zero-shot", "instruct"], default="zero-shot")
    # Same literal as the prompt-prefix default; reuse the constant so they cannot drift.
    parser.add_argument("--instruction", default=DEFAULT_PROMPT_PREFIX)
    parser.add_argument("--prompt-prefix", default=DEFAULT_PROMPT_PREFIX)
    parser.add_argument("--speed", type=float, default=0.9)
    parser.add_argument("--fp16", action="store_true", help="Use fp16 autocast on the selected accelerator.")
    parser.add_argument("--allow-cpu", action="store_true", help="Do not fail if the model is not on the selected accelerator.")
    parser.add_argument("--allow-xpu-fallback", action="store_true", help="Allow PyTorch XPU CPU fallback.")
    return parser


def _load_voice_prompt(voice_dir: Path, prompt_prefix: str) -> tuple[Path, str]:
    """Return ``(prompt.wav path, prefixed prompt text)`` for *voice_dir*.

    Raises:
        FileNotFoundError: if ``prompt.wav`` or ``prompt.txt`` is missing.
    """
    prompt_wav = voice_dir / "prompt.wav"
    prompt_txt_path = voice_dir / "prompt.txt"
    if not prompt_wav.exists():
        raise FileNotFoundError(f"Prompt audio not found at {prompt_wav}")
    if not prompt_txt_path.exists():
        raise FileNotFoundError(f"Prompt text not found at {prompt_txt_path}")
    prompt_text = build_prompt_text(prompt_txt_path.read_text(encoding="utf-8"), prompt_prefix)
    return prompt_wav, prompt_text


def main() -> None:
    """Command-line entry point: synthesize the input text with the prompt voice.

    Steps:
    1. Parse and validate arguments.
    2. Load the input text and the voice prompt (``prompt.wav`` and
       ``prompt.txt`` under ``--voice-dir``).
    3. Select and probe the backend.
    4. Load the CosyVoice model and check which device it is on.
    5. Run synthesis and write the concatenated audio to ``--output``.

    ``--model-dir``, ``--voice-dir`` and ``--output`` are joined onto the
    resolved ``--repo-root``.

    Raises:
        ValueError: propagated from ``load_text``.
        FileNotFoundError: if the voice prompt audio or text is missing.
        RuntimeError: if no accelerator is selected and ``--allow-cpu`` is not
            set; if the model is not on the selected backend (without
            ``--allow-cpu``); or if synthesis yields no audio chunks.
    """
    parser = _build_parser()
    args = parser.parse_args()
    # A plain `args.speed <= 0` check lets NaN through (every comparison with
    # NaN is False) and accepts inf; require a finite, positive value.
    if not (0 < args.speed < float("inf")):
        parser.error("The --speed parameter must be a positive, finite float.")

    # Exported before torch is imported below; presumably read by PyTorch when
    # the XPU backend initialises - TODO confirm.
    os.environ["PYTORCH_ENABLE_XPU_FALLBACK"] = "1" if args.allow_xpu_fallback else "0"
    repo_root = Path(args.repo_root).resolve()

    # Validate cheap user inputs before importing torch and loading the model.
    # A bad path or bad text then fails immediately instead of after a full
    # model load.
    text = load_text(args)
    prompt_wav, prompt_text = _load_voice_prompt(
        (repo_root / args.voice_dir).resolve(), args.prompt_prefix
    )

    configure_paths(repo_root)

    import soundfile as sf
    import torch

    from cosyvoice.cli.cosyvoice import AutoModel

    backend = choose_backend(torch, args.backend)
    print(f"backend={backend}")
    print(f"PYTORCH_ENABLE_XPU_FALLBACK={os.environ['PYTORCH_ENABLE_XPU_FALLBACK']}")
    print(f"torch={torch.__version__}")
    if backend == "cuda":
        print(f"cuda_runtime={torch.version.cuda}")
        cuda_probe(torch)
    elif backend == "xpu":
        xpu_probe(torch)
    elif not args.allow_cpu:
        raise RuntimeError("no CUDA or XPU accelerator available; pass --allow-cpu to continue on CPU")

    model_dir = (repo_root / args.model_dir).resolve()
    cosyvoice = AutoModel(model_dir=str(model_dir), fp16=args.fp16)
    model_device = first_parameter_device(cosyvoice)
    print(f"model_device={model_device}")
    if not args.allow_cpu and (model_device is None or model_device.type != backend):
        raise RuntimeError(f"expected model on {backend}, got {model_device}")

    with torch.no_grad():
        chunks = [chunk["tts_speech"] for chunk in synthesize(cosyvoice, text, args, prompt_wav, prompt_text)]
    if not chunks:
        raise RuntimeError("CosyVoice synthesis returned no audio chunks")
    # Concatenated along dim 1; presumably each chunk is (channels, samples) - TODO confirm.
    audio = torch.cat(chunks, dim=1)

    output_path = (repo_root / args.output).resolve()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    sample_rate = cosyvoice.sample_rate
    # soundfile expects (frames, channels), hence the transpose.
    sf.write(str(output_path), audio.transpose(0, 1).cpu().numpy(), sample_rate)
    print(f"output={output_path}")
    print(f"sample_rate={sample_rate}")
    print(f"duration_seconds={audio.shape[1] / sample_rate:.3f}")
    print("speaker_source=prompt")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
