diff --git a/.gitignore b/.gitignore
index 9b464dbacc4a1530bd63d383f52a1ca719a97cf9..4c763f2015fda2da13d32204d82ca54e3aeca567 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@ __pycache__/constants.cpython-39.pyc
 __pycache__/datahandler.cpython-39.pyc
 __pycache__/seg_model.cpython-39.pyc
 __pycache__/train_seg.cpython-39.pyc
+__pycache__/*
diff --git a/Models/densenet121_run20.pt b/Models/densenet121_run20.pt
new file mode 100644
index 0000000000000000000000000000000000000000..4d269e93b377dfc128fb694988681890632b07fe
Binary files /dev/null and b/Models/densenet121_run20.pt differ
diff --git a/Models/densenet121_run4.pt b/Models/densenet121_run4.pt
new file mode 100644
index 0000000000000000000000000000000000000000..5aaa5af75e05cea30e783090edc02b22a4d84184
Binary files /dev/null and b/Models/densenet121_run4.pt differ
diff --git a/Models/inception_run4.pt b/Models/inception_run4.pt
new file mode 100644
index 0000000000000000000000000000000000000000..7b575182cbeda44d0d1684235d8383dd239430c9
Binary files /dev/null and b/Models/inception_run4.pt differ
diff --git a/Models/resnet50_run4.pt b/Models/resnet50_run4.pt
new file mode 100644
index 0000000000000000000000000000000000000000..aadd0d480285444c12c23069a5530eb5186d5406
Binary files /dev/null and b/Models/resnet50_run4.pt differ
diff --git a/Models/vgg11_run1.pt b/Models/vgg11_run1.pt
new file mode 100644
index 0000000000000000000000000000000000000000..17d0dae81b603cf8cd7e8f25718c0e4f6a067744
Binary files /dev/null and b/Models/vgg11_run1.pt differ
diff --git a/__pycache__/constants.cpython-310.pyc b/__pycache__/constants.cpython-310.pyc
index e3a94dd655e78ec3adcaf3f9ef47ba1d619aacbf..fc700c06791fde09be854713971f34fdacc850a8 100644
Binary files a/__pycache__/constants.cpython-310.pyc and b/__pycache__/constants.cpython-310.pyc differ
diff --git a/__pycache__/datahandler.cpython-310.pyc b/__pycache__/datahandler.cpython-310.pyc
index 3cace1b90a1b9fc1439e90c026c440b33b8ff60a..2a8d3044a04993561a415da14eba078ba54718ae 100644
Binary files a/__pycache__/datahandler.cpython-310.pyc and b/__pycache__/datahandler.cpython-310.pyc differ
diff --git a/constants.py b/constants.py
index 95535283fb47cde587cba22fe0688bd5ecbd8265..4e1b745b4e333c76d1b415ae4a9e94577f1b0e7c 100644
--- a/constants.py
+++ b/constants.py
@@ -1,5 +1,12 @@
-EPOCHS = 50
-BATCH_SIZE = 64
-INPUT_DIM = 128
+EPOCHS = 20
+BATCH_SIZE = 32
 VAL_SPLIT = 2000
-TEST_SPLIT = 500
\ No newline at end of file
+TEST_SPLIT = 500
+
+# Model name can be one of the following:
+# resnet50
+# vgg11
+# densenet121
+# inception
+MODEL_NAME = "densenet121"
+INPUT_DIM = 224  # leave as 224 for all models except inception, which expects 299x299 inputs
\ No newline at end of file
diff --git a/datahandler.py b/datahandler.py
index 19901357f812e63c5924bf293eb2b8abfcae21a4..12df4603eda4209e78587ff2dd9e158612d7a20f 100644
--- a/datahandler.py
+++ b/datahandler.py
@@ -23,7 +23,6 @@ lesion_type_dict = {
 
 # returns trainLoader, valLoader
 def getHamDataLoaders():
-
     # setupHamData returns two dataframes: one with all images, and one with only the lesion_ids that contain a single image (used for evaluating models)
     df_ham, df_ham_single_image = setupHamData()
 
diff --git a/main.py b/main.py
index 5ef1384d603afd4220a0d1d711e7f52ef4b8b994..baf6b731ab23b3a4383c8f7f6db53545aed7ab1d 100644
--- a/main.py
+++ b/main.py
@@ -1,3 +1,4 @@
+import sys
 import matplotlib.pyplot as plt
 import numpy as np
 
@@ -5,11 +6,14 @@ from torchsummary import summary
 import datahandler
 import train_seg
 import seg_model
-from constants import EPOCHS
-
+from constants import EPOCHS, MODEL_NAME
+from train import train
+from model import get_model
 
 def main():
-    pass
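+    # 7 output classes (HAM10000 lesion types); is_finetuning=False leaves all layers trainable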
+    model = get_model(MODEL_NAME, 7, False)
+    train_loader, val_loader = datahandler.getHamDataLoaders()
+    train(train_loader, val_loader, model, EPOCHS)
 
 def seg_main():
     train_loader, val_loader = datahandler.getHamDataLoadersSeg()
@@ -17,7 +21,19 @@ def seg_main():
     #summary(model, (3,128,128), batch_size=16,device="cpu")
     train_seg.train_epochs(model, train_loader, val_loader, EPOCHS)
 
+# When run from the command line, an optional argument selects the training loop:
+# pass 'seg' to run the segmentation training loop;
+# otherwise the classification training loop runs.
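+# e.g. "python main.py seg" runs segmentation, "python main.py" runs classification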
 if __name__ == '__main__':
-    seg_main()
+    if len(sys.argv) > 1:
+        if sys.argv[1] == "seg":
+            print("SEGMENTATION")
+            seg_main()
+        else:
+            print("NORMAL")
+            main()
+    else:
+        print("NORMAL")
+        main()
 # FYI: training samples for all classes except 'melanocytic nevi' have been equalised (samples duplicated) so the class distribution is even.
 # Before/after prints of the tables can be uncommented to show this (lines 142 and 154 in datahandler)
\ No newline at end of file
diff --git a/model.py b/model.py
index b27247ec616d11276b6f9f7739bd44ec0bf95c3b..b892d910803fa19eeab6d42d9defcee791050ae0 100644
--- a/model.py
+++ b/model.py
@@ -1,3 +1,35 @@
-def getModel():
-    model = None
+from torch import nn
+from torchvision import models
+
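+# When is_finetuning is True, freeze all pretrained parameters so only the newly
+# attached classifier head (replaced in get_model below) is updated during training.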
+def set_finetuning(model, is_finetuning):
+    for param in model.parameters():
+        param.requires_grad = not is_finetuning
+
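+# Loads an ImageNet-pretrained backbone and replaces its final classification
+# layer with a fresh nn.Linear sized for num_classes.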
+def get_model(name, num_classes, is_finetuning):
+    if name == "resnet50":
+        model = models.resnet50(pretrained=True)
+        set_finetuning(model, is_finetuning)
+        in_features = model.fc.in_features
+        model.fc = nn.Linear(in_features, num_classes)
+    elif name == "vgg11":
+        model = models.vgg11_bn(pretrained=True)
+        set_finetuning(model, is_finetuning)
+        in_features = model.classifier[6].in_features
+        model.classifier[6] = nn.Linear(in_features, num_classes)
+    elif name == "densenet121":
+        model = models.densenet121(pretrained=True)
+        set_finetuning(model, is_finetuning)
+        in_features = model.classifier.in_features
+        model.classifier = nn.Linear(in_features, num_classes)
+    elif name == "inception":
+        model = models.inception_v3(pretrained=True)
+        set_finetuning(model, is_finetuning)
+
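+        # inception_v3 also has an auxiliary classifier used in training mode, so replace both heads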
+        in_features = model.AuxLogits.fc.in_features
+        model.AuxLogits.fc = nn.Linear(in_features, num_classes)
+        in_features = model.fc.in_features
+        model.fc = nn.Linear(in_features, num_classes)
+    else:
+        raise ValueError(f"Unknown model name '{name}'; expected resnet50, vgg11, densenet121 or inception")
     return model
\ No newline at end of file
diff --git a/train.py b/train.py
index 84e08f4c718d988abfb50733e059a5e4fe01f1f9..4684edc674f4df522b759490383b16b6e3bb39d7 100644
--- a/train.py
+++ b/train.py
@@ -1,3 +1,103 @@
-def train(trainLoader, valLoader, model):
-    trainStats = []
-    return trainStats
+import torch
+from torch import nn, optim
+from tqdm import tqdm
+
+from constants import MODEL_NAME
+
+# Per-epoch [loss, accuracy] pairs for the training and validation passes
+train_stats = []
+val_stats = []
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+criterion = nn.CrossEntropyLoss().to(device)
+
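+# Tracks a running average of a metric (e.g. loss or accuracy) across batches.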
+class AverageMeter(object):
+    def __init__(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
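+# Full training loop: one training pass and one validation pass per epoch,
+# checkpointing the model whenever the validation loss improves.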
+def train(train_loader, val_loader, model, num_epochs):
+    optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
+    best_validation_loss = float("inf")  # any real validation loss will beat this
+
+    for epoch in range(1, num_epochs+1):
+        print(f"EPOCH {epoch}")
+        print("-" * 40)
+        loss_train, acc_train = train_epoch(train_loader, model, optimizer, epoch)
+        print(f"Training   | Loss {loss_train} | Accuracy {acc_train:.3%}")
+        loss_val, acc_val = validate(val_loader, model, epoch)
+        print(f"Validation | Loss {loss_val} | Accuracy {acc_val:.3%}")
+        
+        if loss_val < best_validation_loss:
+            best_validation_loss = loss_val
+            print(f"\tNEW BEST | LOSS {loss_val} | ACCURACY {acc_val:.3%}")
+            torch.save(model.state_dict(), f"Models/{MODEL_NAME}_run{num_epochs}.pt")
+        
+        print("-" * 40 + "\n")
+
+    return train_stats, val_stats
+
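+# One pass over the training set: forward, loss, backward and optimizer step per batch.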
+def train_epoch(train_loader, model, optimizer, epoch_num):
+    model.train()
+    model.to(device)
+
+    train_loss = AverageMeter()
+    train_accuracy = AverageMeter()
+
+    loop = tqdm(train_loader, desc="\tTraining", ncols=80, mininterval=0.1)
+    for images, labels in loop:
+        images, labels = images.to(device), labels.to(device)
+        N = images.size(0)
+
+        optimizer.zero_grad()
+        outputs = model(images)
+        # inception_v3 returns (logits, aux_logits) in training mode; train on the main logits only
+        if isinstance(outputs, tuple):
+            outputs = outputs[0]
+
+        loss = criterion(outputs, labels)
+        loss.backward()
+        optimizer.step()
+        prediction = outputs.max(1, keepdim=True)[1]
+
+        train_accuracy.update(prediction.eq(labels.view_as(prediction)).sum().item() / N)
+        train_loss.update(loss.item())
+
+        loop.set_postfix(loss=loss.item(), accuracy=train_accuracy.avg)
+
+    train_stats.append([train_loss.avg, train_accuracy.avg])
+    return train_loss.avg, train_accuracy.avg
+
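+# One pass over the validation set with gradients disabled; no parameter updates.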
+def validate(val_loader, model, epoch_num):
+    model.eval()
+    val_loss = AverageMeter()
+    val_acc = AverageMeter()
+
+    with torch.no_grad():
+
+        loop = tqdm(val_loader, desc="\tValidation", ncols=80, mininterval=0.1)
+        for images, labels in loop:
+            images, labels = images.to(device), labels.to(device)
+            N = images.size(0)
+
+            outputs = model(images)
+            prediction = outputs.max(1, keepdim=True)[1]
+
+            loss = criterion(outputs, labels)
+
+            val_acc.update(prediction.eq(labels.view_as(prediction)).sum().item()/N)
+            val_loss.update(loss.item())
+            
+            loop.set_postfix(loss=loss.item(), accuracy=val_acc.avg)
+
+    val_stats.append([val_loss.avg, val_acc.avg])
+    return val_loss.avg, val_acc.avg