# chatterbox_dhivehi.py
"""
Dhivehi extension for ChatterboxTTS.

Requires: chatterbox-tts 0.1.4 (not tested on any other version).

Adds:
- load_t3_with_vocab(state_dict, device, force_vocab_size): load T3 with a
  specific vocab size, resizing both the embedding and the projection head,
  and padding checkpoint weights if needed.
- from_dhivehi(...): classmethod for building a ChatterboxTTS from a
  checkpoint directory, using load_t3_with_vocab under the hood
  (defaults to vocab=2000).
- extend_dhivehi(): attach the above to ChatterboxTTS (idempotent).

Usage in app.py:

    import chatterbox_dhivehi

    chatterbox_dhivehi.extend_dhivehi()
    self.model = ChatterboxTTS.from_dhivehi(
        ckpt_dir=Path(self.checkpoint),
        device="cuda" if torch.cuda.is_available() else "cpu",
        force_vocab_size=2000,
    )
"""
from __future__ import annotations

import logging
from pathlib import Path
from typing import Optional, Union

import torch
import torch.nn as nn
from safetensors.torch import load_file

# Core chatterbox imports
from chatterbox.tts import ChatterboxTTS, Conditionals
from chatterbox.models.t3 import T3
from chatterbox.models.s3gen import S3Gen
from chatterbox.models.tokenizers import EnTokenizer
from chatterbox.models.voice_encoder import VoiceEncoder


# Helpers

def _expand_or_trim_rows(t: torch.Tensor, new_rows: int, init_std: float = 0.02) -> torch.Tensor:
    """
    Return a tensor with its first dimension resized to `new_rows`.

    If expanding, newly added rows are randomly initialized N(0, init_std).
    """
    old_rows = t.shape[0]
    if new_rows == old_rows:
        return t.clone()
    if new_rows < old_rows:
        return t[:new_rows].clone()
    # Expand: copy the old rows, randomly initialize the new ones
    out = t.new_empty((new_rows,) + t.shape[1:])
    out[:old_rows] = t
    out[old_rows:].normal_(mean=0.0, std=init_std)
    return out


def _prepare_resized_state_dict(sd: dict, new_vocab: int, init_std: float = 0.02) -> dict:
    """
    Create a modified copy of `sd` where the text_emb/text_head weights
    (and bias) match `new_vocab`.
    """
    sd = sd.copy()
    # Text embedding: [vocab, dim]
    if "text_emb.weight" in sd:
        sd["text_emb.weight"] = _expand_or_trim_rows(sd["text_emb.weight"], new_vocab, init_std)
    # Text projection head: Linear(out=vocab, in=dim)
    if "text_head.weight" in sd:
        sd["text_head.weight"] = _expand_or_trim_rows(sd["text_head.weight"], new_vocab, init_std)
    if "text_head.bias" in sd:
        bias = sd["text_head.bias"]
        if bias.ndim == 1:
            sd["text_head.bias"] = _expand_or_trim_rows(bias.unsqueeze(1), new_vocab, init_std).squeeze(1)
    return sd


def _resize_model_vocab_layers(model: T3, new_vocab: int, dim: Optional[int] = None) -> None:
    """
    Rebuild model.text_emb and model.text_head to match `new_vocab`.

    The embedding dimension is inferred from the existing layers if not provided.
    """
    if dim is None:
        if hasattr(model, "text_emb") and isinstance(model.text_emb, nn.Embedding):
            dim = model.text_emb.embedding_dim
        elif hasattr(model, "text_head") and isinstance(model.text_head, nn.Linear):
            dim = model.text_head.in_features
        else:
            raise RuntimeError("Cannot infer text embedding dimension from T3 model.")
    model.text_emb = nn.Embedding(new_vocab, dim)
    model.text_head = nn.Linear(dim, new_vocab, bias=True)
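
# Hedged self-check for the helpers above (not part of the public API; the
# 4/8/6 sizes are illustrative toys, not real T3 dimensions): pads a tiny
# "checkpoint" to a larger vocab and verifies shapes change while the
# original rows survive verbatim.
def _selfcheck_resize(old_vocab: int = 4, dim: int = 8, new_vocab: int = 6) -> None:
    emb = torch.arange(old_vocab * dim, dtype=torch.float32).reshape(old_vocab, dim)
    sd = {
        "text_emb.weight": emb,
        "text_head.weight": emb.clone(),
        "text_head.bias": torch.zeros(old_vocab),
    }
    out = _prepare_resized_state_dict(sd, new_vocab)
    assert out["text_emb.weight"].shape == (new_vocab, dim)
    assert out["text_head.weight"].shape == (new_vocab, dim)
    assert out["text_head.bias"].shape == (new_vocab,)
    # Rows below the old vocab are copied; only the padding rows are random.
    assert torch.equal(out["text_emb.weight"][:old_vocab], emb)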

# Public API

def load_t3_with_vocab(
    t3_state_dict: dict,
    device: str = "cpu",
    *,
    force_vocab_size: Optional[int] = None,
    init_std: float = 0.02,
) -> T3:
    """
    Load a T3 model with a specified vocabulary size.

    - Removes a leading "t3." prefix on state_dict keys if present.
    - Resizes BOTH `text_emb` and `text_head` to `force_vocab_size`
      (or to the checkpoint vocab if not forced).
    - Pads checkpoint weights when the target vocab is larger than the
      checkpoint's.

    Args:
        t3_state_dict: state dict loaded from t3_cfg.safetensors (or similar).
        device: "cpu", "cuda", or "mps".
        force_vocab_size: desired vocab size (e.g., 2000 for Dhivehi-extended models).
        init_std: std for random init of padded rows.

    Returns:
        T3: model moved to `device` and set to eval().
    """
    logger = logging.getLogger(__name__)

    # Strip the "t3." prefix if present
    if any(k.startswith("t3.") for k in t3_state_dict):
        t3_state_dict = {k[len("t3."):]: v for k, v in t3_state_dict.items()}

    # Derive the checkpoint vocab size if available
    ckpt_vocab_size = None
    if "text_emb.weight" in t3_state_dict and t3_state_dict["text_emb.weight"].ndim == 2:
        ckpt_vocab_size = int(t3_state_dict["text_emb.weight"].shape[0])
    elif "text_head.weight" in t3_state_dict and t3_state_dict["text_head.weight"].ndim == 2:
        ckpt_vocab_size = int(t3_state_dict["text_head.weight"].shape[0])

    target_vocab = int(force_vocab_size) if force_vocab_size is not None else ckpt_vocab_size
    if target_vocab is None:
        raise RuntimeError("Could not determine vocab size. Provide force_vocab_size.")

    logger.info(f"Loading T3 with vocab={target_vocab} (ckpt_vocab={ckpt_vocab_size})")

    # Build a base model and resize its layers to accept the incoming state dict
    t3 = T3()
    _resize_model_vocab_layers(t3, target_vocab)

    # Patch the checkpoint tensors to the target vocab
    patched_sd = _prepare_resized_state_dict(t3_state_dict, target_vocab, init_std)

    # Load (strict=False to tolerate benign extra/missing keys)
    t3.load_state_dict(patched_sd, strict=False)

    return t3.to(device).eval()


def from_dhivehi(
    cls,
    *,
    ckpt_dir: Union[str, Path],
    device: str = "cpu",
    force_vocab_size: int = 2000,
):
    """
    Construct a Dhivehi-extended ChatterboxTTS from a checkpoint directory.

    Expected files in `ckpt_dir`:
    - ve.safetensors
    - t3_cfg.safetensors
    - s3gen.safetensors
    - tokenizer.json
    - conds.pt (optional)
    """
    ckpt_dir = Path(ckpt_dir)

    # Voice encoder
    ve = VoiceEncoder()
    ve.load_state_dict(load_file(ckpt_dir / "ve.safetensors"))
    ve.to(device).eval()

    # T3 with the Dhivehi vocab extension
    t3_state = load_file(ckpt_dir / "t3_cfg.safetensors")
    t3 = load_t3_with_vocab(t3_state, device=device, force_vocab_size=force_vocab_size)

    # S3Gen
    s3gen = S3Gen()
    s3gen.load_state_dict(load_file(ckpt_dir / "s3gen.safetensors"), strict=False)
    s3gen.to(device).eval()

    # Tokenizer
    tokenizer = EnTokenizer(str(ckpt_dir / "tokenizer.json"))

    # Optional conditionals
    conds = None
    conds_path = ckpt_dir / "conds.pt"
    if conds_path.exists():
        # Always safe-load to CPU first; .to(device) afterwards
        conds = Conditionals.load(conds_path, map_location="cpu").to(device)

    return cls(t3, s3gen, ve, tokenizer, device, conds=conds)


def extend_dhivehi():
    """
    Attach Dhivehi-specific helpers to ChatterboxTTS (idempotent).

    - ChatterboxTTS.load_t3_with_vocab (staticmethod)
    - ChatterboxTTS.from_dhivehi (classmethod)
    """
    if getattr(ChatterboxTTS, "_dhivehi_extended", False):
        return
    ChatterboxTTS.load_t3_with_vocab = staticmethod(load_t3_with_vocab)
    ChatterboxTTS.from_dhivehi = classmethod(from_dhivehi)
    ChatterboxTTS._dhivehi_extended = True
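
# Hedged usage sketch: the checkpoint path and sample text below are
# placeholders, while generate() and model.sr are the stock ChatterboxTTS
# API. Assumes conds.pt ships with the checkpoint; otherwise pass
# audio_prompt_path=... to generate().
if __name__ == "__main__":
    import torchaudio

    extend_dhivehi()
    model = ChatterboxTTS.from_dhivehi(
        ckpt_dir=Path("checkpoints/dhivehi"),  # placeholder checkpoint dir
        device="cuda" if torch.cuda.is_available() else "cpu",
        force_vocab_size=2000,
    )
    wav = model.generate("ހެލޯ")  # placeholder Dhivehi text
    torchaudio.save("dhivehi_sample.wav", wav, model.sr)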