#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
from pathlib import Path


# Hugging Face dataset repo id; main() downloads from it with repo_type="dataset".
DEFAULT_DATASET = "NaruseShiroha/Genshin-Furina-English"
# Parquet shard path inside the dataset repo that the prompt row is read from.
DEFAULT_PARQUET = "data/train-00000-of-00002.parquet"
# Zero-based row within that shard to use as the prompt clip (bounds-checked in main()).
# NOTE(review): presumably hand-picked as a clean reference utterance — confirm.
DEFAULT_ROW_INDEX = 20
# Prepended to the transcription for prompt_cosyvoice3.txt unless the text already
# contains the "<|endofprompt|>" marker (see build_prompt_text).
DEFAULT_PROMPT_PREFIX = "You are a helpful assistant.<|endofprompt|>"


def build_prompt_text(raw_text: str, prompt_prefix: str) -> str:
    """Return the prompt transcript in the form CosyVoice3 expects.

    Leading and trailing whitespace is removed first. Text that already carries
    the ``<|endofprompt|>`` marker is returned as-is; anything else gets
    *prompt_prefix* glued to its front.
    """
    cleaned = raw_text.strip()
    has_marker = "<|endofprompt|>" in cleaned
    return cleaned if has_marker else f"{prompt_prefix}{cleaned}"


def prompt_row_data(row: dict, row_index: int) -> tuple[bytes, str]:
    """Extract and validate the prompt audio and transcription from one dataset row.

    Args:
        row: One row as a dict; ``row["audio"]`` is expected to be a dict whose
            ``"bytes"`` entry holds the encoded clip, and ``row["transcription"]``
            (optional) its text.
        row_index: Index of the row, used only in error messages.

    Returns:
        ``(audio_bytes, transcription)``. A missing or ``None`` transcription is
        returned as ``""``.

    Raises:
        ValueError: if the audio entry or its bytes are missing, are not a binary
            buffer, or are empty, or if the transcription is present but not a string.
    """
    audio_data = row.get("audio")
    if not isinstance(audio_data, dict) or audio_data.get("bytes") is None:
        raise ValueError(f"Audio data or bytes missing in row {row_index}")

    audio_bytes = audio_data["bytes"]
    # Accept any binary buffer type, but always hand back immutable bytes.
    if isinstance(audio_bytes, (bytearray, memoryview)):
        audio_bytes = bytes(audio_bytes)
    if not isinstance(audio_bytes, bytes):
        raise ValueError(f"Audio bytes in row {row_index} are not bytes")
    # An empty payload would otherwise be written out as a zero-length prompt file.
    if not audio_bytes:
        raise ValueError(f"Audio bytes in row {row_index} are empty")

    transcription = row.get("transcription")
    if transcription is None:
        transcription = ""
    elif not isinstance(transcription, str):
        raise ValueError(f"Transcription in row {row_index} is not text")

    return audio_bytes, transcription


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for this script."""
    parser = argparse.ArgumentParser(description="Extract a Furina prompt voice for CosyVoice3.")
    parser.add_argument("--repo-root", default=".", help="Directory that --output-dir is resolved against.")
    parser.add_argument("--output-dir", default="voices/furina_en", help="Directory the prompt files are written to.")
    parser.add_argument("--speaker-id", default="furina_en", help="Speaker id echoed on stdout for downstream use.")
    parser.add_argument("--dataset-repo", default=DEFAULT_DATASET, help="Hugging Face dataset repo id.")
    parser.add_argument("--parquet-file", default=DEFAULT_PARQUET, help="Parquet shard path inside the dataset repo.")
    parser.add_argument("--row-index", type=int, default=DEFAULT_ROW_INDEX, help="Zero-based row within the shard.")
    parser.add_argument(
        "--prompt-prefix",
        default=DEFAULT_PROMPT_PREFIX,
        help="Prefix added to the transcript for CosyVoice3 unless it already contains <|endofprompt|>.",
    )
    return parser.parse_args()


def _read_parquet_row(parquet_path: str, row_index: int, source_label: str) -> dict:
    """Return row *row_index* of a local parquet file as a dict, decoding only its row group.

    Raises:
        ValueError: if *row_index* is outside ``[0, num_rows)``.
    """
    import pyarrow.parquet as pq

    with pq.ParquetFile(parquet_path) as parquet:
        metadata = parquet.metadata
        if row_index < 0 or row_index >= metadata.num_rows:
            raise ValueError(
                f"row-index {row_index} is out of bounds for "
                f"{source_label}; table has {metadata.num_rows} rows"
            )
        # Audio shards can be large: locate the row group containing the row and
        # read just that one instead of materializing every audio blob in the file.
        offset = row_index
        for group_index in range(metadata.num_row_groups):
            group_rows = metadata.row_group(group_index).num_rows
            if offset < group_rows:
                group = parquet.read_row_group(group_index)
                return group.slice(offset, 1).to_pylist()[0]
            offset -= group_rows
    # Only reachable if the per-group row counts disagree with num_rows.
    raise ValueError(f"row-index {row_index} not found in any row group of {source_label}")


def main() -> None:
    """Download one dataset row and write it out as a CosyVoice3 prompt voice.

    Writes ``prompt.wav``, ``prompt.txt``, ``prompt_cosyvoice3.txt`` and
    ``metadata.json`` into ``--output-dir`` (resolved against ``--repo-root``),
    then prints their paths and the speaker id as ``key=value`` lines.

    Raises:
        ValueError: if ``--row-index`` is out of range or the row lacks usable
            audio/transcription data.
    """
    args = _parse_args()
    repo_root = Path(args.repo_root).resolve()
    output_dir = (repo_root / args.output_dir).resolve()

    # Deferred until after argument parsing, so `--help` works even when
    # huggingface_hub is not installed.
    from huggingface_hub import hf_hub_download

    parquet_path = hf_hub_download(args.dataset_repo, args.parquet_file, repo_type="dataset")
    row = _read_parquet_row(parquet_path, args.row_index, f"{args.dataset_repo}/{args.parquet_file}")
    audio_bytes, transcription = prompt_row_data(row, args.row_index)

    # Create the directory only once the row has been fetched and validated, so a
    # failed run does not leave an empty output directory behind.
    output_dir.mkdir(parents=True, exist_ok=True)
    prompt_wav = output_dir / "prompt.wav"
    prompt_txt = output_dir / "prompt.txt"
    prompt_cosyvoice3_txt = output_dir / "prompt_cosyvoice3.txt"
    metadata_json = output_dir / "metadata.json"

    metadata = {
        "dataset_repo": args.dataset_repo,
        "parquet_file": args.parquet_file,
        "row_index": args.row_index,
        "transcription": transcription,
        "language": row.get("language"),
        "speaker": row.get("speaker"),
        "speaker_type": row.get("speaker_type"),
        "type": row.get("type"),
        "in_game_filename": row.get("inGameFilename"),
    }

    # The encoded clip is written verbatim.
    # NOTE(review): this assumes the dataset stores WAV-encoded bytes — confirm.
    prompt_wav.write_bytes(audio_bytes)
    prompt_txt.write_text(transcription.strip() + "\n", encoding="utf-8")
    prompt_cosyvoice3_txt.write_text(build_prompt_text(transcription, args.prompt_prefix) + "\n", encoding="utf-8")
    # The file is UTF-8, so keep non-ASCII transcript characters readable.
    metadata_json.write_text(json.dumps(metadata, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")

    print(f"prompt_wav={prompt_wav}")
    print(f"prompt_txt={prompt_txt}")
    print(f"prompt_cosyvoice3_txt={prompt_cosyvoice3_txt}")
    print(f"speaker_id={args.speaker_id}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
