Spaces:
Runtime error
Runtime error
Hugo Flores Garcia
committed on
Commit
·
bf35d45
1
Parent(s):
8c3b3e7
lora interface
Browse files
- app.py +50 -5
- vampnet/interface.py +2 -3
app.py
CHANGED
|
@@ -18,10 +18,45 @@ Interface = argbind.bind(Interface)
|
|
| 18 |
|
| 19 |
conf = argbind.parse_args()
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
# dataset = at.data.datasets.AudioDataset(
|
| 27 |
# loader,
|
|
@@ -55,6 +90,8 @@ def load_example_audio():
|
|
| 55 |
|
| 56 |
|
| 57 |
def _vamp(data, return_mask=False):
|
|
|
|
|
|
|
| 58 |
out_dir = OUT_DIR / str(uuid.uuid4())
|
| 59 |
out_dir.mkdir()
|
| 60 |
sig = at.AudioSignal(data[input_audio])
|
|
@@ -173,6 +210,7 @@ def save_vamp(data):
|
|
| 173 |
"use_coarse2fine": data[use_coarse2fine],
|
| 174 |
"stretch_factor": data[stretch_factor],
|
| 175 |
"seed": data[seed],
|
|
|
|
| 176 |
}
|
| 177 |
|
| 178 |
# save with yaml
|
|
@@ -472,6 +510,13 @@ with gr.Blocks() as demo:
|
|
| 472 |
|
| 473 |
# mask settings
|
| 474 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
vamp_button = gr.Button("generate (vamp)!!!")
|
| 476 |
output_audio = gr.Audio(
|
| 477 |
label="output audio",
|
|
@@ -514,7 +559,7 @@ with gr.Blocks() as demo:
|
|
| 514 |
beat_mask_width,
|
| 515 |
beat_mask_downbeats,
|
| 516 |
seed,
|
| 517 |
-
|
| 518 |
}
|
| 519 |
|
| 520 |
# connect widgets
|
|
|
|
| 18 |
|
| 19 |
conf = argbind.parse_args()
|
| 20 |
|
| 21 |
+
def load_interface():
    """Construct a fresh ``Interface`` under the parsed argbind configuration.

    The instance is created inside ``argbind.scope(conf)`` and its device is
    printed before it is handed back to the caller.

    Returns:
        The newly constructed ``Interface`` instance.
    """
    with argbind.scope(conf):
        fresh = Interface()
        print(f"interface device is {fresh.device}")
        return fresh
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
LORA_NONE = "None"
|
| 30 |
+
def load_loras():
    """Collect the LoRA configurations found under ``conf/generated``.

    Every ``interface.yml`` located (recursively) below ``conf/generated`` is
    parsed with ``yaml.safe_load`` and stored under the name of the directory
    that contains it. A final ``LORA_NONE`` entry mapped to ``None`` is added
    after all discovered configs.

    Returns:
        dict: LoRA name -> parsed config, plus ``LORA_NONE -> None``.
    """
    loras = {}
    # find confs under conf/generated
    # glob order follows the filesystem and is not guaranteed; sort so the
    # resulting dict (and anything built from its keys) is stable across runs.
    for conf_file in sorted(Path("conf/generated").glob("**/interface.yml")):
        # NOTE: configs whose parent directories share a name collide; the
        # one that sorts last wins.
        name = conf_file.parent.name
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(conf_file, encoding="utf-8") as f:
            loras[name] = yaml.safe_load(f)
    # A directory literally named like LORA_NONE would be overwritten here.
    loras[LORA_NONE] = None
    return loras
|
| 39 |
+
|
| 40 |
+
interface = load_interface()
|
| 41 |
+
loras = load_loras()
|
| 42 |
+
cur_lora = LORA_NONE
|
| 43 |
+
|
| 44 |
+
def load_lora(name):
    """Switch the global ``interface`` to the LoRA identified by ``name``.

    Does nothing when ``name`` is already the current LoRA. Selecting
    ``LORA_NONE`` rebuilds the interface via ``load_interface()``; any other
    name loads the coarse and coarse-to-fine LoRA checkpoints listed in that
    LoRA's config into the existing interface.

    Args:
        name: key into the module-level ``loras`` dict, or ``LORA_NONE``.
    """
    global interface, cur_lora

    # Already active: nothing to reload.
    if name == cur_lora:
        return

    # Dropping back to "no LoRA" is done by rebuilding the interface.
    if name == LORA_NONE:
        interface = load_interface()
        cur_lora = LORA_NONE
        return

    lora_conf = loras[name]
    interface.lora_load(
        coarse_ckpt=lora_conf["Interface.coarse_lora_ckpt"],
        c2f_ckpt=lora_conf["Interface.coarse2fine_lora_ckpt"],
        full_ckpts=False,
    )
    cur_lora = name
|
| 60 |
|
| 61 |
# dataset = at.data.datasets.AudioDataset(
|
| 62 |
# loader,
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
def _vamp(data, return_mask=False):
|
| 93 |
+
load_lora(data[lora_choice])
|
| 94 |
+
|
| 95 |
out_dir = OUT_DIR / str(uuid.uuid4())
|
| 96 |
out_dir.mkdir()
|
| 97 |
sig = at.AudioSignal(data[input_audio])
|
|
|
|
| 210 |
"use_coarse2fine": data[use_coarse2fine],
|
| 211 |
"stretch_factor": data[stretch_factor],
|
| 212 |
"seed": data[seed],
|
| 213 |
+
"lora": data[lora_choice],
|
| 214 |
}
|
| 215 |
|
| 216 |
# save with yaml
|
|
|
|
| 510 |
|
| 511 |
# mask settings
|
| 512 |
with gr.Column():
|
| 513 |
+
|
| 514 |
+
lora_choice = gr.Dropdown(
|
| 515 |
+
label="lora choice",
|
| 516 |
+
choices=list(loras.keys()),
|
| 517 |
+
value=LORA_NONE,
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
vamp_button = gr.Button("generate (vamp)!!!")
|
| 521 |
output_audio = gr.Audio(
|
| 522 |
label="output audio",
|
|
|
|
| 559 |
beat_mask_width,
|
| 560 |
beat_mask_downbeats,
|
| 561 |
seed,
|
| 562 |
+
lora_choice,
|
| 563 |
}
|
| 564 |
|
| 565 |
# connect widgets
|
vampnet/interface.py
CHANGED
|
@@ -120,17 +120,16 @@ class Interface(torch.nn.Module):
|
|
| 120 |
if coarse_ckpt is not None:
|
| 121 |
self.coarse.to("cpu")
|
| 122 |
state_dict = torch.load(coarse_ckpt, map_location="cpu")
|
| 123 |
-
|
| 124 |
self.coarse.load_state_dict(state_dict, strict=False)
|
| 125 |
self.coarse.to(self.device)
|
| 126 |
if c2f_ckpt is not None:
|
| 127 |
self.c2f.to("cpu")
|
| 128 |
state_dict = torch.load(c2f_ckpt, map_location="cpu")
|
| 129 |
-
|
| 130 |
self.c2f.load_state_dict(state_dict, strict=False)
|
| 131 |
self.c2f.to(self.device)
|
| 132 |
|
| 133 |
-
|
| 134 |
def s2t(self, seconds: float):
|
| 135 |
"""seconds to tokens"""
|
| 136 |
if isinstance(seconds, np.ndarray):
|
|
|
|
| 120 |
if coarse_ckpt is not None:
|
| 121 |
self.coarse.to("cpu")
|
| 122 |
state_dict = torch.load(coarse_ckpt, map_location="cpu")
|
| 123 |
+
print(f"loading coarse from {coarse_ckpt}")
|
| 124 |
self.coarse.load_state_dict(state_dict, strict=False)
|
| 125 |
self.coarse.to(self.device)
|
| 126 |
if c2f_ckpt is not None:
|
| 127 |
self.c2f.to("cpu")
|
| 128 |
state_dict = torch.load(c2f_ckpt, map_location="cpu")
|
| 129 |
+
print(f"loading c2f from {c2f_ckpt}")
|
| 130 |
self.c2f.load_state_dict(state_dict, strict=False)
|
| 131 |
self.c2f.to(self.device)
|
| 132 |
|
|
|
|
| 133 |
def s2t(self, seconds: float):
|
| 134 |
"""seconds to tokens"""
|
| 135 |
if isinstance(seconds, np.ndarray):
|