import os
import re
import time
from typing import Tuple

import gradio as gr
import torch
import torchaudio
# import spaces
from generator import Segment, load_csm_1b
from huggingface_hub import login

# Suppress torch.compile/Dynamo errors (fall back to eager mode) to avoid Triton errors
torch._dynamo.config.suppress_errors = True

# Check whether a GPU is available and configure the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Log in to the Hugging Face Hub if a token is available
def login_huggingface():
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        print("Logging in to Hugging Face Hub...")
        login(token=hf_token)
        print("Login successful!")
    else:
        print("HF_TOKEN not found in environment variables. Some models may not be accessible.")

# Log in at startup
login_huggingface()
# Global variables tracking model state
generator = None
model_loaded = False

# Load the model on the (Zero)GPU
# @spaces.GPU(duration=30)
def initialize_model():
    global generator, model_loaded
    if not model_loaded:
        print("Loading CSM-1B model on GPU...")
        generator = load_csm_1b(device="cuda")
        model_loaded = True
        print("Model loaded successfully!")
    return generator

# Return the loaded model, initializing it on first use
# @spaces.GPU(duration=30)
def get_model():
    global generator, model_loaded
    if not model_loaded:
        return initialize_model()
    return generator
# Preload the model if the environment variable is set
def preload_model_if_needed():
    if os.environ.get("PRELOAD_MODEL", "").lower() in ("true", "1", "yes"):
        print("PRELOAD_MODEL is set. Attempting to preload model...")
        try:
            # We can't directly call initialize_model() here because it's decorated with
            # @spaces.GPU; instead, we set a flag that is checked when the first request
            # comes in
            global model_loaded
            model_loaded = False
            print("Model will be loaded on first request.")
        except Exception as e:
            print(f"Error during model preloading setup: {e}")
    else:
        print("PRELOAD_MODEL is not set. Model will be loaded on demand.")

# Run the preload check at startup
preload_model_if_needed()
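# For example (illustrative, not required): set PRELOAD_MODEL=true in the Space's
# Settings, or os.environ["PRELOAD_MODEL"] = "true" before this module runs, to opt in.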
# Load an audio file into a mono tensor plus its sample rate
def audio_to_tensor(audio_path: str) -> Tuple[torch.Tensor, int]:
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = waveform.mean(dim=0)  # Downmix stereo to mono if needed
    return waveform, sample_rate

# Save an audio tensor to a WAV file
def save_audio(audio_tensor: torch.Tensor, sample_rate: int) -> str:
    # Save the file to the current directory (or the files directory Gradio uses by default)
    output_path = f"csm1b_output_{int(time.time())}.wav"
    torchaudio.save(output_path, audio_tensor.unsqueeze(0), sample_rate)
    return output_path
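# A minimal round-trip sketch for the helpers above (illustrative only: "input.wav"
# is a hypothetical file, and 24000 Hz stands in for generator.sample_rate, which is
# only known once the model is loaded):
#
#   waveform, sr = audio_to_tensor("input.wav")
#   if sr != 24000:
#       waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=24000)
#   out_path = save_audio(waveform, 24000)  # writes csm1b_output_<timestamp>.wav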
# Generate speech from text with optional context, using ZeroGPU
# @spaces.GPU(duration=30)
def generate_speech(
    text: str,
    speaker_id: int,
    context_audio_path1: str = None,
    context_text1: str = None,
    context_speaker1: int = 0,
    context_audio_path2: str = None,
    context_text2: str = None,
    context_speaker2: int = 1,
    max_duration_ms: float = 30000,
    temperature: float = 0.9,
    top_k: int = 50,
    progress=gr.Progress()
) -> str:
    try:
        # Get the loaded model
        generator = get_model()

        # Prepare the context segments
        context = []
        progress(0.1, "Processing context...")

        # Process context 1
        if context_audio_path1 and context_text1:
            waveform, sample_rate = audio_to_tensor(context_audio_path1)
            # Resample if needed
            if sample_rate != generator.sample_rate:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=generator.sample_rate)
            context.append(Segment(speaker=context_speaker1, text=context_text1, audio=waveform))

        # Process context 2
        if context_audio_path2 and context_text2:
            waveform, sample_rate = audio_to_tensor(context_audio_path2)
            # Resample if needed
            if sample_rate != generator.sample_rate:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=generator.sample_rate)
            context.append(Segment(speaker=context_speaker2, text=context_text2, audio=waveform))

        progress(0.3, "Generating audio...")
        # Generate audio from text
        audio = generator.generate(
            text=text,
            speaker=speaker_id,
            context=context,
            max_audio_length_ms=max_duration_ms,
            # temperature=temperature,
            # topk=top_k
        )

        progress(0.8, "Saving audio...")
        # Save audio to file
        output_path = f"csm1b_output_{int(time.time())}.wav"
        torchaudio.save(output_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
        print(f"Audio saved to {output_path}")
        progress(1.0, "Completed!")
        return output_path
    except Exception as e:
        error_message = str(e)
        # Handle the ZeroGPU quota-exceeded error specially
        if "GPU quota exceeded" in error_message:
            # Extract the wait time from the error message
            wait_time_match = re.search(r"Try again in (\d+:\d+:\d+)", error_message)
            wait_time = wait_time_match.group(1) if wait_time_match else "some time"
            return f"GPU quota exceeded. Please try again in {wait_time}."
        return f"Error generating speech: {error_message}"
# Generate simple speech without context
# @spaces.GPU(duration=30)
def generate_speech_simple(
    text: str,
    speaker_id: int,
    max_duration_ms: float = 30000,
    temperature: float = 0.9,
    top_k: int = 50,
    progress=gr.Progress()
) -> str:
    try:
        # Get the loaded model
        generator = get_model()

        progress(0.3, "Generating audio...")
        # Generate audio from text
        audio = generator.generate(
            text=text,
            speaker=speaker_id,
            context=[],  # No context
            max_audio_length_ms=max_duration_ms,
            # temperature=temperature,
            # topk=top_k
        )

        progress(0.8, "Saving audio...")
        # Save audio to file
        output_path = f"csm1b_output_{int(time.time())}.wav"
        torchaudio.save(output_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
        print(f"Audio saved to {output_path}")
        progress(1.0, "Completed!")
        return output_path
    except Exception as e:
        error_message = str(e)
        # Handle the ZeroGPU quota-exceeded error specially
        if "GPU quota exceeded" in error_message:
            # Extract the wait time from the error message
            wait_time_match = re.search(r"Try again in (\d+:\d+:\d+)", error_message)
            wait_time = wait_time_match.group(1) if wait_time_match else "some time"
            return f"GPU quota exceeded. Please try again in {wait_time}."
        return f"Error generating speech: {error_message}"
# Create the Gradio interface
def create_demo():
    with gr.Blocks(title="CSM-1B Text-to-Speech") as demo:
        gr.Markdown("# CSM-1B Text-to-Speech Demo")
        gr.Markdown("CSM-1B (Conversational Speech Model) is an advanced text-to-speech model capable of generating natural-sounding speech from text.")

        with gr.Tab("Simple Audio Generation"):
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(
                        label="Text to convert to speech",
                        placeholder="Enter the text you want to convert to speech...",
                        lines=5
                    )
                    speaker_id = gr.Number(
                        label="Speaker ID",
                        value=0,
                        precision=0,
                        minimum=0,
                        maximum=10
                    )
                    with gr.Row():
                        max_duration = gr.Slider(
                            label="Maximum Duration (ms)",
                            minimum=1000,
                            maximum=90000,
                            value=30000,
                            step=1000
                        )
                        # temperature = gr.Slider(
                        #     label="Temperature",
                        #     minimum=0.1,
                        #     maximum=1.5,
                        #     value=0.9,
                        #     step=0.1
                        # )
                        # top_k = gr.Slider(
                        #     label="Top-K",
                        #     minimum=1,
                        #     maximum=100,
                        #     value=50,
                        #     step=1
                        # )
                    generate_btn = gr.Button("Generate Audio")
                with gr.Column():
                    output_audio = gr.Audio(label="Output Audio", type="filepath", autoplay=True)
| with gr.Tab("Audio Generation with Context"): | |
| gr.Markdown("This feature allows you to provide audio clips and text as context to help the model generate more appropriate speech.") | |
| with gr.Row(): | |
| with gr.Column(): | |
| context_text1 = gr.Textbox(label="Context Text 1", lines=2) | |
| context_audio1 = gr.Audio(label="Context Audio 1", type="filepath") | |
| context_speaker1 = gr.Number(label="Speaker ID 1", value=0, precision=0) | |
| context_text2 = gr.Textbox(label="Context Text 2", lines=2) | |
| context_audio2 = gr.Audio(label="Context Audio 2", type="filepath") | |
| context_speaker2 = gr.Number(label="Speaker ID 2", value=1, precision=0) | |
| text_input_context = gr.Textbox( | |
| label="Text to convert to speech", | |
| placeholder="Enter the text you want to convert to speech...", | |
| lines=3 | |
| ) | |
| speaker_id_context = gr.Number( | |
| label="Speaker ID", | |
| value=0, | |
| precision=0 | |
| ) | |
| with gr.Row(): | |
| max_duration_context = gr.Slider( | |
| label="Maximum Duration (ms)", | |
| minimum=1000, | |
| maximum=90000, | |
| value=30000, | |
| step=1000 | |
| ) | |
| # temperature_context = gr.Slider( | |
| # label="Temperature", | |
| # minimum=0.1, | |
| # maximum=1.5, | |
| # value=0.9, | |
| # step=0.1 | |
| # ) | |
| # top_k_context = gr.Slider( | |
| # label="Top-K", | |
| # minimum=1, | |
| # maximum=100, | |
| # value=50, | |
| # step=1 | |
| # ) | |
| generate_context_btn = gr.Button("Generate Audio with Context") | |
| with gr.Column(): | |
| output_audio_context = gr.Audio(label="Output Audio", type="filepath", autoplay=True) | |
        # Hugging Face configuration tab
        with gr.Tab("Configuration"):
            gr.Markdown("### Hugging Face Token Configuration")
            gr.Markdown("""
            To use the CSM-1B model, you need access to the model on Hugging Face.
            To configure your token:

            1. Create a token at [Hugging Face Settings](https://huggingface.co/settings/tokens)
            2. Set the `HF_TOKEN` environment variable to your token value

            Note: In Hugging Face Spaces, you can set environment variables in the Space Settings.
            """)
            hf_token_input = gr.Textbox(
                label="Hugging Face Token (only for this session)",
                placeholder="Enter your token...",
                type="password"
            )

            def set_token(token):
                if token:
                    os.environ["HF_TOKEN"] = token
                    login(token=token)
                    return "Token set successfully! You can now load the model."
                return "No token provided. Please enter a valid token."

            set_token_btn = gr.Button("Set Token")
            token_status = gr.Textbox(label="Status", interactive=False)
            set_token_btn.click(fn=set_token, inputs=hf_token_input, outputs=token_status)
        # GPU information tab
        with gr.Tab("GPU Information"):
            gr.Markdown("### About ZeroGPU")
            gr.Markdown("""
            This application uses Hugging Face Spaces' ZeroGPU to optimize GPU usage.
            ZeroGPU frees GPU memory when it is not in use, saving resources and improving performance.
            When you generate audio, the GPU is used automatically and released after completion.

            Note: In the ZeroGPU environment, CUDA is not initialized in the main process, only in functions decorated with @spaces.GPU.
            """)
            gr.Markdown("### GPU Quota Information")
            gr.Markdown("""
            Hugging Face Spaces imposes GPU quota limits:

            - Each GPU operation has a default duration of 60 seconds
            - This app reduces that to 30 seconds for audio generation and 10 seconds for GPU checks
            - If you exceed your quota, you will need to wait for it to reset (usually a few hours)
            - For better performance, try generating shorter audio clips

            If you encounter a "GPU quota exceeded" error, please wait for the specified time and try again.
            """)

            # @spaces.GPU(duration=10)
            def check_gpu():
                if torch.cuda.is_available():
                    gpu_name = torch.cuda.get_device_name(0)
                    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                    return f"GPU: {gpu_name}\nMemory: {gpu_memory:.2f} GB"
                else:
                    return "No GPU found. The application will run on CPU."

            check_gpu_btn = gr.Button("Check GPU")
            gpu_info = gr.Textbox(label="GPU Information", interactive=False)
            check_gpu_btn.click(fn=check_gpu, inputs=None, outputs=gpu_info)

            # Model loading button
            load_model_btn = gr.Button("Load Model")
            model_status = gr.Textbox(label="Model Status", interactive=False)

            # @spaces.GPU(duration=10)
            def load_model_and_report():
                global model_loaded
                if model_loaded:
                    return "Model has already been loaded!"
                else:
                    initialize_model()
                    return "Model loaded successfully!"

            load_model_btn.click(fn=load_model_and_report, inputs=None, outputs=model_status)
        # Connect components
        generate_btn.click(
            fn=generate_speech_simple,
            inputs=[
                text_input,
                speaker_id,
                max_duration,
                # temperature,
                # top_k
            ],
            outputs=output_audio
        )
        generate_context_btn.click(
            fn=generate_speech,
            inputs=[
                text_input_context,
                speaker_id_context,
                context_audio1,
                context_text1,
                context_speaker1,
                context_audio2,
                context_text2,
                context_speaker2,
                max_duration_context,
                # temperature_context,
                # top_k_context
            ],
            outputs=output_audio_context
        )

    return demo

# Launch the application
if __name__ == "__main__":
    demo = create_demo()
    demo.queue().launch()
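# Programmatic usage (a sketch bypassing the UI, using only the generator API already
# relied on above; the text and output filename are illustrative):
#
#   gen = load_csm_1b(device=device)
#   audio = gen.generate(text="Hello!", speaker=0, context=[], max_audio_length_ms=5000)
#   torchaudio.save("hello.wav", audio.unsqueeze(0).cpu(), gen.sample_rate)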