diff --git a/evaluation.py b/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..6c68c2c1164da556d4de604e6107ac4b20fb838c --- /dev/null +++ b/evaluation.py @@ -0,0 +1,104 @@ +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +from sklearn.metrics import accuracy_score, classification_report, confusion_matrix +import torch.optim as optim +import numpy as np + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +def setupEval(model_run, model): + path = 'Models/' + model_run + '.pt' + + model.load_state_dict(torch.load(path, map_location=torch.device(device))) + model.eval() +def metrics(model_run, model, val_loader): + setupEval(model_run, model) + + # Set loss function and optimizer + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters()) + + val_loss = 0 + val_correct = 0 + val_total = 0 + val_preds = [] + val_targets = [] + + with torch.no_grad(): + + for i, data in enumerate(val_loader): + #if i == 5: + #break + inputs, labels = data + inputs, labels = inputs.to(device), labels.to(device) + outputs = model(inputs) + loss = criterion(outputs, labels) + val_loss += loss.item() + _, predicted = torch.max(outputs.data, 1) + val_total += labels.size(0) + val_correct += (predicted == labels).sum().item() + val_preds += predicted.tolist() + val_targets += labels.tolist() + + # Print validation metrics + val_accuracy = val_correct / val_total + val_loss /= len(val_loader) + print(f"Validation Accuracy: {val_accuracy:.4f}") + print(f"Validation Loss: {val_loss:.4f}") + print(classification_report(val_targets, val_preds)) + + # Plot metrics + plt.figure(figsize=(8, 6)) + plt.title("Validation Metrics") + plt.plot(val_loss, label="Loss") + plt.plot(val_accuracy, label="Accuracy") + plt.xlabel("Epoch") + plt.legend() + plt.show() + +def confMtrx(model_run, model, train_loader, val_loader): + setupEval(model_run, model) + + # initialize empty lists for true labels and predicted labels + train_true_labels = [] + train_pred_labels = [] + val_true_labels = [] + val_pred_labels = [] + + # get the true and predicted labels for the train_loader + with torch.no_grad(): + for data in train_loader: + #for i, data in enumerate(train_loader): + #if i == 20: + #break + inputs, labels = data + inputs, labels = inputs.to(device), labels.to(device) + outputs = model(inputs) + _, predicted = torch.max(outputs.data, 1) + train_true_labels += labels.tolist() + train_pred_labels += predicted.tolist() + + # get the true and predicted labels for the val_loader + with torch.no_grad(): + #for data in val_loader: + for i, data in enumerate(val_loader): + if i == 20: + break + inputs, labels = data + inputs, labels = inputs.to(device), labels.to(device) + outputs = model(inputs) + _, predicted = torch.max(outputs.data, 1) + val_true_labels += labels.tolist() + val_pred_labels += predicted.tolist() + + # define the labels for the confusion matrix + labels = ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc'] + + # create the confusion matrices for train and val sets + train_conf_matrix = confusion_matrix(train_true_labels, train_pred_labels, labels=labels) + val_conf_matrix = confusion_matrix(val_true_labels, val_pred_labels, labels=labels) + + print("Train Confusion Matrix:") + print(train_conf_matrix) + print("Val Confusion Matrix:") + print(val_conf_matrix) \ No newline at end of file diff --git a/main.py b/main.py index 37943864d4af16b61bedb0388778e1eaac91add1..74024fa84d9200f96f7742ea79b361307ac18695 100644 --- a/main.py +++ b/main.py @@ -9,12 +9,21 @@ import seg_model from constants import EPOCHS, MODEL_NAME, INPUT_DIM from train import train from model import get_model +from evaluation import metrics, confMtrx import seg_classifier_model + def main(): model = get_model(MODEL_NAME, 7, False) train_loader, val_loader = datahandler.getHamDataLoaders() - train(train_loader, val_loader, model, EPOCHS) + #train(train_loader, val_loader, model, EPOCHS) + + #Evaluation: Uncomment below + model_run = "densenet121_run20" + #metrics(model_run, model, val_loader) + #confMtrx(model_run, model, train_loader, val_loader) + + print("\n\n -------------Done----------------") def seg_main(): train_loader, val_loader = datahandler.getHamDataLoadersSeg()