diff --git a/main.py b/main.py
index 07c70bfc0ffc6740f7349e4bcdb4c93f7828c62d..37943864d4af16b61bedb0388778e1eaac91add1 100644
--- a/main.py
+++ b/main.py
@@ -6,9 +6,10 @@ from torchsummary import summary
 import datahandler
 import train_seg
 import seg_model
-from constants import EPOCHS, MODEL_NAME
+from constants import EPOCHS, MODEL_NAME, INPUT_DIM
 from train import train
 from model import get_model
+import seg_classifier_model
 
 def main():
     model = get_model(MODEL_NAME, 7, False)
@@ -22,6 +23,13 @@ def seg_main():
     model.load_state_dict(torch.load(f"Models/segmentation{INPUT_DIM}skip.pt"))
     train_seg.train_epochs(model, train_loader, val_loader, EPOCHS)
 
+def seg_main_classifier():
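+    # Build the combined model from the pretrained classifier and segmentation checkpoints.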
+    model = seg_classifier_model.get_model("Models/densenet121_run20.pt", f"Models/segmentation{INPUT_DIM}skip.pt")
+    train_loader, val_loader = datahandler.getHamDataLoaders()
+    train_seg.train_epochs(model, train_loader, val_loader, EPOCHS)
+
-# When run from command line it can take an additional argument:
-# if you add the additional argument with parameter 'seg' then it'll run the segmentation training loop
-# otherwise it'll just run the main model training loop
+# When run from the command line, this script accepts an optional argument:
+# 'seg' runs the segmentation training loop, 'segc' runs the combined
+# segmentation-classifier training loop, and anything else runs the
+# main model training loop.
@@ -30,6 +38,9 @@ if __name__ == '__main__':
         if sys.argv[1] == "seg":
             print("SEGMENTATION")
             seg_main()
+        elif sys.argv[1] == "segc":
+            print("SEGMENTATION CLASSIFIER")
+            seg_main_classifier()
         else:
             print("NORMAL")
             main()
diff --git a/seg_classifier_model.py b/seg_classifier_model.py
index f212f72e3fa1e4f1178931c1e38a25764dac9103..09aecd54e5c6adc28c5ee4cd6da9c1f7023bfaca 100644
--- a/seg_classifier_model.py
+++ b/seg_classifier_model.py
@@ -7,15 +7,17 @@ from seg_model import SegmentationModel
 
 class SegClassifier(nn.Module):
     def __init__(self, seg_model, classifier, train_seg = False):
+        super(SegClassifier, self).__init__()
         self.seg_model = seg_model
         self.classifier = classifier
         self.train_seg = train_seg
-        #modifiy classifier first layer not add 4th input channel
+        # Modify the classifier's first conv layer (rather than adding a new one) to accept a 4th input channel.
         new_conv = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
-        model.features.conv0 = new_conv
+        self.classifier.features.conv0 = new_conv
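+        # Note: the replacement conv is randomly initialised; the pretrained conv0 weights are not copied over.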
 
-
-    def forward(x):
+    def forward(self, x):
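+        # train_seg toggles whether gradients flow through the segmentation model during training.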
         if self.train_seg:
             seg_mask = self.seg_model(x)
         else:
@@ -30,7 +30,8 @@ def get_model(classifier_path: str, seg_path: str):
     classifier.classifier = nn.Linear(in_features, 7)
     classifier.load_state_dict(torch.load(classifier_path))
 
-    model = seg_model.SegmentationModel([3,16,32,64,1],[False, True, True])
+    seg_model = SegmentationModel([3,16,32,64,1],[False, True, True])
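+    # The architecture spec must match the one used when the segmentation checkpoint was saved.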
-    #summary(model, (3,128,128), batch_size=16,device="cpu")
+    # summary(seg_model, (3,128,128), batch_size=16, device="cpu")
-    model.load_state_dict(torch.load(seg_path))
+    seg_model.load_state_dict(torch.load(seg_path))
     return SegClassifier(seg_model, classifier)
diff --git a/train_seg.py b/train_seg.py
index 5c18a47b5049bb7df37fc4759d4cb26030227caa..3dc18c3f0254d0a4994a7d16c5d37503169f3525 100644
--- a/train_seg.py
+++ b/train_seg.py
@@ -32,7 +32,7 @@ def train(model: nn.Module, criterion: nn.Module, optimizer: optim.Optimizer, lo
     for x, y in loop:
         x, y = x.to(device), y.to(device)  # Move data to GPU if available
         optimizer.zero_grad()
-        y_pred = model(x)
+        y_pred = model(x)  # keep raw logits: argmax is non-differentiable and would break backprop
         loss = criterion(y_pred, y)
         loss.backward()
         optimizer.step()
@@ -51,7 +51,8 @@ def eval(model: nn.Module, criterion: nn.Module, loader: DataLoader, device: str
     loop = tqdm(loader, desc="\tValidation", ncols=100, mininterval=0.1)
     for x, y in loop:
         x, y = x.to(device), y.to(device)  # Move data to GPU if available
-        y_pred = model(x)
-        loss = criterion(y_pred, y)
+        logits = model(x)
+        loss = criterion(logits, y)  # loss is computed on raw logits
+        y_pred = torch.argmax(logits, dim=1)  # hard labels for the IoU metric (threshold instead if the output is single-channel)
         acc = jaccard_score(y_pred, y)
         epoch_loss += loss.item()
@@ -104,7 +105,7 @@ def train_epochs_classifier(model: nn.Module, train_loader: DataLoader, val_load
     weight_decay = 0.0001
     nesterov = True
     optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
-    criterion = nn.BCEWithLogitsLoss()
+    criterion = nn.CrossEntropyLoss()  # multi-class loss to match the model's 7-class output
     criterion.to(device)
 
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True)