cindyangelira committed · verified
Commit e3eed3f · 1 Parent(s): 3bc158b

Create app.py

Files changed (1): app.py (+198, -0)
app.py ADDED
@@ -0,0 +1,198 @@
+ import spaces
+ import gradio as gr
+ import torch
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     AutoProcessor,
+     AutoModelForSpeechSeq2Seq,
+     VitsModel,  # MMS-TTS checkpoints are VITS models
+     BitsAndBytesConfig
+ )
+ import numpy as np
+ import torchaudio
+
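+ # Pipeline: microphone audio -> Whisper (speech-to-text) -> XVERSE-13B-Chat
+ # (response generation) -> MMS-TTS (text-to-speech), wired up in Gradio Blocks.
+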
+ @spaces.GPU
+ def dummy():  # just a dummy; ZeroGPU Spaces expect at least one @spaces.GPU function
+     pass
+
+ # Constants
+ # DEVICE = "cpu"
+ LANGUAGE_CODES = {
+     "English": "en",
+     "Chinese": "zh"
+ }
+
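+ # These ISO 639-1 codes are passed to Whisper's generate(language=...) below;
+ # the UI dropdown is built from the dict's keys.
+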
+ # Initialize components with efficient settings
+ def initialize_components():
+     # Use XVERSE-13B-Chat as the base model - good multilingual support and reasonable size
+     # Load in 4-bit quantization to reduce memory usage
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.float16,
+     )
+
+     llm = AutoModelForCausalLM.from_pretrained(
+         "xverse/XVERSE-13B-Chat",
+         quantization_config=bnb_config,
+         device_map="auto",
+         trust_remote_code=True,  # XVERSE ships custom modeling code
+     )
+     tokenizer = AutoTokenizer.from_pretrained(
+         "xverse/XVERSE-13B-Chat",
+         trust_remote_code=True,
+     )
+
+     # Whisper model for STT (small for efficiency)
+     processor = AutoProcessor.from_pretrained("openai/whisper-small")
+     stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
+         "openai/whisper-small",
+         torch_dtype=torch.float32,
+         low_cpu_mem_usage=True,
+     )
+
+     # MMS-TTS for TTS (VITS-based; one checkpoint per language, English here)
+     tts_model = load_model("facebook/mms-tts-eng")
+     tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
+
+     return llm, tokenizer, processor, stt_model, tts_model, tts_tokenizer
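+
+ # NOTE: each MMS-TTS checkpoint is single-language; replying in Chinese would
+ # require loading and selecting the matching Mandarin checkpoint at runtime.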
+
+ def load_model(model_name):
+     """Helper to load the VITS text-to-speech model with optimized settings"""
+     return VitsModel.from_pretrained(
+         model_name,
+         torch_dtype=torch.float32,
+         low_cpu_mem_usage=True,
+     )
+
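+ # low_cpu_mem_usage above streams weights in at load time, keeping peak host
+ # RAM near the final model size rather than roughly double it.
+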
+ class ConversationManager:
+     def __init__(self):
+         self.history = []
+
+     def add_message(self, role, content, audio_path=None):
+         self.history.append({
+             "role": role,
+             "content": content,
+             "audio_path": audio_path
+         })
+
+     def get_formatted_history(self):
+         return "\n".join([
+             f"{msg['role']}: {msg['content']}" for msg in self.history
+         ])
+
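+ # NOTE: the flat "Role: message" transcript built above is a simplification;
+ # XVERSE-13B-Chat's native chat format would likely track turns more reliably.
+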
+ def speech_to_text(audio, processor, model, target_language):
+     """Convert speech to text using Whisper"""
+     # Gradio's type="numpy" audio arrives as a (sampling_rate, samples) tuple
+     sampling_rate, waveform = audio
+     waveform = waveform.astype(np.float32)
+     if waveform.ndim > 1:  # downmix stereo to mono
+         waveform = waveform.mean(axis=1)
+     waveform /= max(np.abs(waveform).max(), 1e-9)  # scale int16 samples to [-1, 1]
+     if sampling_rate != 16000:  # Whisper expects 16 kHz input
+         waveform = torchaudio.functional.resample(
+             torch.from_numpy(waveform), sampling_rate, 16000
+         ).numpy()
+
+     input_features = processor(
+         waveform,
+         sampling_rate=16000,
+         return_tensors="pt"
+     ).input_features
+
+     predicted_ids = model.generate(
+         input_features,
+         language=LANGUAGE_CODES[target_language]
+     )
+
+     transcription = processor.batch_decode(
+         predicted_ids,
+         skip_special_tokens=True
+     )[0]
+     return transcription
+
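+ # whisper-small is the multilingual checkpoint, so one model covers both UI
+ # languages; the language code above only steers decoding.
+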
+ def generate_response(prompt, llm, tokenizer):
+     """Generate LLM response with optimized settings"""
+     inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
+     outputs = llm.generate(
+         **inputs,
+         max_new_tokens=256,  # cap new tokens; max_length would count the prompt too
+         num_return_sequences=1,
+         temperature=0.7,
+         do_sample=True,
+         pad_token_id=tokenizer.eos_token_id
+     )
+     # Decode only the newly generated tokens, not the echoed prompt
+     response = tokenizer.decode(
+         outputs[0][inputs["input_ids"].shape[1]:],
+         skip_special_tokens=True
+     )
+     return response
+
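+ # The entire running history is used as the prompt, so long conversations may
+ # eventually need truncation to fit the model's context window.
+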
+ def text_to_speech(text, model, tokenizer, language):
+     """Convert text to speech using MMS-TTS"""
+     # `language` is unused for now: the loaded MMS checkpoint is single-language
+     inputs = tokenizer(text, return_tensors="pt")
+     with torch.no_grad():
+         speech = model(**inputs).waveform[0]  # VITS emits the waveform directly
+     return speech
+
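+ # The waveform is produced at model.config.sampling_rate (16 kHz for MMS-TTS),
+ # which process_conversation forwards to the Gradio audio component.
+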
+ def create_gradio_interface():
+     # Initialize components
+     llm, tokenizer, processor, stt_model, tts_model, tts_tokenizer = initialize_components()
+     conversation_manager = ConversationManager()
+
+     with gr.Blocks() as interface:
+         with gr.Row():
+             language_selector = gr.Dropdown(
+                 choices=list(LANGUAGE_CODES.keys()),
+                 value="English",
+                 label="Select Language"
+             )
+
+         with gr.Row():
+             # Audio input
+             audio_input = gr.Audio(
+                 sources=["microphone"],  # Gradio 4.x keyword; older releases used source=
+                 type="numpy",
+                 label="Speak"
+             )
+
+         with gr.Row():
+             # Chat history display
+             chat_display = gr.Textbox(
+                 value="",
+                 label="Conversation History",
+                 lines=10,
+                 interactive=False  # read-only; gr.Textbox has no `readonly` kwarg
+             )
+
+         with gr.Row():
+             # Assistant's audio response
+             audio_output = gr.Audio(
+                 label="Assistant's Response",
+                 type="numpy"
+             )
+
+         def process_conversation(audio, language):
+             # Speech to text
+             user_text = speech_to_text(
+                 audio,
+                 processor,
+                 stt_model,
+                 language
+             )
+             conversation_manager.add_message("User", user_text)
+
+             # Generate LLM response
+             context = conversation_manager.get_formatted_history()
+             response = generate_response(context, llm, tokenizer)
+             conversation_manager.add_message("Assistant", response)
+
+             # Text to speech
+             speech_output = text_to_speech(
+                 response,
+                 tts_model,
+                 tts_tokenizer,
+                 language
+             )
+
+             return (
+                 conversation_manager.get_formatted_history(),
+                 (tts_model.config.sampling_rate, speech_output.cpu().numpy())
+             )
+
+         audio_input.change(
+             process_conversation,
+             inputs=[audio_input, language_selector],
+             outputs=[chat_display, audio_output]
+         )
+
+     return interface
+
+ # Launch the application
+ if __name__ == "__main__":
+     interface = create_gradio_interface()
+     interface.launch()