alakxender committed
Commit d735744 · 1 Parent(s): 723c802
Files changed (5)
  1. README.md +1 -1
  2. app.py +449 -4
  3. cbox_test.py +79 -0
  4. chatterbox_dhivehi.py +210 -0
  5. requirements.txt +1 -0
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: Chatterbox Tts Dhivehi
+ title: Chatterbox TTS Dhivehi
  emoji: 📉
  colorFrom: red
  colorTo: blue
app.py CHANGED
@@ -1,7 +1,452 @@
+ from pathlib import Path
+
+ # Cache the Dhivehi checkpoints under the user's home directory.
+ # (Defined outside the try block: it is needed below even if the download fails.)
+ _target = Path.home() / ".chatterbox-tts-dhivehi"
+ try:
+     from huggingface_hub import snapshot_download
+     if not (_target.exists() and any(_target.rglob("*"))):
+         snapshot_download(
+             repo_id="alakxender/chatterbox-tts-dhivehi",
+             local_dir=str(_target),
+             local_dir_use_symlinks=False,
+             resume_download=True,
+         )
+ except Exception as _e:
+     # Non-fatal: model loading below will surface a clearer error if files are missing.
+     pass
+
+ from chatterbox.tts import ChatterboxTTS
+ import torchaudio
+ import torch
+ import random
+ import re
+ import numpy as np
  import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import tempfile
+ import chatterbox_dhivehi
+ import warnings
+
+ warnings.filterwarnings("ignore")
+
+ chatterbox_dhivehi.extend_dhivehi()
+
+ class TTSApp:
+     def __init__(self, checkpoint=f"{_target}/kn_cbox"):
+         self.checkpoint = checkpoint
+         self.model = None
+         self.load_model()
+
+     def load_model(self):
+         """Load the TTS model."""
+         try:
+             print(f"Loading model with checkpoint: {self.checkpoint}")
+             self.model = ChatterboxTTS.from_dhivehi(
+                 ckpt_dir=Path(self.checkpoint),
+                 device="cuda" if torch.cuda.is_available() else "cpu"
+             )
+             print("Model loaded successfully!")
+         except Exception as e:
+             print(f"Error loading model: {e}")
+             raise
+
+     def set_seed(self, seed: int):
+         """Set random seeds for reproducibility."""
+         torch.manual_seed(seed)
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed(seed)
+             torch.cuda.manual_seed_all(seed)
+         random.seed(seed)
+         np.random.seed(seed)
+
+     def generate_speech(self, text, reference_audio, exaggeration=0.5,
+                         temperature=0.1, cfg_weight=0.5, seed=42):
+         """Generate speech from text, optionally cloning a reference voice."""
+         # Clean the input text
+         text = self.clean_text(text)
+
+         if not text:
+             return None, "Please enter some text to generate speech."
+
+         if self.model is None:
+             return None, "Model not loaded. Please check your model paths."
+
+         try:
+             # Set seed for reproducibility
+             self.set_seed(seed)
+
+             # Reference audio is optional
+             audio_prompt_path = reference_audio
+
+             print(f"Generating audio for: {text[:50]}...")
+             if audio_prompt_path:
+                 print(f"Using reference audio: {audio_prompt_path}")
+             else:
+                 print("Generating without reference audio")
+
+             # Generate audio, handling the optional reference audio
+             if audio_prompt_path:
+                 audio = self.model.generate(
+                     text=text,
+                     audio_prompt_path=audio_prompt_path,
+                     exaggeration=exaggeration,
+                     temperature=temperature,
+                     cfg_weight=cfg_weight,
+                 )
+             else:
+                 # Try without reference audio first
+                 try:
+                     audio = self.model.generate(
+                         text=text,
+                         exaggeration=exaggeration,
+                         temperature=temperature,
+                         cfg_weight=cfg_weight,
+                     )
+                 except TypeError:
+                     # If the model requires audio_prompt_path, retry with an empty string
+                     audio = self.model.generate(
+                         text=text,
+                         audio_prompt_path="",
+                         exaggeration=exaggeration,
+                         temperature=temperature,
+                         cfg_weight=cfg_weight,
+                     )
+
+             # Save to a temporary file
+             with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
+                 output_path = tmp_file.name
+             torchaudio.save(output_path, audio, 24000)
+
+             return output_path, f"Successfully generated speech! Audio length: {audio.shape[1]/24000:.2f} seconds"
+
+         except Exception as e:
+             error_msg = f"Error generating speech: {e}"
+             print(error_msg)
+             return None, error_msg
+
+     def clean_text(self, text):
+         """Collapse repeated whitespace and strip leading/trailing whitespace."""
+         # Replace any run of whitespace (spaces, tabs, newlines) with a single space
+         text = re.sub(r'\s+', ' ', text)
+         return text.strip()
+
+     def split_sentences(self, text):
+         """Split text on periods, then merge consecutive sentences into chunks of
+         at least 150 characters (the final chunk may be shorter)."""
+         text = self.clean_text(text)
+
+         # Split on periods, keeping each period with its sentence
+         initial_sentences = []
+         current_sentence = ""
+         for char in text:
+             current_sentence += char
+             if char == '.':
+                 stripped_sentence = current_sentence.strip()
+                 if stripped_sentence:
+                     initial_sentences.append(stripped_sentence)
+                 current_sentence = ""
+
+         # Keep any trailing text that lacks a final period
+         stripped_remaining = current_sentence.strip()
+         if stripped_remaining:
+             initial_sentences.append(stripped_remaining)
+
+         # With zero or one sentence there is nothing to merge
+         if len(initial_sentences) <= 1:
+             return initial_sentences
+
+         # Merge sentences until each chunk is at least 150 characters
+         final_sentences = []
+         combined_sentence = ""
+         for sentence in initial_sentences:
+             if combined_sentence:
+                 combined_sentence += " " + sentence
+             else:
+                 combined_sentence = sentence
+             if len(combined_sentence) >= 150:
+                 final_sentences.append(combined_sentence.strip())
+                 combined_sentence = ""
+
+         # Keep any remainder, even if it is shorter than 150 characters
+         if combined_sentence.strip():
+             final_sentences.append(combined_sentence.strip())
+
+         return final_sentences
+
+     def generate_speech_multi_sentence(self, text, reference_audio, exaggeration=0.5,
+                                        temperature=0.1, cfg_weight=0.5, seed=42):
+         """Generate speech chunk by chunk, yielding progress updates."""
+         text = self.clean_text(text)
+
+         if not text:
+             yield None, "Please enter some text to generate speech."
+             return
+
+         if self.model is None:
+             yield None, "Model not loaded. Please check your model paths."
+             return
+
+         # Split text into sentence chunks
+         sentences = self.split_sentences(text)
+
+         # With one chunk or no periods, fall back to the single-shot method
+         if len(sentences) <= 1:
+             yield None, "🎵 Generating single sentence..."
+             result_audio, result_status = self.generate_speech(
+                 text, reference_audio, exaggeration, temperature, cfg_weight, seed
+             )
+             yield result_audio, result_status
+             return
+
+         try:
+             # Set seed for reproducibility
+             self.set_seed(seed)
+
+             # Reference audio is optional
+             audio_prompt_path = reference_audio
+
+             yield None, f"🚀 Starting generation for {len(sentences)} sentences..."
+             print(f"Processing {len(sentences)} sentences...")
+
+             all_audio_segments = []
+             total_duration = 0
+
+             for i, sentence in enumerate(sentences):
+                 # Reserve the last 10% of the progress range for combining
+                 progress_percent = int((i / len(sentences)) * 90)
+                 yield None, f"🎵 Generating sentence {i+1}/{len(sentences)} ({progress_percent}%): {sentence[:50]}..."
+                 print(f"Generating audio for sentence {i+1}/{len(sentences)}: {sentence[:50]}...")
+
+                 # Generate audio for this sentence
+                 try:
+                     if audio_prompt_path:
+                         audio = self.model.generate(
+                             text=sentence,
+                             audio_prompt_path=audio_prompt_path,
+                             exaggeration=exaggeration,
+                             temperature=temperature,
+                             cfg_weight=cfg_weight,
+                         )
+                     else:
+                         # Try without reference audio first
+                         try:
+                             audio = self.model.generate(
+                                 text=sentence,
+                                 exaggeration=exaggeration,
+                                 temperature=temperature,
+                                 cfg_weight=cfg_weight,
+                             )
+                         except TypeError:
+                             # If the model requires audio_prompt_path, retry with an empty string
+                             audio = self.model.generate(
+                                 text=sentence,
+                                 audio_prompt_path="",
+                                 exaggeration=exaggeration,
+                                 temperature=temperature,
+                                 cfg_weight=cfg_weight,
+                             )
+                 except Exception as model_error:
+                     # Some checkpoints fail outright without reference audio; retry defensively
+                     if ("reference_voice.wav not found" in str(model_error)
+                             or "No reference audio provided" in str(model_error)):
+                         print("Attempting generation without reference audio...")
+                         try:
+                             # Some models accept an empty string
+                             audio = self.model.generate(
+                                 text=sentence,
+                                 audio_prompt_path="",
+                                 exaggeration=exaggeration,
+                                 temperature=temperature,
+                                 cfg_weight=cfg_weight,
+                             )
+                         except Exception:
+                             # Otherwise drop the audio_prompt_path parameter entirely
+                             audio = self.model.generate(
+                                 text=sentence,
+                                 exaggeration=exaggeration,
+                                 temperature=temperature,
+                                 cfg_weight=cfg_weight,
+                             )
+                     else:
+                         raise
+
+                 all_audio_segments.append(audio)
+                 total_duration += audio.shape[1] / 24000
+
+             # Concatenate all audio segments
+             yield None, "🔧 Combining audio segments (95%)..."
+             print("Combining audio segments...")
+             combined_audio = torch.cat(all_audio_segments, dim=1)
+
+             # Save to a temporary file
+             with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
+                 output_path = tmp_file.name
+             torchaudio.save(output_path, combined_audio, 24000)
+             print("Multi-sentence processing complete!")
+
+             yield output_path, f"✅ Successfully generated speech from {len(sentences)} sentences! Total audio length: {total_duration:.2f} seconds"
+
+         except Exception as e:
+             error_msg = f"❌ Error generating multi-sentence speech: {e}"
+             print(error_msg)
+             yield None, error_msg
+
+ def get_cbox_dv():
+     """Create the Gradio interface."""
+     # Initialize the TTS app
+     tts_app = TTSApp()
+
+     # Sample texts in Dhivehi
+     sample_texts = [
+         "ކާޑު ނުލައި ފައިސާ ދެއްކޭ ނެޝަނަލް ކިއުއާރް ކޯޑް އެމްއެމްއޭ އިން ތައާރަފްކުރަނީ",
+         """ފުޓްބޯޅަ ސްކޫލްގެ ބިމާއި ގުދަންބަރި ބިމުގައި އިމާރާތް ކުރުމުގެ މަސައްކަތް ހުއްޓާލަން އަންގައިފި...
+ Construction work on football school land and warehouse land has been ordered to stop""",
+         "ސިވިލް ސާވިސްގެ ހިދުމަތުގެ މުއްދަތު ގުނުމުގައި ކުންފުނިތަކާއި އިދާރާތަކަށް ހިދުމަތްކުރި މުއްދަތު ހިމަނަނީ",
+         """އެ ރަށުގެ ބިން ހިއްކުމާއި ބަނދަރުގެ ނެރު ބަދަލުކުރުމާއި ގޮނޑުދޮށް ހިމާޔަތް ކުރުމުގެ މަސައްކަތް އެމްޓީސީސީއާ މިނިސްޓްރީން ހަވާލުކުރީ މިދިޔަ މަހު ރައީސް އެ ރަށަށް ކުރެއްވި ދަތުރުފުޅުގައި.
+ The ministry handed over the land reclamation, replacement of the port canal and beach protection to MTCC during the President's visit to the village last month"""
+     ]
+
+     # Build the UI inside a Blocks context so `app` is defined for the return below
+     with gr.Blocks() as app:
+         with gr.Tab("🎤 ChatterboxTTS"):
+             gr.Markdown("# 🎤 ChatterboxTTS - Dhivehi Text-to-Speech with Voice Cloning")
+             gr.Markdown("Generate natural-sounding Dhivehi speech with voice cloning capabilities.")
+
+             # Row 1: Text input and reference audio
+             with gr.Row():
+                 text_input = gr.Textbox(
+                     label="Text to Convert",
+                     placeholder="Enter Dhivehi text here...",
+                     lines=6,
+                     value=sample_texts[0],
+                     rtl=True,
+                     elem_classes=["textbox1"]
+                 )
+                 reference_audio = gr.Audio(
+                     label="Reference Voice Audio (optional - for voice cloning)",
+                     type="filepath",
+                     sources=["upload", "microphone"],
+                 )
+
+             # Row 2: Example buttons
+             gr.Markdown("**Quick Examples:**")
+             with gr.Row():
+                 sample_btn1 = gr.Button("Sample 1", size="sm")
+                 sample_btn2 = gr.Button("Sample 2", size="sm")
+                 sample_btn3 = gr.Button("Sample 3", size="sm")
+                 sample_btn4 = gr.Button("Sample 4", size="sm")
+
+             # Row 3: Advanced settings
+             with gr.Accordion("Advanced Settings", open=False):
+                 with gr.Row():
+                     exaggeration = gr.Slider(
+                         minimum=0.0, maximum=2.0, value=0.5, step=0.1,
+                         label="Exaggeration", info="Controls expressiveness"
+                     )
+                     temperature = gr.Slider(
+                         minimum=0.01, maximum=1.0, value=0.35, step=0.01,
+                         label="Temperature", info="Controls randomness"
+                     )
+                     cfg_weight = gr.Slider(
+                         minimum=0.0, maximum=2.0, value=0.3, step=0.1,
+                         label="CFG Weight", info="Classifier-free guidance weight"
+                     )
+                     seed = gr.Slider(
+                         minimum=0, maximum=9999, value=42, step=1,
+                         label="Seed", info="For reproducible results"
+                     )
+
+             # Row 4: Generate button
+             generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
+
+             # Row 5: Output section
+             with gr.Row():
+                 with gr.Column():
+                     output_audio = gr.Audio(label="Generated Speech", type="filepath")
+                     status_message = gr.Textbox(label="Status", interactive=False)
+
+             # Event handlers
+             def set_sample_text(sample_idx):
+                 return sample_texts[sample_idx]
+
+             sample_btn1.click(lambda: set_sample_text(0), outputs=[text_input])
+             sample_btn2.click(lambda: set_sample_text(1), outputs=[text_input])
+             sample_btn3.click(lambda: set_sample_text(2), outputs=[text_input])
+             sample_btn4.click(lambda: set_sample_text(3), outputs=[text_input])
+
+             def generate_with_progress(text, reference_audio, exaggeration, temperature, cfg_weight, seed):
+                 """Stream progress updates from the multi-sentence generator."""
+                 for result_audio, result_status in tts_app.generate_speech_multi_sentence(
+                     text, reference_audio, exaggeration, temperature, cfg_weight, seed
+                 ):
+                     yield result_audio, result_status
+
+             generate_btn.click(
+                 fn=generate_with_progress,
+                 inputs=[text_input, reference_audio, exaggeration, temperature, cfg_weight, seed],
+                 outputs=[output_audio, status_message]
+             )
+
+             # Instructions
+             with gr.Accordion("Tips", open=False):
+                 gr.Markdown("""
+ ### General Use (TTS and Voice Agents):
+ - The default settings (exaggeration=0.5, cfg=0.5) work well for most prompts.
+ - If the reference speaker has a fast speaking style, lowering cfg to around 0.3 can improve pacing.
+
+ ### Expressive or Dramatic Speech:
+ - Try lower cfg values (e.g. ~0.3) and increase exaggeration to around 0.7 or higher.
+ - Higher exaggeration tends to speed up speech; reducing cfg helps compensate with slower, more deliberate pacing.
+
+ ### Language Transfer Notes:
+ - Ensure that the reference clip matches the specified language tag. Otherwise, language-transfer outputs may inherit the accent of the reference clip's language.
+ - To mitigate this, set the CFG weight to 0.
+
+ ### Additional Tips:
+ - For best voice-cloning results, use clear audio with minimal background noise.
+ - The reference audio should be 3-10 seconds long.
+ - Use the same seed value for reproducible results.
+                 """)
+
+     return app
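
The chunking rule implemented by the new `split_sentences` method (split on periods, then merge consecutive sentences until each chunk reaches 150 characters) is easy to check in isolation. A minimal standalone sketch of the same behavior, using a hypothetical `chunk_sentences` helper and no model:

```python
# Sketch of the merge-to-150-characters rule used by TTSApp.split_sentences.
def chunk_sentences(text: str, min_len: int = 150) -> list[str]:
    sentences = [s.strip() + "." for s in text.split(".") if s.strip()]
    chunks, buf = [], ""
    for s in sentences:
        buf = f"{buf} {s}" if buf else s
        if len(buf) >= min_len:
            chunks.append(buf)
            buf = ""
    if buf:
        chunks.append(buf)  # the final chunk may be shorter than min_len
    return chunks

print([len(c) for c in chunk_sentences("Short one. " * 20)])  # [153, 65]
```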
cbox_test.py ADDED
@@ -0,0 +1,79 @@
+ from pathlib import Path
+
+ # Cache the Dhivehi checkpoints under the user's home directory.
+ # (Defined outside the try block: CKPT_DIR below needs it even if the download fails.)
+ _target = Path.home() / ".chatterbox-tts-dhivehi"
+ try:
+     from huggingface_hub import snapshot_download
+     if not (_target.exists() and any(_target.rglob("*"))):
+         snapshot_download(
+             repo_id="alakxender/chatterbox-tts-dhivehi",
+             local_dir=str(_target),
+             local_dir_use_symlinks=False,
+             resume_download=True,
+         )
+ except Exception as _e:
+     pass
+
+ from chatterbox.tts import ChatterboxTTS
+ import chatterbox_dhivehi
+ import torchaudio
+ import torch
+ import numpy as np
+ import random
+
+ # ---- User settings (edit these) ----
+ CKPT_DIR = f"{_target}/kn_cbox"  # path to your finetuned checkpoint dir
+ REF_WAV = f"{_target}/samples/reference_audio.wav"  # optional 3–10s clean reference; "" to disable
+ # REF_WAV = ""
+ TEXT = "މި ރިޕޯޓާ ގުޅޭ ގޮތުން އެނިމަލް ވެލްފެއާ މިނިސްޓްރީން އަދި ވާހަކައެއް ނުދައްކާ"  # sample Dhivehi text
+ TEXT = f"{TEXT}, The Animal Welfare Ministry has not yet commented on the report"
+ EXAGGERATION = 0.4
+ TEMPERATURE = 0.3
+ CFG_WEIGHT = 0.7
+ SEED = 42
+ SAMPLE_RATE = 24000
+ OUT_PATH = "out.wav"
+ # ------------------------------------
+
+ # Extend Dhivehi support from the local module
+ chatterbox_dhivehi.extend_dhivehi()
+
+ # Seed for reproducibility
+ torch.manual_seed(SEED)
+ if torch.cuda.is_available():
+     torch.cuda.manual_seed(SEED)
+     torch.cuda.manual_seed_all(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+
+ # Load model
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Loading ChatterboxTTS from: {CKPT_DIR} on {device}")
+ model = ChatterboxTTS.from_dhivehi(ckpt_dir=Path(CKPT_DIR), device=device)
+ print("Model loaded.")
+
+ # Generate (reference audio optional)
+ print(f"Generating audio... ref={'yes' if REF_WAV else 'no'}")
+ gen_kwargs = dict(
+     text=TEXT,
+     exaggeration=EXAGGERATION,
+     temperature=TEMPERATURE,
+     cfg_weight=CFG_WEIGHT,
+ )
+
+ try:
+     if REF_WAV:
+         gen_kwargs["audio_prompt_path"] = REF_WAV
+         audio = model.generate(**gen_kwargs)
+     else:
+         # Try without a reference first; if the backend requires audio_prompt_path, fall back to ""
+         try:
+             audio = model.generate(**gen_kwargs)
+         except TypeError:
+             gen_kwargs["audio_prompt_path"] = ""
+             audio = model.generate(**gen_kwargs)
+ except Exception as e:
+     raise RuntimeError(f"Generation failed: {e}") from e
+
+ # Save
+ torchaudio.save(OUT_PATH, audio, SAMPLE_RATE)
+ dur = audio.shape[1] / SAMPLE_RATE
+ print(f"Saved {OUT_PATH} ({dur:.2f}s)")
chatterbox_dhivehi.py ADDED
@@ -0,0 +1,210 @@
+ # chatterbox_dhivehi.py
+ """
+ Dhivehi extension for ChatterboxTTS.
+
+ Requires: chatterbox-tts 0.1.4 (not tested on any other version)
+
+ Adds:
+ - load_t3_with_vocab(state_dict, device, force_vocab_size): load T3 with a specific
+   vocab size, resizing both the embedding and the projection head, and padding
+   checkpoint weights if needed.
+ - from_dhivehi(...): classmethod for building a ChatterboxTTS from a checkpoint
+   directory, using load_t3_with_vocab under the hood (defaults to vocab=1199).
+ - extend_dhivehi(): attach the above to ChatterboxTTS (idempotent).
+
+ Usage in app.py:
+     import chatterbox_dhivehi
+     chatterbox_dhivehi.extend_dhivehi()
+
+     self.model = ChatterboxTTS.from_dhivehi(
+         ckpt_dir=Path(self.checkpoint),
+         device="cuda" if torch.cuda.is_available() else "cpu",
+         force_vocab_size=2000,
+     )
+ """
+
+ from __future__ import annotations
+ import logging
+ from pathlib import Path
+ from typing import Optional, Union
+
+ import torch
+ import torch.nn as nn
+ from safetensors.torch import load_file
+
+ # Core chatterbox imports
+ from chatterbox.tts import ChatterboxTTS, Conditionals
+ from chatterbox.models.t3 import T3
+ from chatterbox.models.s3gen import S3Gen
+ from chatterbox.models.tokenizers import EnTokenizer
+ from chatterbox.models.voice_encoder import VoiceEncoder
+
+
+ # Helpers
+
+ def _expand_or_trim_rows(t: torch.Tensor, new_rows: int, init_std: float = 0.02) -> torch.Tensor:
+     """
+     Return a tensor with its first dimension resized to `new_rows`.
+     If expanding, newly added rows are randomly initialized N(0, init_std).
+     """
+     old_rows = t.shape[0]
+     if new_rows == old_rows:
+         return t.clone()
+     if new_rows < old_rows:
+         return t[:new_rows].clone()
+     # expand
+     out = t.new_empty((new_rows,) + t.shape[1:])
+     out[:old_rows] = t
+     out[old_rows:].normal_(mean=0.0, std=init_std)
+     return out
+
+
+ def _prepare_resized_state_dict(sd: dict, new_vocab: int, init_std: float = 0.02) -> dict:
+     """
+     Create a modified copy of `sd` where text_emb/text_head weights (and bias) match `new_vocab`.
+     """
+     sd = sd.copy()
+
+     # text embedding: [vocab, dim]
+     if "text_emb.weight" in sd:
+         sd["text_emb.weight"] = _expand_or_trim_rows(sd["text_emb.weight"], new_vocab, init_std)
+
+     # text projection head: Linear(out=vocab, in=dim)
+     if "text_head.weight" in sd:
+         sd["text_head.weight"] = _expand_or_trim_rows(sd["text_head.weight"], new_vocab, init_std)
+     if "text_head.bias" in sd:
+         bias = sd["text_head.bias"]
+         if bias.ndim == 1:
+             sd["text_head.bias"] = _expand_or_trim_rows(bias.unsqueeze(1), new_vocab, init_std).squeeze(1)
+
+     return sd
+
+
+ def _resize_model_vocab_layers(model: T3, new_vocab: int, dim: Optional[int] = None) -> None:
+     """
+     Rebuild model.text_emb and model.text_head to match `new_vocab`.
+     The embedding dim is inferred from existing layers if not provided.
+     """
+     if dim is None:
+         if hasattr(model, "text_emb") and isinstance(model.text_emb, nn.Embedding):
+             dim = model.text_emb.embedding_dim
+         elif hasattr(model, "text_head") and isinstance(model.text_head, nn.Linear):
+             dim = model.text_head.in_features
+         else:
+             raise RuntimeError("Cannot infer text embedding dimension from T3 model.")
+     model.text_emb = nn.Embedding(new_vocab, dim)
+     model.text_head = nn.Linear(dim, new_vocab, bias=True)
+
+
+ # Public API
+
+ def load_t3_with_vocab(
+     t3_state_dict: dict,
+     device: str = "cpu",
+     *,
+     force_vocab_size: Optional[int] = None,
+     init_std: float = 0.02,
+ ) -> T3:
+     """
+     Load a T3 model with a specified vocabulary size.
+
+     - Removes a leading "t3." prefix on state_dict keys if present.
+     - Resizes BOTH `text_emb` and `text_head` to `force_vocab_size` (or to the
+       checkpoint vocab if not forced).
+     - Pads checkpoint weights when the target vocab is larger than the checkpoint's.
+
+     Args:
+         t3_state_dict: state dict loaded from t3_cfg.safetensors (or similar).
+         device: "cpu", "cuda", or "mps".
+         force_vocab_size: desired vocab size (e.g., 2000 for Dhivehi-extended models).
+         init_std: std for random init of padded rows.
+
+     Returns:
+         T3: model moved to `device` and set to eval().
+     """
+     logger = logging.getLogger(__name__)
+
+     # Strip "t3." prefix if present
+     if any(k.startswith("t3.") for k in t3_state_dict.keys()):
+         t3_state_dict = {k[len("t3."):]: v for k, v in t3_state_dict.items()}
+
+     # Derive the checkpoint vocab if available
+     ckpt_vocab_size = None
+     if "text_emb.weight" in t3_state_dict and t3_state_dict["text_emb.weight"].ndim == 2:
+         ckpt_vocab_size = int(t3_state_dict["text_emb.weight"].shape[0])
+     elif "text_head.weight" in t3_state_dict and t3_state_dict["text_head.weight"].ndim == 2:
+         ckpt_vocab_size = int(t3_state_dict["text_head.weight"].shape[0])
+
+     target_vocab = int(force_vocab_size) if force_vocab_size is not None else ckpt_vocab_size
+     if target_vocab is None:
+         raise RuntimeError("Could not determine vocab size. Provide force_vocab_size.")
+
+     logger.info(f"Loading T3 with vocab={target_vocab} (ckpt_vocab={ckpt_vocab_size})")
+
+     # Build a base model and resize its layers to accept the incoming state dict
+     t3 = T3()
+     _resize_model_vocab_layers(t3, target_vocab)
+
+     # Patch the checkpoint tensors to the target vocab
+     patched_sd = _prepare_resized_state_dict(t3_state_dict, target_vocab, init_std)
+
+     # Load (strict=False to tolerate benign extra/missing keys)
+     t3.load_state_dict(patched_sd, strict=False)
+     return t3.to(device).eval()
+
+
+ def from_dhivehi(
+     cls,
+     *,
+     ckpt_dir: Union[str, Path],
+     device: str = "cpu",
+     force_vocab_size: int = 1199,
+ ):
+     """
+     Construct a Dhivehi-extended ChatterboxTTS from a checkpoint directory.
+
+     Expected files in `ckpt_dir`:
+     - ve.safetensors
+     - t3_cfg.safetensors
+     - s3gen.safetensors
+     - tokenizer.json
+     - conds.pt (optional)
+     """
+     ckpt_dir = Path(ckpt_dir)
+
+     # Voice encoder
+     ve = VoiceEncoder()
+     ve.load_state_dict(load_file(ckpt_dir / "ve.safetensors"))
+     ve.to(device).eval()
+
+     # T3 with Dhivehi vocab extension
+     t3_state = load_file(ckpt_dir / "t3_cfg.safetensors")
+     t3 = load_t3_with_vocab(t3_state, device=device, force_vocab_size=force_vocab_size)
+
+     # S3Gen
+     s3gen = S3Gen()
+     s3gen.load_state_dict(load_file(ckpt_dir / "s3gen.safetensors"), strict=False)
+     s3gen.to(device).eval()
+
+     # Tokenizer
+     tokenizer = EnTokenizer(str(ckpt_dir / "tokenizer.json"))
+
+     # Optional conditionals
+     conds = None
+     conds_path = ckpt_dir / "conds.pt"
+     if conds_path.exists():
+         # Always safe-load to CPU first; move to `device` afterwards
+         conds = Conditionals.load(conds_path, map_location="cpu").to(device)
+
+     return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
+
+
+ def extend_dhivehi():
+     """
+     Attach Dhivehi-specific helpers to ChatterboxTTS (idempotent).
+     - ChatterboxTTS.load_t3_with_vocab (staticmethod)
+     - ChatterboxTTS.from_dhivehi (classmethod)
+     """
+     if getattr(ChatterboxTTS, "_dhivehi_extended", False):
+         return
+     ChatterboxTTS.load_t3_with_vocab = staticmethod(load_t3_with_vocab)
+     ChatterboxTTS.from_dhivehi = classmethod(from_dhivehi)
+     ChatterboxTTS._dhivehi_extended = True
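
The row-resizing trick at the heart of `load_t3_with_vocab` is worth seeing on toy tensors. A minimal sketch of the same padding behavior as `_expand_or_trim_rows`, growing a 4-row "embedding" to 6 rows while preserving the original weights:

```python
import torch

emb = torch.randn(4, 8)                    # pretend [vocab=4, dim=8] checkpoint weight
out = emb.new_empty((6,) + emb.shape[1:])  # target vocab = 6
out[:4] = emb                              # original rows preserved
out[4:].normal_(mean=0.0, std=0.02)        # new rows ~ N(0, 0.02), the module's init_std

assert torch.equal(out[:4], emb)
print(out.shape)  # torch.Size([6, 8])
```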
requirements.txt ADDED
@@ -0,0 +1 @@
+ chatterbox-tts==0.1.4