diff --git a/.gitignore b/.gitignore index 9b464dbacc4a1530bd63d383f52a1ca719a97cf9..4c763f2015fda2da13d32204d82ca54e3aeca567 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ __pycache__/constants.cpython-39.pyc __pycache__/datahandler.cpython-39.pyc __pycache__/seg_model.cpython-39.pyc __pycache__/train_seg.cpython-39.pyc +__pycache__/* diff --git a/Models/densenet121_run20.pt b/Models/densenet121_run20.pt new file mode 100644 index 0000000000000000000000000000000000000000..4d269e93b377dfc128fb694988681890632b07fe Binary files /dev/null and b/Models/densenet121_run20.pt differ diff --git a/Models/densenet121_run4.pt b/Models/densenet121_run4.pt new file mode 100644 index 0000000000000000000000000000000000000000..5aaa5af75e05cea30e783090edc02b22a4d84184 Binary files /dev/null and b/Models/densenet121_run4.pt differ diff --git a/Models/inception_run4.pt b/Models/inception_run4.pt new file mode 100644 index 0000000000000000000000000000000000000000..7b575182cbeda44d0d1684235d8383dd239430c9 Binary files /dev/null and b/Models/inception_run4.pt differ diff --git a/Models/resnet50_run4.pt b/Models/resnet50_run4.pt new file mode 100644 index 0000000000000000000000000000000000000000..aadd0d480285444c12c23069a5530eb5186d5406 Binary files /dev/null and b/Models/resnet50_run4.pt differ diff --git a/Models/vgg11_run1.pt b/Models/vgg11_run1.pt new file mode 100644 index 0000000000000000000000000000000000000000..17d0dae81b603cf8cd7e8f25718c0e4f6a067744 Binary files /dev/null and b/Models/vgg11_run1.pt differ diff --git a/__pycache__/constants.cpython-310.pyc b/__pycache__/constants.cpython-310.pyc index e3a94dd655e78ec3adcaf3f9ef47ba1d619aacbf..fc700c06791fde09be854713971f34fdacc850a8 100644 Binary files a/__pycache__/constants.cpython-310.pyc and b/__pycache__/constants.cpython-310.pyc differ diff --git a/__pycache__/datahandler.cpython-310.pyc b/__pycache__/datahandler.cpython-310.pyc index 3cace1b90a1b9fc1439e90c026c440b33b8ff60a..2a8d3044a04993561a415da14eba078ba54718ae 100644 Binary files a/__pycache__/datahandler.cpython-310.pyc and b/__pycache__/datahandler.cpython-310.pyc differ diff --git a/constants.py b/constants.py index 95535283fb47cde587cba22fe0688bd5ecbd8265..4e1b745b4e333c76d1b415ae4a9e94577f1b0e7c 100644 --- a/constants.py +++ b/constants.py @@ -1,5 +1,12 @@ -EPOCHS = 50 -BATCH_SIZE = 64 -INPUT_DIM = 128 +EPOCHS = 20 +BATCH_SIZE = 32 VAL_SPLIT = 2000 -TEST_SPLIT = 500 \ No newline at end of file +TEST_SPLIT = 500 + +# Model name can be one of the following: +# resnet50 +# vgg11 +# densenet121 +# inception +MODEL_NAME = "densenet121" +INPUT_DIM = 224 # leave as 224 for all models except inception \ No newline at end of file diff --git a/datahandler.py b/datahandler.py index 19901357f812e63c5924bf293eb2b8abfcae21a4..12df4603eda4209e78587ff2dd9e158612d7a20f 100644 --- a/datahandler.py +++ b/datahandler.py @@ -23,7 +23,6 @@ lesion_type_dict = { # returns trainLoader, valLoader def getHamDataLoaders(): - # Firstly returns two dataframes, one with all images and another with just the lesion_ids containing a single image (for valuating models) df_ham, df_ham_single_image = setupHamData() diff --git a/main.py b/main.py index 5ef1384d603afd4220a0d1d711e7f52ef4b8b994..baf6b731ab23b3a4383c8f7f6db53545aed7ab1d 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import sys import matplotlib.pyplot as plt import numpy as np @@ -5,11 +6,14 @@ from torchsummary import summary import datahandler import train_seg import seg_model -from constants import EPOCHS - +from constants import EPOCHS, MODEL_NAME +from train import train +from model import get_model def main(): - pass + model = get_model(MODEL_NAME, 7, False) + train_loader, val_loader = datahandler.getHamDataLoaders() + train(train_loader, val_loader, model, EPOCHS) def seg_main(): train_loader, val_loader = datahandler.getHamDataLoadersSeg() @@ -17,7 +21,19 @@ def seg_main(): #summary(model, (3,128,128), batch_size=16,device="cpu") train_seg.train_epochs(model, train_loader, val_loader, EPOCHS) +# When run from command line it can take an additional argument: +# if you add the additional argument with parameter 'seg' then it'll run the segmentation training loop +# otherwise it'll just run the main model training loop if __name__ == '__main__': - seg_main() + if len(sys.argv) > 1: + if sys.argv[1] == "seg": + print("SEGMENTATION") + seg_main() + else: + print("NORMAL") + main() + else: + print("NORMAL") + main() #FYI training samples for all but the 'melanocytic nevi' have been equalised (samples duplicated) so they're an even distribution # Before and after prints of the tables are uncommented to show (lines 142 and 154 in datahandler) \ No newline at end of file diff --git a/model.py b/model.py index b27247ec616d11276b6f9f7739bd44ec0bf95c3b..b892d910803fa19eeab6d42d9defcee791050ae0 100644 --- a/model.py +++ b/model.py @@ -1,3 +1,35 @@ -def getModel(): - model = None +from torch import nn +from torchvision import models + +def set_finetuning(model, is_finetuning): + for param in model.parameters(): + param.required_grad = not is_finetuning + +def get_model(name, num_classes, is_finetuning): + if name == "resnet50": + model = models.resnet50(pretrained=True) + set_finetuning(model, is_finetuning) + in_features = model.fc.in_features + model.fc = nn.Linear(in_features, num_classes) + elif name == "vgg11": + model = models.vgg11_bn(pretrained=True) + set_finetuning(model, is_finetuning) + in_features = model.classifier[6].in_features + model.classifier[6] = nn.Linear(in_features, num_classes) + elif name == "densenet121": + model = models.densenet121(pretrained=True) + set_finetuning(model, is_finetuning) + in_features = model.classifier.in_features + model.classifier = nn.Linear(in_features, num_classes) + elif name == "inception": + model = models.inception_v3(pretrained=True) + set_finetuning(model, is_finetuning) + + in_features = model.AuxLogits.fc.in_features + model.AuxLogits.fc = nn.Linear(in_features, num_classes) + in_features = model.fc.in_features + model.fc = nn.Linear(in_features, num_classes) + else: + model = None + print("No model with that name, things are about to go wrong!") return model \ No newline at end of file diff --git a/train.py b/train.py index 84e08f4c718d988abfb50733e059a5e4fe01f1f9..4684edc674f4df522b759490383b16b6e3bb39d7 100644 --- a/train.py +++ b/train.py @@ -1,3 +1,103 @@ -def train(trainLoader, valLoader, model): - trainStats = [] - return trainStats +import torch +from torch import nn, optim +from tqdm import tqdm + +from torch.autograd import Variable +from constants import MODEL_NAME + +# An array of arrays containing the loss and accuracy of each epoch +train_stats = [] +val_stats = [] + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +criterion = nn.CrossEntropyLoss().to(device) + +class AverageMeter(object): + def __init__(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def train(train_loader, val_loader, model, num_epochs): + optimizer = optim.Adam(model.parameters(), lr=1e-3) + + best_validation_loss = 999999 # arbitrarily large number + + for epoch in range(1, num_epochs+1): + print(f"EPOCH {epoch}") + print(f"-" * 40) + loss_train, acc_train = train_epoch(train_loader, model, optimizer, epoch) + print(f"Training | Loss {loss_train} | Accuracy {acc_train:.3%}") + loss_val, acc_val = validate(val_loader, model, epoch) + print(f"Validation | Loss {loss_val} | Accuracy {acc_val:.3%}") + + if loss_val < best_validation_loss: + best_validation_loss = loss_val + print(f"\tNEW BEST | LOSS {loss_val} | ACCURACY {acc_val:.3%}") + torch.save(model.state_dict(), f"Models/{MODEL_NAME}_run{num_epochs}.pt") + + print(f"-" * 40 + "\n") + + return train_stats, val_stats + +def train_epoch(train_loader, model, optimizer, epoch_num): + model.train() + model.to(device) + + train_loss = AverageMeter() + train_accuracy = AverageMeter() + + loop = tqdm(train_loader, desc="\tTraining", ncols=80, mininterval=0.1) + for images, labels in loop: + images, labels = Variable(images).to(device), Variable(labels).to(device) + N = images.size(0) + + optimizer.zero_grad() + outputs = model(images) + + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + prediction = outputs.max(1, keepdim=True)[1] + + train_accuracy.update(prediction.eq(labels.view_as(prediction)).sum().item() / N) + train_loss.update(loss.item()) + + loop.set_postfix(loss=loss.item(), accuracy=train_accuracy.avg) + + train_stats.append([train_loss.avg, train_accuracy.avg]) + return train_loss.avg, train_accuracy.avg + +def validate(val_loader, model, epoch_num): + model.eval() + val_loss = AverageMeter() + val_acc = AverageMeter() + + with torch.no_grad(): + + loop = tqdm(val_loader, desc="\tValidation", ncols=80, mininterval=0.1) + for images, labels in loop: + images, labels = Variable(images).to(device), Variable(labels).to(device) + N = images.size(0) + + outputs = model(images) + prediction= outputs.max(1, keepdim=True)[1] + + loss = criterion(outputs, labels) + + val_acc.update(prediction.eq(labels.view_as(prediction)).sum().item()/N) + val_loss.update(loss.item()) + + loop.set_postfix(loss=loss.item(), accuracy=val_acc.avg) + + val_stats.append([val_loss.avg, val_acc.avg]) + return val_loss.avg, val_acc.avg