Skip to content
Snippets Groups Projects
Commit 974a21bf authored by Graves, Conor S (UG - Comp Sci & Elec Eng)'s avatar Graves, Conor S (UG - Comp Sci & Elec Eng)
Browse files

Added confusion matrix evaluation

parent ca3dcb8e
Branches master
No related tags found
1 merge request!9Evaluation merge
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
import torch.nn as nn import torch.nn as nn
from sklearn.metrics import accuracy_score, classification_report from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import torch.optim as optim import torch.optim as optim
import numpy as np
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def setupEval(model_run, model): def setupEval(model_run, model):
...@@ -10,7 +11,7 @@ def setupEval(model_run, model): ...@@ -10,7 +11,7 @@ def setupEval(model_run, model):
model.load_state_dict(torch.load(path, map_location=torch.device(device))) model.load_state_dict(torch.load(path, map_location=torch.device(device)))
model.eval() model.eval()
def printMetrics(model_run, model, val_loader): def metrics(model_run, model, val_loader):
setupEval(model_run, model) setupEval(model_run, model)
# Set loss function and optimizer # Set loss function and optimizer
...@@ -53,4 +54,51 @@ def printMetrics(model_run, model, val_loader): ...@@ -53,4 +54,51 @@ def printMetrics(model_run, model, val_loader):
plt.plot(val_accuracy, label="Accuracy") plt.plot(val_accuracy, label="Accuracy")
plt.xlabel("Epoch") plt.xlabel("Epoch")
plt.legend() plt.legend()
plt.show() plt.show()
\ No newline at end of file
def confMtrx(model_run, model, train_loader, val_loader):
setupEval(model_run, model)
# initialize empty lists for true labels and predicted labels
train_true_labels = []
train_pred_labels = []
val_true_labels = []
val_pred_labels = []
# get the true and predicted labels for the train_loader
with torch.no_grad():
for data in train_loader:
#for i, data in enumerate(train_loader):
#if i == 20:
#break
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
train_true_labels += labels.tolist()
train_pred_labels += predicted.tolist()
# get the true and predicted labels for the val_loader
with torch.no_grad():
#for data in val_loader:
for i, data in enumerate(val_loader):
if i == 20:
break
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
val_true_labels += labels.tolist()
val_pred_labels += predicted.tolist()
# define the labels for the confusion matrix
labels = ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
# create the confusion matrices for train and val sets
train_conf_matrix = confusion_matrix(train_true_labels, train_pred_labels, labels=labels)
val_conf_matrix = confusion_matrix(val_true_labels, val_pred_labels, labels=labels)
print("Train Confusion Matrix:")
print(train_conf_matrix)
print("Val Confusion Matrix:")
print(val_conf_matrix)
\ No newline at end of file
...@@ -9,15 +9,18 @@ import seg_model ...@@ -9,15 +9,18 @@ import seg_model
from constants import EPOCHS, MODEL_NAME from constants import EPOCHS, MODEL_NAME
from train import train from train import train
from model import get_model from model import get_model
from evaluation import printMetrics from evaluation import metrics, confMtrx
def main(): def main():
model = get_model(MODEL_NAME, 7, False) model = get_model(MODEL_NAME, 7, False)
train_loader, val_loader = datahandler.getHamDataLoaders() train_loader, val_loader = datahandler.getHamDataLoaders()
#train(train_loader, val_loader, model, EPOCHS) #train(train_loader, val_loader, model, EPOCHS)
model_run = "densenet121_run4" #Evaluation: Uncomment below
printMetrics(model_run, model, val_loader) model_run = "densenet121_run20"
#metrics(model_run, model, val_loader)
#confMtrx(model_run, model, train_loader, val_loader)
print("\n\n -------------Done----------------") print("\n\n -------------Done----------------")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment