# chatterbox-tts-dhivehi / cbox_test.py
# alakxender's picture
# t
# d735744
from pathlib import Path
import os
# Local directory holding the downloaded model repository. It is bound
# before the try block because the user settings further down build their
# paths from it; binding it inside the try left it undefined (NameError
# later) whenever the huggingface_hub import itself failed.
_target = Path.home() / ".chatterbox-tts-dhivehi"
try:
    from huggingface_hub import snapshot_download

    # Download only when the directory is missing or contains nothing.
    if not (_target.exists() and any(_target.rglob("*"))):
        snapshot_download(
            repo_id="alakxender/chatterbox-tts-dhivehi",
            local_dir=str(_target),
            # NOTE(review): these two options are reported as deprecated by
            # newer huggingface_hub releases — confirm against the installed
            # version before removing them.
            local_dir_use_symlinks=False,
            resume_download=True,
        )
except Exception as _e:
    # Best effort: the script continues so that files already present on
    # disk can still be used, but the failure is reported instead of being
    # silently swallowed.
    print(f"Warning: could not fetch model files into {_target}: {_e!r}")
from chatterbox.tts import ChatterboxTTS
import chatterbox_dhivehi
import torchaudio
import torch
import numpy as np
import random
# ---- User settings (edit these) ----
# Text form of the local model directory; the paths below are built from it.
_root = str(_target)

# Path to your finetuned checkpoint directory.
CKPT_DIR = f"{_root}/kn_cbox"

# Optional 3–10s clean reference clip. To run without a reference, set
# REF_WAV = "" instead.
REF_WAV = f"{_root}/samples/reference_audio.wav"

# Text to synthesize: a sample Dhivehi sentence, a comma, then an English one.
_DHIVEHI_SENTENCE = "މި ރިޕޯޓާ ގުޅޭ ގޮތުން އެނިމަލް ވެލްފެއާ މިނިސްޓްރީން އަދި ވާހަކައެއް ނުދައްކާ"
_ENGLISH_SENTENCE = "The Animal Welfare Ministry has not yet commented on the report"
TEXT = f"{_DHIVEHI_SENTENCE}, {_ENGLISH_SENTENCE}"

# Forwarded to model.generate() as exaggeration / temperature / cfg_weight.
EXAGGERATION = 0.4
TEMPERATURE = 0.3
CFG_WEIGHT = 0.7

# Seed applied to the torch, random and numpy generators.
SEED = 42

# Sample rate passed to torchaudio.save and used to compute the duration.
# NOTE(review): assumed to match the model's output rate — confirm.
SAMPLE_RATE = 24000

# Output file written by torchaudio.save.
OUT_PATH = "out.wav"
# ------------------------------------
# Enable Dhivehi support before the model is loaded.
# NOTE(review): presumably this is what provides ChatterboxTTS.from_dhivehi
# used later in the script — confirm in chatterbox_dhivehi.
chatterbox_dhivehi.extend_dhivehi()


def _seed_everything(seed):
    """Seed the torch (CPU and, if available, CUDA), random and numpy RNGs."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)


# Fix all random generators so repeated runs use the same seed.
_seed_everything(SEED)
# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the finetuned checkpoint from CKPT_DIR onto the chosen device.
print(f"Loading ChatterboxTTS from: {CKPT_DIR} on {device}")
_ckpt_path = Path(CKPT_DIR)
model = ChatterboxTTS.from_dhivehi(ckpt_dir=_ckpt_path, device=device)
print("Model loaded.")
# Generate speech for TEXT; the reference audio is optional.
print(f"Generating audio... ref={'yes' if REF_WAV else 'no'}")
gen_kwargs = dict(
    text=TEXT,
    exaggeration=EXAGGERATION,
    temperature=TEMPERATURE,
    cfg_weight=CFG_WEIGHT,
)
try:
    if REF_WAV:
        # A reference clip is configured: pass its path to the model.
        gen_kwargs["audio_prompt_path"] = REF_WAV
        audio = model.generate(**gen_kwargs)
    else:
        # Try without reference first; if the backend rejects the call with
        # a TypeError (e.g. audio_prompt_path is required), retry with "".
        try:
            audio = model.generate(**gen_kwargs)
        except TypeError:
            gen_kwargs["audio_prompt_path"] = ""
            audio = model.generate(**gen_kwargs)
except Exception as e:
    # Chain the original exception explicitly so the traceback shows the
    # underlying failure as the direct cause of this error.
    raise RuntimeError(f"Generation failed: {e}") from e
# Write the generated audio to OUT_PATH at SAMPLE_RATE.
torchaudio.save(OUT_PATH, audio, SAMPLE_RATE)

# audio is indexed as 2-D here; dimension 1 is taken as the sample count.
num_samples = audio.shape[1]
dur = num_samples / SAMPLE_RATE
print(f"Saved {OUT_PATH} ({dur:.2f}s)")