from pathlib import Path

import argbind
import torch
from tqdm import tqdm

from audiotools import AudioSignal
@torch.inference_mode()  # no autograd needed for inference
def coarse2fine_infer(
    signal,
    model,
    vqvae,
    device,
):
    output = {}
    w = signal.to(device)

    # Encode the waveform into discrete codes with the codec.
    z = vqvae.encode(w.audio_data, w.sample_rate)["codes"]
    model.to(device)

    # Plain decode of the unmodified codes, as a reference.
    output["reconstructed"] = model.to_signal(z, vqvae).cpu()

    # Mask everything except the conditioning (coarse) codebooks,
    # so only the fine codebooks get resampled.
    mask = torch.ones_like(z)
    mask[:, : model.n_conditioning_codebooks, :] = 0

    # Iterative gumbel sampling of the fine codes.
    output["sampled"] = model.sample(
        codec=vqvae,
        time_steps=z.shape[-1],
        sampling_steps=12,
        start_tokens=z,
        mask=mask,
        temperature=0.85,
        top_k=None,
        sample="gumbel",
        typical_filtering=True,
        return_signal=True,
    ).cpu()

    # Greedy single-step decode, as a baseline.
    output["argmax"] = model.sample(
        codec=vqvae,
        time_steps=z.shape[-1],
        sampling_steps=1,
        start_tokens=z,
        mask=mask,
        temperature=1.0,
        top_k=None,
        sample="argmax",
        typical_filtering=True,
        return_signal=True,
    ).cpu()

    return output
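
# Minimal sketch of calling coarse2fine_infer directly on a single file.
# "input.wav" and the checkpoint paths are placeholders, not paths from
# this repo:
#
#   from vampnet.modules.transformer import VampNet
#   from lac.model.lac import LAC
#
#   sig = AudioSignal("input.wav", duration=3.0)
#   model = VampNet.load("weights.pth").eval()
#   vqvae = LAC.load("codec.pth").to("cuda").eval()
#   out = coarse2fine_infer(sig, model, vqvae, "cuda")
#   out["sampled"].write("sampled.wav")
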
@argbind.bind(without_prefix=True)  # expose main's kwargs as CLI flags
def main(
    sources=[
        "/data/spotdl/audio/val", "/data/spotdl/audio/test"
    ],
    exp_name="noise_mode",
    model_paths=[
        "runs/c2f-exp-03.22.23/ckpt/mask/epoch=400/vampnet/weights.pth",
        "runs/c2f-exp-03.22.23/ckpt/random/epoch=400/vampnet/weights.pth",
    ],
    model_keys=[
        "mask",
        "random",
    ],
    vqvae_path: str = "runs/codec-ckpt/codec.pth",
    device: str = "cuda",
    output_dir: str = ".",
    max_excerpts: int = 5000,
    duration: float = 3.0,
):
    from vampnet.modules.transformer import VampNet
    from lac.model.lac import LAC

    # Load one VampNet checkpoint per key.
    models = {
        k: VampNet.load(p) for k, p in zip(model_keys, model_paths)
    }
    for model in models.values():
        model.eval()
    print(f"Loaded {len(models)} models.")

    vqvae = LAC.load(vqvae_path)
    vqvae.to(device)
    vqvae.eval()
    print("Loaded VQVAE.")

    output_dir = Path(output_dir) / f"{exp_name}-samples"

    from audiotools.data.datasets import AudioLoader, AudioDataset
    loader = AudioLoader(sources=sources)
    dataset = AudioDataset(
        loader,
        sample_rate=vqvae.sample_rate,
        duration=duration,
        n_examples=max_excerpts,
        without_replacement=True,
    )

    for i in tqdm(range(max_excerpts)):
        sig = dataset[i]["signal"]
        # These AudioSignal ops modify the signal in place.
        sig.resample(vqvae.sample_rate).normalize(-24).ensure_max_of_audio(1.0)

        for model_key, model in models.items():
            out = coarse2fine_infer(sig, model, vqvae, device)
            out_dir = output_dir / model_key / Path(sig.path_to_file).stem
            out_dir.mkdir(parents=True, exist_ok=True)
            for k, s in out.items():
                s.write(out_dir / f"{k}.wav")


if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        main()
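
# Usage sketch. The script filename is a placeholder; flag names assume the
# without_prefix argbind binding above, and the default checkpoint paths
# from main() would need to exist locally:
#
#   python coarse2fine_sample.py \
#       --vqvae_path runs/codec-ckpt/codec.pth \
#       --max_excerpts 100 --duration 3.0 --device cuda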