from pytorch_lightning.callbacks import Callback
import pytorch_lightning as pl
from einops import rearrange
import torch
import wandb
from torch import Tensor

from remfx import effects

ALL_EFFECTS = effects.Pedalboard_Effects

class AudioCallback(Callback):
    def __init__(self, sample_rate, log_audio, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.log_audio = log_audio
        # Log the raw training batch only once, and only if audio logging is enabled
        self.log_train_audio = log_audio
        self.sample_rate = sample_rate

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # Log the initial (effected) audio and its clean target
        if self.log_train_audio:
            x, y, _, _ = batch
            # Concatenate samples along time for easier viewing in the dashboard
            input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
            target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)

            log_wandb_audio_batch(
                logger=trainer.logger,
                id="input_effected_audio",
                samples=input_samples.cpu(),
                sampling_rate=self.sample_rate,
                caption="Training Data",
            )
            log_wandb_audio_batch(
                logger=trainer.logger,
                id="target_audio",
                samples=target_samples.cpu(),
                sampling_rate=self.sample_rate,
                caption="Target Data",
            )
            # Only log once per training run
            self.log_train_audio = False

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
        x, target, _, rem_fx_labels = batch
        # Only run on the first batch
        if batch_idx == 0 and self.log_audio:
            with torch.no_grad():
                # Imported here to avoid a circular import
                from remfx.models import RemFXChainInference

                if isinstance(pl_module, RemFXChainInference):
                    y = pl_module.sample(batch)
                    # Decode each multi-hot label vector into a list of effect names
                    effects_present_name = [
                        [
                            ALL_EFFECTS[i].__name__.replace("RandomPedalboard", "")
                            for i, effect in enumerate(effect_label)
                            if effect == 1.0
                        ]
                        for effect_label in rem_fx_labels
                    ]
                    for label in effects_present_name:
                        # Callbacks have no self.log; log through the LightningModule
                        pl_module.log(f"{'_'.join(label)}", 0.0)
                else:
                    y = pl_module.model.sample(x)

            # Concatenate prediction, input, and target for easier viewing in
            # the dashboard, with 2 seconds of silence between each sample
            silence = torch.zeros_like(x)
            silence = silence[..., : self.sample_rate * 2]
            concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
            log_wandb_audio_batch(
                logger=trainer.logger,
                id="prediction_input_target",
                samples=concat_samples.cpu(),
                sampling_rate=self.sample_rate,
                caption=f"Epoch {trainer.current_epoch}",
            )

    def on_test_batch_start(self, *args):
        # Reuse the validation logging for test batches
        self.on_validation_batch_start(*args)
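
# Usage sketch (illustrative, not part of the original module): attaching the
# callback to a Lightning Trainer with a W&B logger. `MyModel`, `datamodule`,
# and the 48 kHz sample rate are hypothetical placeholders.
#
#   model = MyModel()
#   trainer = pl.Trainer(
#       logger=pl.loggers.WandbLogger(project="remfx"),
#       callbacks=[AudioCallback(sample_rate=48000, log_audio=True)],
#   )
#   trainer.fit(model, datamodule=datamodule)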

def log_wandb_audio_batch(
    logger: pl.loggers.WandbLogger,
    id: str,
    samples: Tensor,
    sampling_rate: int,
    caption: str = "",
    max_items: int = 10,
):
    # Only log when a WandbLogger is attached; silently skip otherwise
    if not isinstance(logger, pl.loggers.WandbLogger):
        return
    num_items = samples.shape[0]
    # wandb.Audio expects (time, channels), so move channels last
    samples = rearrange(samples, "b c t -> b t c")
    for idx in range(min(num_items, max_items)):
        logger.experiment.log(
            {
                f"{id}_{idx}": wandb.Audio(
                    samples[idx].cpu().numpy(),
                    caption=caption,
                    sample_rate=sampling_rate,
                )
            }
        )
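
# Standalone sketch: logging a dummy batch of audio directly with the helper.
# The tensor shape (batch, channels, time) and the 440 Hz sine tone are
# illustrative assumptions, not project data; `trainer` is assumed to exist.
#
#   t = torch.linspace(0, 1, 48000)
#   dummy = torch.sin(2 * torch.pi * 440 * t).repeat(2, 1, 1)  # (2, 1, 48000)
#   log_wandb_audio_batch(
#       logger=trainer.logger,
#       id="debug_audio",
#       samples=dummy,
#       sampling_rate=48000,
#       caption="sanity check",
#   )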