|
|
import os |
|
|
import gdown |
|
|
import zipfile |
|
|
import shutil |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchvision.datasets as datasets |
|
|
import torchvision.transforms as transforms |
|
|
from torch.utils.data import DataLoader |
|
|
import time |
|
|
import modules.model as model |
|
|
|
|
|
|
|
|
# Download and unpack the CelebA archive from Google Drive unless a local
# ``celeba/`` directory already exists.
if not os.path.exists('celeba/'):
    url = 'https://drive.google.com/file/d/13vkq4tFCPE8O78KTj84HHM6kBnYkt8gP/view?usp=sharing'
    output = 'download.zip'

    # fuzzy=True asks gdown to pull the file id out of the "view" sharing URL.
    # Some gdown versions return None instead of raising on failure; fail
    # loudly here rather than with a confusing zipfile error below.
    if gdown.download(url, output, fuzzy=True) is None:
        raise RuntimeError(f'Failed to download dataset archive from {url}')

    with zipfile.ZipFile(output, 'r') as zip_ref:
        zip_ref.extractall()

    os.remove(output)

    # Archives zipped on macOS may contain a __MACOSX metadata folder; it is
    # not always present, so a missing folder must not abort the script.
    shutil.rmtree('__MACOSX', ignore_errors=True)
|
|
|
|
|
|
|
|
def _pick_device():
    """Return ``(torch.device, display name)`` for the preferred backend.

    Backends are probed in preference order (Apple MPS, then CUDA) and probing
    stops at the first one that reports itself available; the CPU is the
    fallback when neither is.
    """
    preferences = (
        (torch.backends.mps.is_available, 'mps', 'Apple Silicon GPU'),
        (torch.cuda.is_available, 'cuda', 'CUDA'),
    )
    for is_available, backend, label in preferences:
        if is_available():
            return torch.device(backend), label
    return torch.device('cpu'), 'CPU'


device, device_name = _pick_device()

# Tensors created by torch factory functions now default to this device.
torch.set_default_device(device)

print(f'\nDevice: {device_name}')
|
|
|
|
|
|
|
|
# Each image is resized to imsize x imsize and then FiveCrop'd to crop_size,
# so every crop spans 80% of the resized side.
crop_size = 128
# round() instead of int() so the value never depends on float truncation
# of 128 / 0.8 (== 160).
imsize = round(crop_size / 0.8)

batch_size = 10

# Built once and reused, instead of constructing a new ToTensor per crop.
_to_tensor = transforms.ToTensor()


def _stack_crops(crops):
    """Convert each crop to a tensor and stack them along a new leading dimension."""
    return torch.stack([_to_tensor(crop) for crop in crops])


fivecrop_transform = transforms.Compose([
    transforms.Resize([imsize, imsize]),
    transforms.Grayscale(1),
    transforms.FiveCrop(crop_size),
    # A module-level function (unlike an inline lambda) can be pickled, which
    # DataLoader worker processes require under the "spawn" start method.
    transforms.Lambda(_stack_crops),
    # NOTE(review): mean=0, std=1 leaves values unchanged; kept as a
    # placeholder — substitute dataset statistics if normalization is wanted.
    transforms.Normalize(0, 1),
])
|
|
|
|
|
# CelebA with attribute targets, read from the current working directory.
# NOTE(review): split='all' trains on every image, including those in the
# valid/test partitions — confirm that is intended.
train_dataset = datasets.CelebA(
    '',
    split='all',
    target_type='attr',
    transform=fivecrop_transform,
    download=True,
)

# The shuffle generator is created on the selected device; presumably this is
# required because torch.set_default_device(device) is in effect — confirm.
_shuffle_rng = torch.Generator(device=device)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=_shuffle_rng,
)
|
|
|
|
|
|
|
|
# Column of the per-image attribute tensor used as the binary class label.
# NOTE(review): presumably index 20 is 'Male' in CelebA's 40-attribute list — confirm.
factor = 20

# Seed the global RNG before building the network so that weight
# initialisation is reproducible (assuming the model draws from torch's
# global RNG).
torch.manual_seed(2687)
resnet = model.resnetModel_128()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(resnet.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)
# Multiply the learning rate by 0.1 on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
|
|
|
|
|
def mins_to_hours(mins):
    """Split a number of minutes into whole hours and leftover minutes.

    Args:
        mins: Minutes to split (int or float).

    Returns:
        ``(hours, rem_mins)`` where ``hours`` is an int,
        ``0 <= rem_mins < 60`` and ``hours * 60 + rem_mins == mins``.
    """
    # divmod keeps integer inputs exact (no float division) and keeps hours and
    # remainder consistent for negative inputs as well.
    hours, rem_mins = divmod(mins, 60)
    return int(hours), rem_mins
|
|
|
|
|
# Train for a fixed number of epochs. Progress (loss, accuracy and an ETA) is
# printed every 10 batches and the weights are checkpointed after each epoch.
epochs = 2

train_losses = []    # per-batch training loss
train_accuracy = []  # per-batch fraction of correctly classified images

for i in range(epochs):
    epoch_time = 0.0
    # Each batch is timed from the end of the previous one, so DataLoader
    # fetching and transforms are part of the measured time and the ETA.
    batch_start = time.perf_counter()

    for j, (X_train, y_train) in enumerate(train_loader):
        X_train = X_train.to(device)
        # Labels must be on the same device as the predictions for the loss.
        y_train = y_train[:, factor].to(device)

        # Fold the crop dimension into the batch, score every crop, then
        # average the per-crop outputs back to one prediction per image.
        bs, ncrops, c, h, w = X_train.size()
        y_pred_crops = resnet(X_train.view(-1, c, h, w))
        y_pred = y_pred_crops.view(bs, ncrops, -1).mean(1)

        loss = criterion(y_pred, y_train)

        with torch.no_grad():
            predicted = y_pred.argmax(dim=1)
            train_batch_accuracy = (predicted == y_train).float().mean().item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_losses.append(batch_loss)
        train_accuracy.append(train_batch_accuracy)

        batch_end = time.perf_counter()
        epoch_time += batch_end - batch_start
        batch_start = batch_end

        if (j + 1) % 10 == 0:
            # ETA is extrapolated from the mean batch time of this epoch so far.
            avg_batch_time = epoch_time / (j + 1)
            batches_remaining = len(train_loader) - (j + 1)
            epoch_mins_remaining = round(batches_remaining * avg_batch_time / 60)
            epoch_time_remaining = mins_to_hours(epoch_mins_remaining)

            # Time left overall: rest of this epoch plus all later epochs.
            full_epoch = avg_batch_time * len(train_loader)
            epochs_remaining = epochs - (i + 1)
            rem_epoch_mins_remaining = epoch_mins_remaining + round(full_epoch * epochs_remaining / 60)
            rem_epoch_time_remaining = mins_to_hours(rem_epoch_mins_remaining)

            print(f'\nEpoch: {i+1}/{epochs} | Train Batch: {j+1}/{len(train_loader)}')
            print(f'Current epoch: {epoch_time_remaining[0]} hours {epoch_time_remaining[1]} minutes')
            print(f'Remaining epochs: {rem_epoch_time_remaining[0]} hours {rem_epoch_time_remaining[1]} minutes')
            print(f'Train Loss: {batch_loss}')
            print(f'Train Accuracy: {train_batch_accuracy}')

    scheduler.step()

    trained_model_name = f'{resnet.model_name}_epoch_{i + 1}.pt'
    torch.save(resnet.state_dict(), trained_model_name)