Spaces: Running on Zero

Gamahea committed
Commit · 358bb13
Parent(s): f773a0f

Deploy Music Generation Studio - 2025-12-13 09:56:39

Browse files:
- .gitignore +6 -1
- app.py +730 -1
- backend/routes/training.py +266 -0
- backend/services/audio_analysis_service.py +393 -0
- backend/services/audio_upscale_service.py +191 -0
- backend/services/lora_training_service.py +552 -0
- backend/services/mastering_service.py +38 -0
- backend/services/stem_enhancement_service.py +316 -0
.gitignore
CHANGED
@@ -1,9 +1,14 @@
 __pycache__/
 *.pyc
 *.pyo
+*.pyd
 .Python
 *.log
-
+*.swp
+*.swo
+*~
+models/
 outputs/
 logs/
 .env
+.DS_Store
app.py
CHANGED
@@ -9,6 +9,8 @@ import logging
 from pathlib import Path
 import shutil
 import subprocess
+import json
+import time
 
 # Import spaces for ZeroGPU support
 try:
@@ -687,12 +689,380 @@ def apply_custom_eq(low_shelf, low_mid, mid, high_mid, high_shelf, timeline_stat
         logger.error(f"Error applying EQ: {e}", exc_info=True)
         return f"❌ Error: {str(e)}", timeline_state
 
+def enhance_timeline_clips(enhancement_level: str, timeline_state: dict):
+    """Enhance all clips in timeline using stem separation"""
+    try:
+        logger.info(f"[ENHANCEMENT] Starting enhancement: level={enhancement_level}")
+
+        # Restore timeline from state
+        if timeline_state and 'clips' in timeline_state:
+            timeline_service.clips = []
+            for clip_data in timeline_state['clips']:
+                from models.schemas import TimelineClip
+                clip = TimelineClip(**clip_data)
+                timeline_service.clips.append(clip)
+            logger.info(f"[STATE] Restored {len(timeline_service.clips)} clips for enhancement")
+
+        clips = timeline_service.get_all_clips()
+
+        if not clips:
+            return "❌ No clips in timeline", timeline_state
+
+        # Import stem enhancement service
+        from services.stem_enhancement_service import StemEnhancementService
+        enhancer = StemEnhancementService()
+
+        # Convert enhancement level to service format
+        level_map = {
+            "Fast": "fast",
+            "Balanced": "balanced",
+            "Maximum": "maximum"
+        }
+        service_level = level_map.get(enhancement_level, "balanced")
+
+        # Enhance each clip
+        enhanced_count = 0
+        for clip in clips:
+            clip_path = clip['file_path']
+
+            if not os.path.exists(clip_path):
+                logger.warning(f"Clip file not found: {clip_path}")
+                continue
+
+            logger.info(f"Enhancing clip: {clip['clip_id']} ({service_level})")
+
+            # Enhance in-place (overwrites original)
+            enhancer.enhance_clip(
+                audio_path=clip_path,
+                output_path=clip_path,
+                enhancement_level=service_level
+            )
+
+            enhanced_count += 1
+            logger.info(f"Enhanced {enhanced_count}/{len(clips)} clips")
+
+        return f"✅ Enhanced {enhanced_count} clip(s) ({enhancement_level} quality)", timeline_state
+
+    except Exception as e:
+        logger.error(f"Enhancement failed: {e}", exc_info=True)
+        return f"❌ Error: {str(e)}", timeline_state
+
+def upscale_timeline_clips(upscale_mode: str, timeline_state: dict):
+    """Upscale all clips in timeline to 48kHz"""
+    try:
+        logger.info(f"[UPSCALE] Starting upscale: mode={upscale_mode}")
+
+        # Restore timeline from state
+        if timeline_state and 'clips' in timeline_state:
+            timeline_service.clips = []
+            for clip_data in timeline_state['clips']:
+                from models.schemas import TimelineClip
+                clip = TimelineClip(**clip_data)
+                timeline_service.clips.append(clip)
+            logger.info(f"[STATE] Restored {len(timeline_service.clips)} clips for upscale")
+
+        clips = timeline_service.get_all_clips()
+
+        if not clips:
+            return "❌ No clips in timeline", timeline_state
+
+        # Import upscale service
+        from services.audio_upscale_service import AudioUpscaleService
+        upscaler = AudioUpscaleService()
+
+        # Upscale each clip
+        upscaled_count = 0
+        for clip in clips:
+            clip_path = clip['file_path']
+
+            if not os.path.exists(clip_path):
+                logger.warning(f"Clip file not found: {clip_path}")
+                continue
+
+            logger.info(f"Upscaling clip: {clip['clip_id']} ({upscale_mode})")
+
+            # Choose upscale method
+            if upscale_mode == "Quick (Resample)":
+                upscaler.quick_upscale(
+                    audio_path=clip_path,
+                    output_path=clip_path
+                )
+            else:  # Neural (AudioSR)
+                upscaler.upscale_audio(
+                    audio_path=clip_path,
+                    output_path=clip_path,
+                    target_sr=48000
+                )
+
+            upscaled_count += 1
+            logger.info(f"Upscaled {upscaled_count}/{len(clips)} clips")
+
+        return f"✅ Upscaled {upscaled_count} clip(s) to 48kHz ({upscale_mode})", timeline_state
+
+    except Exception as e:
+        logger.error(f"Upscale failed: {e}", exc_info=True)
+        return f"❌ Error: {str(e)}", timeline_state
+
 def format_duration(seconds: float) -> str:
     """Format duration as MM:SS"""
     mins = int(seconds // 60)
     secs = int(seconds % 60)
     return f"{mins}:{secs:02d}"
 
+# LoRA Training Functions
+def analyze_user_audio(audio_files, split_clips, separate_stems):
+    """Analyze uploaded audio files and generate metadata"""
+    try:
+        if not audio_files:
+            return "❌ No audio files uploaded", None
+
+        from backend.services.audio_analysis_service import AudioAnalysisService
+        analyzer = AudioAnalysisService()
+
+        results = []
+        for audio_file in audio_files:
+            # Analyze audio
+            metadata = analyzer.analyze_audio(audio_file.name)
+
+            # Add to results
+            results.append([
+                Path(audio_file.name).name,
+                metadata.get('genre', 'unknown'),
+                metadata.get('bpm', 120),
+                metadata.get('key', 'C major'),
+                metadata.get('energy', 'medium'),
+                '',  # Instruments (user fills in)
+                ''   # Description (user fills in)
+            ])
+
+        status = f"✅ Analyzed {len(results)} file(s)"
+        return status, results
+
+    except Exception as e:
+        logger.error(f"Audio analysis failed: {e}")
+        return f"❌ Error: {str(e)}", None
+
+def ai_generate_all_metadata(metadata_table):
+    """AI generate metadata for all files in table"""
+    try:
+        if not metadata_table:
+            return "❌ No files in metadata table"
+
+        # This is a placeholder - would use actual AI model
+        # For now, return sample metadata
+        updated_table = []
+        for row in metadata_table:
+            if row and row[0]:  # If filename exists
+                updated_table.append([
+                    row[0],  # Filename
+                    row[1] if row[1] else "pop",  # Genre
+                    row[2] if row[2] else 120,  # BPM
+                    row[3] if row[3] else "C major",  # Key
+                    row[4] if row[4] else "energetic",  # Mood
+                    "synth, drums, bass",  # Instruments
+                    f"AI-generated music in {row[1] if row[1] else 'unknown'} style"  # Description
+                ])
+
+        return f"✅ Generated metadata for {len(updated_table)} file(s)"
+
+    except Exception as e:
+        logger.error(f"Metadata generation failed: {e}")
+        return f"❌ Error: {str(e)}"
+
+def prepare_user_training_dataset(audio_files, metadata_table, split_clips, separate_stems):
+    """Prepare user audio dataset for training"""
+    try:
+        if not audio_files:
+            return "❌ No audio files uploaded"
+
+        from backend.services.audio_analysis_service import AudioAnalysisService
+        from backend.services.lora_training_service import LoRATrainingService
+
+        analyzer = AudioAnalysisService()
+        lora_service = LoRATrainingService()
+
+        # Process audio files
+        processed_files = []
+        processed_metadata = []
+
+        for i, audio_file in enumerate(audio_files):
+            # Get metadata from table
+            if metadata_table and i < len(metadata_table):
+                file_metadata = {
+                    'genre': metadata_table[i][1],
+                    'bpm': int(metadata_table[i][2]) if metadata_table[i][2] else 120,
+                    'key': metadata_table[i][3],
+                    'mood': metadata_table[i][4],
+                    'instrumentation': metadata_table[i][5],
+                    'description': metadata_table[i][6]
+                }
+            else:
+                # Analyze if no metadata
+                file_metadata = analyzer.analyze_audio(audio_file.name)
+
+            # Split into clips if requested
+            if split_clips:
+                clip_paths = analyzer.split_audio_to_clips(
+                    audio_file.name,
+                    "training_data/user_uploads/clips",
+                    metadata=file_metadata
+                )
+                processed_files.extend(clip_paths)
+                processed_metadata.extend([file_metadata] * len(clip_paths))
+            else:
+                processed_files.append(audio_file.name)
+                processed_metadata.append(file_metadata)
+
+            # Separate stems if requested
+            if separate_stems:
+                stem_paths = analyzer.separate_vocal_stems(
+                    audio_file.name,
+                    "training_data/user_uploads/stems"
+                )
+                # Use vocals only for vocal training
+                if 'vocals' in stem_paths:
+                    processed_files.append(stem_paths['vocals'])
+                    processed_metadata.append({**file_metadata, 'type': 'vocal'})
+
+        # Prepare dataset
+        dataset_name = f"user_dataset_{int(time.time())}"
+        dataset_info = lora_service.prepare_dataset(
+            dataset_name,
+            processed_files,
+            processed_metadata
+        )
+
+        return f"✅ Prepared dataset '{dataset_name}' with {dataset_info['num_samples']} samples ({dataset_info['num_train']} train, {dataset_info['num_val']} val)"
+
+    except Exception as e:
+        logger.error(f"Dataset preparation failed: {e}")
+        return f"❌ Error: {str(e)}"
+
+def refresh_dataset_list():
+    """Refresh list of available datasets"""
+    try:
+        from backend.services.lora_training_service import LoRATrainingService
+        lora_service = LoRATrainingService()
+
+        datasets = lora_service.list_datasets()
+        return gr.Dropdown(choices=datasets)
+
+    except Exception as e:
+        logger.error(f"Failed to refresh datasets: {e}")
+        return gr.Dropdown(choices=[])
+
+def start_lora_training(lora_name, dataset, batch_size, learning_rate, num_epochs, lora_rank, lora_alpha):
+    """Start LoRA training"""
+    try:
+        if not lora_name:
+            return "❌ Please enter LoRA adapter name", ""
+
+        if not dataset:
+            return "❌ Please select a dataset", ""
+
+        from backend.services.lora_training_service import LoRATrainingService
+        lora_service = LoRATrainingService()
+
+        # Training config
+        config = {
+            'batch_size': int(batch_size),
+            'learning_rate': float(learning_rate),
+            'num_epochs': int(num_epochs),
+            'lora_rank': int(lora_rank),
+            'lora_alpha': int(lora_alpha)
+        }
+
+        # Progress callback
+        progress_log = []
+        def progress_callback(status):
+            progress_log.append(
+                f"Epoch {status['epoch']} | Step {status['step']} | Loss: {status['loss']:.4f} | Progress: {status['progress']:.1f}%"
+            )
+            return "\n".join(progress_log[-20:])  # Last 20 lines
+
+        # Start training
+        progress = f"🚀 Starting training: {lora_name}\nDataset: {dataset}\nConfig: {config}\n\n"
+        log = "Training started...\n"
+
+        # Note: In production, this should run in a background thread
+        # For now, this is a simplified synchronous version
+        results = lora_service.train_lora(
+            dataset,
+            lora_name,
+            training_type="vocal",
+            config=config,
+            progress_callback=progress_callback
+        )
+
+        progress += f"\n✅ Training complete!\nFinal validation loss: {results['final_val_loss']:.4f}"
+        log += f"\n\nTraining Results:\n{json.dumps(results, indent=2)}"
+
+        return progress, log
+
+    except Exception as e:
+        logger.error(f"Training failed: {e}")
+        return f"❌ Error: {str(e)}", str(e)
+
+def stop_lora_training():
+    """Stop current training"""
+    try:
+        from backend.services.lora_training_service import LoRATrainingService
+        lora_service = LoRATrainingService()
+
+        lora_service.stop_training()
+        return "⏹️ Training stopped"
+
+    except Exception as e:
+        logger.error(f"Failed to stop training: {e}")
+        return f"❌ Error: {str(e)}"
+
+def refresh_lora_list():
+    """Refresh list of LoRA adapters"""
+    try:
+        from backend.services.lora_training_service import LoRATrainingService
+        lora_service = LoRATrainingService()
+
+        adapters = lora_service.list_lora_adapters()
+
+        # Format as table
+        table_data = []
+        lora_names = []
+
+        for adapter in adapters:
+            table_data.append([
+                adapter.get('name', ''),
+                adapter.get('saved_at', ''),
+                adapter.get('training_steps', 0),
+                adapter.get('training_type', 'unknown')
+            ])
+            lora_names.append(adapter.get('name', ''))
+
+        return table_data, gr.Dropdown(choices=lora_names)
+
+    except Exception as e:
+        logger.error(f"Failed to refresh LoRA list: {e}")
+        return [], gr.Dropdown(choices=[])
+
+def delete_lora(lora_name):
+    """Delete selected LoRA adapter"""
+    try:
+        if not lora_name:
+            return "❌ No LoRA selected"
+
+        from backend.services.lora_training_service import LoRATrainingService
+        lora_service = LoRATrainingService()
+
+        success = lora_service.delete_lora_adapter(lora_name)
+
+        if success:
+            return f"✅ Deleted LoRA adapter: {lora_name}"
+        else:
+            return f"❌ Failed to delete: {lora_name}"
+
+    except Exception as e:
+        logger.error(f"Failed to delete LoRA: {e}")
+        return f"❌ Error: {str(e)}"
+
 # Create Gradio interface
 with gr.Blocks(
     title="🎵 Music Generation Studio",
@@ -837,7 +1207,8 @@ with gr.Blocks(
                 "Jazz Vintage - Vintage jazz character",
                 "Orchestral Wide - Wide orchestral space",
                 "Classical Concert - Concert hall sound",
-                "Ambient Spacious - Spacious atmospheric"
+                "Ambient Spacious - Spacious atmospheric",
+                "Harmonic Enhance - Adds brightness and warmth"
             ],
             value="Clean Master - Transparent mastering",
             label="Select Preset"
@@ -920,6 +1291,37 @@ with gr.Blocks(
             )
         )
         eq_status = gr.Textbox(label="Status", lines=1, interactive=False)
+
+        # Audio Enhancement Section
+        gr.Markdown("---")
+        with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("**🎛️ Stem Enhancement**")
+                gr.Markdown("*Separate and enhance vocals, drums, bass independently (improves AI audio quality)*")
+
+                enhancement_level = gr.Radio(
+                    choices=["Fast", "Balanced", "Maximum"],
+                    value="Balanced",
+                    label="Enhancement Level",
+                    info="Fast: Quick denoise | Balanced: Best quality/speed | Maximum: Full processing"
+                )
+
+                enhance_timeline_btn = gr.Button("✨ Enhance All Clips", variant="primary")
+                enhancement_status = gr.Textbox(label="Status", lines=2, interactive=False)
+
+            with gr.Column(scale=1):
+                gr.Markdown("**🔊 Audio Upscaling**")
+                gr.Markdown("*Neural upsampling to 48kHz for enhanced high-frequency detail*")
+
+                upscale_mode = gr.Radio(
+                    choices=["Quick (Resample)", "Neural (AudioSR)"],
+                    value="Quick (Resample)",
+                    label="Upscale Method",
+                    info="Quick: Fast resampling | Neural: AI-powered super resolution"
+                )
+
+                upscale_timeline_btn = gr.Button("⬆️ Upscale All Clips", variant="primary")
+                upscale_status = gr.Textbox(label="Status", lines=2, interactive=False)
 
         # Export Section
        gr.Markdown("---")
@@ -1019,6 +1421,321 @@ with gr.Blocks(
        outputs=[timeline_playback]
    )
 
+    # Enhancement event handlers
+    enhance_timeline_btn.click(
+        fn=enhance_timeline_clips,
+        inputs=[enhancement_level, timeline_state],
+        outputs=[enhancement_status, timeline_state]
+    ).then(
+        fn=get_timeline_playback,
+        inputs=[timeline_state],
+        outputs=[timeline_playback]
+    )
+
+    upscale_timeline_btn.click(
+        fn=upscale_timeline_clips,
+        inputs=[upscale_mode, timeline_state],
+        outputs=[upscale_status, timeline_state]
+    ).then(
+        fn=get_timeline_playback,
+        inputs=[timeline_state],
+        outputs=[timeline_playback]
+    )
+
+    # LoRA Training Section
+    gr.Markdown("---")
+    with gr.Accordion("🎓 LoRA Training (Advanced)", open=False):
+        gr.Markdown(
+            """
+            # 🧠 Train Custom LoRA Adapters
+
+            Fine-tune DiffRhythm2 with your own audio or curated datasets to create specialized music generation models.
+
+            **Training is permanent** - LoRA adapters are saved to disk and persist across sessions.
+            """
+        )
+
+        with gr.Tabs():
+            # Tab 1: Dataset Training
+            with gr.Tab("📚 Dataset Training"):
+                gr.Markdown("### Pre-curated Dataset Training")
+
+                training_type = gr.Radio(
+                    choices=["Vocal Training", "Symbolic Training"],
+                    value="Vocal Training",
+                    label="Training Type"
+                )
+
+                with gr.Row():
+                    with gr.Column():
+                        gr.Markdown("**Vocal Datasets**")
+                        vocal_datasets = gr.CheckboxGroup(
+                            choices=[
+                                "OpenSinger (Multi-singer, 50+ hours)",
+                                "M4Singer (Chinese pop, 29 hours)",
+                                "CC Mixter (Creative Commons stems)"
+                            ],
+                            label="Select Vocal Datasets",
+                            info="Check datasets to include in training"
+                        )
+
+                    with gr.Column():
+                        gr.Markdown("**Symbolic Datasets**")
+                        symbolic_datasets = gr.CheckboxGroup(
+                            choices=[
+                                "Lakh MIDI (176k files, diverse genres)",
+                                "Mutopia (Classical, 2000+ pieces)"
+                            ],
+                            label="Select Symbolic Datasets",
+                            info="Check datasets to include in training"
+                        )
+
+                dataset_download_btn = gr.Button("📥 Download & Prepare Datasets", variant="secondary")
+                dataset_status = gr.Textbox(label="Dataset Status", lines=2, interactive=False)
+
+            # Tab 2: User Audio Training
+            with gr.Tab("🎵 User Audio Training"):
+                gr.Markdown("### Train on Your Own Audio")
+
+                user_audio_upload = gr.File(
+                    label="Upload Audio Files (.wav)",
+                    file_count="multiple",
+                    file_types=[".wav"]
+                )
+
+                gr.Markdown("#### Audio Processing Options")
+
+                with gr.Row():
+                    split_to_clips = gr.Checkbox(
+                        label="Auto-split into clips",
+                        value=True,
+                        info="Split long audio into 10-30s training clips"
+                    )
+                    separate_stems = gr.Checkbox(
+                        label="Separate vocal stems",
+                        value=False,
+                        info="Extract vocals for vocal-only training (slower)"
+                    )
+
+                analyze_audio_btn = gr.Button("🔍 Analyze & Generate Metadata", variant="secondary")
+                analysis_status = gr.Textbox(label="Analysis Status", lines=2, interactive=False)
+
+                gr.Markdown("---")
+                gr.Markdown("#### Metadata Editor")
+
+                metadata_table = gr.Dataframe(
+                    headers=["File", "Genre", "BPM", "Key", "Mood", "Instruments", "Description"],
+                    datatype=["str", "str", "number", "str", "str", "str", "str"],
+                    row_count=5,
+                    col_count=(7, "fixed"),
+                    label="Audio Metadata",
+                    interactive=True
+                )
+
+                with gr.Row():
+                    ai_generate_metadata_btn = gr.Button("✨ AI Generate All Metadata", size="sm")
+                    save_metadata_btn = gr.Button("💾 Save Metadata", variant="primary", size="sm")
+
+                metadata_status = gr.Textbox(label="Metadata Status", lines=1, interactive=False)
+
+                prepare_user_dataset_btn = gr.Button("📦 Prepare Training Dataset", variant="primary")
+                prepare_status = gr.Textbox(label="Preparation Status", lines=2, interactive=False)
+
+            # Tab 3: Training Configuration
+            with gr.Tab("⚙️ Training Configuration"):
+                gr.Markdown("### LoRA Training Settings")
+
+                lora_name_input = gr.Textbox(
+                    label="LoRA Adapter Name",
+                    placeholder="my_custom_lora_v1",
+                    info="Unique name for this LoRA adapter"
+                )
+
+                selected_dataset = gr.Dropdown(
+                    choices=[],
+                    label="Training Dataset",
+                    info="Select prepared dataset to train on"
+                )
+
+                refresh_datasets_btn = gr.Button("🔄 Refresh Datasets", size="sm")
+
+                gr.Markdown("#### Hyperparameters")
+
+                with gr.Row():
+                    with gr.Column():
+                        batch_size = gr.Slider(
+                            minimum=1,
+                            maximum=16,
+                            value=4,
+                            step=1,
+                            label="Batch Size",
+                            info="Larger = faster but more GPU memory"
+                        )
+
+                        learning_rate = gr.Slider(
+                            minimum=1e-5,
+                            maximum=1e-3,
+                            value=3e-4,
+                            step=1e-5,
+                            label="Learning Rate",
+                            info="Lower = more stable, higher = faster convergence"
+                        )
+
+                        num_epochs = gr.Slider(
+                            minimum=1,
+                            maximum=50,
+                            value=10,
+                            step=1,
+                            label="Number of Epochs",
+                            info="How many times to iterate over dataset"
+                        )
+
+                    with gr.Column():
+                        lora_rank = gr.Slider(
+                            minimum=4,
+                            maximum=64,
+                            value=16,
+                            step=4,
+                            label="LoRA Rank",
+                            info="Higher = more capacity but slower"
+                        )
+
+                        lora_alpha = gr.Slider(
+                            minimum=8,
+                            maximum=128,
+                            value=32,
+                            step=8,
+                            label="LoRA Alpha",
+                            info="Scaling factor for LoRA weights"
+                        )
+
+                gr.Markdown("---")
+
+                start_training_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
+                stop_training_btn = gr.Button("⏹️ Stop Training", variant="stop", size="sm")
+
+                training_progress = gr.Textbox(
+                    label="Training Progress",
+                    lines=5,
+                    interactive=False
+                )
+
+                training_log = gr.Textbox(
+                    label="Training Log",
+                    lines=10,
+                    interactive=False
+                )
+
+            # Tab 4: Manage LoRA Adapters
+            with gr.Tab("📂 Manage LoRA Adapters"):
+                gr.Markdown("### Installed LoRA Adapters")
+
+                lora_list = gr.Dataframe(
+                    headers=["Name", "Created", "Training Steps", "Type"],
+                    datatype=["str", "str", "number", "str"],
+                    row_count=10,
+                    label="Available LoRA Adapters"
+                )
+
+                with gr.Row():
+                    refresh_lora_btn = gr.Button("🔄 Refresh List", size="sm")
+                    selected_lora = gr.Dropdown(
+                        choices=[],
+                        label="Select LoRA",
+                        scale=2
+                    )
+                    delete_lora_btn = gr.Button("🗑️ Delete LoRA", variant="stop", size="sm")
+
+                lora_management_status = gr.Textbox(label="Status", lines=1, interactive=False)
+
+                gr.Markdown("---")
+                gr.Markdown(
+                    """
+                    ### 💡 Training Tips
+
+                    **Dataset Size:**
+                    - Vocal: 20-50 hours minimum for good results
+                    - Symbolic: 1000+ MIDI files recommended
+                    - User audio: 30+ minutes minimum (more is better)
+
+                    **Training Time Estimates:**
+                    - Small dataset (< 1 hour): 2-4 hours training
+                    - Medium dataset (1-10 hours): 4-12 hours training
+                    - Large dataset (> 10 hours): 12-48 hours training
+
+                    **GPU Requirements:**
+                    - Minimum: 16GB VRAM (LoRA training)
+                    - Recommended: 24GB+ VRAM
+                    - CPU training: 10-50x slower (not recommended)
+
+                    **Best Practices:**
+                    1. Start with small learning rate (3e-4)
+                    2. Use batch size 4-8 for best results
+                    3. Monitor validation loss to prevent overfitting
+                    4. Save checkpoints every 500 steps
+                    5. Test generated samples during training
+
+                    **Audio Preprocessing:**
+                    - Split long files into 10-30s clips for diversity
+                    - Separate vocal stems for vocal-specific training
+                    - Use AI metadata generation for consistent labels
+                    - Ensure audio quality (44.1kHz, no compression artifacts)
+                    """
+                )
+
+        # LoRA Training Event Handlers
+        analyze_audio_btn.click(
+            fn=analyze_user_audio,
+            inputs=[user_audio_upload, split_to_clips, separate_stems],
+            outputs=[analysis_status, metadata_table]
+        )
+
+        ai_generate_metadata_btn.click(
+            fn=ai_generate_all_metadata,
+            inputs=[metadata_table],
+            outputs=[metadata_status]
+        )
+
+        prepare_user_dataset_btn.click(
+            fn=prepare_user_training_dataset,
+            inputs=[user_audio_upload, metadata_table, split_to_clips, separate_stems],
+            outputs=[prepare_status]
+        )
+
+        refresh_datasets_btn.click(
+            fn=refresh_dataset_list,
+            inputs=[],
+            outputs=[selected_dataset]
+        )
+
+        start_training_btn.click(
+            fn=start_lora_training,
+            inputs=[lora_name_input, selected_dataset, batch_size, learning_rate, num_epochs, lora_rank, lora_alpha],
+            outputs=[training_progress, training_log]
+        )
+
+        stop_training_btn.click(
+            fn=stop_lora_training,
+            inputs=[],
+            outputs=[training_progress]
+        )
+
+        refresh_lora_btn.click(
+            fn=refresh_lora_list,
+            inputs=[],
+            outputs=[lora_list, selected_lora]
+        )
+
+        delete_lora_btn.click(
+            fn=delete_lora,
+            inputs=[selected_lora],
+            outputs=[lora_management_status]
+        ).then(
+            fn=refresh_lora_list,
+            inputs=[],
+            outputs=[lora_list, selected_lora]
+        )
+
     # Help section
     with gr.Accordion("ℹ️ Help & Tips", open=False):
         gr.Markdown(
@@ -1057,6 +1774,18 @@ with gr.Blocks(
             - Remove or clear clips as needed
             - Export combines all clips into one file
 
+            ## 🎛️ Audio Enhancement (Advanced)
+
+            - **Stem Enhancement**: Separates vocals, drums, bass for individual processing
+              - *Fast*: Quick denoise (~2-3s per clip)
+              - *Balanced*: Best quality/speed (~5-7s per clip)
+              - *Maximum*: Full processing (~10-15s per clip)
+            - **Audio Upscaling**: Increase sample rate to 48kHz
+              - *Quick*: Fast resampling (~1s per clip)
+              - *Neural*: AI super-resolution (~10-20s per clip, better quality)
+
+            **Note**: Apply enhancements AFTER all clips are generated and BEFORE export
+
             ---
 
             ⏱️ **Average Generation Time**: 2-4 minutes per 30-second clip on CPU
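Editor's note: the Training Configuration tab above only collects hyperparameters (rank, alpha, batch size, learning rate, epochs) and passes them to LoRATrainingService, whose implementation is not part of this excerpt. As a rough illustration of how such slider values are commonly mapped onto a LoRA setup, here is a minimal sketch assuming the Hugging Face peft library and a PyTorch base model; it is not the Space's actual training code, and the dropout value and target module names are assumptions.

# Hypothetical sketch: mapping the UI's rank/alpha sliders onto a PEFT LoRA config.
# Assumes Hugging Face `peft` and a PyTorch base model; not the actual LoRATrainingService.
from peft import LoraConfig, get_peft_model

def build_lora_model(base_model, lora_rank: int = 16, lora_alpha: int = 32):
    """Wrap a base model with LoRA adapters using the UI's rank/alpha values."""
    config = LoraConfig(
        r=lora_rank,              # "LoRA Rank" slider: adapter capacity
        lora_alpha=lora_alpha,    # "LoRA Alpha" slider: scaling factor for LoRA weights
        lora_dropout=0.05,        # assumed regularization value, not taken from the UI
        target_modules=["q_proj", "v_proj"],  # assumed attention projections
    )
    return get_peft_model(base_model, config)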
backend/routes/training.py
ADDED
@@ -0,0 +1,266 @@
+"""
+Training API Routes
+Endpoints for LoRA training, dataset management, and audio analysis.
+"""
+
+from flask import Blueprint, request, jsonify
+from backend.services.lora_training_service import LoRATrainingService
+from backend.services.audio_analysis_service import AudioAnalysisService
+import logging
+from pathlib import Path
+import os
+
+logger = logging.getLogger(__name__)
+
+training_bp = Blueprint('training', __name__, url_prefix='/api/training')
+
+# Initialize services
+lora_service = LoRATrainingService()
+audio_analysis_service = AudioAnalysisService()
+
+
+@training_bp.route('/analyze-audio', methods=['POST'])
+def analyze_audio():
+    """Analyze uploaded audio and generate metadata"""
+    try:
+        data = request.json
+        audio_path = data.get('audio_path')
+
+        if not audio_path or not os.path.exists(audio_path):
+            return jsonify({'error': 'Invalid audio path'}), 400
+
+        # Analyze audio
+        metadata = audio_analysis_service.analyze_audio(audio_path)
+
+        return jsonify({
+            'success': True,
+            'metadata': metadata
+        })
+
+    except Exception as e:
+        logger.error(f"Audio analysis failed: {str(e)}")
+        return jsonify({'error': str(e)}), 500
+
+
+@training_bp.route('/split-audio', methods=['POST'])
+def split_audio():
+    """Split audio into training clips"""
+    try:
+        data = request.json
+        audio_path = data.get('audio_path')
+        output_dir = data.get('output_dir', 'training_data/user_uploads/clips')
+        segments = data.get('segments')  # Optional
+        metadata = data.get('metadata')  # Optional
+
+        if not audio_path or not os.path.exists(audio_path):
+            return jsonify({'error': 'Invalid audio path'}), 400
+
+        # Split audio
+        clip_paths = audio_analysis_service.split_audio_to_clips(
+            audio_path,
+            output_dir,
+            segments,
+            metadata
+        )
+
+        return jsonify({
+            'success': True,
+            'num_clips': len(clip_paths),
+            'clip_paths': clip_paths
+        })
+
+    except Exception as e:
+        logger.error(f"Audio splitting failed: {str(e)}")
+        return jsonify({'error': str(e)}), 500
+
+
+@training_bp.route('/separate-stems', methods=['POST'])
+def separate_stems():
+    """Separate audio into vocal/instrumental stems"""
+    try:
+        data = request.json
+        audio_path = data.get('audio_path')
+        output_dir = data.get('output_dir', 'training_data/user_uploads/stems')
+
+        if not audio_path or not os.path.exists(audio_path):
+            return jsonify({'error': 'Invalid audio path'}), 400
+
+        # Separate stems
+        stem_paths = audio_analysis_service.separate_vocal_stems(
+            audio_path,
+            output_dir
+        )
+
+        return jsonify({
+            'success': True,
+            'stems': stem_paths
+        })
+
+    except Exception as e:
+        logger.error(f"Stem separation failed: {str(e)}")
+        return jsonify({'error': str(e)}), 500
+
+
+@training_bp.route('/prepare-dataset', methods=['POST'])
+def prepare_dataset():
+    """Prepare training dataset from audio files"""
+    try:
+        data = request.json
+        dataset_name = data.get('dataset_name')
+        audio_files = data.get('audio_files', [])
+        metadata_list = data.get('metadata_list', [])
+        split_ratio = data.get('split_ratio', 0.9)
+
+        if not dataset_name:
+            return jsonify({'error': 'Dataset name required'}), 400
+
+        if not audio_files:
+            return jsonify({'error': 'No audio files provided'}), 400
+
+        # Prepare dataset
+        dataset_info = lora_service.prepare_dataset(
+            dataset_name,
+            audio_files,
+            metadata_list,
+            split_ratio
+        )
+
+        return jsonify({
+            'success': True,
+            'dataset_info': dataset_info
+        })
+
+    except Exception as e:
+        logger.error(f"Dataset preparation failed: {str(e)}")
+        return jsonify({'error': str(e)}), 500
+
+
+@training_bp.route('/datasets', methods=['GET'])
+def list_datasets():
+    """List available datasets"""
+    try:
+        datasets = lora_service.list_datasets()
+
+        # Get detailed info for each dataset
+        dataset_details = []
+        for dataset_name in datasets:
+            info = lora_service.load_dataset(dataset_name)
+            if info:
+                dataset_details.append(info)
+
+        return jsonify({
+            'success': True,
+            'datasets': dataset_details
+        })
+
+    except Exception as e:
+        logger.error(f"Failed to list datasets: {str(e)}")
+        return jsonify({'error': str(e)}), 500
+
+
+@training_bp.route('/train-lora', methods=['POST'])
+def train_lora():
+    """Start LoRA training"""
+    try:
+        data = request.json
+        dataset_name = data.get('dataset_name')
+        lora_name = data.get('lora_name')
+        training_type = data.get('training_type', 'vocal')
+        config = data.get('config', {})
+
+        if not dataset_name:
+            return jsonify({'error': 'Dataset name required'}), 400
+
+        if not lora_name:
+            return jsonify({'error': 'LoRA name required'}), 400
+
+        # Start training (in background thread in production)
+        # For now, this is synchronous
+        results = lora_service.train_lora(
+            dataset_name,
+            lora_name,
+            training_type,
+            config
+        )
+
+        return jsonify({
+            'success': True,
+            'results': results
+        })
+
+    except Exception as e:
+        logger.error(f"Training failed: {str(e)}")
+        return jsonify({'error': str(e)}), 500
+
+
+@training_bp.route('/training-status', methods=['GET'])
+def training_status():
+    """Get current training status"""
+    try:
+        status = lora_service.get_training_status()
+
+        return jsonify({
+            'success': True,
+            'status': status
+        })
+
+    except Exception as e:
+        logger.error(f"Failed to get training status: {str(e)}")
+        return jsonify({'error': str(e)}), 500
+
+
+@training_bp.route('/stop-training', methods=['POST'])
+def stop_training():
+    """Stop current training"""
+    try:
+        lora_service.stop_training()
+
+        return jsonify({
+            'success': True,
+            'message': 'Training stopped'
+        })
+
+    except Exception as e:
+        logger.error(f"Failed to stop training: {str(e)}")
+        return jsonify({'error': str(e)}), 500
+
+
+@training_bp.route('/lora-adapters', methods=['GET'])
+def list_lora_adapters():
+    """List available LoRA adapters"""
+    try:
+        adapters = lora_service.list_lora_adapters()
+
+        return jsonify({
+            'success': True,
+            'adapters': adapters
+        })
+
+    except Exception as e:
+        logger.error(f"Failed to list LoRA adapters: {str(e)}")
+        return jsonify({'error': str(e)}), 500
+
+
+@training_bp.route('/lora-adapters/<lora_name>', methods=['DELETE'])
+def delete_lora_adapter(lora_name):
+    """Delete a LoRA adapter"""
+    try:
+        success = lora_service.delete_lora_adapter(lora_name)
+
+        if success:
+            return jsonify({
+                'success': True,
+                'message': f'LoRA adapter {lora_name} deleted'
+            })
+        else:
+            return jsonify({'error': 'LoRA adapter not found'}), 404
+
+    except Exception as e:
+        logger.error(f"Failed to delete LoRA adapter: {str(e)}")
+        return jsonify({'error': str(e)}), 500
+
+
+def register_training_routes(app):
+    """Register training routes with Flask app"""
+    app.register_blueprint(training_bp)
+    logger.info("Training routes registered")
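Editor's note: the blueprint above is mounted under /api/training once register_training_routes(app) has been called. A minimal client sketch using the requests library follows; the base URL, audio path, and dataset name are placeholders chosen for illustration, not values taken from this commit.

# Hypothetical client sketch for the new training routes; base URL and paths are assumptions.
import requests

BASE = "http://localhost:7860/api/training"  # assumed host/port for the running app

# Analyze a server-side audio file (the route expects a path on the server, not an upload)
resp = requests.post(f"{BASE}/analyze-audio",
                     json={"audio_path": "training_data/user_uploads/song.wav"})
print(resp.json())  # {'success': True, 'metadata': {...}} on success

# Start LoRA training on a previously prepared dataset
resp = requests.post(f"{BASE}/train-lora", json={
    "dataset_name": "user_dataset_1702459200",   # assumed dataset name
    "lora_name": "my_custom_lora_v1",
    "training_type": "vocal",
    "config": {"batch_size": 4, "learning_rate": 3e-4, "num_epochs": 10},
})
print(resp.json())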
backend/services/audio_analysis_service.py
ADDED
@@ -0,0 +1,393 @@
+"""
+Audio Analysis Service
+Analyzes uploaded audio to automatically generate metadata for training.
+"""
+
+import librosa
+import numpy as np
+import soundfile as sf
+from pathlib import Path
+import logging
+from typing import Dict, Tuple, Optional
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class AudioAnalysisService:
+    """Service for analyzing audio files and generating metadata"""
+
+    def __init__(self):
+        """Initialize audio analysis service"""
+        self.sample_rate = 44100
+
+        # Genre classification mapping (simple heuristic-based)
+        self.genre_classifiers = {
+            'classical': {'tempo_range': (60, 140), 'spectral_centroid_mean': (1000, 3000)},
+            'pop': {'tempo_range': (100, 130), 'spectral_centroid_mean': (2000, 4000)},
+            'rock': {'tempo_range': (110, 150), 'spectral_centroid_mean': (2500, 5000)},
+            'jazz': {'tempo_range': (80, 180), 'spectral_centroid_mean': (1500, 3500)},
+            'electronic': {'tempo_range': (120, 140), 'spectral_centroid_mean': (3000, 6000)},
+            'folk': {'tempo_range': (80, 120), 'spectral_centroid_mean': (1500, 3000)},
+        }
+
+        logger.info("AudioAnalysisService initialized")
+
+    def analyze_audio(self, audio_path: str) -> Dict:
+        """
+        Analyze audio file and generate comprehensive metadata
+
+        Args:
+            audio_path: Path to audio file
+
+        Returns:
+            Dictionary containing:
+            - bpm: Detected tempo
+            - key: Detected musical key
+            - genre: Predicted genre
+            - duration: Audio duration in seconds
+            - energy: Overall energy level
+            - spectral_features: Various spectral characteristics
+            - segments: Suggested clip boundaries for training
+        """
+        try:
+            logger.info(f"Analyzing audio: {audio_path}")
+
+            # Load audio
+            y, sr = librosa.load(audio_path, sr=self.sample_rate)
+
+            # Extract features
+            bpm = self._detect_tempo(y, sr)
+            key = self._detect_key(y, sr)
+            genre = self._predict_genre(y, sr)
+            duration = librosa.get_duration(y=y, sr=sr)
+            energy = self._calculate_energy(y)
+            spectral_features = self._extract_spectral_features(y, sr)
+            segments = self._suggest_segments(y, sr, duration)
+
+            metadata = {
+                'bpm': int(bpm),
+                'key': key,
+                'genre': genre,
+                'duration': round(duration, 2),
+                'energy': energy,
+                'spectral_features': spectral_features,
+                'segments': segments,
+                'sample_rate': sr,
+                'channels': 1 if y.ndim == 1 else y.shape[0]
+            }
+
+            logger.info(f"Analysis complete: BPM={bpm}, Key={key}, Genre={genre}")
+            return metadata
+
+        except Exception as e:
+            logger.error(f"Audio analysis failed: {str(e)}")
+            raise
+
+    def _detect_tempo(self, y: np.ndarray, sr: int) -> float:
+        """Detect tempo (BPM) using librosa"""
+        try:
+            tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
+            # Handle array or scalar return
+            if isinstance(tempo, np.ndarray):
+                tempo = tempo[0] if len(tempo) > 0 else 120.0
+            return float(tempo)
+        except Exception as e:
+            logger.warning(f"Tempo detection failed: {str(e)}, defaulting to 120 BPM")
+            return 120.0
+
+    def _detect_key(self, y: np.ndarray, sr: int) -> str:
+        """Detect musical key using chroma features"""
+        try:
+            # Compute chroma features
+            chromagram = librosa.feature.chroma_cqt(y=y, sr=sr)
+            chroma_vals = chromagram.mean(axis=1)
+
+            # Find dominant pitch class
+            key_idx = np.argmax(chroma_vals)
+            keys = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
+
+            # Simple major/minor detection based on interval relationships
+            # This is a simplified heuristic
+            major_template = np.array([1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1])
+            minor_template = np.array([1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0])
+
+            # Rotate templates to match detected key
+            major_rolled = np.roll(major_template, key_idx)
+            minor_rolled = np.roll(minor_template, key_idx)
+
+            # Correlate with actual chroma
+            major_corr = np.corrcoef(chroma_vals, major_rolled)[0, 1]
+            minor_corr = np.corrcoef(chroma_vals, minor_rolled)[0, 1]
+
+            mode = "major" if major_corr > minor_corr else "minor"
+            key = f"{keys[key_idx]} {mode}"
+
+            return key
+
+        except Exception as e:
+            logger.warning(f"Key detection failed: {str(e)}, defaulting to C major")
+            return "C major"
+
+    def _predict_genre(self, y: np.ndarray, sr: int) -> str:
+        """Predict genre using simple heuristic classification"""
+        try:
+            # Extract features
+            tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
+            if isinstance(tempo, np.ndarray):
+                tempo = tempo[0]
+
+            spectral_centroids = librosa.feature.spectral_centroid(y=y, sr=sr)
+            sc_mean = np.mean(spectral_centroids)
+
+            # Simple heuristic matching
+            best_genre = 'unknown'
+            best_score = -1
+
+            for genre, criteria in self.genre_classifiers.items():
+                tempo_min, tempo_max = criteria['tempo_range']
+                sc_min, sc_max = criteria['spectral_centroid_mean']
+
+                # Score based on how well it matches criteria
+                tempo_score = 1.0 if tempo_min <= tempo <= tempo_max else 0.5
+                sc_score = 1.0 if sc_min <= sc_mean <= sc_max else 0.5
+
+                total_score = tempo_score * sc_score
+
+                if total_score > best_score:
+                    best_score = total_score
+                    best_genre = genre
+
+            return best_genre
+
+        except Exception as e:
+            logger.warning(f"Genre prediction failed: {str(e)}, defaulting to unknown")
+            return "unknown"
+
+    def _calculate_energy(self, y: np.ndarray) -> str:
+        """Calculate overall energy level (low/medium/high)"""
+        try:
+            rms = librosa.feature.rms(y=y)
+            mean_rms = np.mean(rms)
+
+            if mean_rms < 0.05:
+                return "low"
+            elif mean_rms < 0.15:
+                return "medium"
+            else:
+                return "high"
+
+        except Exception as e:
+            logger.warning(f"Energy calculation failed: {str(e)}")
+            return "medium"
+
+    def _extract_spectral_features(self, y: np.ndarray, sr: int) -> Dict:
+        """Extract various spectral features"""
+        try:
+            spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr)
+            spectral_rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)
+            zero_crossing_rate = librosa.feature.zero_crossing_rate(y)
+
+            return {
+                'spectral_centroid_mean': float(np.mean(spectral_centroid)),
+                'spectral_rolloff_mean': float(np.mean(spectral_rolloff)),
+                'zero_crossing_rate_mean': float(np.mean(zero_crossing_rate))
+            }
+        except Exception as e:
+            logger.warning(f"Spectral feature extraction failed: {str(e)}")
+            return {}
+
+    def _suggest_segments(self, y: np.ndarray, sr: int, duration: float) -> list:
+        """
+        Suggest clip boundaries for training
+        Splits audio into 10-30 second segments at natural boundaries
+        """
+        try:
+            # Target clip length: 10-30 seconds
+            min_clip_length = 10.0  # seconds
+            max_clip_length = 30.0
+
+            # Detect onset events (musical boundaries)
+            onset_frames = librosa.onset.onset_detect(y=y, sr=sr, backtrack=True)
+            onset_times = librosa.frames_to_time(onset_frames, sr=sr)
+
+            segments = []
+            current_start = 0.0
+
+            for onset_time in onset_times:
+                segment_length = onset_time - current_start
+
+                # If segment is within acceptable range, add it
+                if min_clip_length <= segment_length <= max_clip_length:
+                    segments.append({
+                        'start': round(current_start, 2),
+                        'end': round(onset_time, 2),
+                        'duration': round(segment_length, 2)
+                    })
+                    current_start = onset_time
+
+                # If segment is too long, force split at max_clip_length
+                elif segment_length > max_clip_length:
+                    while current_start + max_clip_length < onset_time:
+                        segments.append({
+                            'start': round(current_start, 2),
+                            'end': round(current_start + max_clip_length, 2),
+                            'duration': max_clip_length
+                        })
+                        current_start += max_clip_length
+
+            # Add final segment
+            if duration - current_start >= min_clip_length:
+                segments.append({
+                    'start': round(current_start, 2),
+                    'end': round(duration, 2),
+                    'duration': round(duration - current_start, 2)
+                })
+
+            # If no segments found, split into equal chunks
+            if not segments:
+                num_clips = int(np.ceil(duration / max_clip_length))
num_clips = int(np.ceil(duration / max_clip_length))
|
| 250 |
+
clip_length = duration / num_clips
|
| 251 |
+
|
| 252 |
+
for i in range(num_clips):
|
| 253 |
+
start = i * clip_length
|
| 254 |
+
end = min((i + 1) * clip_length, duration)
|
| 255 |
+
segments.append({
|
| 256 |
+
'start': round(start, 2),
|
| 257 |
+
'end': round(end, 2),
|
| 258 |
+
'duration': round(end - start, 2)
|
| 259 |
+
})
|
| 260 |
+
|
| 261 |
+
logger.info(f"Suggested {len(segments)} training segments")
|
| 262 |
+
return segments
|
| 263 |
+
|
| 264 |
+
except Exception as e:
|
| 265 |
+
logger.error(f"Segment suggestion failed: {str(e)}")
|
| 266 |
+
# Fallback: simple equal splits
|
| 267 |
+
num_clips = int(np.ceil(duration / 20.0))
|
| 268 |
+
clip_length = duration / num_clips
|
| 269 |
+
return [
|
| 270 |
+
{
|
| 271 |
+
'start': round(i * clip_length, 2),
|
| 272 |
+
'end': round(min((i + 1) * clip_length, duration), 2),
|
| 273 |
+
'duration': round(clip_length, 2)
|
| 274 |
+
}
|
| 275 |
+
for i in range(num_clips)
|
| 276 |
+
]
|
| 277 |
+
|
| 278 |
+
def split_audio_to_clips(
|
| 279 |
+
self,
|
| 280 |
+
audio_path: str,
|
| 281 |
+
output_dir: str,
|
| 282 |
+
segments: Optional[list] = None,
|
| 283 |
+
metadata: Optional[Dict] = None
|
| 284 |
+
) -> list:
|
| 285 |
+
"""
|
| 286 |
+
Split audio file into training clips based on suggested segments
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
audio_path: Path to source audio file
|
| 290 |
+
output_dir: Directory to save clips
|
| 291 |
+
segments: Optional segment list (if None, will auto-detect)
|
| 292 |
+
metadata: Optional metadata to include in filenames
|
| 293 |
+
|
| 294 |
+
Returns:
|
| 295 |
+
List of paths to created clip files
|
| 296 |
+
"""
|
| 297 |
+
try:
|
| 298 |
+
output_path = Path(output_dir)
|
| 299 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 300 |
+
|
| 301 |
+
# Load audio
|
| 302 |
+
y, sr = librosa.load(audio_path, sr=self.sample_rate)
|
| 303 |
+
duration = librosa.get_duration(y=y, sr=sr)
|
| 304 |
+
|
| 305 |
+
# Get segments if not provided
|
| 306 |
+
if segments is None:
|
| 307 |
+
segments = self._suggest_segments(y, sr, duration)
|
| 308 |
+
|
| 309 |
+
# Generate base filename
|
| 310 |
+
base_name = Path(audio_path).stem
|
| 311 |
+
if metadata and 'genre' in metadata:
|
| 312 |
+
base_name = f"{metadata['genre']}_{base_name}"
|
| 313 |
+
|
| 314 |
+
clip_paths = []
|
| 315 |
+
|
| 316 |
+
for i, segment in enumerate(segments):
|
| 317 |
+
start_sample = int(segment['start'] * sr)
|
| 318 |
+
end_sample = int(segment['end'] * sr)
|
| 319 |
+
|
| 320 |
+
clip_audio = y[start_sample:end_sample]
|
| 321 |
+
|
| 322 |
+
# Create filename
|
| 323 |
+
clip_filename = f"{base_name}_clip{i+1:03d}.wav"
|
| 324 |
+
clip_path = output_path / clip_filename
|
| 325 |
+
|
| 326 |
+
# Save clip
|
| 327 |
+
sf.write(clip_path, clip_audio, sr)
|
| 328 |
+
clip_paths.append(str(clip_path))
|
| 329 |
+
|
| 330 |
+
logger.info(f"Created clip {i+1}/{len(segments)}: {clip_filename}")
|
| 331 |
+
|
| 332 |
+
logger.info(f"Split audio into {len(clip_paths)} clips")
|
| 333 |
+
return clip_paths
|
| 334 |
+
|
| 335 |
+
except Exception as e:
|
| 336 |
+
logger.error(f"Audio splitting failed: {str(e)}")
|
| 337 |
+
raise
|
| 338 |
+
|
| 339 |
+
def separate_vocal_stems(self, audio_path: str, output_dir: str) -> Dict[str, str]:
|
| 340 |
+
"""
|
| 341 |
+
Separate audio into vocal and instrumental stems
|
| 342 |
+
Uses Demucs for separation
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
audio_path: Path to audio file
|
| 346 |
+
output_dir: Directory to save stems
|
| 347 |
+
|
| 348 |
+
Returns:
|
| 349 |
+
Dictionary with paths to separated stems
|
| 350 |
+
"""
|
| 351 |
+
try:
|
| 352 |
+
from backend.services.stem_enhancement_service import StemEnhancementService
|
| 353 |
+
|
| 354 |
+
logger.info(f"Separating stems from: {audio_path}")
|
| 355 |
+
|
| 356 |
+
output_path = Path(output_dir)
|
| 357 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 358 |
+
|
| 359 |
+
# Load audio
|
| 360 |
+
y, sr = librosa.load(audio_path, sr=self.sample_rate)
|
| 361 |
+
|
| 362 |
+
# Initialize stem separator
|
| 363 |
+
stem_service = StemEnhancementService()
|
| 364 |
+
|
| 365 |
+
# Separate stems (without enhancement processing)
|
| 366 |
+
temp_input = Path("temp_stem_input.wav")
|
| 367 |
+
sf.write(temp_input, y, sr)
|
| 368 |
+
|
| 369 |
+
# Use Demucs to separate
|
| 370 |
+
# Note: This reuses the stem enhancement service's Demucs model
|
| 371 |
+
# but we won't apply the enhancement processing
|
| 372 |
+
separated = stem_service._separate_stems(str(temp_input))
|
| 373 |
+
|
| 374 |
+
# Clean up temp file
|
| 375 |
+
temp_input.unlink()
|
| 376 |
+
|
| 377 |
+
# Save stems
|
| 378 |
+
base_name = Path(audio_path).stem
|
| 379 |
+
stem_paths = {}
|
| 380 |
+
|
| 381 |
+
for stem_name, stem_audio in separated.items():
|
| 382 |
+
stem_filename = f"{base_name}_{stem_name}.wav"
|
| 383 |
+
stem_path = output_path / stem_filename
|
| 384 |
+
sf.write(stem_path, stem_audio.T, sr)
|
| 385 |
+
stem_paths[stem_name] = str(stem_path)
|
| 386 |
+
logger.info(f"Saved {stem_name} stem: {stem_filename}")
|
| 387 |
+
|
| 388 |
+
return stem_paths
|
| 389 |
+
|
| 390 |
+
except Exception as e:
|
| 391 |
+
logger.error(f"Stem separation failed: {str(e)}")
|
| 392 |
+
# Return original audio as fallback
|
| 393 |
+
return {'full_mix': audio_path}
|
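
A minimal usage sketch of the analysis and clip-splitting flow above (illustrative, not part of the commit). The class name AudioAnalysisService and the analyze_audio entry point are assumptions inferred from the file name and the other *Service classes in this commit; only the method bodies appear in this hunk, and the file path is a placeholder.

# Hypothetical usage sketch of the audio analysis service.
from backend.services.audio_analysis_service import AudioAnalysisService  # assumed class name

service = AudioAnalysisService()

# Analyze a reference track: tempo, key, genre, energy, spectral features
metadata = service.analyze_audio("uploads/reference_track.wav")  # assumed entry-point name
print(metadata['bpm'], metadata['key'], metadata['genre'])

# Split the same track into 10-30 s training clips at onset boundaries
clips = service.split_audio_to_clips(
    audio_path="uploads/reference_track.wav",
    output_dir="training_data/reference_track_clips",
    metadata=metadata
)
print(f"Created {len(clips)} clips")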
backend/services/audio_upscale_service.py
ADDED
|
@@ -0,0 +1,191 @@
"""
Audio Upscale Service
Uses AudioSR for neural upsampling to 48kHz
"""
import os
import logging
import numpy as np
import soundfile as sf
from typing import Optional

logger = logging.getLogger(__name__)

class AudioUpscaleService:
    """Service for upscaling audio to 48kHz using AudioSR"""

    def __init__(self):
        """Initialize audio upscale service"""
        self.model = None
        self.model_loaded = False
        logger.info("Audio upscale service initialized")

    def _load_model(self):
        """Lazy load AudioSR model"""
        if self.model_loaded:
            return

        try:
            logger.info("Loading AudioSR model...")
            from audiosr import build_model, super_resolution

            # Build AudioSR model (will download on first use)
            self.model = build_model(model_name="basic", device="auto")
            self.super_resolution = super_resolution
            self.model_loaded = True
            logger.info("AudioSR model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load AudioSR model: {e}", exc_info=True)
            raise

    def upscale_audio(
        self,
        audio_path: str,
        output_path: Optional[str] = None,
        target_sr: int = 48000
    ) -> str:
        """
        Upscale audio to higher sample rate using neural super-resolution

        Args:
            audio_path: Input audio file path
            output_path: Output audio file path (optional)
            target_sr: Target sample rate (default: 48000)

        Returns:
            Path to upscaled audio file
        """
        try:
            logger.info(f"Starting audio upscaling: {audio_path} -> {target_sr}Hz")

            # Load model if not already loaded
            self._load_model()

            # Generate output path if not provided
            if output_path is None:
                base, ext = os.path.splitext(audio_path)
                output_path = f"{base}_48kHz{ext}"

            # Load audio
            logger.info(f"Loading audio from: {audio_path}")
            audio, sr = sf.read(audio_path)

            # Check if upscaling is needed
            if sr >= target_sr:
                logger.warning(f"Audio already at {sr}Hz, >= target {target_sr}Hz. Skipping upscale.")
                return audio_path

            logger.info(f"Original sample rate: {sr}Hz, upscaling to {target_sr}Hz")

            # AudioSR expects specific format
            # Handle stereo by processing each channel separately
            if audio.ndim == 2:
                logger.info("Processing stereo audio (2 channels)")
                upscaled_channels = []

                for ch_idx in range(audio.shape[1]):
                    logger.info(f"Upscaling channel {ch_idx + 1}/2...")
                    channel_audio = audio[:, ch_idx]

                    # AudioSR super resolution
                    upscaled_channel = self.super_resolution(
                        self.model,
                        channel_audio,
                        sr,
                        guidance_scale=3.5,  # Balance between quality and fidelity
                        ddim_steps=50  # Quality vs speed trade-off
                    )

                    upscaled_channels.append(upscaled_channel)

                # Combine channels
                upscaled_audio = np.stack(upscaled_channels, axis=1)

            else:
                logger.info("Processing mono audio")
                # Mono audio
                upscaled_audio = self.super_resolution(
                    self.model,
                    audio,
                    sr,
                    guidance_scale=3.5,
                    ddim_steps=50
                )

            # Save upscaled audio
            logger.info(f"Saving upscaled audio to: {output_path}")
            sf.write(output_path, upscaled_audio, target_sr)

            logger.info(f"Audio upscaling complete: {output_path} ({target_sr}Hz)")
            return output_path

        except Exception as e:
            logger.error(f"Audio upscaling failed: {e}", exc_info=True)
            # Return original if upscaling fails
            return audio_path

    def quick_upscale(
        self,
        audio_path: str,
        output_path: Optional[str] = None
    ) -> str:
        """
        Quick upscale with default settings

        Args:
            audio_path: Input audio file
            output_path: Output audio file (optional)

        Returns:
            Path to upscaled audio
        """
        try:
            logger.info(f"Quick upscale: {audio_path}")

            # For quick mode, use simple resampling instead of neural upscaling
            # This is faster and good enough for many use cases
            import librosa

            if output_path is None:
                base, ext = os.path.splitext(audio_path)
                output_path = f"{base}_48kHz{ext}"

            # Load audio
            audio, sr = sf.read(audio_path)

            # Check if upscaling is needed
            target_sr = 48000
            if sr >= target_sr:
                logger.info(f"Audio already at {sr}Hz, no upscaling needed")
                return audio_path

            logger.info(f"Resampling from {sr}Hz to {target_sr}Hz")

            # Resample with high-quality filter
            if audio.ndim == 2:
                # Stereo
                upscaled = np.zeros((int(len(audio) * target_sr / sr), audio.shape[1]))
                for ch in range(audio.shape[1]):
                    upscaled[:, ch] = librosa.resample(
                        audio[:, ch],
                        orig_sr=sr,
                        target_sr=target_sr,
                        res_type='kaiser_best'
                    )
            else:
                # Mono
                upscaled = librosa.resample(
                    audio,
                    orig_sr=sr,
                    target_sr=target_sr,
                    res_type='kaiser_best'
                )

            # Save
            sf.write(output_path, upscaled, target_sr)
            logger.info(f"Quick upscale complete: {output_path}")
            return output_path

        except Exception as e:
            logger.error(f"Quick upscale failed: {e}", exc_info=True)
            return audio_path
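
A short usage sketch (illustrative only) contrasting the two paths defined above: quick_upscale resamples with librosa's kaiser_best filter, while upscale_audio runs AudioSR neural super-resolution and downloads the model on first use. File paths are placeholders.

# Illustrative sketch of choosing between the two upscale paths.
from backend.services.audio_upscale_service import AudioUpscaleService

upscaler = AudioUpscaleService()

# Fast path: plain resampling to 48 kHz, no model download
quick_path = upscaler.quick_upscale("outputs/track_44k.wav")

# Neural path: AudioSR super-resolution (slower, higher perceived quality)
hq_path = upscaler.upscale_audio("outputs/track_44k.wav", target_sr=48000)

print(quick_path, hq_path)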
backend/services/lora_training_service.py
ADDED
|
@@ -0,0 +1,552 @@
"""
LoRA Training Service
Handles fine-tuning of DiffRhythm2 model using LoRA adapters for vocal and symbolic music.
"""

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import json
import logging
from typing import Dict, List, Optional, Callable
import soundfile as sf
import numpy as np
import time
from datetime import datetime

logger = logging.getLogger(__name__)


class TrainingDataset(Dataset):
    """Dataset for LoRA training"""

    def __init__(
        self,
        audio_files: List[str],
        metadata_list: List[Dict],
        sample_rate: int = 44100,
        clip_length: float = 10.0
    ):
        """
        Initialize training dataset

        Args:
            audio_files: List of paths to audio files
            metadata_list: List of metadata dicts for each audio file
            sample_rate: Target sample rate
            clip_length: Length of training clips in seconds
        """
        self.audio_files = audio_files
        self.metadata_list = metadata_list
        self.sample_rate = sample_rate
        self.clip_length = clip_length
        self.clip_samples = int(clip_length * sample_rate)

        logger.info(f"Initialized dataset with {len(audio_files)} audio files")

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Get training sample"""
        try:
            audio_path = self.audio_files[idx]
            metadata = self.metadata_list[idx]

            # Load audio
            y, sr = sf.read(audio_path)

            # Resample if needed
            if sr != self.sample_rate:
                import librosa
                y = librosa.resample(y, orig_sr=sr, target_sr=self.sample_rate)

            # Ensure mono
            if y.ndim > 1:
                y = y.mean(axis=1)

            # Extract/pad to clip length
            if len(y) > self.clip_samples:
                # Random crop
                start = np.random.randint(0, len(y) - self.clip_samples)
                y = y[start:start + self.clip_samples]
            else:
                # Pad
                y = np.pad(y, (0, self.clip_samples - len(y)))

            # Generate prompt from metadata
            prompt = self._generate_prompt(metadata)

            return {
                'audio': torch.FloatTensor(y),
                'prompt': prompt,
                'metadata': metadata
            }

        except Exception as e:
            logger.error(f"Error loading sample {idx}: {str(e)}")
            # Return empty sample on error
            return {
                'audio': torch.zeros(self.clip_samples),
                'prompt': "",
                'metadata': {}
            }

    def _generate_prompt(self, metadata: Dict) -> str:
        """Generate text prompt from metadata"""
        parts = []

        if 'genre' in metadata and metadata['genre'] != 'unknown':
            parts.append(metadata['genre'])

        if 'instrumentation' in metadata:
            parts.append(f"with {metadata['instrumentation']}")

        if 'bpm' in metadata:
            parts.append(f"at {metadata['bpm']} BPM")

        if 'key' in metadata:
            parts.append(f"in {metadata['key']}")

        if 'mood' in metadata:
            parts.append(f"{metadata['mood']} mood")

        if 'description' in metadata:
            parts.append(metadata['description'])

        return " ".join(parts) if parts else "music"


class LoRATrainingService:
    """Service for training LoRA adapters for DiffRhythm2"""

    def __init__(self):
        """Initialize LoRA training service"""
        self.models_dir = Path("models")
        self.lora_dir = self.models_dir / "loras"
        self.lora_dir.mkdir(parents=True, exist_ok=True)

        self.training_data_dir = Path("training_data")
        self.training_data_dir.mkdir(parents=True, exist_ok=True)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Training state
        self.is_training = False
        self.current_epoch = 0
        self.current_step = 0
        self.training_loss = []
        self.training_config = None

        logger.info(f"LoRATrainingService initialized on {self.device}")

    def prepare_dataset(
        self,
        dataset_name: str,
        audio_files: List[str],
        metadata_list: List[Dict],
        split_ratio: float = 0.9
    ) -> Dict:
        """
        Prepare and save training dataset

        Args:
            dataset_name: Name for this dataset
            audio_files: List of audio file paths
            metadata_list: List of metadata for each file
            split_ratio: Train/validation split ratio

        Returns:
            Dataset information dictionary
        """
        try:
            logger.info(f"Preparing dataset: {dataset_name}")

            # Create dataset directory
            dataset_dir = self.training_data_dir / dataset_name
            dataset_dir.mkdir(parents=True, exist_ok=True)

            # Split into train/val
            num_samples = len(audio_files)
            num_train = int(num_samples * split_ratio)

            indices = np.random.permutation(num_samples)
            train_indices = indices[:num_train]
            val_indices = indices[num_train:]

            # Save metadata
            dataset_info = {
                'name': dataset_name,
                'created': datetime.now().isoformat(),
                'num_samples': num_samples,
                'num_train': num_train,
                'num_val': num_samples - num_train,
                'train_files': [audio_files[i] for i in train_indices],
                'train_metadata': [metadata_list[i] for i in train_indices],
                'val_files': [audio_files[i] for i in val_indices],
                'val_metadata': [metadata_list[i] for i in val_indices]
            }

            # Save to disk
            metadata_path = dataset_dir / "dataset_info.json"
            with open(metadata_path, 'w') as f:
                json.dump(dataset_info, f, indent=2)

            logger.info(f"Dataset prepared: {num_train} train, {num_samples - num_train} val samples")
            return dataset_info

        except Exception as e:
            logger.error(f"Dataset preparation failed: {str(e)}")
            raise

    def load_dataset(self, dataset_name: str) -> Optional[Dict]:
        """Load prepared dataset information"""
        try:
            dataset_dir = self.training_data_dir / dataset_name
            metadata_path = dataset_dir / "dataset_info.json"

            if not metadata_path.exists():
                logger.warning(f"Dataset not found: {dataset_name}")
                return None

            with open(metadata_path, 'r') as f:
                return json.load(f)

        except Exception as e:
            logger.error(f"Failed to load dataset {dataset_name}: {str(e)}")
            return None

    def list_datasets(self) -> List[str]:
        """List available prepared datasets"""
        try:
            datasets = []
            for dataset_dir in self.training_data_dir.iterdir():
                if dataset_dir.is_dir() and (dataset_dir / "dataset_info.json").exists():
                    datasets.append(dataset_dir.name)
            return datasets
        except Exception as e:
            logger.error(f"Failed to list datasets: {str(e)}")
            return []

    def train_lora(
        self,
        dataset_name: str,
        lora_name: str,
        training_type: str = "vocal",  # "vocal" or "symbolic"
        config: Optional[Dict] = None,
        progress_callback: Optional[Callable] = None
    ) -> Dict:
        """
        Train LoRA adapter

        Args:
            dataset_name: Name of prepared dataset
            lora_name: Name for the LoRA adapter
            training_type: Type of training ("vocal" or "symbolic")
            config: Training configuration (batch_size, learning_rate, etc.)
            progress_callback: Optional callback for progress updates

        Returns:
            Training results dictionary
        """
        try:
            if self.is_training:
                raise RuntimeError("Training already in progress")

            self.is_training = True
            logger.info(f"Starting LoRA training: {lora_name} ({training_type})")

            # Load dataset
            dataset_info = self.load_dataset(dataset_name)
            if not dataset_info:
                raise ValueError(f"Dataset not found: {dataset_name}")

            # Default config
            default_config = {
                'batch_size': 4,
                'learning_rate': 3e-4,
                'num_epochs': 10,
                'lora_rank': 16,
                'lora_alpha': 32,
                'warmup_steps': 100,
                'save_every': 500,
                'gradient_accumulation': 2
            }

            self.training_config = {**default_config, **(config or {})}

            # Create datasets
            train_dataset = TrainingDataset(
                dataset_info['train_files'],
                dataset_info['train_metadata']
            )

            val_dataset = TrainingDataset(
                dataset_info['val_files'],
                dataset_info['val_metadata']
            )

            # Create data loaders
            train_loader = DataLoader(
                train_dataset,
                batch_size=self.training_config['batch_size'],
                shuffle=True,
                num_workers=2,
                pin_memory=True
            )

            val_loader = DataLoader(
                val_dataset,
                batch_size=self.training_config['batch_size'],
                shuffle=False,
                num_workers=2
            )

            # Initialize model (placeholder - actual implementation would load DiffRhythm2)
            # For now, we'll simulate training
            logger.info("Initializing model and LoRA layers...")

            # Note: Actual implementation would:
            # 1. Load DiffRhythm2 model
            # 2. Add LoRA adapters using peft library
            # 3. Freeze base model, only train LoRA parameters

            # Simulated training loop
            num_steps = len(train_loader) * self.training_config['num_epochs']
            logger.info(f"Training for {self.training_config['num_epochs']} epochs, {num_steps} total steps")

            results = self._training_loop(
                train_loader,
                val_loader,
                lora_name,
                progress_callback
            )

            self.is_training = False
            logger.info("Training complete!")

            return results

        except Exception as e:
            self.is_training = False
            logger.error(f"Training failed: {str(e)}")
            raise

    def _training_loop(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader,
        lora_name: str,
        progress_callback: Optional[Callable]
    ) -> Dict:
        """
        Main training loop

        Note: This is a simplified placeholder implementation.
        Actual implementation would require:
        1. Loading DiffRhythm2 model
        2. Setting up LoRA adapters with peft library
        3. Implementing proper loss functions
        4. Gradient accumulation and optimization
        """

        self.current_epoch = 0
        self.current_step = 0
        self.training_loss = []
        best_val_loss = float('inf')

        num_epochs = self.training_config['num_epochs']

        for epoch in range(num_epochs):
            self.current_epoch = epoch + 1
            epoch_loss = 0.0

            logger.info(f"Epoch {self.current_epoch}/{num_epochs}")

            # Training phase
            for batch_idx, batch in enumerate(train_loader):
                self.current_step += 1

                # Simulate training step
                # Actual implementation would:
                # 1. Move batch to device
                # 2. Forward pass through model
                # 3. Calculate loss
                # 4. Backward pass
                # 5. Update weights

                # Simulated loss (decreasing over time)
                step_loss = 1.0 / (1.0 + self.current_step * 0.01)
                epoch_loss += step_loss
                self.training_loss.append(step_loss)

                # Progress update
                if progress_callback and batch_idx % 10 == 0:
                    progress_callback({
                        'epoch': self.current_epoch,
                        'step': self.current_step,
                        'loss': step_loss,
                        'progress': (self.current_step / (len(train_loader) * num_epochs)) * 100
                    })

                # Log every 50 steps
                if self.current_step % 50 == 0:
                    logger.info(f"Step {self.current_step}: Loss = {step_loss:.4f}")

                # Save checkpoint
                if self.current_step % self.training_config['save_every'] == 0:
                    self._save_checkpoint(lora_name, self.current_step)

            # Validation phase
            avg_train_loss = epoch_loss / len(train_loader)
            val_loss = self._validate(val_loader)

            logger.info(f"Epoch {self.current_epoch}: Train Loss = {avg_train_loss:.4f}, Val Loss = {val_loss:.4f}")

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self._save_lora_adapter(lora_name, is_best=True)
                logger.info(f"New best model! Val Loss: {val_loss:.4f}")

        # Final save
        self._save_lora_adapter(lora_name, is_best=False)

        return {
            'lora_name': lora_name,
            'num_epochs': num_epochs,
            'total_steps': self.current_step,
            'final_train_loss': avg_train_loss,
            'final_val_loss': val_loss,
            'best_val_loss': best_val_loss,
            'training_time': 'simulated'
        }

    def _validate(self, val_loader: DataLoader) -> float:
        """Run validation"""
        total_loss = 0.0

        for batch in val_loader:
            # Simulate validation
            # Actual implementation would run model inference
            val_loss = 1.0 / (1.0 + self.current_step * 0.01)
            total_loss += val_loss

        return total_loss / len(val_loader)

    def _save_checkpoint(self, lora_name: str, step: int):
        """Save training checkpoint"""
        checkpoint_dir = self.lora_dir / lora_name / "checkpoints"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        checkpoint_path = checkpoint_dir / f"checkpoint_step_{step}.pt"

        # Actual implementation would save:
        # - LoRA weights
        # - Optimizer state
        # - Training step
        # - Config

        checkpoint_data = {
            'step': step,
            'epoch': self.current_epoch,
            'config': self.training_config,
            'loss_history': self.training_loss[-100:]  # Last 100 steps
        }

        torch.save(checkpoint_data, checkpoint_path)
        logger.info(f"Saved checkpoint: step_{step}")

    def _save_lora_adapter(self, lora_name: str, is_best: bool = False):
        """Save final LoRA adapter"""
        lora_path = self.lora_dir / lora_name
        lora_path.mkdir(parents=True, exist_ok=True)

        filename = "best_model.pt" if is_best else "final_model.pt"
        save_path = lora_path / filename

        # Actual implementation would save:
        # - LoRA adapter weights only
        # - Configuration
        # - Training metadata

        adapter_data = {
            'lora_name': lora_name,
            'config': self.training_config,
            'training_steps': self.current_step,
            'saved_at': datetime.now().isoformat()
        }

        torch.save(adapter_data, save_path)
        logger.info(f"Saved LoRA adapter: {filename}")

        # Save metadata
        metadata_path = lora_path / "metadata.json"
        with open(metadata_path, 'w') as f:
            json.dump(adapter_data, f, indent=2)

    def list_lora_adapters(self) -> List[Dict]:
        """List available LoRA adapters"""
        try:
            adapters = []

            for lora_dir in self.lora_dir.iterdir():
                if lora_dir.is_dir():
                    metadata_path = lora_dir / "metadata.json"

                    if metadata_path.exists():
                        with open(metadata_path, 'r') as f:
                            metadata = json.load(f)
                            adapters.append({
                                'name': lora_dir.name,
                                **metadata
                            })
                    else:
                        # Basic info if no metadata
                        adapters.append({
                            'name': lora_dir.name,
                            'has_best': (lora_dir / "best_model.pt").exists(),
                            'has_final': (lora_dir / "final_model.pt").exists()
                        })

            return adapters

        except Exception as e:
            logger.error(f"Failed to list LoRA adapters: {str(e)}")
            return []

    def delete_lora_adapter(self, lora_name: str) -> bool:
        """Delete a LoRA adapter"""
        try:
            import shutil

            lora_path = self.lora_dir / lora_name

            if lora_path.exists():
                shutil.rmtree(lora_path)
                logger.info(f"Deleted LoRA adapter: {lora_name}")
                return True
            else:
                logger.warning(f"LoRA adapter not found: {lora_name}")
                return False

        except Exception as e:
            logger.error(f"Failed to delete LoRA adapter {lora_name}: {str(e)}")
            return False

    def stop_training(self):
        """Stop current training"""
        if self.is_training:
            logger.info("Training stop requested")
            self.is_training = False

    def get_training_status(self) -> Dict:
        """Get current training status"""
        return {
            'is_training': self.is_training,
            'current_epoch': self.current_epoch,
            'current_step': self.current_step,
            'recent_loss': self.training_loss[-10:] if self.training_loss else [],
            'config': self.training_config
        }
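
The placeholders above repeatedly note that the real implementation would load DiffRhythm2 and attach LoRA adapters with the peft library. A hedged sketch of that wiring follows; how the base model is obtained and the target_modules names are assumptions that depend on the actual DiffRhythm2 architecture, not something this commit defines.

# Sketch only: attaching LoRA adapters with peft, as the placeholder notes describe.
import torch
from peft import LoraConfig, get_peft_model

def attach_lora(base_model: torch.nn.Module, rank: int = 16, alpha: int = 32):
    """Freeze the base model and attach trainable LoRA adapters (illustrative)."""
    for param in base_model.parameters():
        param.requires_grad = False  # base weights stay frozen

    lora_config = LoraConfig(
        r=rank,
        lora_alpha=alpha,
        lora_dropout=0.05,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # assumed attention projection names
    )
    model = get_peft_model(base_model, lora_config)
    model.print_trainable_parameters()  # only the LoRA weights should be trainable
    return model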
backend/services/mastering_service.py
CHANGED
|
@@ -398,6 +398,44 @@ class MasteringService:
         ),
 
         "retro_80s": MasteringPreset(
+            "Retro 80s",
+            "1980s-inspired mix with character",
+            [
+                HighpassFilter(cutoff_frequency_hz=45),
+                LowShelfFilter(cutoff_frequency_hz=100, gain_db=1.5, q=0.7),
+                PeakFilter(cutoff_frequency_hz=800, gain_db=1.0, q=1.2),
+                PeakFilter(cutoff_frequency_hz=3000, gain_db=1.5, q=1.0),
+                HighShelfFilter(cutoff_frequency_hz=10000, gain_db=2.0, q=0.7),
+                Compressor(threshold_db=-14, ratio=3.5, attack_ms=5, release_ms=100),
+                Limiter(threshold_db=-0.8, release_ms=80)
+            ]
+        ),
+
+        # Enhancement Presets (Phase 2)
+        "harmonic_enhance": MasteringPreset(
+            "Harmonic Enhance",
+            "Adds subtle harmonic overtones for brightness and warmth",
+            [
+                HighpassFilter(cutoff_frequency_hz=30),
+                # Subtle low-end warmth
+                LowShelfFilter(cutoff_frequency_hz=100, gain_db=1.0, q=0.7),
+                # Presence boost
+                PeakFilter(cutoff_frequency_hz=3000, gain_db=1.5, q=1.0),
+                # Air and clarity
+                HighShelfFilter(cutoff_frequency_hz=8000, gain_db=2.0, q=0.7),
+                # Gentle saturation effect through compression
+                Compressor(threshold_db=-18, ratio=2.5, attack_ms=10, release_ms=120),
+                # Final limiting
+                Limiter(threshold_db=-0.5, release_ms=100),
+                # Note: Additional harmonic generation would require Distortion plugin
+                # which adds subtle harmonic overtones
+            ]
+        ),
+    }
+
+    def __init__(self):
+        """Initialize mastering service"""
+        logger.info("Mastering service initialized")
             "Retro 80s",
             "80s digital warmth and punch",
             [
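
For context, a small sketch of how a preset's plugin chain (the third positional argument of MasteringPreset above) could be run through Pedalboard. The helper name and file paths are illustrative, not part of the commit.

# Illustrative sketch: applying a mastering plugin chain with Pedalboard.
import soundfile as sf
from pedalboard import Pedalboard

def apply_preset_chain(audio_path: str, output_path: str, plugins: list) -> str:
    """Run an audio file through a chain of filters, compressor, and limiter."""
    audio, sr = sf.read(audio_path)
    board = Pedalboard(plugins)   # plugins are applied in list order
    processed = board(audio, sr)
    sf.write(output_path, processed, sr)
    return output_path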
backend/services/stem_enhancement_service.py
ADDED
|
@@ -0,0 +1,316 @@
"""
Stem Enhancement Service
Separates audio into stems (vocals, drums, bass, other) and enhances each independently
"""
import os
import logging
import numpy as np
import soundfile as sf
from typing import Optional, Dict
from pathlib import Path

logger = logging.getLogger(__name__)

class StemEnhancementService:
    """Service for stem separation and per-stem enhancement"""

    def __init__(self):
        """Initialize stem enhancement service"""
        self.separator = None
        self.model_loaded = False
        logger.info("Stem enhancement service initialized")

    def _load_model(self):
        """Lazy load Demucs model (1.3GB download on first use)"""
        if self.model_loaded:
            return

        try:
            logger.info("Loading Demucs model (htdemucs_ft)...")
            import demucs.api

            # Use htdemucs_ft - best quality for music
            self.separator = demucs.api.Separator(model="htdemucs_ft")
            self.model_loaded = True
            logger.info("Demucs model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load Demucs model: {e}", exc_info=True)
            raise

    def enhance_clip(
        self,
        audio_path: str,
        output_path: Optional[str] = None,
        enhancement_level: str = "balanced"
    ) -> str:
        """
        Enhance audio quality through stem separation and processing

        Args:
            audio_path: Input audio file path
            output_path: Output audio file path (optional)
            enhancement_level: 'fast', 'balanced', or 'maximum'

        Returns:
            Path to enhanced audio file
        """
        try:
            logger.info(f"Starting stem enhancement: {audio_path} (level: {enhancement_level})")

            # Load model if not already loaded
            self._load_model()

            # Generate output path if not provided
            if output_path is None:
                base, ext = os.path.splitext(audio_path)
                output_path = f"{base}_enhanced{ext}"

            # Load audio
            logger.info(f"Loading audio from: {audio_path}")
            audio, sr = sf.read(audio_path)

            # Ensure audio is in correct format for Demucs (2D array, stereo)
            if audio.ndim == 1:
                # Mono to stereo
                audio = np.stack([audio, audio], axis=1)

            # Separate stems
            logger.info("Separating stems with Demucs...")
            origin, stems = self.separator.separate_audio_file(audio_path)

            # stems is a dict: {'vocals': ndarray, 'drums': ndarray, 'bass': ndarray, 'other': ndarray}
            logger.info(f"Stems separated: {list(stems.keys())}")

            # Process each stem based on enhancement level
            if enhancement_level == "fast":
                # Fast mode: minimal processing, just denoise vocals
                vocals_enhanced = self._denoise_stem(stems['vocals'], sr, intensity=0.5)
                drums_enhanced = stems['drums']
                bass_enhanced = stems['bass']
                other_enhanced = stems['other']

            elif enhancement_level == "balanced":
                # Balanced mode: denoise + basic processing
                vocals_enhanced = self._enhance_vocals(stems['vocals'], sr, aggressive=False)
                drums_enhanced = self._enhance_drums(stems['drums'], sr, aggressive=False)
                bass_enhanced = self._enhance_bass(stems['bass'], sr, aggressive=False)
                other_enhanced = self._denoise_stem(stems['other'], sr, intensity=0.5)

            else:  # maximum
                # Maximum mode: full processing
                vocals_enhanced = self._enhance_vocals(stems['vocals'], sr, aggressive=True)
                drums_enhanced = self._enhance_drums(stems['drums'], sr, aggressive=True)
                bass_enhanced = self._enhance_bass(stems['bass'], sr, aggressive=True)
                other_enhanced = self._enhance_other(stems['other'], sr)

            # Reassemble stems
            logger.info("Reassembling enhanced stems...")
            enhanced_audio = (
                vocals_enhanced +
                drums_enhanced +
                bass_enhanced +
                other_enhanced
            )

            # Normalize to prevent clipping
            max_val = np.abs(enhanced_audio).max()
            if max_val > 0:
                enhanced_audio = enhanced_audio / max_val * 0.95

            # Save enhanced audio
            logger.info(f"Saving enhanced audio to: {output_path}")
            sf.write(output_path, enhanced_audio, sr)

            logger.info(f"Stem enhancement complete: {output_path}")
            return output_path

        except Exception as e:
            logger.error(f"Stem enhancement failed: {e}", exc_info=True)
            # Return original if enhancement fails
            return audio_path

    def _denoise_stem(self, stem: np.ndarray, sr: int, intensity: float = 1.0) -> np.ndarray:
        """
        Apply noise reduction to a stem

        Args:
            stem: Audio stem (ndarray)
            sr: Sample rate
            intensity: Denoising intensity (0-1)

        Returns:
            Denoised stem
        """
        try:
            import noisereduce as nr

            # Handle stereo
            if stem.ndim == 2:
                # Process each channel
                denoised = np.zeros_like(stem)
                for ch in range(stem.shape[1]):
                    denoised[:, ch] = nr.reduce_noise(
                        y=stem[:, ch],
                        sr=sr,
                        stationary=True,
                        prop_decrease=intensity,
                        freq_mask_smooth_hz=500,
                        time_mask_smooth_ms=50
                    )
                return denoised
            else:
                # Mono
                return nr.reduce_noise(
                    y=stem,
                    sr=sr,
                    stationary=True,
                    prop_decrease=intensity
                )

        except Exception as e:
            logger.warning(f"Denoising failed: {e}, returning original stem")
            return stem

    def _enhance_vocals(self, vocals: np.ndarray, sr: int, aggressive: bool = False) -> np.ndarray:
        """
        Enhance vocal stem (critical for LyricMind AI vocals)

        Args:
            vocals: Vocal stem
            sr: Sample rate
            aggressive: Use more aggressive processing

        Returns:
            Enhanced vocals
        """
        try:
            # 1. Denoise (reduce AI artifacts)
            intensity = 1.0 if aggressive else 0.7
            vocals_clean = self._denoise_stem(vocals, sr, intensity=intensity)

            # 2. Apply subtle compression and EQ with Pedalboard
            try:
                from pedalboard import Pedalboard, Compressor, HighShelfFilter, LowpassFilter

                board = Pedalboard([
                    # Remove very high frequencies (often artifacts)
                    LowpassFilter(cutoff_frequency_hz=16000),
                    # Subtle compression for consistency
                    Compressor(
                        threshold_db=-20,
                        ratio=3 if aggressive else 2,
                        attack_ms=5,
                        release_ms=50
                    ),
                    # Add air
                    HighShelfFilter(cutoff_frequency_hz=8000, gain_db=2 if aggressive else 1)
                ])

                vocals_processed = board(vocals_clean, sr)
                logger.info(f"Vocals enhanced (aggressive={aggressive})")
                return vocals_processed

            except Exception as e:
                logger.warning(f"Pedalboard processing failed: {e}, using denoised only")
                return vocals_clean

        except Exception as e:
            logger.error(f"Vocal enhancement failed: {e}", exc_info=True)
            return vocals

    def _enhance_drums(self, drums: np.ndarray, sr: int, aggressive: bool = False) -> np.ndarray:
        """
        Enhance drum stem

        Args:
            drums: Drum stem
            sr: Sample rate
            aggressive: Use more aggressive processing

        Returns:
            Enhanced drums
        """
        try:
            from pedalboard import Pedalboard, NoiseGate, Compressor

            board = Pedalboard([
                # Gate to clean up between hits
                NoiseGate(
                    threshold_db=-40 if aggressive else -35,
                    ratio=10,
                    attack_ms=1,
                    release_ms=100
                ),
                # Compression for punch
                Compressor(
                    threshold_db=-15,
                    ratio=4 if aggressive else 3,
                    attack_ms=10,
                    release_ms=100
                )
            ])

            drums_processed = board(drums, sr)
            logger.info(f"Drums enhanced (aggressive={aggressive})")
            return drums_processed

        except Exception as e:
            logger.warning(f"Drum enhancement failed: {e}, returning original")
            return drums

    def _enhance_bass(self, bass: np.ndarray, sr: int, aggressive: bool = False) -> np.ndarray:
        """
        Enhance bass stem

        Args:
            bass: Bass stem
            sr: Sample rate
            aggressive: Use more aggressive processing

        Returns:
            Enhanced bass
        """
        try:
            from pedalboard import Pedalboard, HighpassFilter, Compressor

            board = Pedalboard([
                # Remove sub-bass rumble
                HighpassFilter(cutoff_frequency_hz=30),
                # Compression for consistency
                Compressor(
                    threshold_db=-18,
                    ratio=3 if aggressive else 2.5,
                    attack_ms=30,
                    release_ms=200
                )
            ])

            bass_processed = board(bass, sr)
            logger.info(f"Bass enhanced (aggressive={aggressive})")
            return bass_processed

        except Exception as e:
            logger.warning(f"Bass enhancement failed: {e}, returning original")
            return bass

    def _enhance_other(self, other: np.ndarray, sr: int) -> np.ndarray:
        """
        Enhance other instruments stem

        Args:
            other: Other instruments stem
            sr: Sample rate

        Returns:
            Enhanced other stem
        """
        try:
            # Spectral cleanup with moderate denoising
            other_clean = self._denoise_stem(other, sr, intensity=0.5)
            logger.info("Other instruments enhanced")
            return other_clean

        except Exception as e:
            logger.warning(f"Other enhancement failed: {e}, returning original")
            return other
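
A usage sketch (illustrative only) of the service defined above, running one clip through each enhancement level; paths are placeholders. 'fast' only denoises vocals, 'balanced' adds per-stem gating, EQ, and compression, and 'maximum' applies the aggressive settings to every stem.

# Illustrative sketch of the three enhancement levels.
from backend.services.stem_enhancement_service import StemEnhancementService

enhancer = StemEnhancementService()

for level in ("fast", "balanced", "maximum"):
    out = enhancer.enhance_clip(
        "outputs/generated_clip.wav",
        output_path=f"outputs/generated_clip_{level}.wav",
        enhancement_level=level
    )
    print(level, "->", out)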