diff --git a/main.py b/main.py
index 37943864d4af16b61bedb0388778e1eaac91add1..3e2e19c424f0f41c4f672976f0f41f20952fd55f 100644
--- a/main.py
+++ b/main.py
@@ -26,7 +26,7 @@ def seg_main():
 def seg_main_classifier():
     model = seg_classifier_model.get_model("Models/densenet121_run20.pt", f"Models/segmentation{INPUT_DIM}skip.pt")
     train_loader, val_loader = datahandler.getHamDataLoaders()
-    train_seg.train_epochs(model, train_loader, val_loader, EPOCHS)
+    train_seg.train_epochs_classifier(model, train_loader, val_loader, EPOCHS)
 
 # When run from command line it can take an additional argument:
 # if you add the additional argument with parameter 'seg' then it'll run the segmentation training loop
diff --git a/seg_classifier_model.py b/seg_classifier_model.py
index 09aecd54e5c6adc28c5ee4cd6da9c1f7023bfaca..aaffc5ccb988a6cbb087a2434a763a86fe9e88ee 100644
--- a/seg_classifier_model.py
+++ b/seg_classifier_model.py
@@ -10,7 +10,7 @@ class SegClassifier(nn.Module):
         super(SegClassifier, self).__init__()
         self.seg_model = seg_model
         self.classifier = classifier
-        self.train_seg = train_seg
+        self.train_seg = True
         #modifiy classifier first layer not add 4th input channel
         new_conv = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
         self.classifier.features.conv0 = new_conv
diff --git a/train_seg.py b/train_seg.py
index 3dc18c3f0254d0a4994a7d16c5d37503169f3525..8ff9b5d651c547f284e740d89302fce3ec225621 100644
--- a/train_seg.py
+++ b/train_seg.py
@@ -1,3 +1,4 @@
+from typing import Callable
 import torch.optim as optim
 import torch.nn as nn
 from torch.utils.data import DataLoader
@@ -23,7 +24,13 @@ def jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor, threshold: float =
     
     return jaccard
 
-def train(model: nn.Module, criterion: nn.Module, optimizer: optim.Optimizer, loader: DataLoader, device: str):
+def accuracy(y_pred, y_true):
+    correct = 0
+    for pred, y in zip(y_pred.argmax(1).cpu().numpy(), y_true.cpu().numpy()):
+        if pred == y: correct += 1
+    return correct/len(y_true)
+
+def train(model: nn.Module, criterion: nn.Module, optimizer: optim.Optimizer, loader: DataLoader, device: str, accuracy_fn: Callable):
     epoch_loss = 0
     epoch_acc = 0
     model.train()
@@ -32,18 +39,18 @@ def train(model: nn.Module, criterion: nn.Module, optimizer: optim.Optimizer, lo
     for x, y in loop:
         x, y = x.to(device), y.to(device)  # Move data to GPU if available
         optimizer.zero_grad()
-        y_pred = torch.argmax(model(x))
+        y_pred = model(x).float()
         loss = criterion(y_pred, y)
         loss.backward()
         optimizer.step()
-        acc = jaccard_score(y_pred, y)
+        acc = accuracy_fn(y_pred, y)
         epoch_loss += loss.item()
         epoch_acc += acc
         loop.set_postfix(loss = loss.item(), accuracy=acc)
 
     return epoch_loss / len(loader) , epoch_acc / len(loader)
 
-def eval(model: nn.Module, criterion: nn.Module, loader: DataLoader, device: str):
+def eval(model: nn.Module, criterion: nn.Module, loader: DataLoader, device: str, accuracy_fn: Callable):
     epoch_loss = 0
     epoch_acc = 0
     model.eval()
@@ -53,7 +60,7 @@ def eval(model: nn.Module, criterion: nn.Module, loader: DataLoader, device: str
         x, y = x.to(device), y.to(device)  # Move data to GPU if available
-        y_pred = torch.argmax(model(x))
+        y_pred = model(x).float()
         loss = criterion(y_pred, y)
-        acc = jaccard_score(y_pred, y)
+        acc = accuracy_fn(y_pred, y)
         epoch_loss += loss.item()
         epoch_acc += acc
         loop.set_postfix(loss = loss.item(), accuracy=acc)
@@ -79,8 +86,8 @@ def train_epochs(model: nn.Module, train_loader: DataLoader, val_loader: DataLoa
         print(f"Epoch: {epoch}")
         print("-"*20)
 
-        train_loss, train_acc = train(model, criterion, optimizer, train_loader, device)
-        val_loss, val_acc = eval(model, criterion, val_loader, device)
+        train_loss, train_acc = train(model, criterion, optimizer, train_loader, device, jaccard_score)
+        val_loss, val_acc = eval(model, criterion, val_loader, device, jaccard_score)
 
         print(f'\t     Train Loss: {train_loss:.3f},    Train Accuracy: {train_acc:.3f}')
         print(f'\tValidation Loss: {val_loss:.3f}, Validation Accuracy: {val_acc:.3f}')
@@ -104,19 +111,18 @@ def train_epochs_classifier(model: nn.Module, train_loader: DataLoader, val_load
     weight_decay = 0.0001
     nesterov = True
     optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
-    nn.CrossEntropyLoss().to(device)
+    criterion = nn.CrossEntropyLoss()
     criterion.to(device)
 
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True)
-
     best_val_loss = float('inf')
     train_metrics = {'train_loss': [], 'train_acc':[],  'val_loss': [], 'val_acc':[]}
     for epoch in range(1, epochs+1):
         print(f"Epoch: {epoch}")
         print("-"*20)
 
-        train_loss, train_acc = train(model, criterion, optimizer, train_loader, device)
-        val_loss, val_acc = eval(model, criterion, val_loader, device)
+        train_loss, train_acc = train(model, criterion, optimizer, train_loader, device, accuracy)
+        val_loss, val_acc = eval(model, criterion, val_loader, device, accuracy)
 
         print(f'\t     Train Loss: {train_loss:.3f},    Train Accuracy: {train_acc:.3f}')
         print(f'\tValidation Loss: {val_loss:.3f}, Validation Accuracy: {val_acc:.3f}')
@@ -148,6 +154,6 @@ def train_epochs_classifier(model: nn.Module, train_loader: DataLoader, val_load
 
         if val_loss < best_val_loss:
             best_val_loss = val_loss
-            torch.save(model.state_dict(), f"Models/segmentation_classifier{INPUT_DIM}skip.pt")
+            torch.save(model.state_dict(), f"Models/segmentation_classifier{INPUT_DIM}skip{epoch}.pt")
 
         scheduler.step(val_loss)
\ No newline at end of file