Spaces:
Runtime error
Runtime error
Hugo Flores Garcia
commited on
Commit
·
0321745
1
Parent(s):
2ddbddc
taking prompt c2f tokens into account
Browse files- app.py +3 -2
- vampnet/interface.py +40 -20
app.py
CHANGED
|
@@ -114,7 +114,7 @@ def _vamp(data, return_mask=False):
|
|
| 114 |
)
|
| 115 |
|
| 116 |
if use_coarse2fine:
|
| 117 |
-
zv = interface.coarse_to_fine(zv, temperature=data[temp])
|
| 118 |
|
| 119 |
sig = interface.to_signal(zv).cpu()
|
| 120 |
print("done")
|
|
@@ -410,7 +410,8 @@ with gr.Blocks() as demo:
|
|
| 410 |
|
| 411 |
use_coarse2fine = gr.Checkbox(
|
| 412 |
label="use coarse2fine",
|
| 413 |
-
value=True
|
|
|
|
| 414 |
)
|
| 415 |
|
| 416 |
num_steps = gr.Slider(
|
|
|
|
| 114 |
)
|
| 115 |
|
| 116 |
if use_coarse2fine:
|
| 117 |
+
zv = interface.coarse_to_fine(zv, temperature=data[temp], mask=mask)
|
| 118 |
|
| 119 |
sig = interface.to_signal(zv).cpu()
|
| 120 |
print("done")
|
|
|
|
| 410 |
|
| 411 |
use_coarse2fine = gr.Checkbox(
|
| 412 |
label="use coarse2fine",
|
| 413 |
+
value=True,
|
| 414 |
+
visible=False
|
| 415 |
)
|
| 416 |
|
| 417 |
num_steps = gr.Slider(
|
vampnet/interface.py
CHANGED
|
@@ -22,6 +22,7 @@ def signal_concat(
|
|
| 22 |
|
| 23 |
return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
|
| 24 |
|
|
|
|
| 25 |
def _load_model(
|
| 26 |
ckpt: str,
|
| 27 |
lora_ckpt: str = None,
|
|
@@ -275,36 +276,47 @@ class Interface(torch.nn.Module):
|
|
| 275 |
|
| 276 |
def coarse_to_fine(
|
| 277 |
self,
|
| 278 |
-
|
|
|
|
| 279 |
**kwargs
|
| 280 |
):
|
| 281 |
assert self.c2f is not None, "No coarse2fine model loaded"
|
| 282 |
-
length =
|
| 283 |
chunk_len = self.s2t(self.c2f.chunk_size_s)
|
| 284 |
-
n_chunks = math.ceil(
|
| 285 |
|
| 286 |
# zero pad to chunk_len
|
| 287 |
if length % chunk_len != 0:
|
| 288 |
pad_len = chunk_len - (length % chunk_len)
|
| 289 |
-
|
|
|
|
| 290 |
|
| 291 |
-
n_codebooks_to_append = self.c2f.n_codebooks -
|
| 292 |
if n_codebooks_to_append > 0:
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
torch.zeros(
|
| 296 |
], dim=1)
|
| 297 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
fine_z = []
|
| 299 |
for i in range(n_chunks):
|
| 300 |
-
chunk =
|
|
|
|
|
|
|
| 301 |
chunk = self.c2f.generate(
|
| 302 |
codec=self.codec,
|
| 303 |
time_steps=chunk_len,
|
| 304 |
start_tokens=chunk,
|
| 305 |
return_signal=False,
|
|
|
|
| 306 |
**kwargs
|
| 307 |
)
|
|
|
|
| 308 |
fine_z.append(chunk)
|
| 309 |
|
| 310 |
fine_z = torch.cat(fine_z, dim=-1)
|
|
@@ -337,6 +349,12 @@ class Interface(torch.nn.Module):
|
|
| 337 |
**kwargs
|
| 338 |
)
|
| 339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
if return_mask:
|
| 341 |
return c_vamp, cz_masked
|
| 342 |
|
|
@@ -352,17 +370,18 @@ if __name__ == "__main__":
|
|
| 352 |
at.util.seed(42)
|
| 353 |
|
| 354 |
interface = Interface(
|
| 355 |
-
coarse_ckpt="./models/
|
| 356 |
-
coarse2fine_ckpt="./models/
|
| 357 |
-
codec_ckpt="./models/
|
| 358 |
device="cuda",
|
| 359 |
wavebeat_ckpt="./models/wavebeat.pth"
|
| 360 |
)
|
| 361 |
|
| 362 |
|
| 363 |
-
sig = at.AudioSignal.
|
| 364 |
|
| 365 |
z = interface.encode(sig)
|
|
|
|
| 366 |
|
| 367 |
# mask = linear_random(z, 1.0)
|
| 368 |
# mask = mask_and(
|
|
@@ -374,13 +393,14 @@ if __name__ == "__main__":
|
|
| 374 |
# )
|
| 375 |
# )
|
| 376 |
|
| 377 |
-
mask = interface.make_beat_mask(
|
| 378 |
-
|
| 379 |
-
)
|
| 380 |
# mask = dropout(mask, 0.0)
|
| 381 |
# mask = codebook_unmask(mask, 0)
|
|
|
|
|
|
|
| 382 |
|
| 383 |
-
breakpoint()
|
| 384 |
zv, mask_z = interface.coarse_vamp(
|
| 385 |
z,
|
| 386 |
mask=mask,
|
|
@@ -389,16 +409,16 @@ if __name__ == "__main__":
|
|
| 389 |
return_mask=True,
|
| 390 |
gen_fn=interface.coarse.generate
|
| 391 |
)
|
|
|
|
| 392 |
|
| 393 |
use_coarse2fine = True
|
| 394 |
if use_coarse2fine:
|
| 395 |
-
zv = interface.coarse_to_fine(zv, temperature=0.8)
|
|
|
|
| 396 |
|
| 397 |
mask = interface.to_signal(mask_z).cpu()
|
| 398 |
|
| 399 |
sig = interface.to_signal(zv).cpu()
|
| 400 |
print("done")
|
| 401 |
|
| 402 |
-
sig.write("output3.wav")
|
| 403 |
-
mask.write("mask.wav")
|
| 404 |
|
|
|
|
| 22 |
|
| 23 |
return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
|
| 24 |
|
| 25 |
+
|
| 26 |
def _load_model(
|
| 27 |
ckpt: str,
|
| 28 |
lora_ckpt: str = None,
|
|
|
|
| 276 |
|
| 277 |
def coarse_to_fine(
|
| 278 |
self,
|
| 279 |
+
z: torch.Tensor,
|
| 280 |
+
mask: torch.Tensor = None,
|
| 281 |
**kwargs
|
| 282 |
):
|
| 283 |
assert self.c2f is not None, "No coarse2fine model loaded"
|
| 284 |
+
length = z.shape[-1]
|
| 285 |
chunk_len = self.s2t(self.c2f.chunk_size_s)
|
| 286 |
+
n_chunks = math.ceil(z.shape[-1] / chunk_len)
|
| 287 |
|
| 288 |
# zero pad to chunk_len
|
| 289 |
if length % chunk_len != 0:
|
| 290 |
pad_len = chunk_len - (length % chunk_len)
|
| 291 |
+
z = torch.nn.functional.pad(z, (0, pad_len))
|
| 292 |
+
mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None
|
| 293 |
|
| 294 |
+
n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
|
| 295 |
if n_codebooks_to_append > 0:
|
| 296 |
+
z = torch.cat([
|
| 297 |
+
z,
|
| 298 |
+
torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
|
| 299 |
], dim=1)
|
| 300 |
|
| 301 |
+
# set the mask to 0 for all conditioning codebooks
|
| 302 |
+
if mask is not None:
|
| 303 |
+
mask = mask.clone()
|
| 304 |
+
mask[:, :self.c2f.n_conditioning_codebooks, :] = 0
|
| 305 |
+
|
| 306 |
fine_z = []
|
| 307 |
for i in range(n_chunks):
|
| 308 |
+
chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
|
| 309 |
+
mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None
|
| 310 |
+
|
| 311 |
chunk = self.c2f.generate(
|
| 312 |
codec=self.codec,
|
| 313 |
time_steps=chunk_len,
|
| 314 |
start_tokens=chunk,
|
| 315 |
return_signal=False,
|
| 316 |
+
mask=mask_chunk,
|
| 317 |
**kwargs
|
| 318 |
)
|
| 319 |
+
breakpoint()
|
| 320 |
fine_z.append(chunk)
|
| 321 |
|
| 322 |
fine_z = torch.cat(fine_z, dim=-1)
|
|
|
|
| 349 |
**kwargs
|
| 350 |
)
|
| 351 |
|
| 352 |
+
# add the fine codes back in
|
| 353 |
+
c_vamp = torch.cat(
|
| 354 |
+
[c_vamp, z[:, self.coarse.n_codebooks :, :]],
|
| 355 |
+
dim=1
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
if return_mask:
|
| 359 |
return c_vamp, cz_masked
|
| 360 |
|
|
|
|
| 370 |
at.util.seed(42)
|
| 371 |
|
| 372 |
interface = Interface(
|
| 373 |
+
coarse_ckpt="./models/vampnet/coarse.pth",
|
| 374 |
+
coarse2fine_ckpt="./models/vampnet/c2f.pth",
|
| 375 |
+
codec_ckpt="./models/vampnet/codec.pth",
|
| 376 |
device="cuda",
|
| 377 |
wavebeat_ckpt="./models/wavebeat.pth"
|
| 378 |
)
|
| 379 |
|
| 380 |
|
| 381 |
+
sig = at.AudioSignal('assets/example.wav')
|
| 382 |
|
| 383 |
z = interface.encode(sig)
|
| 384 |
+
breakpoint()
|
| 385 |
|
| 386 |
# mask = linear_random(z, 1.0)
|
| 387 |
# mask = mask_and(
|
|
|
|
| 393 |
# )
|
| 394 |
# )
|
| 395 |
|
| 396 |
+
# mask = interface.make_beat_mask(
|
| 397 |
+
# sig, 0.0, 0.075
|
| 398 |
+
# )
|
| 399 |
# mask = dropout(mask, 0.0)
|
| 400 |
# mask = codebook_unmask(mask, 0)
|
| 401 |
+
|
| 402 |
+
mask = inpaint(z, n_prefix=100, n_suffix=100)
|
| 403 |
|
|
|
|
| 404 |
zv, mask_z = interface.coarse_vamp(
|
| 405 |
z,
|
| 406 |
mask=mask,
|
|
|
|
| 409 |
return_mask=True,
|
| 410 |
gen_fn=interface.coarse.generate
|
| 411 |
)
|
| 412 |
+
|
| 413 |
|
| 414 |
use_coarse2fine = True
|
| 415 |
if use_coarse2fine:
|
| 416 |
+
zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask)
|
| 417 |
+
breakpoint()
|
| 418 |
|
| 419 |
mask = interface.to_signal(mask_z).cpu()
|
| 420 |
|
| 421 |
sig = interface.to_signal(zv).cpu()
|
| 422 |
print("done")
|
| 423 |
|
|
|
|
|
|
|
| 424 |
|