vibingvoice commited on
Commit
d9adea5
·
verified ·
1 Parent(s): 2f52f66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -68
app.py CHANGED
@@ -1,27 +1,28 @@
1
  import os
2
  import sys
3
- import torch
4
  import spaces
 
5
  import numpy as np
6
  import soundfile as sf
7
  import librosa
8
  import logging
9
  import gradio as gr
10
  import tempfile
11
- from typing import Dict, Optional, List
 
12
 
13
  # --- 1. Setup Environment ---
14
 
15
- # Add the project root to the Python path to allow importing local modules
16
  project_root = os.path.dirname(os.path.abspath(__file__))
17
  if project_root not in sys.path:
18
  sys.path.insert(0, project_root)
19
 
20
- # Configure logging to see VibeVoice messages
21
  logging.basicConfig(level=logging.INFO, format='[%(name)s] %(message)s')
22
  logger = logging.getLogger("VibeVoiceGradio")
23
 
24
- # Mock ComfyUI's folder_paths module for model caching
25
  class MockFolderPaths:
26
  def get_folder_paths(self, folder_name):
27
  if folder_name == "checkpoints":
@@ -32,28 +33,36 @@ class MockFolderPaths:
32
 
33
  sys.modules['folder_paths'] = MockFolderPaths()
34
 
35
- # Import the node class after setting up the environment
36
- # We use MultiSpeakerNode as it can handle single-speaker text too.
37
  from nodes.multi_speaker_node import VibeVoiceMultipleSpeakersNode
38
 
39
- # --- 2. Load Model Globally ---
40
 
41
- logger.info("Initializing VibeVoice node...")
42
- # We use the multi-speaker node as it can handle single-speaker cases gracefully.
43
- # This instance will hold the model in memory for all Gradio calls.
44
- vibevoice_node = VibeVoiceMultipleSpeakersNode()
45
 
46
  try:
47
- logger.info("Loading VibeVoice-Large model. This may take a while on the first run...")
48
- # Pre-load the model into the node instance.
49
- vibevoice_node.load_model(
50
  model_name='VibeVoice-Large',
51
  model_path='aoi-ot/VibeVoice-Large',
52
  attention_type='auto'
53
  )
54
- logger.info("VibeVoice-Large model loaded successfully!")
 
 
 
 
 
 
 
 
55
  except Exception as e:
56
- logger.error(f"Failed to load the model: {e}")
57
  logger.error("Please ensure you have an internet connection for the first run and sufficient VRAM.")
58
  sys.exit(1)
59
 
@@ -61,7 +70,7 @@ except Exception as e:
61
  # --- 3. Helper Functions ---
62
 
63
  def load_audio_for_node(file_path: Optional[str]) -> Optional[Dict]:
64
- """Loads an audio file from a path and formats it for the VibeVoice node."""
65
  if file_path is None:
66
  return None
67
  try:
@@ -75,19 +84,22 @@ def load_audio_for_node(file_path: Optional[str]) -> Optional[Dict]:
75
  def save_audio_to_tempfile(audio_dict: Dict) -> Optional[str]:
76
  """Saves the node's audio output to a temporary WAV file for Gradio."""
77
  if not audio_dict or "waveform" not in audio_dict:
78
- logger.error("Invalid audio dictionary received from node.")
79
  return None
80
 
81
- waveform_tensor = audio_dict["waveform"]
82
- sample_rate = audio_dict["sample_rate"]
83
-
84
- waveform_np = waveform_tensor.squeeze().cpu().numpy()
85
-
86
- # Create a temporary file
87
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
88
- sf.write(tmpfile.name, waveform_np, sample_rate)
89
  return tmpfile.name
90
 
 
 
 
 
 
 
 
 
 
91
  # --- 4. Gradio Core Logic ---
92
 
93
  @spaces.GPU
@@ -103,52 +115,68 @@ def generate_speech_gradio(
103
  use_sampling: bool,
104
  temperature: float,
105
  top_p: float,
 
106
  progress=gr.Progress(track_tqdm=True)
107
  ):
108
- """The main function that Gradio will call to generate speech."""
109
  if not text or not text.strip():
110
  raise gr.Error("Please provide some text to generate.")
111
 
112
- progress(0, desc="Processing audio inputs...")
113
- logger.info("Processing user inputs...")
114
-
115
- # Load uploaded voices
116
- speaker_voices = [
117
- load_audio_for_node(speaker1_audio_path),
118
- load_audio_for_node(speaker2_audio_path),
119
- load_audio_for_node(speaker3_audio_path),
120
- load_audio_for_node(speaker4_audio_path),
121
- ]
122
 
123
  progress(0.2, desc="Generating speech... (this can take a moment)")
124
- logger.info("Calling VibeVoice model to generate speech...")
125
-
126
  try:
127
- # Call the generate_speech method on our globally loaded node
128
- audio_output_tuple = vibevoice_node.generate_speech(
129
- text=text,
130
- model='VibeVoice-Large',
131
- attention_type='auto',
132
- free_memory_after_generate=False, # Keep model in memory for next call
133
- diffusion_steps=int(diffusion_steps),
134
- seed=int(seed),
135
- cfg_scale=cfg_scale,
136
- use_sampling=use_sampling,
137
- speaker1_voice=speaker_voices[0],
138
- speaker2_voice=speaker_voices[1],
139
- speaker3_voice=speaker_voices[2],
140
- speaker4_voice=speaker_voices[3],
141
- temperature=temperature,
142
- top_p=top_p
143
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  except Exception as e:
145
- logger.error(f"Error during speech generation: {e}")
146
  raise gr.Error(f"An error occurred during generation: {e}")
147
 
148
  progress(0.9, desc="Saving audio file...")
149
- logger.info("Generation complete. Saving audio output.")
150
-
151
- # Save the output to a temporary file for Gradio to serve
152
  output_audio_path = save_audio_to_tempfile(audio_output_tuple[0])
153
 
154
  if output_audio_path is None:
@@ -161,7 +189,7 @@ def generate_speech_gradio(
161
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
162
  gr.Markdown(
163
  "# VibeVoice Text-to-Speech Demo\n"
164
- "Generate multi-speaker conversations with optional voice cloning using Microsoft's VibeVoice-Large model."
165
  )
166
 
167
  with gr.Row():
@@ -169,15 +197,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
169
  text_input = gr.Textbox(
170
  label="Text Input",
171
  placeholder=(
172
- "Enter text using speaker tags like [1]:, [2]:, etc.\n\n"
173
  "[1]: Hello, I'm the first speaker.\n"
174
- "[2]: Hi there, I'm the second! How are you?\n"
175
- "[1]: I'm doing great, thanks for asking!"
176
  ),
177
  lines=8,
178
  max_lines=20
179
  )
180
- with gr.Accordion("Upload Speaker Voices (Optional)", open=False):
181
  gr.Markdown("Upload a short audio clip (3-30 seconds, clear audio) for each speaker you want to clone.")
182
  with gr.Row():
183
  speaker1_audio = gr.Audio(label="Speaker 1 Voice", type="filepath")
@@ -193,6 +220,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
193
  use_sampling = gr.Checkbox(label="Use Sampling", value=False, interactive=True, info="Enable for more varied, less deterministic output.")
194
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.05, value=0.95, interactive=True, info="Only used when sampling is enabled.")
195
  top_p = gr.Slider(label="Top P", minimum=0.1, maximum=1.0, step=0.05, value=0.95, interactive=True, info="Only used when sampling is enabled.")
 
196
 
197
  with gr.Column(scale=1):
198
  generate_button = gr.Button("Generate Speech", variant="primary")
@@ -201,7 +229,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
201
  inputs = [
202
  text_input,
203
  speaker1_audio, speaker2_audio, speaker3_audio, speaker4_audio,
204
- seed, diffusion_steps, cfg_scale, use_sampling, temperature, top_p
205
  ]
206
 
207
  generate_button.click(
@@ -211,5 +239,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
211
  )
212
 
213
  if __name__ == "__main__":
214
- # Launch the Gradio app
215
- demo.launch(share=True) # Add share=True to create a public link: demo.launch(share=True)
 
1
  import os
2
  import sys
 
3
  import spaces
4
+ import torch
5
  import numpy as np
6
  import soundfile as sf
7
  import librosa
8
  import logging
9
  import gradio as gr
10
  import tempfile
11
+ import re
12
+ from typing import Dict, Optional
13
 
14
  # --- 1. Setup Environment ---
15
 
16
+ # Add the project root to the Python path
17
  project_root = os.path.dirname(os.path.abspath(__file__))
18
  if project_root not in sys.path:
19
  sys.path.insert(0, project_root)
20
 
21
+ # Configure logging
22
  logging.basicConfig(level=logging.INFO, format='[%(name)s] %(message)s')
23
  logger = logging.getLogger("VibeVoiceGradio")
24
 
25
+ # Mock ComfyUI's folder_paths module
26
  class MockFolderPaths:
27
  def get_folder_paths(self, folder_name):
28
  if folder_name == "checkpoints":
 
33
 
34
  sys.modules['folder_paths'] = MockFolderPaths()
35
 
36
+ # Import BOTH node classes
37
+ from nodes.single_speaker_node import VibeVoiceSingleSpeakerNode
38
  from nodes.multi_speaker_node import VibeVoiceMultipleSpeakersNode
39
 
40
+ # --- 2. Load Models and Share Weights ---
41
 
42
+ logger.info("Initializing VibeVoice nodes...")
43
+ # Instantiate both node types.
44
+ single_speaker_node = VibeVoiceSingleSpeakerNode()
45
+ multi_speaker_node = VibeVoiceMultipleSpeakersNode()
46
 
47
  try:
48
+ logger.info("Loading VibeVoice-Large model once. This may take a while on the first run...")
49
+ # Load the model into one node first.
50
+ multi_speaker_node.load_model(
51
  model_name='VibeVoice-Large',
52
  model_path='aoi-ot/VibeVoice-Large',
53
  attention_type='auto'
54
  )
55
+
56
+ logger.info("Sharing loaded model weights between node instances...")
57
+ single_speaker_node.model = multi_speaker_node.model
58
+ single_speaker_node.processor = multi_speaker_node.processor
59
+ single_speaker_node.current_model_path = multi_speaker_node.current_model_path
60
+ single_speaker_node.current_attention_type = multi_speaker_node.current_attention_type
61
+
62
+ logger.info("VibeVoice-Large model loaded and shared successfully!")
63
+
64
  except Exception as e:
65
+ logger.error(f"Failed to load the model: {e}", exc_info=True)
66
  logger.error("Please ensure you have an internet connection for the first run and sufficient VRAM.")
67
  sys.exit(1)
68
 
 
70
  # --- 3. Helper Functions ---
71
 
72
  def load_audio_for_node(file_path: Optional[str]) -> Optional[Dict]:
73
+ """Loads an audio file and formats it for the node."""
74
  if file_path is None:
75
  return None
76
  try:
 
84
  def save_audio_to_tempfile(audio_dict: Dict) -> Optional[str]:
85
  """Saves the node's audio output to a temporary WAV file for Gradio."""
86
  if not audio_dict or "waveform" not in audio_dict:
 
87
  return None
88
 
89
+ waveform_np = audio_dict["waveform"].squeeze().cpu().numpy()
 
 
 
 
 
90
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
91
+ sf.write(tmpfile.name, waveform_np, audio_dict["sample_rate"])
92
  return tmpfile.name
93
 
94
+ def detect_speaker_count(text: str) -> int:
95
+ """Analyzes text to count the number of unique speakers."""
96
+ speaker_tags = re.findall(r'\[(\d+)\]\s*:', text)
97
+ if not speaker_tags:
98
+ # No tags found, treat as a single speaker monologue.
99
+ return 1
100
+ unique_speakers = set(int(tag) for tag in speaker_tags)
101
+ return len(unique_speakers)
102
+
103
  # --- 4. Gradio Core Logic ---
104
 
105
  @spaces.GPU
 
115
  use_sampling: bool,
116
  temperature: float,
117
  top_p: float,
118
+ max_words_per_chunk: int,
119
  progress=gr.Progress(track_tqdm=True)
120
  ):
121
+ """The main function that Gradio will call, now with dynamic node switching."""
122
  if not text or not text.strip():
123
  raise gr.Error("Please provide some text to generate.")
124
 
125
+ progress(0, desc="Analyzing text and loading voices...")
126
+
127
+ speaker_count = detect_speaker_count(text)
128
+
129
+ # Load voices
130
+ speaker1_voice = load_audio_for_node(speaker1_audio_path)
131
+ speaker2_voice = load_audio_for_node(speaker2_audio_path)
132
+ speaker3_voice = load_audio_for_node(speaker3_audio_path)
133
+ speaker4_voice = load_audio_for_node(speaker4_audio_path)
 
134
 
135
  progress(0.2, desc="Generating speech... (this can take a moment)")
136
+
 
137
  try:
138
+ if speaker_count <= 1:
139
+ logger.info(f"Detected single speaker. Using VibeVoiceSingleSpeakerNode.")
140
+ # Prepare text for single speaker node (remove tags like [1]:)
141
+ processed_text = re.sub(r'\[1\]\s*:', '', text).strip()
142
+
143
+ audio_output_tuple = single_speaker_node.generate_speech(
144
+ text=processed_text,
145
+ model='VibeVoice-Large',
146
+ attention_type='auto',
147
+ free_memory_after_generate=False,
148
+ diffusion_steps=int(diffusion_steps),
149
+ seed=int(seed),
150
+ cfg_scale=cfg_scale,
151
+ use_sampling=use_sampling,
152
+ voice_to_clone=speaker1_voice, # Use speaker 1's voice for cloning
153
+ temperature=temperature,
154
+ top_p=top_p,
155
+ max_words_per_chunk=int(max_words_per_chunk)
156
+ )
157
+ else:
158
+ logger.info(f"Detected {speaker_count} speakers. Using VibeVoiceMultipleSpeakersNode.")
159
+ audio_output_tuple = multi_speaker_node.generate_speech(
160
+ text=text,
161
+ model='VibeVoice-Large',
162
+ attention_type='auto',
163
+ free_memory_after_generate=False,
164
+ diffusion_steps=int(diffusion_steps),
165
+ seed=int(seed),
166
+ cfg_scale=cfg_scale,
167
+ use_sampling=use_sampling,
168
+ speaker1_voice=speaker1_voice,
169
+ speaker2_voice=speaker2_voice,
170
+ speaker3_voice=speaker3_voice,
171
+ speaker4_voice=speaker4_voice,
172
+ temperature=temperature,
173
+ top_p=top_p
174
+ )
175
  except Exception as e:
176
+ logger.error(f"Error during speech generation: {e}", exc_info=True)
177
  raise gr.Error(f"An error occurred during generation: {e}")
178
 
179
  progress(0.9, desc="Saving audio file...")
 
 
 
180
  output_audio_path = save_audio_to_tempfile(audio_output_tuple[0])
181
 
182
  if output_audio_path is None:
 
189
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
190
  gr.Markdown(
191
  "# VibeVoice Text-to-Speech Demo\n"
192
+ "Generate single or multi-speaker audio. For single-speaker monologues, the system automatically uses a specialized node with text chunking."
193
  )
194
 
195
  with gr.Row():
 
197
  text_input = gr.Textbox(
198
  label="Text Input",
199
  placeholder=(
200
+ "Enter plain text for a single speaker, or use tags like [1]:, [2]: for multiple speakers.\n\n"
201
  "[1]: Hello, I'm the first speaker.\n"
202
+ "[2]: Hi there, I'm the second! How are you?"
 
203
  ),
204
  lines=8,
205
  max_lines=20
206
  )
207
+ with gr.Accordion("Upload Speaker Voices (Optional)", open=True):
208
  gr.Markdown("Upload a short audio clip (3-30 seconds, clear audio) for each speaker you want to clone.")
209
  with gr.Row():
210
  speaker1_audio = gr.Audio(label="Speaker 1 Voice", type="filepath")
 
220
  use_sampling = gr.Checkbox(label="Use Sampling", value=False, interactive=True, info="Enable for more varied, less deterministic output.")
221
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.05, value=0.95, interactive=True, info="Only used when sampling is enabled.")
222
  top_p = gr.Slider(label="Top P", minimum=0.1, maximum=1.0, step=0.05, value=0.95, interactive=True, info="Only used when sampling is enabled.")
223
+ max_words_per_chunk = gr.Slider(label="Max Words Per Chunk", minimum=100, maximum=500, step=10, value=250, interactive=True, info="For long single-speaker text. Splits text to avoid errors.")
224
 
225
  with gr.Column(scale=1):
226
  generate_button = gr.Button("Generate Speech", variant="primary")
 
229
  inputs = [
230
  text_input,
231
  speaker1_audio, speaker2_audio, speaker3_audio, speaker4_audio,
232
+ seed, diffusion_steps, cfg_scale, use_sampling, temperature, top_p, max_words_per_chunk
233
  ]
234
 
235
  generate_button.click(
 
239
  )
240
 
241
  if __name__ == "__main__":
242
+ demo.launch()