"""
LoRA Training Service

Handles fine-tuning of DiffRhythm2 model using LoRA adapters for vocal and
symbolic music.

NOTE: model loading and the optimisation step are still placeholders; the
training loop simulates a decreasing loss so the surrounding workflow
(dataset preparation, checkpointing, adapter export/import) can be exercised.
"""

import json
import logging
import shutil
import tempfile
import time
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

logger = logging.getLogger(__name__)


def _resolve_child_dir(base_dir: Path, name: str) -> Path:
    """Return ``base_dir / name`` after checking ``name`` is a single path component.

    Dataset and LoRA names arrive from callers and from imported archives.
    Without this check an empty name resolves to ``base_dir`` itself (so
    deleting adapter ``""`` would remove every adapter) and names such as
    ``"../x"`` would escape ``base_dir``.

    Raises:
        ValueError: if ``name`` is not a string, is empty, ``.`` or ``..``,
            or contains a path separator / drive prefix.
    """
    if not isinstance(name, str) or name in ("", ".", ".."):
        raise ValueError(f"Invalid name: {name!r}")
    if "/" in name or "\\" in name or Path(name).name != name:
        raise ValueError(f"Invalid name (must be a single path component): {name!r}")
    return base_dir / name


def _collate_batch(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate ``TrainingDataset`` samples into one batch.

    ``default_collate`` recurses into the per-sample ``metadata`` dicts and
    requires every sample in a batch to share the same keys and value types.
    User metadata is heterogeneous, and failed samples carry ``{}``, so only
    the fixed-length audio tensors are stacked; prompts and metadata stay
    plain lists.
    """
    return {
        'audio': torch.stack([sample['audio'] for sample in batch]),
        'prompt': [sample['prompt'] for sample in batch],
        'metadata': [sample['metadata'] for sample in batch],
    }


class TrainingDataset(Dataset):
    """Dataset for LoRA training.

    Each item is a dict with a fixed-length mono clip (``'audio'``), the text
    prompt derived from the file's metadata (``'prompt'``) and the metadata
    itself (``'metadata'``).
    """

    def __init__(
        self,
        audio_files: List[str],
        metadata_list: List[Dict],
        sample_rate: int = 44100,
        clip_length: float = 10.0
    ):
        """
        Initialize training dataset

        Args:
            audio_files: List of paths to audio files
            metadata_list: List of metadata dicts for each audio file
                (same length and order as ``audio_files``)
            sample_rate: Target sample rate
            clip_length: Length of training clips in seconds

        Raises:
            ValueError: if ``audio_files`` and ``metadata_list`` differ in length
                (otherwise the extra indices would silently yield empty samples).
        """
        if len(audio_files) != len(metadata_list):
            raise ValueError(
                f"audio_files ({len(audio_files)}) and metadata_list "
                f"({len(metadata_list)}) must have the same length"
            )

        self.audio_files = audio_files
        self.metadata_list = metadata_list
        self.sample_rate = sample_rate
        self.clip_length = clip_length
        self.clip_samples = int(clip_length * sample_rate)

        logger.info(f"Initialized dataset with {len(audio_files)} audio files")

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Load one sample: read, downmix to mono, resample, then crop/pad.

        On any loading error a silent (all-zero) clip with an empty prompt and
        empty metadata is returned, so one bad file does not abort an epoch.
        """
        try:
            audio_path = self.audio_files[idx]
            metadata = self.metadata_list[idx]

            # soundfile returns shape (frames,) for mono and
            # (frames, channels) for multi-channel files.
            y, sr = sf.read(audio_path)

            # Downmix BEFORE resampling: librosa.resample operates along the
            # last axis, which for (frames, channels) would be the channel axis.
            if y.ndim > 1:
                y = y.mean(axis=1)

            # Resample if needed
            if sr != self.sample_rate:
                import librosa
                y = librosa.resample(y, orig_sr=sr, target_sr=self.sample_rate)

            # Extract/pad to clip length
            if len(y) > self.clip_samples:
                # Random crop; randint's upper bound is exclusive, so +1 keeps
                # the final window reachable.
                start = np.random.randint(0, len(y) - self.clip_samples + 1)
                y = y[start:start + self.clip_samples]
            else:
                # Zero-pad short clips at the end
                y = np.pad(y, (0, self.clip_samples - len(y)))

            # Generate prompt from metadata
            prompt = self._generate_prompt(metadata)

            return {
                'audio': torch.FloatTensor(y),
                'prompt': prompt,
                'metadata': metadata
            }

        except Exception as e:
            logger.error(f"Error loading sample {idx}: {str(e)}")
            # Return empty sample on error
            return {
                'audio': torch.zeros(self.clip_samples),
                'prompt': "",
                'metadata': {}
            }

    def _generate_prompt(self, metadata: Dict) -> str:
        """Build a text prompt from metadata fields; falls back to ``"music"``."""
        parts = []

        if 'genre' in metadata and metadata['genre'] != 'unknown':
            parts.append(metadata['genre'])

        if 'instrumentation' in metadata:
            parts.append(f"with {metadata['instrumentation']}")

        if 'bpm' in metadata:
            parts.append(f"at {metadata['bpm']} BPM")

        if 'key' in metadata:
            parts.append(f"in {metadata['key']}")

        if 'mood' in metadata:
            parts.append(f"{metadata['mood']} mood")

        if 'description' in metadata:
            parts.append(metadata['description'])

        return " ".join(parts) if parts else "music"


class LoRATrainingService:
    """Service for training LoRA adapters for DiffRhythm2"""

    def __init__(self):
        """Initialize LoRA training service and create its working directories."""
        self.models_dir = Path("models")
        self.lora_dir = self.models_dir / "loras"
        self.lora_dir.mkdir(parents=True, exist_ok=True)

        self.training_data_dir = Path("training_data")
        self.training_data_dir.mkdir(parents=True, exist_ok=True)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Training state
        self.is_training = False
        # Set by stop_training(); polled by the training loop between batches.
        self._stop_requested = False
        self.current_epoch = 0
        self.current_step = 0
        self.training_loss = []
        self.training_config = None

        logger.info(f"LoRATrainingService initialized on {self.device}")

    def prepare_dataset(
        self,
        dataset_name: str,
        audio_files: List[str],
        metadata_list: List[Dict],
        split_ratio: float = 0.9
    ) -> Dict:
        """
        Prepare and save training dataset

        Args:
            dataset_name: Name for this dataset (a single directory name)
            audio_files: List of audio file paths
            metadata_list: List of metadata for each file (same length/order)
            split_ratio: Train/validation split ratio. With two or more samples
                the split is clamped so both sides get at least one sample,
                because train_lora rejects empty splits.

        Returns:
            Dataset information dictionary

        Raises:
            ValueError: on mismatched input lengths or an invalid dataset name.
        """
        try:
            logger.info(f"Preparing dataset: {dataset_name}")

            if len(audio_files) != len(metadata_list):
                raise ValueError(
                    f"audio_files ({len(audio_files)}) and metadata_list "
                    f"({len(metadata_list)}) must have the same length"
                )

            # Create dataset directory
            dataset_dir = _resolve_child_dir(self.training_data_dir, dataset_name)
            dataset_dir.mkdir(parents=True, exist_ok=True)

            # Split into train/val
            num_samples = len(audio_files)
            num_train = int(num_samples * split_ratio)
            if num_samples >= 2:
                num_train = min(max(num_train, 1), num_samples - 1)

            indices = np.random.permutation(num_samples)
            train_indices = indices[:num_train]
            val_indices = indices[num_train:]

            # Save metadata
            dataset_info = {
                'name': dataset_name,
                'created': datetime.now().isoformat(),
                'num_samples': num_samples,
                'num_train': num_train,
                'num_val': num_samples - num_train,
                'train_files': [audio_files[i] for i in train_indices],
                'train_metadata': [metadata_list[i] for i in train_indices],
                'val_files': [audio_files[i] for i in val_indices],
                'val_metadata': [metadata_list[i] for i in val_indices]
            }

            # Save to disk
            metadata_path = dataset_dir / "dataset_info.json"
            with open(metadata_path, 'w') as f:
                json.dump(dataset_info, f, indent=2)

            logger.info(f"Dataset prepared: {num_train} train, {num_samples - num_train} val samples")

            return dataset_info

        except Exception as e:
            logger.error(f"Dataset preparation failed: {str(e)}")
            raise

    def load_dataset(self, dataset_name: str) -> Optional[Dict]:
        """Load prepared dataset information; returns None if missing or invalid."""
        try:
            dataset_dir = _resolve_child_dir(self.training_data_dir, dataset_name)
            metadata_path = dataset_dir / "dataset_info.json"

            if not metadata_path.exists():
                logger.warning(f"Dataset not found: {dataset_name}")
                return None

            with open(metadata_path, 'r', encoding='utf-8') as f:
                return json.load(f)

        except Exception as e:
            logger.error(f"Failed to load dataset {dataset_name}: {str(e)}")
            return None

    def list_datasets(self) -> List[str]:
        """List available prepared datasets (sorted by name)."""
        try:
            return sorted(
                dataset_dir.name
                for dataset_dir in self.training_data_dir.iterdir()
                if dataset_dir.is_dir() and (dataset_dir / "dataset_info.json").exists()
            )
        except Exception as e:
            logger.error(f"Failed to list datasets: {str(e)}")
            return []

    def list_loras(self) -> List[str]:
        """List available LoRA adapters (directories with adapter files), sorted."""
        try:
            loras = []
            if not self.lora_dir.exists():
                return loras

            for lora_path in self.lora_dir.iterdir():
                if lora_path.is_dir():
                    # Check for adapter files
                    if (lora_path / "adapter_config.json").exists():
                        loras.append(lora_path.name)
                    # Also check for .safetensors or .bin files
                    elif list(lora_path.glob("*.safetensors")) or list(lora_path.glob("*.bin")):
                        loras.append(lora_path.name)

            return sorted(loras)

        except Exception as e:
            logger.error(f"Failed to list LoRAs: {str(e)}")
            return []

    def train_lora(
        self,
        dataset_name: str,
        lora_name: str,
        training_type: str = "vocal",  # "vocal" or "symbolic"
        config: Optional[Dict] = None,
        progress_callback: Optional[Callable] = None
    ) -> Dict:
        """
        Train LoRA adapter

        Args:
            dataset_name: Name of prepared dataset
            lora_name: Name for the LoRA adapter (a single directory name)
            training_type: Type of training ("vocal" or "symbolic")
            config: Training configuration (batch_size, learning_rate, etc.)
            progress_callback: Optional callback for progress updates

        Returns:
            Training results dictionary

        Raises:
            RuntimeError: if a training run is already in progress.
            ValueError: if the dataset is missing/unprepared/empty or the
                LoRA name is invalid.
        """
        # Check and claim the flag OUTSIDE the try/finally: a rejected second
        # call must not clear the flag of the run that is still in progress.
        if self.is_training:
            raise RuntimeError("Training already in progress")

        self.is_training = True
        self._stop_requested = False

        try:
            logger.info(f"Starting LoRA training: {lora_name} ({training_type})")

            # Adapter and checkpoints are written under lora_dir / lora_name.
            _resolve_child_dir(self.lora_dir, lora_name)

            # Load dataset
            dataset_info = self.load_dataset(dataset_name)
            if not dataset_info:
                raise ValueError(f"Dataset not found: {dataset_name}")

            # Check if dataset is from HuggingFace and needs preparation
            if dataset_info.get('hf_dataset') and not dataset_info.get('prepared'):
                raise ValueError(
                    f"Dataset '{dataset_name}' is a HuggingFace dataset that hasn't been prepared for training yet. "
                    f"Please use the 'User Audio Training' tab to upload and prepare your own audio files, "
                    f"or wait for dataset preparation features to be implemented."
                )

            # Validate dataset has required fields
            if 'train_files' not in dataset_info or 'val_files' not in dataset_info:
                raise ValueError(
                    f"Dataset '{dataset_name}' is missing required training files. "
                    f"Please use prepared datasets or upload your own audio in the 'User Audio Training' tab."
                )

            train_files = dataset_info['train_files']
            val_files = dataset_info['val_files']

            # Validate datasets are not empty
            if not train_files:
                raise ValueError(
                    f"Dataset '{dataset_name}' has no training samples. "
                    f"The dataset may not have been prepared correctly. "
                    f"Please re-prepare the dataset or use a different one."
                )

            if not val_files:
                raise ValueError(
                    f"Dataset '{dataset_name}' has no validation samples. "
                    f"The dataset may not have been prepared correctly. "
                    f"Please re-prepare the dataset or use a different one."
                )

            # Metadata is optional per file; missing lists fall back to {}
            # (prompt "music"). Length mismatches are rejected by TrainingDataset.
            train_metadata = dataset_info.get('train_metadata') or [{} for _ in train_files]
            val_metadata = dataset_info.get('val_metadata') or [{} for _ in val_files]

            # Default config
            default_config = {
                'batch_size': 4,
                'learning_rate': 3e-4,
                'num_epochs': 10,
                'lora_rank': 16,
                'lora_alpha': 32,
                'warmup_steps': 100,
                'save_every': 500,
                'gradient_accumulation': 2
            }

            self.training_config = {**default_config, **(config or {})}

            # Create datasets
            train_dataset = TrainingDataset(train_files, train_metadata)
            val_dataset = TrainingDataset(val_files, val_metadata)

            # Create data loaders
            # Disable pin_memory and num_workers for compatibility with ZeroGPU and CPU
            # pin_memory requires persistent CUDA access which ZeroGPU doesn't provide at this stage
            train_loader = DataLoader(
                train_dataset,
                batch_size=self.training_config['batch_size'],
                shuffle=True,
                num_workers=0,
                pin_memory=False,
                collate_fn=_collate_batch
            )

            val_loader = DataLoader(
                val_dataset,
                batch_size=self.training_config['batch_size'],
                shuffle=False,
                num_workers=0,
                pin_memory=False,
                collate_fn=_collate_batch
            )

            # Initialize model (placeholder - actual implementation would load DiffRhythm2)
            # For now, we'll simulate training
            logger.info("Initializing model and LoRA layers...")

            # Note: Actual implementation would:
            # 1. Load DiffRhythm2 model
            # 2. Add LoRA adapters using peft library
            # 3. Freeze base model, only train LoRA parameters

            # Simulated training loop
            num_steps = len(train_loader) * self.training_config['num_epochs']
            logger.info(f"Training for {self.training_config['num_epochs']} epochs, {num_steps} total steps")

            results = self._training_loop(
                train_loader,
                val_loader,
                lora_name,
                progress_callback
            )

            logger.info("Training complete!")
            return results

        except Exception as e:
            logger.error(f"Training failed: {str(e)}")
            raise
        finally:
            self.is_training = False
            self._stop_requested = False

    def _training_loop(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader,
        lora_name: str,
        progress_callback: Optional[Callable]
    ) -> Dict:
        """
        Main training loop

        Runs up to ``num_epochs`` epochs, validating after each epoch and saving
        the best adapter by validation loss, then saves the final adapter.
        Honours ``stop_training()`` between batches. Checkpoints are written
        every ``save_every`` steps (disabled when ``save_every <= 0``).

        Note: This is a simplified placeholder implementation.
        Actual implementation would require:
        1. Loading DiffRhythm2 model
        2. Setting up LoRA adapters with peft library
        3. Implementing proper loss functions
        4. Gradient accumulation and optimization
        """
        self.current_epoch = 0
        self.current_step = 0
        self.training_loss = []

        best_val_loss = float('inf')
        # Defined up front so the summary is valid when num_epochs == 0 or a
        # stop is requested before any batch has run.
        avg_train_loss = float('nan')
        val_loss = float('nan')
        stopped_early = False

        num_epochs = self.training_config['num_epochs']
        save_every = self.training_config['save_every']
        total_steps = len(train_loader) * num_epochs

        for epoch in range(num_epochs):
            self.current_epoch = epoch + 1
            epoch_loss = 0.0
            num_batches = 0

            logger.info(f"Epoch {self.current_epoch}/{num_epochs}")

            # Training phase
            for batch_idx, batch in enumerate(train_loader):
                if self._stop_requested:
                    break

                self.current_step += 1
                num_batches += 1

                # Simulate training step
                # Actual implementation would:
                # 1. Move batch to device
                # 2. Forward pass through model
                # 3. Calculate loss
                # 4. Backward pass
                # 5. Update weights

                # Simulated loss (decreasing over time)
                step_loss = 1.0 / (1.0 + self.current_step * 0.01)
                epoch_loss += step_loss
                self.training_loss.append(step_loss)

                # Progress update
                if progress_callback and batch_idx % 10 == 0:
                    progress_callback({
                        'epoch': self.current_epoch,
                        'step': self.current_step,
                        'loss': step_loss,
                        'progress': (self.current_step / total_steps) * 100
                    })

                # Log every 50 steps
                if self.current_step % 50 == 0:
                    logger.info(f"Step {self.current_step}: Loss = {step_loss:.4f}")

                # Save checkpoint (save_every <= 0 disables periodic checkpoints
                # and avoids a modulo by zero)
                if save_every > 0 and self.current_step % save_every == 0:
                    self._save_checkpoint(lora_name, self.current_step)

            # Validation phase (only if this epoch trained on anything)
            if num_batches > 0:
                avg_train_loss = epoch_loss / num_batches
                val_loss = self._validate(val_loader)

                logger.info(f"Epoch {self.current_epoch}: Train Loss = {avg_train_loss:.4f}, Val Loss = {val_loss:.4f}")

                # Save best model
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    self._save_lora_adapter(lora_name, is_best=True)
                    logger.info(f"New best model! Val Loss: {val_loss:.4f}")

            if self._stop_requested:
                stopped_early = True
                logger.info(f"Training stopped early at step {self.current_step}")
                break

        # Final save
        self._save_lora_adapter(lora_name, is_best=False)

        return {
            'lora_name': lora_name,
            'num_epochs': num_epochs,
            'total_steps': self.current_step,
            'final_train_loss': avg_train_loss,
            'final_val_loss': val_loss,
            'best_val_loss': best_val_loss,
            'training_time': 'simulated',
            'stopped_early': stopped_early
        }

    def _validate(self, val_loader: DataLoader) -> float:
        """Run validation and return the mean loss over all validation batches.

        Returns ``inf`` for an empty loader (instead of dividing by zero) so an
        empty validation set can never register as a new best model.
        """
        total_loss = 0.0
        num_batches = 0

        for _batch in val_loader:
            # Simulate validation
            # Actual implementation would run model inference
            total_loss += 1.0 / (1.0 + self.current_step * 0.01)
            num_batches += 1

        if num_batches == 0:
            return float('inf')
        return total_loss / num_batches

    def _save_checkpoint(self, lora_name: str, step: int):
        """Save training checkpoint to <lora_dir>/<lora_name>/checkpoints/."""
        checkpoint_dir = self.lora_dir / lora_name / "checkpoints"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        checkpoint_path = checkpoint_dir / f"checkpoint_step_{step}.pt"

        # Actual implementation would save:
        # - LoRA weights
        # - Optimizer state
        # - Training step
        # - Config

        checkpoint_data = {
            'step': step,
            'epoch': self.current_epoch,
            'config': self.training_config,
            'loss_history': self.training_loss[-100:]  # Last 100 steps
        }

        torch.save(checkpoint_data, checkpoint_path)
        logger.info(f"Saved checkpoint: step_{step}")

    def _save_lora_adapter(self, lora_name: str, is_best: bool = False):
        """Save the adapter as best_model.pt / final_model.pt and write metadata.json."""
        lora_path = self.lora_dir / lora_name
        lora_path.mkdir(parents=True, exist_ok=True)

        filename = "best_model.pt" if is_best else "final_model.pt"
        save_path = lora_path / filename

        # Actual implementation would save:
        # - LoRA adapter weights only
        # - Configuration
        # - Training metadata

        adapter_data = {
            'lora_name': lora_name,
            'config': self.training_config,
            'training_steps': self.current_step,
            'saved_at': datetime.now().isoformat()
        }

        torch.save(adapter_data, save_path)
        logger.info(f"Saved LoRA adapter: {filename}")

        # Save metadata
        metadata_path = lora_path / "metadata.json"
        with open(metadata_path, 'w') as f:
            json.dump(adapter_data, f, indent=2)

    def list_lora_adapters(self) -> List[Dict]:
        """List available LoRA adapters with their metadata.

        Adapters whose metadata.json is missing or unreadable are still listed
        with basic ``has_best`` / ``has_final`` info, so one corrupt file does
        not hide every other adapter.
        """
        try:
            adapters = []

            for lora_dir in sorted(self.lora_dir.iterdir()):
                if not lora_dir.is_dir():
                    continue

                metadata_path = lora_dir / "metadata.json"
                if metadata_path.exists():
                    try:
                        with open(metadata_path, 'r', encoding='utf-8') as f:
                            metadata = json.load(f)
                    except (OSError, ValueError) as e:
                        logger.warning(f"Unreadable metadata for LoRA {lora_dir.name}: {str(e)}")
                        metadata = None

                    if isinstance(metadata, dict):
                        adapters.append({
                            'name': lora_dir.name,
                            **metadata
                        })
                        continue

                # Basic info if no (usable) metadata
                adapters.append({
                    'name': lora_dir.name,
                    'has_best': (lora_dir / "best_model.pt").exists(),
                    'has_final': (lora_dir / "final_model.pt").exists()
                })

            return adapters

        except Exception as e:
            logger.error(f"Failed to list LoRA adapters: {str(e)}")
            return []

    def delete_lora_adapter(self, lora_name: str) -> bool:
        """Delete a LoRA adapter directory.

        Returns False (without deleting anything) for names that are empty or
        would resolve outside the LoRA directory.
        """
        try:
            lora_path = _resolve_child_dir(self.lora_dir, lora_name)

            if lora_path.exists():
                shutil.rmtree(lora_path)
                logger.info(f"Deleted LoRA adapter: {lora_name}")
                return True
            else:
                logger.warning(f"LoRA adapter not found: {lora_name}")
                return False

        except Exception as e:
            logger.error(f"Failed to delete LoRA adapter {lora_name}: {str(e)}")
            return False

    def stop_training(self):
        """Request that the current training run stop.

        The training loop checks the request before each batch; it then
        validates what it has trained, saves the adapter and returns early.
        ``is_training`` is cleared by ``train_lora`` once the loop has exited,
        so a new run cannot start while the old one is still unwinding.
        """
        if self.is_training:
            logger.info("Training stop requested")
            self._stop_requested = True

    def get_training_status(self) -> Dict:
        """Get current training status"""
        return {
            'is_training': self.is_training,
            'current_epoch': self.current_epoch,
            'current_step': self.current_step,
            'recent_loss': self.training_loss[-10:] if self.training_loss else [],
            'config': self.training_config
        }

    def export_lora_adapter(self, lora_name: str) -> Optional[str]:
        """
        Export a LoRA adapter as a zip file for download

        Args:
            lora_name: Name of the LoRA adapter to export

        Returns:
            Path to the exported zip file (outputs/lora_exports/<name>.zip),
            or None if failed
        """
        try:
            lora_path = _resolve_child_dir(self.lora_dir, lora_name)
            if not lora_path.is_dir():
                logger.error(f"LoRA adapter not found: {lora_name}")
                return None

            # Create exports directory if it doesn't exist
            exports_dir = Path("outputs/lora_exports")
            exports_dir.mkdir(parents=True, exist_ok=True)

            # Create zip file
            zip_path = exports_dir / f"{lora_name}.zip"

            # Remove existing zip if present
            if zip_path.exists():
                zip_path.unlink()

            # make_archive appends ".zip" to the base name, yielding zip_path.
            # The archive contains the adapter directory's contents at its root.
            shutil.make_archive(
                str(exports_dir / lora_name),
                'zip',
                str(lora_path)
            )

            logger.info(f"Exported LoRA adapter to: {zip_path}")
            return str(zip_path)

        except Exception as e:
            logger.error(f"Failed to export LoRA adapter {lora_name}: {str(e)}")
            return None

    def import_lora_adapter(self, zip_path: str) -> Optional[str]:
        """
        Import a LoRA adapter from a zip file

        Accepts either metadata.json at the archive root (name taken from its
        'lora_name' field) or a subfolder containing metadata.json (name taken
        from the folder). Existing adapters are never overwritten; a timestamp
        (and counter, if needed) suffix is added instead.

        Args:
            zip_path: Path to the zip file containing LoRA adapter

        Returns:
            Name of the imported LoRA adapter, or None if failed
        """
        try:
            # Extract to temporary directory first. ZipFile.extractall strips
            # absolute paths and ".." components from member names.
            with tempfile.TemporaryDirectory() as temp_dir:
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(temp_dir)

                temp_path = Path(temp_dir)

                # Case 1: Check if metadata.json is at root level
                if (temp_path / "metadata.json").exists():
                    logger.info("Found metadata.json at root level")
                    source_dir = temp_path

                    # Read metadata to get LoRA name
                    with open(temp_path / "metadata.json", 'r', encoding='utf-8') as f:
                        metadata = json.load(f)
                    lora_name = metadata.get('lora_name', 'imported_lora') if isinstance(metadata, dict) else 'imported_lora'

                    # The name comes from an untrusted archive; refuse anything
                    # that could escape the LoRA directory.
                    try:
                        _resolve_child_dir(self.lora_dir, lora_name)
                    except ValueError:
                        logger.warning(f"Unsafe LoRA name {lora_name!r} in metadata.json; using 'imported_lora'")
                        lora_name = 'imported_lora'
                else:
                    # Case 2: Look for a subfolder with metadata.json
                    lora_folders = sorted(
                        d for d in temp_path.iterdir()
                        if d.is_dir() and (d / "metadata.json").exists()
                    )

                    if not lora_folders:
                        logger.error("No valid LoRA adapter found in zip file. Expected metadata.json at root or in a subfolder.")
                        return None

                    source_dir = lora_folders[0]
                    lora_name = source_dir.name
                    logger.info(f"Found LoRA in subfolder: {lora_name}")

                # Copy to loras directory
                dest_path = self.lora_dir / lora_name

                # If already exists, rename with timestamp (plus a counter when
                # two imports land in the same second)
                if dest_path.exists():
                    original_name = lora_name
                    base_name = f"{lora_name}_{int(time.time())}"
                    lora_name = base_name
                    counter = 1
                    while (self.lora_dir / lora_name).exists():
                        lora_name = f"{base_name}_{counter}"
                        counter += 1
                    dest_path = self.lora_dir / lora_name
                    logger.info(f"LoRA '{original_name}' already exists, importing as '{lora_name}'")

                # Copy the whole adapter tree (files and subdirectories);
                # copytree creates dest_path and any missing parents.
                shutil.copytree(source_dir, dest_path)

                logger.info(f"✅ Imported LoRA adapter: {lora_name}")
                return lora_name

        except Exception as e:
            logger.error(f"Failed to import LoRA adapter: {str(e)}", exc_info=True)
            return None