Spaces:
Runtime error
Runtime error
Hugo Flores Garcia
commited on
Commit
·
93ca721
1
Parent(s):
a004369
LAC!
Browse files- .gitignore +2 -1
- scripts/exp/train.py +2 -2
- vampnet/interface.py +2 -2
.gitignore
CHANGED
|
@@ -179,4 +179,5 @@ gradio-outputs/
|
|
| 179 |
models/
|
| 180 |
samples*/
|
| 181 |
models-all/
|
| 182 |
-
models.zip
|
|
|
|
|
|
| 179 |
models/
|
| 180 |
samples*/
|
| 181 |
models-all/
|
| 182 |
+
models.zip
|
| 183 |
+
.git-old
|
scripts/exp/train.py
CHANGED
|
@@ -20,7 +20,7 @@ import vampnet
|
|
| 20 |
from vampnet.modules.transformer import VampNet
|
| 21 |
from vampnet.util import codebook_unflatten, codebook_flatten
|
| 22 |
from vampnet import mask as pmask
|
| 23 |
-
from
|
| 24 |
|
| 25 |
|
| 26 |
# Enable cudnn autotuner to speed up training
|
|
@@ -109,7 +109,7 @@ def load(
|
|
| 109 |
load_weights: bool = False,
|
| 110 |
fine_tune_checkpoint: Optional[str] = None,
|
| 111 |
):
|
| 112 |
-
codec =
|
| 113 |
codec.eval()
|
| 114 |
|
| 115 |
model, v_extra = None, {}
|
|
|
|
| 20 |
from vampnet.modules.transformer import VampNet
|
| 21 |
from vampnet.util import codebook_unflatten, codebook_flatten
|
| 22 |
from vampnet import mask as pmask
|
| 23 |
+
from dac.model.dac import DAC
|
| 24 |
|
| 25 |
|
| 26 |
# Enable cudnn autotuner to speed up training
|
|
|
|
| 109 |
load_weights: bool = False,
|
| 110 |
fine_tune_checkpoint: Optional[str] = None,
|
| 111 |
):
|
| 112 |
+
codec = DAC.load(args["codec_ckpt"], map_location="cpu")
|
| 113 |
codec.eval()
|
| 114 |
|
| 115 |
model, v_extra = None, {}
|
vampnet/interface.py
CHANGED
|
@@ -11,7 +11,7 @@ from .modules.transformer import VampNet
|
|
| 11 |
from .beats import WaveBeat
|
| 12 |
from .mask import *
|
| 13 |
|
| 14 |
-
from
|
| 15 |
|
| 16 |
|
| 17 |
def signal_concat(
|
|
@@ -63,7 +63,7 @@ class Interface(torch.nn.Module):
|
|
| 63 |
):
|
| 64 |
super().__init__()
|
| 65 |
assert codec_ckpt is not None, "must provide a codec checkpoint"
|
| 66 |
-
self.codec =
|
| 67 |
self.codec.eval()
|
| 68 |
self.codec.to(device)
|
| 69 |
|
|
|
|
| 11 |
from .beats import WaveBeat
|
| 12 |
from .mask import *
|
| 13 |
|
| 14 |
+
from dac.model.dac import DAC
|
| 15 |
|
| 16 |
|
| 17 |
def signal_concat(
|
|
|
|
| 63 |
):
|
| 64 |
super().__init__()
|
| 65 |
assert codec_ckpt is not None, "must provide a codec checkpoint"
|
| 66 |
+
self.codec = DAC.load(Path(codec_ckpt))
|
| 67 |
self.codec.eval()
|
| 68 |
self.codec.to(device)
|
| 69 |
|