Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from ..misc import MetricLogger, SmoothedValue, reduce_dict | |
def train_one_epoch(
    model: nn.Module, criterion: nn.Module, dataloader, optimizer, ema, epoch, device
):
    """Train ``model`` for one pass over ``dataloader`` and return averaged stats.

    For every ``(imgs, labels)`` batch the function moves both tensors to
    ``device``, runs the forward pass, evaluates ``criterion(preds, labels,
    epoch)``, and performs one optimizer step. When ``ema`` is not ``None`` its
    ``update(model)`` method is called right after the optimizer step.

    Args:
        model: Network to train; switched to training mode on entry.
        criterion: Loss module called as ``criterion(preds, labels, epoch)``.
        dataloader: Iterable yielding ``(imgs, labels)`` pairs.
        optimizer: Optimizer whose first param group's ``lr`` is logged.
        ema: Optional object exposing ``update(model)``; skipped when ``None``.
        epoch: Current epoch index, used in the log header and passed to the
            criterion.
        device: Target device for the batch tensors.

    Returns:
        Dict mapping each meter name (``"lr"``, ``"loss"``) to that meter's
        ``global_avg``.
    """
    model.train()

    logger = MetricLogger(delimiter=" ")
    lr_meter = SmoothedValue(window_size=1, fmt="{value:.6f}")
    logger.add_meter("lr", lr_meter)

    log_interval = 100
    batches = logger.log_every(dataloader, log_interval, f"Epoch: [{epoch}]")

    for imgs, labels in batches:
        imgs, labels = imgs.to(device), labels.to(device)

        # Forward pass and loss in one expression; the criterion also
        # receives the epoch index.
        loss: torch.Tensor = criterion(model(imgs), labels, epoch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if ema is not None:
            ema.update(model)

        # NOTE(review): reduce_dict presumably aggregates the loss across
        # distributed workers before it is logged -- confirm in ..misc.
        reduced = reduce_dict({"loss": loss})
        logger.update(**{name: value.item() for name, value in reduced.items()})
        logger.update(lr=optimizer.param_groups[0]["lr"])

    logger.synchronize_between_processes()
    print("Averaged stats:", logger)

    return {name: meter.global_avg for name, meter in logger.meters.items()}
@torch.no_grad()
def evaluate(model, criterion, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and return averaged accuracy and loss.

    The whole function runs under ``torch.no_grad()``: evaluation never calls
    ``backward()``, so recording an autograd graph for every batch would only
    cost memory and time.

    Args:
        model: Network to evaluate; switched to eval mode on entry.
        criterion: Loss callable invoked as ``criterion(preds, labels)``.
        dataloader: Iterable yielding ``(imgs, labels)`` pairs.
        device: Target device for the batch tensors.

    Returns:
        Dict mapping each meter name (``"acc"``, ``"loss"``) to that meter's
        ``global_avg``.
    """
    model.eval()

    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter("acc", SmoothedValue(window_size=1))
    metric_logger.add_meter("loss", SmoothedValue(window_size=1))

    header = "Test:"
    for imgs, labels in metric_logger.log_every(dataloader, 10, header):
        imgs, labels = imgs.to(device), labels.to(device)
        preds = model(imgs)

        # Fraction of samples whose arg-max over the last dim equals the label.
        # Assumes preds is (batch, num_classes) and labels is (batch,) -- TODO confirm.
        acc = (preds.argmax(dim=-1) == labels).sum() / preds.shape[0]
        loss = criterion(preds, labels)

        # NOTE(review): reduce_dict presumably aggregates values across
        # distributed workers -- confirm in ..misc.
        dict_reduced = reduce_dict({"acc": acc, "loss": loss})
        reduced_values = {k: v.item() for k, v in dict_reduced.items()}
        metric_logger.update(**reduced_values)

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    return stats