Update app.py
Browse files
app.py
CHANGED
|
@@ -19,7 +19,7 @@ import spaces # Para ZeroGPU
|
|
| 19 |
# =============================================================================
|
| 20 |
|
| 21 |
class TextToMusicGenerator(nn.Module):
|
| 22 |
-
"""Arquitetura do modelo Ricco"""
|
| 23 |
|
| 24 |
def __init__(self, config):
|
| 25 |
super().__init__()
|
|
@@ -69,6 +69,29 @@ class TextToMusicGenerator(nn.Module):
|
|
| 69 |
nn.Tanh()
|
| 70 |
)
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
# Congelar BERT
|
| 73 |
for param in self.text_encoder.parameters():
|
| 74 |
param.requires_grad = False
|
|
|
|
| 19 |
# =============================================================================
|
| 20 |
|
| 21 |
class TextToMusicGenerator(nn.Module):
|
| 22 |
+
"""Arquitetura do modelo Ricco (com discriminator para compatibilidade)"""
|
| 23 |
|
| 24 |
def __init__(self, config):
|
| 25 |
super().__init__()
|
|
|
|
| 69 |
nn.Tanh()
|
| 70 |
)
|
| 71 |
|
| 72 |
+
# Discriminator (para compatibilidade com o modelo treinado)
|
| 73 |
+
self.discriminator = nn.Sequential(
|
| 74 |
+
nn.Conv2d(1, 32, 4, stride=2, padding=1),
|
| 75 |
+
nn.LeakyReLU(0.2),
|
| 76 |
+
|
| 77 |
+
nn.Conv2d(32, 64, 4, stride=2, padding=1),
|
| 78 |
+
nn.BatchNorm2d(64),
|
| 79 |
+
nn.LeakyReLU(0.2),
|
| 80 |
+
|
| 81 |
+
nn.Conv2d(64, 128, 4, stride=2, padding=1),
|
| 82 |
+
nn.BatchNorm2d(128),
|
| 83 |
+
nn.LeakyReLU(0.2),
|
| 84 |
+
|
| 85 |
+
nn.Conv2d(128, 256, 4, stride=2, padding=1),
|
| 86 |
+
nn.BatchNorm2d(256),
|
| 87 |
+
nn.LeakyReLU(0.2),
|
| 88 |
+
|
| 89 |
+
nn.AdaptiveAvgPool2d(1),
|
| 90 |
+
nn.Flatten(),
|
| 91 |
+
nn.Linear(256, 1),
|
| 92 |
+
nn.Sigmoid()
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
# Congelar BERT
|
| 96 |
for param in self.text_encoder.parameters():
|
| 97 |
param.requires_grad = False
|