import base64
import io

import gradio as gr
import spaces
import torch
from colpali_engine.models import ColQwen2_5Omni, ColQwen2_5OmniProcessor
from pydub import AudioSegment
from scipy.io import wavfile
from scipy.io.wavfile import write

# Global model variables
model = None
processor = None


def load_model():
    """Load the model and processor once and cache them globally."""
    global model, processor
    if model is None:
        model = ColQwen2_5Omni.from_pretrained(
            "vidore/colqwen-omni-v0.1",
            torch_dtype=torch.bfloat16,
            device_map="cpu",  # Start on CPU for ZeroGPU
            attn_implementation="eager",  # ZeroGPU compatible
        ).eval()
        processor = ColQwen2_5OmniProcessor.from_pretrained("vidore/colqwen-omni-v0.1")
    return model, processor


def chunk_audio(audio_file, chunk_length=30):
    """Split audio into mono 16 kHz chunks of `chunk_length` seconds."""
    audio = AudioSegment.from_file(audio_file)  # gr.Audio(type="filepath") passes a path string
    audios = []
    target_rate = 16000
    chunk_length_ms = chunk_length * 1000

    for i in range(0, len(audio), chunk_length_ms):
        chunk = audio[i:i + chunk_length_ms]
        chunk = chunk.set_channels(1).set_frame_rate(target_rate)
        buf = io.BytesIO()
        chunk.export(buf, format="wav")
        buf.seek(0)
        rate, data = wavfile.read(buf)
        audios.append(data)

    return audios


@spaces.GPU(duration=120)
def embed_audio_chunks(audios):
    """Embed audio chunks using the GPU."""
    model, processor = load_model()
    model = model.to("cuda")

    # Process in batches
    from torch.utils.data import DataLoader

    dataloader = DataLoader(
        dataset=audios,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: processor.process_audios(x),
    )

    embeddings = []
    for batch_doc in dataloader:
        with torch.no_grad():
            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        embeddings.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

    # Move model back to CPU to free GPU memory
    model = model.to("cpu")
    torch.cuda.empty_cache()

    return embeddings


@spaces.GPU(duration=60)
def search_audio(query, embeddings, audios, top_k=5):
    """Search for the audio chunks most relevant to the query."""
    model, processor = load_model()
    model = model.to("cuda")

    # Process query
    batch_queries = processor.process_queries([query]).to(model.device)
    with torch.no_grad():
        query_embeddings = model(**batch_queries)

    # Score against all chunk embeddings (late-interaction scoring)
    scores = processor.score_multi_vector(query_embeddings, embeddings)
    top_k = min(top_k, len(embeddings))  # Guard against audio shorter than top_k chunks
    top_indices = scores[0].topk(top_k).indices.tolist()

    # Move model back to CPU
    model = model.to("cpu")
    torch.cuda.empty_cache()

    return top_indices


def audio_to_base64(data, rate=16000):
    """Convert raw audio samples to a base64-encoded WAV string."""
    buf = io.BytesIO()
    write(buf, rate, data)
    buf.seek(0)
    return base64.b64encode(buf.read()).decode("utf-8")


def process_audio_rag(audio_file, query, chunk_length=30, use_openai=False, openai_key=None):
    """Main processing function: chunk, embed, search, and optionally answer with OpenAI."""
    if not audio_file:
        return "Please upload an audio file", None, None

    # Chunk audio
    audios = chunk_audio(audio_file, chunk_length)

    # Embed chunks
    embeddings = embed_audio_chunks(audios)

    # Search for relevant chunks
    top_indices = search_audio(query, embeddings, audios)

    # Prepare results
    result_text = f"Found {len(top_indices)} relevant audio chunks:\n"
    result_text += f"Chunk indices: {top_indices}\n\n"

    # Save the first result as an audio file
    first_chunk_path = "result_chunk.wav"
    wavfile.write(first_chunk_path, 16000, audios[top_indices[0]])

    # Optional: Use OpenAI for answer generation
    if use_openai and openai_key:
        from openai import OpenAI

        client = OpenAI(api_key=openai_key)
        content = [{"type": "text", "text": f"Answer the query using the audio files. Query: {query}"}]
        for idx in top_indices[:3]:  # Use top 3 chunks
            content.extend([
                {"type": "text", "text": f"Audio chunk #{idx}:"},
                {
                    "type": "input_audio",
                    "input_audio": {
                        "data": audio_to_base64(audios[idx]),
                        "format": "wav",
                    },
                },
            ])

        try:
            completion = client.chat.completions.create(
                model="gpt-4o-audio-preview",
                messages=[{"role": "user", "content": content}],
            )
            result_text += f"\nOpenAI Answer: {completion.choices[0].message.content}"
        except Exception as e:
            result_text += f"\nOpenAI Error: {str(e)}"

    # Create audio visualization
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(audios[top_indices[0]])
    ax.set_title(f"Waveform of top matching chunk (#{top_indices[0]})")
    ax.set_xlabel("Samples")
    ax.set_ylabel("Amplitude")
    plt.tight_layout()

    return result_text, first_chunk_path, fig


# Create Gradio interface
with gr.Blocks(title="AudioRAG Demo") as demo:
    gr.Markdown("# AudioRAG Demo - Semantic Audio Search")
    gr.Markdown("Upload an audio file and search through it using natural language queries!")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(label="Upload Audio File", type="filepath")
            query_input = gr.Textbox(label="Search Query", placeholder="What are you looking for in the audio?")
            chunk_length = gr.Slider(minimum=10, maximum=60, value=30, step=5, label="Chunk Length (seconds)")

            with gr.Accordion("OpenAI Integration (Optional)", open=False):
                use_openai = gr.Checkbox(label="Use OpenAI for answer generation")
                openai_key = gr.Textbox(label="OpenAI API Key", type="password")

            search_btn = gr.Button("Search Audio", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(label="Results", lines=10)
            output_audio = gr.Audio(label="Top Matching Audio Chunk", type="filepath")
            output_plot = gr.Plot(label="Audio Waveform")

    search_btn.click(
        fn=process_audio_rag,
        inputs=[audio_input, query_input, chunk_length, use_openai, openai_key],
        outputs=[output_text, output_audio, output_plot],
    )

    gr.Examples(
        examples=[
            ["example_audio.wav", "Was Hannibal well liked by his men?", 30],
            ["podcast.mp3", "What did they say about climate change?", 20],
        ],
        inputs=[audio_input, query_input, chunk_length],
    )


if __name__ == "__main__":
    # Load model on startup
    load_model()
    demo.launch()