| import torch |
| import torch.nn as nn |
| from torchvision import transforms, models |
| from torch.utils.data import DataLoader, Subset |
| from torchvision.datasets import ImageFolder |
| from ClassUtils import CrosswalkDataset |
| import numpy as np |
| import random |
| import time |
|
|
|
|
import warnings

# Hide DeprecationWarning noise raised from any module (the regex ".*" matches
# every module name) so the training log stays readable.
warnings.filterwarnings("ignore", category=DeprecationWarning, module=r'.*')
|
|
| |
# Training hyperparameters.
learning_rate = 0.004  # Adam step size
epoch_num = 25         # number of full passes over the sampled training subset
|
|
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained VGG16. ``weights`` must be a VGG16_Weights *member*: passing the
# enum class itself (``models.VGG16_Weights``) is rejected by torchvision's
# weight verification with a TypeError.
vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

# Replace the final classifier layer with a 2-output linear head, keeping the
# same input width as the layer it replaces.
vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, 2)

vgg16 = vgg16.to(device)

# Binary cross-entropy on probabilities (not logits): the training loop is
# expected to apply torch.sigmoid to the model output before calling this.
loss_function = nn.BCELoss()
|
|
| |
| if __name__ == "__main__": |
| |
| optimiser = torch.optim.Adam(params= |
| filter(lambda p: p.requires_grad, vgg16.parameters()), |
| lr=learning_rate) |
|
|
|
|
| training_dataset = CrosswalkDataset("zebra_annotations/classification_data") |
| training_loader = DataLoader(Subset(training_dataset, random.sample(range(len(training_dataset)-1), 25000)), batch_size=128, shuffle=True) |
|
|
| for param in vgg16.features.parameters(): |
| param.requires_grad = False |
|
|
|
|
| vgg16.train() |
| print(len(training_dataset)) |
| for epoch in range(epoch_num): |
| running_loss = 0.0 |
| start_time = time.time() |
| last_time = start_time |
| for images, gt in training_loader: |
| images, gt = images.to(device), gt.to(device) |
|
|
| classifications = torch.sigmoid(vgg16(images)) |
| loss = loss_function(classifications, gt) |
| optimiser.zero_grad() |
| loss.backward() |
| optimiser.step() |
|
|
| batch_time = time.time() |
|
|
| running_loss += loss.item() |
|
|
| last_time = batch_time |
| print(",,, ---") |
|
|
| |
| print(f"\nEpoch {epoch + 1} of {epoch_num} has a per image loss of [{running_loss/len(training_loader):.4f}]") |
| print(f"{(last_time - start_time):.6f}") |
|
|
| |
| torch.save(vgg16.state_dict(), "VGG16_Full_State_Dict.pth") |
| |
| |
| torch.save(vgg16.classifier[6].state_dict(), "vgg16_binary_classifier_onlyHead.pth") |
|
|