diff --git a/main.py b/main.py
index 37943864d4af16b61bedb0388778e1eaac91add1..3e2e19c424f0f41c4f672976f0f41f20952fd55f 100644
--- a/main.py
+++ b/main.py
@@ -26,7 +26,7 @@ def seg_main():
 def seg_main_classifier():
     model = seg_classifier_model.get_model("Models/densenet121_run20.pt", f"Models/segmentation{INPUT_DIM}skip.pt")
     train_loader, val_loader = datahandler.getHamDataLoaders()
-    train_seg.train_epochs(model, train_loader, val_loader, EPOCHS)
+    train_seg.train_epochs_classifier(model, train_loader, val_loader, EPOCHS)
 
 # When run from the command line it can take an additional argument:
 # if you pass 'seg' as that argument, it runs the segmentation training loop
diff --git a/seg_classifier_model.py b/seg_classifier_model.py
index 09aecd54e5c6adc28c5ee4cd6da9c1f7023bfaca..aaffc5ccb988a6cbb087a2434a763a86fe9e88ee 100644
--- a/seg_classifier_model.py
+++ b/seg_classifier_model.py
@@ -10,7 +10,7 @@ class SegClassifier(nn.Module):
         super(SegClassifier, self).__init__()
         self.seg_model = seg_model
         self.classifier = classifier
-        self.train_seg = train_seg
+        self.train_seg = True
         # modify the classifier's first layer to accept a 4th input channel
         new_conv = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
         self.classifier.features.conv0 = new_conv
diff --git a/train_seg.py b/train_seg.py
index 3dc18c3f0254d0a4994a7d16c5d37503169f3525..8ff9b5d651c547f284e740d89302fce3ec225621 100644
--- a/train_seg.py
+++ b/train_seg.py
@@ -1,3 +1,4 @@
+from typing import Callable
 import torch.optim as optim
 import torch.nn as nn
 from torch.utils.data import DataLoader
@@ -23,7 +24,13 @@ def jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor, threshold: float =
     return jaccard
 
-def train(model: nn.Module, criterion: nn.Module, optimizer: optim.Optimizer, loader: DataLoader, device: str):
+def accuracy(y_pred, y_true):
+    correct = 0
+    for pred, y in zip(y_pred.argmax(1).cpu().numpy(), y_true.cpu().numpy()):
+        if pred == y: correct += 1
+    return correct / len(y_true)
+
+def train(model: nn.Module, criterion: nn.Module, optimizer: optim.Optimizer, loader: DataLoader, device: str, accuracy_fn: Callable):
     epoch_loss = 0
     epoch_acc = 0
     model.train()
@@ -32,18 +39,18 @@ def train(model: nn.Module, criterion: nn.Module, optimizer: optim.Optimizer, lo
     for x, y in loop:
         x, y = x.to(device), y.to(device)  # Move data to GPU if available
         optimizer.zero_grad()
-        y_pred = torch.argmax(model(x))
+        y_pred = model(x).float()
         loss = criterion(y_pred, y)
         loss.backward()
         optimizer.step()
-        acc = jaccard_score(y_pred, y)
+        acc = accuracy_fn(y_pred, y)
         epoch_loss += loss.item()
         epoch_acc += acc
         loop.set_postfix(loss=loss.item(), accuracy=acc)
 
     return epoch_loss / len(loader), epoch_acc / len(loader)
 
-def eval(model: nn.Module, criterion: nn.Module, loader: DataLoader, device: str):
+def eval(model: nn.Module, criterion: nn.Module, loader: DataLoader, device: str, accuracy_fn: Callable):
     epoch_loss = 0
     epoch_acc = 0
     model.eval()
@@ -53,7 +60,7 @@ def eval(model: nn.Module, criterion: nn.Module, loader: DataLoader, device: str
         x, y = x.to(device), y.to(device)  # Move data to GPU if available
-        y_pred = torch.argmax(model(x))
+        y_pred = model(x).float()
         loss = criterion(y_pred, y)
-        acc = jaccard_score(y_pred, y)
+        acc = accuracy_fn(y_pred, y)
         epoch_loss += loss.item()
         epoch_acc += acc
         loop.set_postfix(loss=loss.item(), accuracy=acc)
@@ -79,8 +86,8 @@ def train_epochs(model: nn.Module, train_loader: DataLoader, val_loader: DataLoa
         print(f"Epoch: {epoch}")
         print("-"*20)
 
-        train_loss, train_acc = train(model, criterion, optimizer, train_loader, device)
-        val_loss, val_acc = eval(model, criterion, val_loader, device)
+        train_loss, train_acc = train(model, criterion, optimizer, train_loader, device, jaccard_score)
+        val_loss, val_acc = eval(model, criterion, val_loader, device, jaccard_score)
         print(f'\t Train Loss: {train_loss:.3f}, Train Accuracy: {train_acc:.3f}')
         print(f'\tValidation Loss: {val_loss:.3f}, Validation Accuracy: {val_acc:.3f}')
@@ -104,19 +111,18 @@ def train_epochs_classifier(model: nn.Module, train_loader: DataLoader, val_load
     weight_decay = 0.0001
     nesterov = True
     optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
-    nn.CrossEntropyLoss().to(device)
+    criterion = nn.CrossEntropyLoss()
+    criterion.to(device)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True)
     best_val_loss = float('inf')
     train_metrics = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
 
     for epoch in range(1, epochs+1):
         print(f"Epoch: {epoch}")
         print("-"*20)
-        train_loss, train_acc = train(model, criterion, optimizer, train_loader, device)
-        val_loss, val_acc = eval(model, criterion, val_loader, device)
+        train_loss, train_acc = train(model, criterion, optimizer, train_loader, device, accuracy)
+        val_loss, val_acc = eval(model, criterion, val_loader, device, accuracy)
         print(f'\t Train Loss: {train_loss:.3f}, Train Accuracy: {train_acc:.3f}')
         print(f'\tValidation Loss: {val_loss:.3f}, Validation Accuracy: {val_acc:.3f}')
@@ -148,6 +154,6 @@ def train_epochs_classifier(model: nn.Module, train_loader: DataLoader, val_load
 
         if val_loss < best_val_loss:
             best_val_loss = val_loss
-            torch.save(model.state_dict(), f"Models/segmentation_classifier{INPUT_DIM}skip.pt")
+            torch.save(model.state_dict(), f"Models/segmentation_classifier{INPUT_DIM}skip{epoch}.pt")
         scheduler.step(val_loss)
\ No newline at end of file
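
Note on the patch above: `train` and `eval` now take the metric as a `Callable`, so `train_epochs` injects `jaccard_score` for segmentation while `train_epochs_classifier` injects `accuracy` for classification. Below is a minimal, self-contained sketch of that pluggable-metric pattern; the linear model, batch shapes, and `run_batch` helper are illustrative assumptions, not code from this repo.

import torch
import torch.nn as nn

# Vectorized equivalent of the accuracy() helper added in train_seg.py:
# the fraction of argmax predictions that match the integer labels.
def accuracy(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    return (y_pred.argmax(1) == y_true).float().mean().item()

def run_batch(model, criterion, x, y, accuracy_fn):
    # Mirrors one iteration of the patched train()/eval() body:
    # forward pass, loss, then whichever metric was injected.
    y_pred = model(x).float()
    loss = criterion(y_pred, y)
    return loss.item(), accuracy_fn(y_pred, y)

# Dummy stand-ins: 8 features, 3 classes, batch of 4 (illustrative only).
model = nn.Linear(8, 3)
criterion = nn.CrossEntropyLoss()
x, y = torch.randn(4, 8), torch.randint(0, 3, (4,))
print(run_batch(model, criterion, x, y, accuracy))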