Skip to content
Snippets Groups Projects
Commit 9e378340 authored by Graves, Conor S (UG - Comp Sci & Elec Eng)'s avatar Graves, Conor S (UG - Comp Sci & Elec Eng)
Browse files

Merge branch 'Evaluation' into 'main'

Evaluation merge

See merge request !9
parents be7e96b7 daecd9f6
No related branches found
No related tags found
1 merge request!9Evaluation merge
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import torch.optim as optim
import numpy as np
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def setupEval(model_run, model):
path = 'Models/' + model_run + '.pt'
model.load_state_dict(torch.load(path, map_location=torch.device(device)))
model.eval()
def metrics(model_run, model, val_loader):
setupEval(model_run, model)
# Set loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
val_loss = 0
val_correct = 0
val_total = 0
val_preds = []
val_targets = []
with torch.no_grad():
for i, data in enumerate(val_loader):
#if i == 5:
#break
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
loss = criterion(outputs, labels)
val_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
val_total += labels.size(0)
val_correct += (predicted == labels).sum().item()
val_preds += predicted.tolist()
val_targets += labels.tolist()
# Print validation metrics
val_accuracy = val_correct / val_total
val_loss /= len(val_loader)
print(f"Validation Accuracy: {val_accuracy:.4f}")
print(f"Validation Loss: {val_loss:.4f}")
print(classification_report(val_targets, val_preds))
# Plot metrics
plt.figure(figsize=(8, 6))
plt.title("Validation Metrics")
plt.plot(val_loss, label="Loss")
plt.plot(val_accuracy, label="Accuracy")
plt.xlabel("Epoch")
plt.legend()
plt.show()
def confMtrx(model_run, model, train_loader, val_loader):
setupEval(model_run, model)
# initialize empty lists for true labels and predicted labels
train_true_labels = []
train_pred_labels = []
val_true_labels = []
val_pred_labels = []
# get the true and predicted labels for the train_loader
with torch.no_grad():
for data in train_loader:
#for i, data in enumerate(train_loader):
#if i == 20:
#break
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
train_true_labels += labels.tolist()
train_pred_labels += predicted.tolist()
# get the true and predicted labels for the val_loader
with torch.no_grad():
#for data in val_loader:
for i, data in enumerate(val_loader):
if i == 20:
break
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
val_true_labels += labels.tolist()
val_pred_labels += predicted.tolist()
# define the labels for the confusion matrix
labels = ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
# create the confusion matrices for train and val sets
train_conf_matrix = confusion_matrix(train_true_labels, train_pred_labels, labels=labels)
val_conf_matrix = confusion_matrix(val_true_labels, val_pred_labels, labels=labels)
print("Train Confusion Matrix:")
print(train_conf_matrix)
print("Val Confusion Matrix:")
print(val_conf_matrix)
\ No newline at end of file
...@@ -9,12 +9,21 @@ import seg_model ...@@ -9,12 +9,21 @@ import seg_model
from constants import EPOCHS, MODEL_NAME, INPUT_DIM from constants import EPOCHS, MODEL_NAME, INPUT_DIM
from train import train from train import train
from model import get_model from model import get_model
from evaluation import metrics, confMtrx
import seg_classifier_model import seg_classifier_model
def main(): def main():
model = get_model(MODEL_NAME, 7, False) model = get_model(MODEL_NAME, 7, False)
train_loader, val_loader = datahandler.getHamDataLoaders() train_loader, val_loader = datahandler.getHamDataLoaders()
train(train_loader, val_loader, model, EPOCHS) #train(train_loader, val_loader, model, EPOCHS)
#Evaluation: Uncomment below
model_run = "densenet121_run20"
#metrics(model_run, model, val_loader)
#confMtrx(model_run, model, train_loader, val_loader)
print("\n\n -------------Done----------------")
def seg_main(): def seg_main():
train_loader, val_loader = datahandler.getHamDataLoadersSeg() train_loader, val_loader = datahandler.getHamDataLoadersSeg()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment