# nodes/multi_speaker_node.py
# Created by Fabio Sarracino
import logging
import re
import torch
import numpy as np
from typing import Optional
from .base_vibevoice import BaseVibeVoiceNode
# Setup logging
logger = logging.getLogger("VibeVoice")
class VibeVoiceMultipleSpeakersNode(BaseVibeVoiceNode):
def __init__(self):
super().__init__()
# Register this instance for memory management
try:
from .free_memory_node import VibeVoiceFreeMemoryNode
VibeVoiceFreeMemoryNode.register_multi_speaker(self)
        except Exception:
            # The free-memory node is optional; failing to register is non-fatal
            pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"text": ("STRING", {
"multiline": True,
"default": "[1]: Hello, this is the first speaker.\n[2]: Hi there, I'm the second speaker.\n[1]: Nice to meet you!\n[2]: Nice to meet you too!",
"tooltip": "Text with speaker labels. Use '[N]:' format where N is 1-4. Gets disabled when connected to another node.",
"forceInput": False,
"dynamicPrompts": True
}),
"model": (["VibeVoice-1.5B", "VibeVoice-Large", "VibeVoice-Large-Quant-4Bit"], {
"default": "VibeVoice-Large", # Large recommended for multi-speaker
"tooltip": "Model to use. Large is recommended for multi-speaker generation, Quant-4Bit uses less VRAM (CUDA only)"
}),
"attention_type": (["auto", "eager", "sdpa", "flash_attention_2", "sage"], {
"default": "auto",
"tooltip": "Attention implementation. Auto selects the best available, eager is standard, sdpa is optimized PyTorch, flash_attention_2 requires compatible GPU, sage uses quantized attention for speedup (CUDA only)"
}),
"free_memory_after_generate": ("BOOLEAN", {"default": True, "tooltip": "Free model from memory after generation to save VRAM/RAM. Disable to keep model loaded for faster subsequent generations"}),
"diffusion_steps": ("INT", {"default": 20, "min": 5, "max": 100, "step": 1, "tooltip": "Number of denoising steps. More steps = better quality but slower. Default: 20"}),
"seed": ("INT", {"default": 42, "min": 0, "max": 2**32-1, "tooltip": "Random seed for generation. Default 42 is used in official examples"}),
"cfg_scale": ("FLOAT", {"default": 1.3, "min": 0.5, "max": 3.5, "step": 0.05, "tooltip": "Classifier-free guidance scale (official default: 1.3)"}),
"use_sampling": ("BOOLEAN", {"default": False, "tooltip": "Enable sampling mode. When False (default), uses deterministic generation like official examples"}),
},
"optional": {
"speaker1_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 1. If not provided, synthetic voice will be used."}),
"speaker2_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 2. If not provided, synthetic voice will be used."}),
"speaker3_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 3. If not provided, synthetic voice will be used."}),
"speaker4_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 4. If not provided, synthetic voice will be used."}),
"temperature": ("FLOAT", {"default": 0.95, "min": 0.1, "max": 2.0, "step": 0.05, "tooltip": "Only used when sampling is enabled"}),
"top_p": ("FLOAT", {"default": 0.95, "min": 0.1, "max": 1.0, "step": 0.05, "tooltip": "Only used when sampling is enabled"}),
}
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "generate_speech"
CATEGORY = "VibeVoiceWrapper"
DESCRIPTION = "Generate multi-speaker conversations with up to 4 distinct voices using Microsoft VibeVoice"
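    # AUDIO values follow ComfyUI's convention: a dict with a "waveform" tensor
    # shaped [batch, channels, samples] and an integer "sample_rate". A minimal
    # sketch of what this node returns (one second of silence at 24kHz):
    #   {"waveform": torch.zeros(1, 1, 24000), "sample_rate": 24000}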
def _prepare_voice_sample(self, voice_audio, speaker_idx: int) -> Optional[np.ndarray]:
"""Prepare a single voice sample from input audio"""
return self._prepare_audio_from_comfyui(voice_audio)
    def generate_speech(self, text: str = "", model: str = "VibeVoice-Large",
attention_type: str = "auto", free_memory_after_generate: bool = True,
diffusion_steps: int = 20, seed: int = 42, cfg_scale: float = 1.3,
use_sampling: bool = False, speaker1_voice=None, speaker2_voice=None,
speaker3_voice=None, speaker4_voice=None,
temperature: float = 0.95, top_p: float = 0.95):
"""Generate multi-speaker speech from text using VibeVoice"""
try:
# Check text input
if not text or not text.strip():
raise Exception("No text provided. Please enter text with speaker labels (e.g., '[1]: Hello' or '[2]: Hi')")
# First detect how many speakers are in the text
bracket_pattern = r'\[(\d+)\]\s*:'
speakers_numbers = sorted(list(set([int(m) for m in re.findall(bracket_pattern, text)])))
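            # Example: re.findall(bracket_pattern, "[1]: Hi\n[2] : Hey") -> ['1', '2'],
            # so speakers_numbers becomes [1, 2] (numbers need not be contiguous)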
# Limit to 1-4 speakers
if not speakers_numbers:
num_speakers = 1 # Default to 1 if no speaker format found
else:
num_speakers = min(max(speakers_numbers), 4) # Max speaker number, capped at 4
                if max(speakers_numbers) > 4:
                    logger.warning(f"Found {max(speakers_numbers)} speakers, but only 4 voice inputs are supported")
            # Convert [N]: directly to Speaker (N-1): for the VibeVoice processor;
            # this avoids multiple conversion steps
            converted_text = text
            # (type, content, speaker) segments for pause-aware processing; stays
            # empty on paths that fall back to plain text generation
            speaker_segments_with_pauses = []
            # The speaker numbers were already collected above; reuse them
            speakers_in_text = speakers_numbers
if not speakers_in_text:
# No [N]: format found, try Speaker N: format
speaker_pattern = r'Speaker\s+(\d+)\s*:'
speakers_in_text = sorted(list(set([int(m) for m in re.findall(speaker_pattern, text)])))
if speakers_in_text:
                    # Text already in "Speaker N:" format; convert to 0-based.
                    # Renumber in ascending order so a freshly decremented label
                    # (e.g. "Speaker 2:" -> "Speaker 1:") is not matched and
                    # decremented again by a later substitution
                    for speaker_num in sorted(speakers_in_text):
pattern = f'Speaker\\s+{speaker_num}\\s*:'
replacement = f'Speaker {speaker_num - 1}:'
converted_text = re.sub(pattern, replacement, converted_text)
else:
# No speaker format found
speakers_in_text = [1]
# Parse pause keywords even for single speaker
pause_segments = self._parse_pause_keywords(text)
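                    # _parse_pause_keywords (inherited from BaseVibeVoiceNode) is
                    # expected to return an ordered list of ('text', str) and
                    # ('pause', duration_ms) tuples; durations are consumed as
                    # milliseconds when silence is generated further below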
# Store speaker segments for pause processing
speaker_segments_with_pauses = []
segments = []
for seg_type, seg_content in pause_segments:
if seg_type == 'pause':
speaker_segments_with_pauses.append(('pause', seg_content, None))
else:
# Clean up newlines
text_clean = seg_content.replace('\n', ' ').replace('\r', ' ')
text_clean = ' '.join(text_clean.split())
if text_clean:
speaker_segments_with_pauses.append(('text', text_clean, 1))
segments.append(f"Speaker 0: {text_clean}")
# Join all segments for fallback
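                    # e.g. unlabeled input "Hello world" becomes "Speaker 0: Hello world"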
converted_text = '\n'.join(segments) if segments else f"Speaker 0: {text}"
else:
# Convert [N]: directly to Speaker (N-1): and handle multi-line text
# Split text to preserve speaker segments while cleaning up newlines within each segment
segments = []
# Find all speaker markers with their positions
speaker_matches = list(re.finditer(f'\\[({"|".join(map(str, speakers_in_text))})\\]\\s*:', converted_text))
# Store speaker segments for pause processing
speaker_segments_with_pauses = []
for i, match in enumerate(speaker_matches):
speaker_num = int(match.group(1))
start = match.end()
# Find where this speaker's text ends (at next speaker or end of text)
if i + 1 < len(speaker_matches):
end = speaker_matches[i + 1].start()
else:
end = len(converted_text)
# Extract the speaker's text (keep pause keywords for now)
speaker_text = converted_text[start:end].strip()
# Parse pause keywords within this speaker's text
pause_segments = self._parse_pause_keywords(speaker_text)
# Process each segment (text or pause) for this speaker
for seg_type, seg_content in pause_segments:
if seg_type == 'pause':
# Add pause segment
speaker_segments_with_pauses.append(('pause', seg_content, None))
else:
# Clean up the text segment
text_clean = seg_content.replace('\n', ' ').replace('\r', ' ')
text_clean = ' '.join(text_clean.split())
if text_clean: # Only add non-empty text
# Add text segment with speaker info
speaker_segments_with_pauses.append(('text', text_clean, speaker_num))
# Also build the traditional segments for fallback
segments.append(f'Speaker {speaker_num - 1}: {text_clean}')
# Join all segments with newlines (required for multi-speaker format) - for fallback
converted_text = '\n'.join(segments) if segments else ""
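                # e.g. "[1]: Hello\n[2]: Hi" -> "Speaker 0: Hello\nSpeaker 1: Hi"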
# Build speaker names list - these are just for logging, not used by processor
# The processor uses the speaker labels in the text itself
speakers = [f"Speaker {i}" for i in range(len(speakers_in_text))]
# Get model mapping and load model with attention type
model_mapping = self._get_model_mapping()
model_path = model_mapping.get(model, model)
self.load_model(model, model_path, attention_type)
voice_inputs = [speaker1_voice, speaker2_voice, speaker3_voice, speaker4_voice]
# Prepare voice samples in order of appearance
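            # e.g. for speakers_in_text == [1, 3], samples come from
            # voice_inputs[0] and voice_inputs[2] respectively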
voice_samples = []
for i, speaker_num in enumerate(speakers_in_text):
idx = speaker_num - 1 # Convert to 0-based for voice array
# Try to use provided voice sample
if idx < len(voice_inputs) and voice_inputs[idx] is not None:
voice_sample = self._prepare_voice_sample(voice_inputs[idx], idx)
if voice_sample is None:
# Use the actual speaker index for consistent synthetic voice
voice_sample = self._create_synthetic_voice_sample(idx)
else:
# Use the actual speaker index for consistent synthetic voice
voice_sample = self._create_synthetic_voice_sample(idx)
voice_samples.append(voice_sample)
# Ensure voice_samples count matches detected speakers
if len(voice_samples) != len(speakers_in_text):
logger.error(f"Mismatch: {len(speakers_in_text)} speakers but {len(voice_samples)} voice samples!")
raise Exception(f"Voice sample count mismatch: expected {len(speakers_in_text)}, got {len(voice_samples)}")
            # Check if we have pause segments to process
            if speaker_segments_with_pauses:
# Process segments with pauses
all_audio_segments = []
sample_rate = 24000 # VibeVoice uses 24kHz
# Group consecutive text segments from same speaker for efficiency
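                # Example: [('text', 'Hi', 1), ('text', 'there', 1), ('pause', 500, None), ('text', 'Hey', 2)]
                # groups into: [('text_group', ['Hi', 'there'], 1), ('pause', 500, None), ('text_group', ['Hey'], 2)]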
grouped_segments = []
current_group = []
current_speaker = None
for seg_type, seg_content, speaker_num in speaker_segments_with_pauses:
if seg_type == 'pause':
# Save current group if any
if current_group:
grouped_segments.append(('text_group', current_group, current_speaker))
current_group = []
current_speaker = None
# Add pause
grouped_segments.append(('pause', seg_content, None))
else:
# Text segment
if speaker_num == current_speaker:
# Same speaker, add to current group
current_group.append(seg_content)
else:
# Different speaker, save current group and start new one
if current_group:
grouped_segments.append(('text_group', current_group, current_speaker))
current_group = [seg_content]
current_speaker = speaker_num
# Save last group if any
if current_group:
grouped_segments.append(('text_group', current_group, current_speaker))
# Process grouped segments
for seg_type, seg_content, speaker_num in grouped_segments:
if seg_type == 'pause':
# Generate silence
duration_ms = seg_content
logger.info(f"Adding {duration_ms}ms pause")
silence_audio = self._generate_silence(duration_ms, sample_rate)
all_audio_segments.append(silence_audio)
else:
# Process text group for a speaker
combined_text = ' '.join(seg_content)
formatted_text = f"Speaker {speaker_num - 1}: {combined_text}"
# Get voice sample for this speaker
speaker_idx = speakers_in_text.index(speaker_num)
speaker_voice_samples = [voice_samples[speaker_idx]]
logger.info(f"Generating audio for Speaker {speaker_num}: {len(combined_text.split())} words")
# Generate audio for this speaker's text
segment_audio = self._generate_with_vibevoice(
formatted_text, speaker_voice_samples, cfg_scale, seed,
diffusion_steps, use_sampling, temperature, top_p
)
all_audio_segments.append(segment_audio)
# Concatenate all audio segments
if all_audio_segments:
logger.info(f"Concatenating {len(all_audio_segments)} audio segments (including pauses)...")
# Extract waveforms
waveforms = []
for audio_segment in all_audio_segments:
if isinstance(audio_segment, dict) and "waveform" in audio_segment:
waveforms.append(audio_segment["waveform"])
if waveforms:
# Filter out None values if any
valid_waveforms = [w for w in waveforms if w is not None]
if valid_waveforms:
# Concatenate along time dimension
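                            # (waveforms are [batch, channels, samples], so dim=-1 is the sample/time axis)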
combined_waveform = torch.cat(valid_waveforms, dim=-1)
audio_dict = {
"waveform": combined_waveform,
"sample_rate": sample_rate
}
logger.info(f"Successfully generated multi-speaker audio with pauses")
else:
raise Exception("No valid audio waveforms generated")
else:
raise Exception("Failed to extract waveforms from audio segments")
else:
raise Exception("No audio segments generated")
else:
# Fallback to original method without pause support
logger.info("Processing without pause support (no pause keywords found)")
audio_dict = self._generate_with_vibevoice(
converted_text, voice_samples, cfg_scale, seed, diffusion_steps,
use_sampling, temperature, top_p
)
# Free memory if requested
if free_memory_after_generate:
self.free_memory()
return (audio_dict,)
except Exception as e:
# Check if this is an interruption by the user
import comfy.model_management as mm
if isinstance(e, mm.InterruptProcessingException):
# User interrupted - just log it and re-raise to stop the workflow
logger.info("Generation interrupted by user")
raise # Propagate the interruption to stop the workflow
else:
# Real error - show it
logger.error(f"Multi-speaker speech generation failed: {str(e)}")
raise Exception(f"Error generating multi-speaker speech: {str(e)}")
@classmethod
    def IS_CHANGED(cls, text="", model="VibeVoice-Large",
speaker1_voice=None, speaker2_voice=None,
speaker3_voice=None, speaker4_voice=None, **kwargs):
"""Cache key for ComfyUI"""
voices_hash = hash(str([speaker1_voice, speaker2_voice, speaker3_voice, speaker4_voice]))
return f"{hash(text)}_{model}_{voices_hash}_{kwargs.get('cfg_scale', 1.3)}_{kwargs.get('seed', 0)}"