Skip to content
Snippets Groups Projects
Commit db37ed33 authored by Fish, Edward (PG/R - Music & Media)'s avatar Fish, Edward (PG/R - Music & Media)
Browse files

Upload New File

parent 2ca89f7a
No related branches found
No related tags found
No related merge requests found
custom.py 0 → 100644
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 22 13:33:55 2019
@author: ed
"""
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import torchvision.models as models
import scipy
from torch.utils.tensorboard import SummaryWriter
writer=SummaryWriter()
#Hyper parameters
batch_size= 16
transform = transforms.Compose([
transforms.RandomResizedCrop(200),
#transforms.RandomAffine(45),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
train_dataset = torchvision.datasets.ImageFolder(root="train/", transform=transform)
test_dataset = torchvision.datasets.ImageFolder(root="test/", transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
#Show images
classes= ('35mm', '50mm')
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# get some random training images
dataiter = iter(train_loader)
images, labels = dataiter.next()
# show images
imshow(torchvision.utils.make_grid(images))
class Unit(nn.Module):
def __init__(self,in_channels,out_channels):
super(Unit,self).__init__()
self.conv = nn.Conv2d(in_channels=in_channels,kernel_size=3,out_channels=out_channels,stride=1,padding=1)
self.bn = nn.BatchNorm2d(num_features=out_channels)
self.relu = nn.ReLU()
def forward(self,input):
output = self.conv(input)
output = self.bn(output)
output = self.relu(output)
return output
class SimpleNet(nn.Module):
def __init__(self,num_classes=2):
super(SimpleNet,self).__init__()
#Create 14 layers of the unit with max pooling in between
self.unit1 = Unit(in_channels=3,out_channels=32)
self.unit2 = Unit(in_channels=32, out_channels=32)
self.unit3 = Unit(in_channels=32, out_channels=32)
self.pool1 = nn.MaxPool2d(kernel_size=2)
self.unit4 = Unit(in_channels=32, out_channels=64)
self.unit5 = Unit(in_channels=64, out_channels=128)
self.unit6 = Unit(in_channels=128, out_channels=128)
self.unit7 = Unit(in_channels=128, out_channels=192)
self.pool2 = nn.MaxPool2d(kernel_size=2)
self.unit8 = Unit(in_channels=192,out_channels=192)
self.unit9 = Unit(in_channels=192, out_channels=192)
self.unit10 = Unit(in_channels=192,out_channels=256)
self.unit11 = Unit(in_channels=256, out_channels=256)
self.pool3 = nn.MaxPool2d(kernel_size=2)
self.unit12 = Unit(in_channels=256, out_channels=128)
self.unit13 = Unit(in_channels=128, out_channels=128)
self.unit14 = Unit(in_channels=128, out_channels=128)
self.avgpool = nn.AvgPool2d(kernel_size=25)
#Add all the units into the Sequential layer in exact order
self.net = nn.Sequential(self.unit1, self.unit2, self.unit3, self.pool1, self.unit4, self.unit5, self.unit6
,self.unit7, self.pool2, self.unit8, self.unit9, self.unit10, self.unit11, self.pool3,
self.unit12, self.unit13, self.unit14, self.avgpool)
#self.fc = nn.Dropout()
self.f2 = nn.Linear(in_features=128, out_features=num_classes)
def forward(self, input):
output = self.net(input)
output = output.view(-1,128)
#output = self.fc(output)
output = self.f2(output)
return output
def adjust_learning_rate(epoch):
lr = 0.01
if epoch > 5:
lr = lr / 10
for param_group in optimizer.param_groups:
param_group["lr"] = lr
def save_models(epoch):
torch.save(model.state_dict(), "cifar10model_{}.model".format(epoch))
print("Chekcpoint saved")
def train(num_epochs):
best_acc = 0.0
for epoch in range(num_epochs):
model.train()
train_acc = 0.0
train_loss = 0.0
for i, (images, labels) in enumerate(train_loader):
# Move images and labels to gpu if available
images = images.to(device)
labels = labels.to(device)
# Clear all accumulated gradients
optimizer.zero_grad()
# Predict classes using images from the test set
outputs = model(images)
# Compute the loss based on the predictions and actual labels
loss = loss_fn(outputs, labels)
# Backpropagate the loss
loss.backward()
# Adjust parameters according to the computed gradients
optimizer.step()
train_loss += loss.item()*images.size(0)
_, prediction = torch.max(outputs.data, 1)
train_acc += torch.sum(prediction == labels.data)
# Call the learning rate adjustment function
adjust_learning_rate(epoch)
# Compute the average acc and loss over all 50000 training images
train_acc = train_acc / 8400
train_loss = train_loss / 8400
writer.add_scalar('loss-train', train_loss, epoch)
writer.add_scalar('acc-train', train_acc, epoch)
# Evaluate on the test set
test_acc = test(epoch)
# Save the model if the test acc is greater than our current best
if test_acc > best_acc:
save_models(epoch)
best_acc = test_acc
# Print the metrics
print("Epoch {}, Train Accuracy: {} , TrainLoss: {} , Test Accuracy: {}".format(epoch, train_acc, train_loss,test_acc))
def test(e):
model.eval()
test_acc = 0.0
for i, (images, labels) in enumerate(val_loader):
images = images.to(device)
labels = labels.to(device)
#Predict classes using images from the test set
outputs = model(images)
_,prediction = torch.max(outputs.data, 1)
#prediction = prediction.cpu().numpy()
test_acc += torch.sum(prediction == labels.data)
#Compute the average acc and loss over all 10000 test images
test_acc = test_acc / 3600
writer.add_scalar('test_acc', test_acc, e)
return test_acc
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleNet(num_classes=2)
writer.add_graph(model, images)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
loss_fn = nn.CrossEntropyLoss()
if __name__ == "__main__":
train(200)
writer.close()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment