Spaces:
Runtime error
Runtime error
Commit · 32701af
1 Parent(s): 568c3f1
Fix classifier for non-panns models
Browse files
- remfx/models.py +95 -29
remfx/models.py
CHANGED
|
@@ -12,6 +12,7 @@ from remfx.utils import spectrogram
|
|
| 12 |
from remfx.tcn import TCN
|
| 13 |
from remfx.utils import causal_crop
|
| 14 |
from remfx import effects
|
|
|
|
| 15 |
import asteroid
|
| 16 |
import random
|
| 17 |
|
|
@@ -438,19 +439,54 @@ class FXClassifier(pl.LightningModule):
|
|
| 438 |
self.mixup = mixup
|
| 439 |
self.label_smoothing = label_smoothing
|
| 440 |
|
| 441 |
-
self.
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
self.metrics
|
| 445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
)
|
| 447 |
-
|
| 448 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
)
|
| 450 |
-
self.
|
| 451 |
-
|
| 452 |
)
|
| 453 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
def forward(self, x: torch.Tensor, train: bool = False):
|
| 455 |
return self.network(x, train=train)
|
| 456 |
|
|
@@ -467,8 +503,13 @@ class FXClassifier(pl.LightningModule):
|
|
| 467 |
else:
|
| 468 |
outputs = self(x, train)
|
| 469 |
loss = 0
|
| 470 |
-
|
| 471 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
|
| 473 |
self.log(
|
| 474 |
f"{mode}_loss",
|
|
@@ -480,32 +521,57 @@ class FXClassifier(pl.LightningModule):
|
|
| 480 |
sync_dist=True,
|
| 481 |
)
|
| 482 |
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
self.log(
|
| 489 |
-
f"{mode}
|
| 490 |
-
|
| 491 |
on_step=True,
|
| 492 |
on_epoch=True,
|
| 493 |
prog_bar=True,
|
| 494 |
logger=True,
|
| 495 |
sync_dist=True,
|
| 496 |
)
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 509 |
return loss
|
| 510 |
|
| 511 |
def training_step(self, batch, batch_idx):
|
|
|
|
| 12 |
from remfx.tcn import TCN
|
| 13 |
from remfx.utils import causal_crop
|
| 14 |
from remfx import effects
|
| 15 |
+
from remfx.classifier import Cnn14
|
| 16 |
import asteroid
|
| 17 |
import random
|
| 18 |
|
|
|
|
| 439 |
self.mixup = mixup
|
| 440 |
self.label_smoothing = label_smoothing
|
| 441 |
|
| 442 |
+
if isinstance(self.network, Cnn14):
|
| 443 |
+
self.loss_fn = torch.nn.BCELoss()
|
| 444 |
+
|
| 445 |
+
self.metrics = torch.nn.ModuleDict()
|
| 446 |
+
for effect in self.effects:
|
| 447 |
+
self.metrics[
|
| 448 |
+
f"train_{effect}_acc"
|
| 449 |
+
] = torchmetrics.classification.Accuracy(task="binary")
|
| 450 |
+
self.metrics[
|
| 451 |
+
f"valid_{effect}_acc"
|
| 452 |
+
] = torchmetrics.classification.Accuracy(task="binary")
|
| 453 |
+
self.metrics[
|
| 454 |
+
f"test_{effect}_acc"
|
| 455 |
+
] = torchmetrics.classification.Accuracy(task="binary")
|
| 456 |
+
else:
|
| 457 |
+
self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
|
| 458 |
+
self.train_f1 = torchmetrics.classification.MultilabelF1Score(
|
| 459 |
+
5, average="none", multidim_average="global"
|
| 460 |
+
)
|
| 461 |
+
self.val_f1 = torchmetrics.classification.MultilabelF1Score(
|
| 462 |
+
5, average="none", multidim_average="global"
|
| 463 |
+
)
|
| 464 |
+
self.test_f1 = torchmetrics.classification.MultilabelF1Score(
|
| 465 |
+
5, average="none", multidim_average="global"
|
| 466 |
)
|
| 467 |
+
|
| 468 |
+
self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
|
| 469 |
+
5, threshold=0.5, average="macro", multidim_average="global"
|
| 470 |
+
)
|
| 471 |
+
self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
|
| 472 |
+
5, threshold=0.5, average="macro", multidim_average="global"
|
| 473 |
)
|
| 474 |
+
self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
|
| 475 |
+
5, threshold=0.5, average="macro", multidim_average="global"
|
| 476 |
)
|
| 477 |
|
| 478 |
+
self.metrics = {
|
| 479 |
+
"train": self.train_f1,
|
| 480 |
+
"valid": self.val_f1,
|
| 481 |
+
"test": self.test_f1,
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
self.avg_metrics = {
|
| 485 |
+
"train": self.train_f1_avg,
|
| 486 |
+
"valid": self.val_f1_avg,
|
| 487 |
+
"test": self.test_f1_avg,
|
| 488 |
+
}
|
| 489 |
+
|
| 490 |
def forward(self, x: torch.Tensor, train: bool = False):
|
| 491 |
return self.network(x, train=train)
|
| 492 |
|
|
|
|
| 503 |
else:
|
| 504 |
outputs = self(x, train)
|
| 505 |
loss = 0
|
| 506 |
+
# Multi-head binary loss
|
| 507 |
+
if isinstance(self.network, Cnn14):
|
| 508 |
+
for idx, output in enumerate(outputs):
|
| 509 |
+
loss += self.loss_fn(output.squeeze(-1), wet_label[..., idx])
|
| 510 |
+
else:
|
| 511 |
+
# Output is a 2d tensor
|
| 512 |
+
loss = self.loss_fn(outputs, wet_label)
|
| 513 |
|
| 514 |
self.log(
|
| 515 |
f"{mode}_loss",
|
|
|
|
| 521 |
sync_dist=True,
|
| 522 |
)
|
| 523 |
|
| 524 |
+
if isinstance(self.network, Cnn14):
|
| 525 |
+
acc_metrics = []
|
| 526 |
+
for idx, effect_name in enumerate(self.effects):
|
| 527 |
+
acc_metric = self.metrics[f"{mode}_{effect_name}_acc"](
|
| 528 |
+
outputs[idx].squeeze(-1), wet_label[..., idx]
|
| 529 |
+
)
|
| 530 |
+
self.log(
|
| 531 |
+
f"{mode}_{effect_name}_acc",
|
| 532 |
+
acc_metric,
|
| 533 |
+
on_step=True,
|
| 534 |
+
on_epoch=True,
|
| 535 |
+
prog_bar=True,
|
| 536 |
+
logger=True,
|
| 537 |
+
sync_dist=True,
|
| 538 |
+
)
|
| 539 |
+
acc_metrics.append(acc_metric)
|
| 540 |
+
|
| 541 |
self.log(
|
| 542 |
+
f"{mode}_avg_acc",
|
| 543 |
+
torch.mean(torch.stack(acc_metrics)),
|
| 544 |
on_step=True,
|
| 545 |
on_epoch=True,
|
| 546 |
prog_bar=True,
|
| 547 |
logger=True,
|
| 548 |
sync_dist=True,
|
| 549 |
)
|
| 550 |
+
else:
|
| 551 |
+
metrics = self.metrics[mode](torch.sigmoid(outputs), wet_label.long())
|
| 552 |
+
for idx, effect_name in enumerate(self.effects):
|
| 553 |
+
self.log(
|
| 554 |
+
f"{mode}_f1_{effect_name}",
|
| 555 |
+
metrics[idx],
|
| 556 |
+
on_step=True,
|
| 557 |
+
on_epoch=True,
|
| 558 |
+
prog_bar=True,
|
| 559 |
+
logger=True,
|
| 560 |
+
sync_dist=True,
|
| 561 |
+
)
|
| 562 |
+
avg_metrics = self.avg_metrics[mode](
|
| 563 |
+
torch.sigmoid(outputs), wet_label.long()
|
| 564 |
+
)
|
| 565 |
|
| 566 |
+
self.log(
|
| 567 |
+
f"{mode}_avg_acc",
|
| 568 |
+
avg_metrics,
|
| 569 |
+
on_step=True,
|
| 570 |
+
on_epoch=True,
|
| 571 |
+
prog_bar=True,
|
| 572 |
+
logger=True,
|
| 573 |
+
sync_dist=True,
|
| 574 |
+
)
|
| 575 |
return loss
|
| 576 |
|
| 577 |
def training_step(self, batch, batch_idx):
|