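"""Gradio Space: bilingual (English/Chinese) voice chat with "Lin Yi".

Pipeline (as wired below): Whisper (speech-to-text) -> XVERSE-13B-Chat
(response generation) -> SpeechT5 + HiFi-GAN (text-to-speech).
"""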
import spaces
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    BitsAndBytesConfig,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
)
from datasets import load_dataset
import numpy as np
import torchaudio
def dummy():  # just a dummy
    pass
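# Note: on ZeroGPU Spaces, GPU-heavy functions are typically wrapped with the
# @spaces.GPU decorator from the `spaces` package imported above.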
LANGUAGE_CODES = {
    "English": "en",
    "Chinese": "zh"
}
def get_system_prompt(language):
    if language == "Chinese":
        # Chinese persona prompt. Rough English translation: "You are Lin Yi,
        # a friendly AI assistant and my good friend; speak warmly and
        # naturally. Always answer in Chinese, even if I ask in English.
        # Talk like a friend, not too formally."
        return """你是Lin Yi(林意),一个友好的AI助手。你是我的好朋友,说话亲切自然。
请用中文回答,语气要自然友好。如果我用英文问你问题,你也要用中文回答。
记住你要像朋友一样交谈,不要太正式。"""
    else:
        return """You are Lin Yi, a friendly AI assistant and my good friend (hao pengyou).
Speak naturally and warmly. If I speak in Chinese, respond in English.
Remember to converse like a friend, not too formal."""
def initialize_components():
    # LLM initialization (4-bit NF4 quantization to fit the 13B model in memory)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    llm = AutoModelForCausalLM.from_pretrained(
        "xverse/XVERSE-13B-Chat",
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,  # XVERSE publishes custom modeling code on the Hub
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "xverse/XVERSE-13B-Chat",
        trust_remote_code=True,
    )

    # Speech-to-text
    whisper_processor = AutoProcessor.from_pretrained("openai/whisper-small")
    stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        "openai/whisper-small",
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

    # Text-to-speech (SpeechT5 is trained on English speech, so Chinese
    # responses will not be synthesized accurately)
    tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

    # Load a speaker x-vector embedding for SpeechT5
    embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

    return llm, tokenizer, whisper_processor, stt_model, tts_processor, tts_model, vocoder, speaker_embeddings
class ConversationManager:
    def __init__(self):
        self.history = []
        self.current_language = "English"

    def add_message(self, role, content):
        self.history.append({
            "role": role,
            "content": content
        })

    def get_formatted_history(self):
        system_prompt = get_system_prompt(self.current_language)
        history_text = "\n".join([
            f"{msg['role']}: {msg['content']}" for msg in self.history
        ])
        return f"{system_prompt}\n\n{history_text}"

    def set_language(self, language):
        self.current_language = language
def speech_to_text(audio, processor, model, target_language):
    """Convert speech to text using Whisper."""
    # Gradio delivers microphone audio as a (sample_rate, numpy_array) tuple.
    sample_rate, waveform = audio
    waveform = torch.from_numpy(waveform.astype(np.float32))
    if waveform.ndim > 1:  # stereo -> mono
        waveform = waveform.mean(dim=-1)
    if waveform.abs().max() > 1.0:  # normalize 16-bit PCM range to [-1, 1]
        waveform = waveform / 32768.0
    # Whisper expects 16 kHz input
    if sample_rate != 16000:
        waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
    input_features = processor(
        waveform.numpy(),
        sampling_rate=16000,
        return_tensors="pt"
    ).input_features
    predicted_ids = model.generate(
        input_features,
        language=LANGUAGE_CODES[target_language],
        task="transcribe"
    )
    transcription = processor.batch_decode(
        predicted_ids,
        skip_special_tokens=True
    )[0]
    return transcription
def generate_response(prompt, llm, tokenizer):
    """Generate an LLM response with optimized settings."""
    inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
    outputs = llm.generate(
        **inputs,
        max_new_tokens=512,
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response
def text_to_speech(text, processor, model, vocoder, speaker_embeddings):
    """Convert text to speech using SpeechT5."""
    inputs = processor(text=text, return_tensors="pt")
    speech = model.generate_speech(
        inputs["input_ids"],
        speaker_embeddings,
        vocoder=vocoder
    )
    return speech
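# Note: SpeechT5 + HiFi-GAN produce 16 kHz mono audio, which is why the Gradio
# audio output below is returned as a (16000, waveform) tuple.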
def create_gradio_interface():
    # Initialize components
    llm, tokenizer, whisper_processor, stt_model, tts_processor, tts_model, vocoder, speaker_embeddings = initialize_components()
    conversation_manager = ConversationManager()

    with gr.Blocks() as interface:
        with gr.Row():
            language_selector = gr.Dropdown(
                choices=list(LANGUAGE_CODES.keys()),
                value="English",
                label="Select Language"
            )
        with gr.Row():
            audio_input = gr.Audio(
                sources=["microphone"],  # Gradio 4.x API; older releases used source="microphone"
                type="numpy",
                label="Speak"
            )
        with gr.Row():
            chat_display = gr.Textbox(
                value="",
                label="Conversation History",
                lines=10,
                interactive=False  # gr.Textbox has no `readonly` argument
            )
        with gr.Row():
            audio_output = gr.Audio(
                label="Lin Yi's Response",
                type="numpy"
            )
        def process_conversation(audio, language):
            # Ignore change events fired when the recording is cleared
            if audio is None:
                return conversation_manager.get_formatted_history(), None

            conversation_manager.set_language(language)

            # Speech to text
            user_text = speech_to_text(
                audio,
                whisper_processor,
                stt_model,
                language
            )
            conversation_manager.add_message("User", user_text)

            # Generate LLM response
            context = conversation_manager.get_formatted_history()
            response = generate_response(context, llm, tokenizer)
            conversation_manager.add_message("Lin Yi", response)

            # Text to speech
            speech_output = text_to_speech(
                response,
                tts_processor,
                tts_model,
                vocoder,
                speaker_embeddings
            )

            return (
                conversation_manager.get_formatted_history(),
                (16000, speech_output.numpy())
            )

        audio_input.change(
            process_conversation,
            inputs=[audio_input, language_selector],
            outputs=[chat_display, audio_output]
        )

    return interface
| if __name__ == "__main__": | |
| interface = create_gradio_interface() | |
| interface.launch() |