diff --git a/Models/segmentation128skip.pt b/Models/segmentation128skip.pt
index f392d2f351a0a6fbb9b7bc48f25530d039458fdd..a1904e91bd4716043c668e73c27bb1cd07c01a53 100644
Binary files a/Models/segmentation128skip.pt and b/Models/segmentation128skip.pt differ
diff --git a/Models/segmentation224skip.pt b/Models/segmentation224skip.pt
new file mode 100644
index 0000000000000000000000000000000000000000..b2cf17e05777de509d3fbdda6e87f1e603fad6d2
Binary files /dev/null and b/Models/segmentation224skip.pt differ
diff --git a/constants.py b/constants.py
index 95535283fb47cde587cba22fe0688bd5ecbd8265..0023f34b62c60fc5f177d80c4d6ad2e06b2e8b0d 100644
--- a/constants.py
+++ b/constants.py
@@ -1,5 +1,5 @@
 EPOCHS = 50
-BATCH_SIZE = 64
-INPUT_DIM = 128
+BATCH_SIZE = 32
+INPUT_DIM = 224
 VAL_SPLIT = 2000
 TEST_SPLIT = 500
\ No newline at end of file
diff --git a/main.py b/main.py
index 5ef1384d603afd4220a0d1d711e7f52ef4b8b994..b0877bfbf36120a74a87eda4b9f20fc4a82e5c48 100644
--- a/main.py
+++ b/main.py
@@ -1,11 +1,11 @@
 import matplotlib.pyplot as plt
 import numpy as np
-
+import torch
 from torchsummary import summary
 import datahandler
 import train_seg
 import seg_model
-from constants import EPOCHS
+from constants import EPOCHS, INPUT_DIM
 
 
 def main():
@@ -15,6 +15,7 @@ def seg_main():
     train_loader, val_loader = datahandler.getHamDataLoadersSeg()
     model = seg_model.SegmentationModel([3,16,32,64,1],[False, True, True])
     #summary(model, (3,128,128), batch_size=16,device="cpu")
+    model.load_state_dict(torch.load(f"Models/segmentation{INPUT_DIM}skip.pt"))
     train_seg.train_epochs(model, train_loader, val_loader, EPOCHS)
 
 if __name__ == '__main__':
diff --git a/seg_classifier_model.py b/seg_classifier_model.py
index 2a4d92c9dc7483f4be009859c88f857478066115..f212f72e3fa1e4f1178931c1e38a25764dac9103 100644
--- a/seg_classifier_model.py
+++ b/seg_classifier_model.py
@@ -1,19 +1,36 @@
 import torch
 import torch.nn as nn
+from torchvision import models
 
 from constants import INPUT_DIM
+from seg_model import SegmentationModel
 
 class SegClassifier(nn.Module):
     def __init__(self, seg_model, classifier, train_seg = False):
         self.seg_model = seg_model
         self.classifier = classifier
+        self.train_seg = train_seg
         #modifiy classifier first layer not add 4th input channel
+        new_conv = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
+        model.features.conv0 = new_conv
+
 
     def forward(x):
         if self.train_seg:
+            seg_mask = self.seg_model(x)
+        else:
             with torch.no_grad():
                 seg_mask = self.seg_model(x)
-        else:
-            seg_mask = self.seg_model(x)
         x = torch.cat([x,seg_mask],1)
         return self.classifier(x)
+
+def get_model(classifier_path: str, seg_path: str):
+    classifier = models.densenet121(pretrained=True)
+    in_features = classifier.classifier.in_features
+    classifier.classifier = nn.Linear(in_features, 7)
+    classifier.load_state_dict(torch.load(classifier_path))
+
+    model = seg_model.SegmentationModel([3,16,32,64,1],[False, True, True])
+    #summary(model, (3,128,128), batch_size=16,device="cpu")
+    model.load_state_dict(torch.load(seg_path))
+    return SegClassifier(seg_model, classifier)
diff --git a/seg_model.py b/seg_model.py
index 4e814e2cf7951d18b31521f29953a2153f4ba6f4..8d0db719d35fbe72b08165c5709ba4bbabe55fd4 100644
--- a/seg_model.py
+++ b/seg_model.py
@@ -4,10 +4,10 @@ import torch.nn as nn
 from constants import INPUT_DIM
 
 class ConvModule(nn.Module):
-    def __init__(self, in_channels, out_channels, pool=False):
+    def __init__(self, in_channels, out_channels, pool=False, kernel_size=3):
         super().__init__()
         # Define your network architecture here
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same')
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding='same')
         self.relu = nn.ReLU(inplace=True)
         self.bn = nn.BatchNorm2d(out_channels)
         self.dropout = nn.Dropout2d(p=0.1)
@@ -22,11 +22,11 @@ class ConvModule(nn.Module):
         return x
 
 class DeconvModule(nn.Module):
-    def __init__(self, in_channels, out_channels, depool=False):
+    def __init__(self, in_channels, out_channels, depool=False, kernel_size=3):
         super().__init__()
         stride = 2 if depool else 1
         output_padding = 1 if depool else 0
-        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, output_padding=output_padding)
+        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1, output_padding=output_padding)
         self.relu = nn.ReLU(inplace=True)
         self.bn1 = nn.BatchNorm2d(out_channels)
         self.dropout = nn.Dropout2d(p=0.1)
@@ -38,7 +38,6 @@ class DeconvModule(nn.Module):
         x = self.dropout(x)
         return x
 
-
 class SegmentationModel(nn.Module):
     def __init__(self, channels: [int], pool: [bool]):
         super(SegmentationModel, self).__init__()
@@ -50,19 +49,23 @@ class SegmentationModel(nn.Module):
         self.en5 =  ConvModule(64,128) #skip connection
         self.en6 =  ConvModule(128,128, True)
         self.en7 = ConvModule(128, 256) #skip connection
-        self.en8 = ConvModule(256, 256)
+        self.en8 = ConvModule(256, 256, True)
+        self.en9 = ConvModule(256, 512) #skip connection
+        self.en10 = ConvModule(512, 512)
 
         #Decoder
-        self.de1 = DeconvModule(256, 256)
-        self.de2 = DeconvModule(256*2, 128) #skip connection
-        self.de3 = DeconvModule(128, 128, True)
-        self.de4 = DeconvModule(128*2, 64)#skip connection
-        self.de5 = DeconvModule(64, 64, True)
-        self.de6 = DeconvModule(64*2, 32) #skip connection
-        self.de7 = DeconvModule(32, 16, True)
-        self.de8 = DeconvModule(16*2, 1)# skip connection
+        self.de1 = DeconvModule(512, 512)
+        self.de2 = DeconvModule(512*2, 256)#skip connection
+        self.de3 = DeconvModule(256, 256, True)
+        self.de4 = DeconvModule(256*2, 128) #skip connection
+        self.de5 = DeconvModule(128, 128, True)
+        self.de6 = DeconvModule(128*2, 64)#skip connection
+        self.de7 = DeconvModule(64, 64, True)
+        self.de8 = DeconvModule(64*2, 32) #skip connection
+        self.de9 = DeconvModule(32, 16, True)
+        self.de10 = DeconvModule(16*2, 1)# skip connection
 
-        self.sigmoid = nn.Sigmoid()
+        #self.sigmoid = nn.Sigmoid()
 
     def forward(self, x):
         # Perform forward pass of the network
@@ -78,17 +81,22 @@ class SegmentationModel(nn.Module):
         x = self.en7(x)
         x7 = x
         x = self.en8(x)
+        x = self.en9(x)
+        x9 = x
+        x = self.en10(x)
 
 
         x = self.de1(x)
-        x = self.de2(torch.concat((x,x7),1))
+        x = self.de2(torch.concat((x,x9),1))
         x = self.de3(x)
-        x = self.de4(torch.concat((x,x5),1))
+        x = self.de4(torch.concat((x,x7),1))
         x = self.de5(x)
-        x = self.de6(torch.concat((x,x3),1))
+        x = self.de6(torch.concat((x,x5),1))
         x = self.de7(x)
-        x = self.de8(torch.concat((x,x1),1))
+        x = self.de8(torch.concat((x,x3),1))
+        x = self.de9(x)
+        x = self.de10(torch.concat((x,x1),1))
 
 
-        x = self.sigmoid(x)
+        #x = self.sigmoid(x)
         return x
diff --git a/train_seg.py b/train_seg.py
index 3954c6acb4dc37f101cc2a9083b84ce5017f1b26..5c18a47b5049bb7df37fc4759d4cb26030227caa 100644
--- a/train_seg.py
+++ b/train_seg.py
@@ -2,15 +2,33 @@ import torch.optim as optim
 import torch.nn as nn
 from torch.utils.data import DataLoader
 import torch
+import numpy as np
 from tqdm import tqdm
 
 from constants import INPUT_DIM
 
+def jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor, threshold: float = 0.5) -> float:
+    # Convert logits to probabilities using the sigmoid function
+    y_pred = torch.sigmoid(y_pred.cpu())
+    y_true = y_true.cpu().numpy()
+    # Convert probabilities to binary mask using the threshold
+    y_pred = (y_pred >= threshold).numpy()
+    
+    # Calculate the intersection and union of the two masks
+    intersection = np.logical_and(y_pred, y_true)
+    union = np.logical_or(y_pred, y_true)
+    
+    # Calculate the Jaccard score
+    jaccard = np.sum(intersection) / np.sum(union)
+    
+    return jaccard
+
 def train(model: nn.Module, criterion: nn.Module, optimizer: optim.Optimizer, loader: DataLoader, device: str):
     epoch_loss = 0
+    epoch_acc = 0
     model.train()
 
-    loop = tqdm(loader, desc="\tTraining", ncols=80, mininterval=0.1)
+    loop = tqdm(loader, desc="\tTraining", ncols=100, mininterval=0.1)
     for x, y in loop:
         x, y = x.to(device), y.to(device)  # Move data to GPU if available
         optimizer.zero_grad()
@@ -18,50 +36,118 @@ def train(model: nn.Module, criterion: nn.Module, optimizer: optim.Optimizer, lo
         loss = criterion(y_pred, y)
         loss.backward()
         optimizer.step()
+        acc = jaccard_score(y_pred, y)
         epoch_loss += loss.item()
-        loop.set_postfix(loss = loss.item())
+        epoch_acc += acc
+        loop.set_postfix(loss = loss.item(), accuracy=acc)
 
-    return epoch_loss / len(loader)
+    return epoch_loss / len(loader) , epoch_acc / len(loader)
 
 def eval(model: nn.Module, criterion: nn.Module, loader: DataLoader, device: str):
     epoch_loss = 0
+    epoch_acc = 0
     model.eval()
 
-    loop = tqdm(loader, desc="\tValidation", ncols=80, mininterval=0.1)
+    loop = tqdm(loader, desc="\tValidation", ncols=100, mininterval=0.1)
     for x, y in loop:
         x, y = x.to(device), y.to(device)  # Move data to GPU if available
         y_pred = model(x)
         loss = criterion(y_pred, y)
+        acc = jaccard_score(y_pred, y)
         epoch_loss += loss.item()
-        loop.set_postfix(loss = loss.item())
+        epoch_acc += acc
+        loop.set_postfix(loss = loss.item(), accuracy=acc)
 
-    return epoch_loss / len(loader)
+    return epoch_loss / len(loader) , epoch_acc / len(loader)
 
 def train_epochs(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, epochs: int):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
-    optimizer = optim.Adam(model.parameters(), lr=0.02)
+    lr = 0.0001
+    momentum = 0.9
+    weight_decay = 0.0001
+    nesterov = True
+    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
     criterion = nn.BCEWithLogitsLoss()
     criterion.to(device)
 
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True)
 
     best_val_loss = float('inf')
-    train_metrics = {'train_loss': [], 'val_loss': []}
+    train_metrics = {'train_loss': [], 'train_acc':[],  'val_loss': [], 'val_acc':[]}
     for epoch in range(1, epochs+1):
         print(f"Epoch: {epoch}")
         print("-"*20)
-        train_loss = train(model, criterion, optimizer, train_loader, device)
-        print(f'\tTrain Loss: {train_loss:.3f}')
 
-        val_loss = eval(model, criterion, val_loader, device)
-        print(f'\tValidation Loss: {val_loss:.3f}')
+        train_loss, train_acc = train(model, criterion, optimizer, train_loader, device)
+        val_loss, val_acc = eval(model, criterion, val_loader, device)
+
+        print(f'\t     Train Loss: {train_loss:.3f},    Train Accuracy: {train_acc:.3f}')
+        print(f'\tValidation Loss: {val_loss:.3f}, Validation Accuracy: {val_acc:.3f}')
 
         train_metrics['train_loss'].append(train_loss)
+        train_metrics['train_acc'].append(train_acc)
         train_metrics['val_loss'].append(val_loss)
+        train_metrics['val_acc'].append(val_acc)
 
         if val_loss < best_val_loss:
             best_val_loss = val_loss
             torch.save(model.state_dict(), f"Models/segmentation{INPUT_DIM}skip.pt")
 
+        scheduler.step(val_loss)
+
+def train_epochs_classifier(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, epochs: int):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+    lr = 0.0001
+    momentum = 0.9
+    weight_decay = 0.0001
+    nesterov = True
+    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
+    criterion = nn.BCEWithLogitsLoss()
+    criterion.to(device)
+
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True)
+
+    best_val_loss = float('inf')
+    train_metrics = {'train_loss': [], 'train_acc':[],  'val_loss': [], 'val_acc':[]}
+    for epoch in range(1, epochs+1):
+        print(f"Epoch: {epoch}")
+        print("-"*20)
+
+        train_loss, train_acc = train(model, criterion, optimizer, train_loader, device)
+        val_loss, val_acc = eval(model, criterion, val_loader, device)
+
+        print(f'\t     Train Loss: {train_loss:.3f},    Train Accuracy: {train_acc:.3f}')
+        print(f'\tValidation Loss: {val_loss:.3f}, Validation Accuracy: {val_acc:.3f}')
+
+        train_metrics['train_loss'].append(train_loss)
+        train_metrics['train_acc'].append(train_acc)
+        train_metrics['val_loss'].append(val_loss)
+        train_metrics['val_acc'].append(val_acc)
+
+        scheduler.step(val_loss)
+
+    model.train_seg = True
+
+    best_val_loss = float('inf')
+    for epoch in range(1, epochs+1):
+        print(f"Epoch: {epoch}")
+        print("-"*20)
+
+        train_loss, train_acc = train(model, criterion, optimizer, train_loader, device)
+        val_loss, val_acc = eval(model, criterion, val_loader, device)
+
+        print(f'\t     Train Loss: {train_loss:.3f},    Train Accuracy: {train_acc:.3f}')
+        print(f'\tValidation Loss: {val_loss:.3f}, Validation Accuracy: {val_acc:.3f}')
+
+        train_metrics['train_loss'].append(train_loss)
+        train_metrics['train_acc'].append(train_acc)
+        train_metrics['val_loss'].append(val_loss)
+        train_metrics['val_acc'].append(val_acc)
+
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
+            torch.save(model.state_dict(), f"Models/segmentation_classifier{INPUT_DIM}skip.pt")
+
         scheduler.step(val_loss)
\ No newline at end of file