#!/usr/bin/env python3
import os
import sys
import subprocess
import tempfile
import gradio as gr
from collections import defaultdict
from typing import Dict, List
# Fix NumPy compatibility before importing torch/torchaudio
import numpy as np
if hasattr(np, '__version__') and np.__version__.startswith('2.'):
# Downgrade numpy if version 2.x is detected
subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy<2.0.0", "--force-reinstall"])
    # Restart the process so the imports below run against the reinstalled NumPy (os.execv does not return)
os.execv(sys.executable, [sys.executable] + sys.argv)
import torch
import torchaudio
import openai
# Try to load .env file if python-dotenv is available
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
pass # python-dotenv not installed, continue without it
# Import whisper_at only when needed (for tagging)
whisper = None
# Define known vocal and instrumental tags for classification
VOCAL_TAGS = {
"Singing", "Speech", "Choir", "Female singing", "Male singing",
"Chant", "Yodeling", "Shout", "Bellow", "Rapping", "Narration",
"Child singing", "Vocal music", "Opera", "A capella", "Voice",
"Male speech, man speaking", "Female speech, woman speaking",
"Child speech, kid speaking", "Conversation", "Narration, monologue",
"Babbling", "Speech synthesizer", "Whoop", "Yell", "Battle cry",
"Children shouting", "Screaming", "Whispering", "Mantra",
"Synthetic singing", "Humming", "Whistling", "Beatboxing",
"Gospel music", "Lullaby", "Groan", "Grunt"
}
# Definitive speech tags that guarantee vocal classification
DEFINITIVE_SPEECH_TAGS = {
"Male speech, man speaking", "Female speech, woman speaking",
"Child speech, kid speaking", "Conversation", "Narration, monologue"
}
INSTRUMENTAL_TAGS = {
"Piano", "Electric piano", "Keyboard (musical)", "Synthesizer", "Organ",
"Electronic organ", "Harpsichord", "Guitar", "Bass guitar", "Drums", "Violin",
"Trumpet", "Flute", "Saxophone", "Plucked string instrument", "Electric guitar",
"Acoustic guitar", "Steel guitar, slide guitar", "Banjo", "Sitar", "Mandolin",
"Ukulele", "Hammond organ", "Percussion", "Drum kit", "Drum machine", "Drum",
"Snare drum", "Bass drum", "Timpani", "Tabla", "Cymbal", "Hi-hat", "Tambourine",
"Marimba, xylophone", "Vibraphone", "Brass instrument", "French horn", "Trombone",
"Bowed string instrument", "String section", "Violin, fiddle", "Cello", "Double bass",
"Wind instrument, woodwind instrument", "Clarinet", "Harp", "Harmonica", "Accordion"
}
# Genre tags for fancy music classification
GENRE_TAGS = {
# Main genres
"Pop music", "Rock music", "Jazz", "Classical music", "Electronic music",
"Blues", "Country", "Folk music", "Reggae", "Funk", "Soul music",
"Rhythm and blues", "Gospel music", "Opera", "Hip hop music",
# Electronic subgenres
"House music", "Techno", "Dubstep", "Drum and bass", "Electronica",
"Electronic dance music", "Ambient music", "Trance music",
# Rock subgenres
"Heavy metal", "Punk rock", "Grunge", "Progressive rock", "Rock and roll",
"Psychedelic rock",
# World music
"Music of Latin America", "Salsa music", "Flamenco", "Music of Africa",
"Afrobeat", "Music of Asia", "Carnatic music", "Music of Bollywood",
"Middle Eastern music", "Traditional music",
# Other genres
"Swing music", "Bluegrass", "Ska", "Disco", "New-age music",
"Independent music", "Christian music", "Soundtrack music",
"Theme music", "Video game music", "Dance music", "Wedding music",
"Christmas music", "Music for children",
# Mood/style tags
"Happy music", "Funny music", "Sad music", "Tender music",
"Exciting music", "Angry music", "Scary music",
# Vocal styles
"A capella", "Vocal music", "Choir", "Chant", "Mantra", "Lullaby",
"Beatboxing", "Rapping", "Yodeling"
}
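# Note: the tag strings above are intended to mirror the AudioSet-style labels that
# Whisper-AT emits; the classification helpers below simply test set membership on them.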
class AudioCaptionGenerator:
def __init__(self, api_key: str):
"""
Initialize the caption generator with OpenAI API key
Args:
api_key (str): Your OpenAI API key
"""
self.client = openai.OpenAI(api_key=api_key)
def create_caption_prompt(self, classification: str, genres: List[str]) -> str:
"""
Create a prompt for OpenAI to generate an audio caption
Args:
classification (str): Type of audio (Speech/Vocal, Song, Instrumental)
genres (List[str]): List of genre tags
Returns:
str: Formatted prompt for OpenAI
"""
genre_list = ", ".join(genres) if genres else "Various"
prompt = f"""Create a descriptive and engaging caption for an audio track with the following characteristics:
Classification: {classification}
Genres: {genre_list}
Please write a caption that:
1. Describes the audio in an engaging way
2. Incorporates the key genres naturally
3. Matches the classification type (vocal/instrumental/speech)
4. Is suitable for social media or music platforms
5. Is concise but descriptive (1-2 sentences)
Caption:"""
return prompt
def generate_caption(self, classification: str, genres: List[str],
model: str = "gpt-3.5-turbo", temperature: float = 0.7) -> Dict:
"""
Generate a caption for the audio based on classification and genres
Args:
classification (str): Audio classification
genres (List[str]): List of detected genres
model (str): OpenAI model to use
            temperature (float): Sampling temperature (0.0-2.0); higher values give more varied captions
Returns:
Dict: Contains original data, prompt, and generated caption
"""
try:
# Create prompt
prompt = self.create_caption_prompt(classification, genres)
# Generate caption using OpenAI
response = self.client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": "You are a creative music and audio content writer who creates engaging captions for audio tracks."},
{"role": "user", "content": prompt}
],
temperature=temperature,
max_tokens=150
)
caption = response.choices[0].message.content.strip()
return {
"success": True,
"classification": classification,
"genres": genres,
"prompt": prompt,
"caption": caption,
"model_used": model
}
except Exception as e:
return {
"success": False,
"error": str(e),
"classification": classification,
"genres": genres,
"caption": None
}
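# Example usage (illustrative sketch only; assumes a valid key in the OPENAI_API_KEY env var):
#   generator = AudioCaptionGenerator(os.environ["OPENAI_API_KEY"])
#   result = generator.generate_caption("Song", ["Pop music", "Dance music"])
#   if result["success"]:
#       print(result["caption"])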
def load_vad_model():
"""Load Silero VAD model"""
try:
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
force_reload=False,
trust_repo=True)
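        # utils is a tuple of helper callables; per the snakers4/silero-vad hub entry point,
        # its first element is the get_speech_timestamps function used below.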
get_speech_timestamps = utils[0]
return model, get_speech_timestamps
except Exception as e:
print(f"Error loading VAD model: {e}")
raise
def convert_audio_with_ffmpeg(input_path, output_path):
"""Convert audio file to WAV format using ffmpeg"""
try:
cmd = [
'ffmpeg', '-i', input_path,
'-ar', '16000', # Set sample rate to 16kHz
'-ac', '1', # Convert to mono
'-y', # Overwrite output file
output_path
]
        subprocess.run(cmd, check=True, capture_output=True, text=True)
return True
except (subprocess.CalledProcessError, FileNotFoundError) as e:
print(f"FFmpeg conversion failed: {e}")
return False
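# The conversion above is equivalent to running, for example:
#   ffmpeg -i input.mp3 -ar 16000 -ac 1 -y /tmp/out.wav
# (input.mp3 and /tmp/out.wav are placeholder paths; 16 kHz mono is what Silero VAD expects.)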
def detect_vocals_vad(file_path, model, get_speech_timestamps):
"""Detect vocals using VAD and return detection results"""
if not os.path.exists(file_path):
return None, f"Error: Audio file '{file_path}' not found."
waveform = None
sample_rate = None
temp_file = None
try:
# Try to load the audio file directly with torchaudio
try:
waveform, sample_rate = torchaudio.load(file_path)
except Exception as e1:
print(f"Direct loading failed: {e1}")
# Try converting with ffmpeg
temp_file = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
temp_file.close()
if convert_audio_with_ffmpeg(file_path, temp_file.name):
try:
waveform, sample_rate = torchaudio.load(temp_file.name)
except Exception as e2:
return None, f"Error loading converted audio: {str(e2)}"
else:
return None, f"Failed to convert audio file: {str(e1)}"
# Convert to mono if stereo
if len(waveform.shape) > 1 and waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
# Ensure sample rate is 16000 Hz as required by Silero VAD
if sample_rate != 16000:
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
waveform = resampler(waveform)
sample_rate = 16000
# Flatten to 1D array if needed
if len(waveform.shape) > 1:
waveform = waveform.squeeze()
# Get speech timestamps using Silero VAD
speech_timestamps = get_speech_timestamps(waveform, model, threshold=0.5, sampling_rate=16000)
# Calculate speech statistics
total_duration_seconds = len(waveform) / sample_rate
speech_duration = sum([t['end'] - t['start'] for t in speech_timestamps]) / sample_rate
speech_percentage = (speech_duration / total_duration_seconds) * 100 if total_duration_seconds > 0 else 0
# Determine if vocal is detected
vocal_detected = len(speech_timestamps) > 0 and speech_percentage > 1.0
return {
'vocal_detected': vocal_detected,
'speech_percentage': round(speech_percentage, 2),
'total_duration': round(total_duration_seconds, 2),
'speech_duration': round(speech_duration, 2)
}, None
except Exception as e:
return None, f"Error processing audio: {str(e)}"
finally:
# Clean up temporary file if created
if temp_file and os.path.exists(temp_file.name):
try:
os.unlink(temp_file.name)
            except OSError:
                pass  # best-effort cleanup; ignore failures when removing the temp file
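# Example usage (sketch; "clip.wav" is a placeholder path):
#   model, get_ts = load_vad_model()
#   result, err = detect_vocals_vad("clip.wav", model, get_ts)
#   if err is None:
#       print(result["vocal_detected"], result["speech_percentage"])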
def extract_genre_tags(top_tags):
"""Extract genre tags from detected top tags"""
detected_genres = []
for tag in top_tags:
if tag in GENRE_TAGS:
detected_genres.append(tag)
return detected_genres
def extract_instrumental_tags(top_tags):
"""Extract instrumental tags from detected top tags"""
detected_instruments = []
for tag in top_tags:
if tag in INSTRUMENTAL_TAGS:
detected_instruments.append(tag)
return detected_instruments
def classify_audio_tags(top_tags):
"""Classify audio based on detected tags"""
# Check for definitive speech tags first - if any are present, it's definitely vocal
has_definitive_speech = any(tag in DEFINITIVE_SPEECH_TAGS for tag in top_tags)
if has_definitive_speech:
return "Vocal"
# Regular classification logic as fallback
has_vocal = any(tag in VOCAL_TAGS for tag in top_tags)
has_instrumental = any(tag in INSTRUMENTAL_TAGS for tag in top_tags)
if has_vocal and not has_instrumental:
return "Vocal"
elif has_instrumental and not has_vocal:
return "Instrumental"
elif has_vocal and has_instrumental:
return "Song"
else:
return "Unknown"
def classify_with_tagging(audio_path, model_size="small"):
"""Classify audio using Whisper-AT tagging"""
global whisper
# Import whisper_at when needed
if whisper is None:
try:
import whisper_at as whisper
except ImportError:
return "Error: whisper-at not installed. Please install it first.", [], [], []
audio_tagging_time_resolution = 4.8
try:
model = whisper.load_model(model_size)
result = model.transcribe(audio_path, at_time_res=audio_tagging_time_resolution)
audio_tag_result = whisper.parse_at_label(
result,
language='en',
top_k=15,
p_threshold=-5
)
all_tags_set = set()
tag_freq = defaultdict(int)
for segment in audio_tag_result:
# Update tag set and frequency
for tag, score in segment['audio tags']:
all_tags_set.add(tag)
tag_freq[tag] += 1
        # Find top tags (those appearing in more than one segment; can be empty for very short clips)
top_tags = [tag for tag, freq in tag_freq.items() if freq > 1]
# Extract genre and instrumental tags
genre_tags = extract_genre_tags(top_tags)
instrumental_tags = extract_instrumental_tags(top_tags)
classification = classify_audio_tags(top_tags)
# Return all detected tags along with classification
return classification, genre_tags, instrumental_tags, top_tags
except Exception as e:
return f"Error in tagging: {str(e)}", [], [], []
def is_vocal_classification(classification):
"""Check if the classification indicates vocal/speech content"""
vocal_keywords = ["vocal", "speech", "song"]
return any(keyword in classification.lower() for keyword in vocal_keywords)
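# e.g. is_vocal_classification("🎤 Song") -> True, is_vocal_classification("🎵 Instrumental") -> False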
# Global variables for models
vad_model = None
get_speech_timestamps = None
def initialize_models():
"""Initialize VAD model once"""
global vad_model, get_speech_timestamps
if vad_model is None:
try:
print("Loading VAD model...")
vad_model, get_speech_timestamps = load_vad_model()
print("VAD model loaded successfully!")
except Exception as e:
            raise RuntimeError(f"Error loading VAD model: {e}") from e
def process_audio_file(audio_file, model_size, vad_only, openai_key, generate_caption, caption_model):
"""Process audio file and return results"""
try:
# Initialize models if needed
initialize_models()
if audio_file is None:
return "❌ Please upload an audio file.", "", "", "", ""
audio_path = audio_file
# Step 1: VAD Detection
vad_result, vad_error = detect_vocals_vad(audio_path, vad_model, get_speech_timestamps)
if vad_error:
return f"❌ {vad_error}", "", "", "", ""
        vad_info = f"""📊 **VAD Analysis Results:**
• Vocal Detected: {'✅ Yes' if vad_result['vocal_detected'] else '❌ No'}
• Speech Percentage: {vad_result['speech_percentage']}%
• Total Duration: {vad_result['total_duration']} seconds
• Speech Duration: {vad_result['speech_duration']} seconds"""
# Step 2: Classification Logic
if vad_only:
# VAD-only mode
if vad_result['vocal_detected']:
                final_classification = "🎤 Vocal (VAD detected)"
reason = "VAD-only mode - vocals detected"
else:
                final_classification = "🎵 Instrumental (No vocals detected by VAD)"
reason = "VAD-only mode - no vocals detected"
detected_genres = []
detected_instruments = []
all_detected_tags = []
else:
# ALWAYS run tagging for detailed classification (whether vocals detected or not)
tag_classification, detected_genres, detected_instruments, all_detected_tags = classify_with_tagging(audio_path, model_size)
# Use VAD as the definitive decision maker for vocal vs instrumental
if vad_result['vocal_detected']:
# VAD detected vocals - use tagging for detailed classification
                final_classification = f"🎤 {tag_classification}"
reason = "Vocals detected by VAD, classified using audio tagging"
else:
# VAD detected no vocals - it's instrumental, but we still have tagging data
                final_classification = "🎵 Instrumental"
reason = "No vocals detected by VAD (definitive decision), with audio tagging analysis"
# Format classification info
        classification_info = f"""🎯 **Final Classification:** {final_classification}
📝 **Reason:** {reason}"""
# Format genre information
if detected_genres:
            genre_info = "🎭 **Detected Genres/Styles:**\n"
            for genre in detected_genres:
                genre_info += f"• {genre}\n"
        else:
            genre_info = "🎭 **Detected Genres/Styles:** None detected"
# Format instrumental information - SHOW FOR ALL NON-VOCAL CONTENT
if not is_vocal_classification(final_classification) or "Instrumental" in final_classification:
# Show instrumental tags for instrumental content
if detected_instruments:
                instrument_info = "🎹 **Detected Instruments:**\n"
                for instrument in detected_instruments:
                    instrument_info += f"• {instrument}\n"
            else:
                instrument_info = "🎹 **Detected Instruments:** None detected"
else:
# For vocal content, still show if there are instruments detected
if detected_instruments:
                instrument_info = "🎹 **Detected Instruments (accompanying):**\n"
                for instrument in detected_instruments:
                    instrument_info += f"• {instrument}\n"
else:
instrument_info = ""
# Generate caption if requested
caption_info = ""
if generate_caption and openai_key:
try:
caption_generator = AudioCaptionGenerator(openai_key)
caption_result = caption_generator.generate_caption(
                    final_classification.replace('🎤 ', '').replace('🎵 ', ''),
detected_genres,
model=caption_model
)
if caption_result["success"]:
                    caption_info = f"""✨ **Generated Caption:**
"{caption_result['caption']}"
"""
else:
caption_info = f"❌ **Caption generation failed:** {caption_result['error']}"
except Exception as e:
caption_info = f"❌ **Error generating caption:** {str(e)}"
elif generate_caption and not openai_key:
            caption_info = "⚠️ **Caption generation skipped:** No OpenAI API key provided"
return vad_info, classification_info, genre_info, instrument_info, caption_info
except Exception as e:
return f"❌ **Error processing audio:** {str(e)}", "", "", "", ""
def create_gradio_interface():
"""Create and configure the Gradio interface"""
# Custom CSS for better styling
css = """
.gradio-container {
max-width: 1200px !important;
margin: auto !important;
}
.result-box {
border: 1px solid #333333;
border-radius: 8px;
padding: 15px;
margin: 10px 0;
background-color: #000000;
color: #ffffff;
}
"""
    with gr.Blocks(css=css, title="🎵 Audio Classification & Caption Generator") as demo:
gr.Markdown("""
# 🎵 Audio Classification & Caption Generator
Upload an audio file to analyze its content and generate captions. This tool uses:
- **Silero VAD** for voice activity detection
- **Whisper-AT** for detailed audio tagging and classification
- **OpenAI GPT** for intelligent caption generation
""")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("## πŸ“ Upload & Settings")
audio_input = gr.Audio(
label="Upload Audio File (WAV, MP3, FLAC, M4A, etc.)",
type="filepath"
)
with gr.Accordion("βš™οΈ Advanced Settings", open=False):
model_size = gr.Dropdown(
choices=['tiny', 'base', 'small', 'medium', 'large-v1'],
value='small',
label="Whisper Model Size (larger = more accurate but slower)"
)
vad_only = gr.Checkbox(
label="VAD Only Mode (skip detailed tagging)",
value=False
)
gr.Markdown("## πŸ€– Caption Generation")
generate_caption = gr.Checkbox(
label="Generate AI Caption (requires OpenAI API key)",
value=False # Default to False for public deployment
)
openai_key = gr.Textbox(
label="OpenAI API Key (sk-...)",
type="password",
placeholder="Enter your OpenAI API key for caption generation"
)
caption_model = gr.Dropdown(
choices=['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo'],
value='gpt-3.5-turbo',
label="Caption Model",
info="OpenAI model for caption generation"
)
                process_btn = gr.Button("🚀 Analyze Audio", variant="primary", size="lg")
with gr.Column(scale=2):
gr.Markdown("## πŸ“Š Analysis Results")
vad_output = gr.Markdown(
label="VAD Analysis",
elem_classes=["result-box"]
)
classification_output = gr.Markdown(
label="Classification",
elem_classes=["result-box"]
)
genre_output = gr.Markdown(
label="Genres",
elem_classes=["result-box"]
)
instrument_output = gr.Markdown(
label="Detected Instruments",
elem_classes=["result-box"]
)
caption_output = gr.Markdown(
label="Generated Caption",
elem_classes=["result-box"]
)
# Event handlers
process_btn.click(
fn=process_audio_file,
inputs=[
audio_input,
model_size,
vad_only,
openai_key,
generate_caption,
caption_model
],
outputs=[
vad_output,
classification_output,
genre_output,
instrument_output,
caption_output
]
)
# Example files section
gr.Markdown("""
## 💡 Tips
- **Supported formats:** WAV, MP3, FLAC, M4A, OGG, and more
- **Best results:** Use clear, high-quality audio files
- **Processing time:** Depends on file length and model size
- **OpenAI API:** Required for caption generation (get yours at [OpenAI](https://platform.openai.com/))
""")
return demo
def install_system_dependencies():
"""Install system dependencies if needed"""
try:
# Check if ffmpeg is available
subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
except (subprocess.CalledProcessError, FileNotFoundError):
print("Installing ffmpeg...")
try:
subprocess.run(['apt-get', 'update'], check=True)
subprocess.run(['apt-get', 'install', '-y', 'ffmpeg'], check=True)
        except Exception:
            print("Warning: Could not install ffmpeg. Audio conversion may fail.")
def main():
# Install system dependencies
install_system_dependencies()
# For Hugging Face deployment, we'll run on 0.0.0.0:7860
host = "0.0.0.0"
port = 7860
# Create and launch the interface
demo = create_gradio_interface()
print("πŸš€ Starting Audio Classification Server...")
print(f"πŸ“‘ Server will run on http://{host}:{port}")
demo.launch(
server_name=host,
server_port=port,
share=False, # Don't create external share link on HF
show_error=True
)
if __name__ == '__main__':
main()