diff --git a/Models/segmentation128skip.pt b/Models/segmentation128skip.pt
index f392d2f351a0a6fbb9b7bc48f25530d039458fdd..a1904e91bd4716043c668e73c27bb1cd07c01a53 100644
Binary files a/Models/segmentation128skip.pt and b/Models/segmentation128skip.pt differ
diff --git a/Models/segmentation224skip.pt b/Models/segmentation224skip.pt
new file mode 100644
index 0000000000000000000000000000000000000000..b2cf17e05777de509d3fbdda6e87f1e603fad6d2
Binary files /dev/null and b/Models/segmentation224skip.pt differ
diff --git a/constants.py b/constants.py
index 95535283fb47cde587cba22fe0688bd5ecbd8265..0023f34b62c60fc5f177d80c4d6ad2e06b2e8b0d 100644
--- a/constants.py
+++ b/constants.py
@@ -1,5 +1,5 @@
 EPOCHS = 50
-BATCH_SIZE = 64
-INPUT_DIM = 128
+BATCH_SIZE = 32
+INPUT_DIM = 224
 VAL_SPLIT = 2000
 TEST_SPLIT = 500
\ No newline at end of file
diff --git a/main.py b/main.py
index 5ef1384d603afd4220a0d1d711e7f52ef4b8b994..b0877bfbf36120a74a87eda4b9f20fc4a82e5c48 100644
--- a/main.py
+++ b/main.py
@@ -1,11 +1,11 @@
 import matplotlib.pyplot as plt
 import numpy as np
-
+import torch
 from torchsummary import summary
 import datahandler
 import train_seg
 import seg_model
-from constants import EPOCHS
+from constants import EPOCHS, INPUT_DIM
 
 
 def main():
@@ -15,6 +15,7 @@ def seg_main():
     train_loader, val_loader = datahandler.getHamDataLoadersSeg()
     model = seg_model.SegmentationModel([3,16,32,64,1],[False, True, True])
     #summary(model, (3,128,128), batch_size=16,device="cpu")
+    model.load_state_dict(torch.load(f"Models/segmentation{INPUT_DIM}skip.pt"))
    train_seg.train_epochs(model, train_loader, val_loader, EPOCHS)
 
 if __name__ == '__main__':
diff --git a/seg_classifier_model.py b/seg_classifier_model.py
index 2a4d92c9dc7483f4be009859c88f857478066115..f212f72e3fa1e4f1178931c1e38a25764dac9103 100644
--- a/seg_classifier_model.py
+++ b/seg_classifier_model.py
@@ -1,19 +1,36 @@
 import torch
 import torch.nn as nn
+from torchvision import models
 
 from constants import INPUT_DIM
+from seg_model import SegmentationModel
 
 class SegClassifier(nn.Module):
     def __init__(self, seg_model, classifier, train_seg = False):
+        super().__init__()
         self.seg_model = seg_model
         self.classifier = classifier
+        self.train_seg = train_seg
 
+        # modify the classifier's first conv layer to accept a 4th input channel
+        new_conv = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
+        self.classifier.features.conv0 = new_conv
+
-    def forward(x):
+    def forward(self, x):
         if self.train_seg:
+            seg_mask = self.seg_model(x)
+        else:
             with torch.no_grad():
                 seg_mask = self.seg_model(x)
-        else:
-            seg_mask = self.seg_model(x)
 
         x = torch.cat([x,seg_mask],1)
         return self.classifier(x)
+
+def get_model(classifier_path: str, seg_path: str):
+    classifier = models.densenet121(pretrained=True)
+    in_features = classifier.classifier.in_features
+    classifier.classifier = nn.Linear(in_features, 7)
+    classifier.load_state_dict(torch.load(classifier_path))
+
+    model = SegmentationModel([3,16,32,64,1],[False, True, True])
+    model.load_state_dict(torch.load(seg_path))
+    return SegClassifier(model, classifier)
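For reference, a minimal usage sketch of the new `get_model` helper, assuming both checkpoints exist on disk; the classifier checkpoint path below is a hypothetical placeholder, not a file added by this change. Note that the mask channel carries raw logits now that the segmentation model's final sigmoid is disabled (see the seg_model.py diff below).

```python
# Hedged sketch, not part of the patch: exercise get_model end to end.
# "Models/classifier224.pt" is a hypothetical path used for illustration only.
import torch
from seg_classifier_model import get_model

model = get_model("Models/classifier224.pt", "Models/segmentation224skip.pt")
model.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 224, 224)   # one RGB image at the new INPUT_DIM
    logits = model(x)                 # SegClassifier concatenates the mask -> 4-channel input
print(logits.shape)                   # expected: torch.Size([1, 7])
```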
diff --git a/seg_model.py b/seg_model.py
index 4e814e2cf7951d18b31521f29953a2153f4ba6f4..8d0db719d35fbe72b08165c5709ba4bbabe55fd4 100644
--- a/seg_model.py
+++ b/seg_model.py
@@ -4,10 +4,10 @@ import torch.nn as nn
 from constants import INPUT_DIM
 
 class ConvModule(nn.Module):
-    def __init__(self, in_channels, out_channels, pool=False):
+    def __init__(self, in_channels, out_channels, pool=False, kernel_size=3):
         super().__init__()
         # Define your network architecture here
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same')
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding='same')
         self.relu = nn.ReLU(inplace=True)
         self.bn = nn.BatchNorm2d(out_channels)
         self.dropout = nn.Dropout2d(p=0.1)
@@ -22,11 +22,11 @@ class ConvModule(nn.Module):
         return x
 
 class DeconvModule(nn.Module):
-    def __init__(self, in_channels, out_channels, depool=False):
+    def __init__(self, in_channels, out_channels, depool=False, kernel_size=3):
         super().__init__()
         stride = 2 if depool else 1
         output_padding = 1 if depool else 0
-        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, output_padding=output_padding)
+        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1, output_padding=output_padding)
         self.relu = nn.ReLU(inplace=True)
         self.bn1 = nn.BatchNorm2d(out_channels)
         self.dropout = nn.Dropout2d(p=0.1)
@@ -38,7 +38,6 @@ class DeconvModule(nn.Module):
         x = self.dropout(x)
         return x
 
-
 class SegmentationModel(nn.Module):
     def __init__(self, channels: [int], pool: [bool]):
         super(SegmentationModel, self).__init__()
@@ -50,19 +49,23 @@ class SegmentationModel(nn.Module):
         self.en5 = ConvModule(64,128) #skip connection
         self.en6 = ConvModule(128,128, True)
         self.en7 = ConvModule(128, 256) #skip connection
-        self.en8 = ConvModule(256, 256)
+        self.en8 = ConvModule(256, 256, True)
+        self.en9 = ConvModule(256, 512) #skip connection
+        self.en10 = ConvModule(512, 512)
 
         #Decoder
-        self.de1 = DeconvModule(256, 256)
-        self.de2 = DeconvModule(256*2, 128) #skip connection
-        self.de3 = DeconvModule(128, 128, True)
-        self.de4 = DeconvModule(128*2, 64) #skip connection
-        self.de5 = DeconvModule(64, 64, True)
-        self.de6 = DeconvModule(64*2, 32) #skip connection
-        self.de7 = DeconvModule(32, 16, True)
-        self.de8 = DeconvModule(16*2, 1) #skip connection
+        self.de1 = DeconvModule(512, 512)
+        self.de2 = DeconvModule(512*2, 256) #skip connection
+        self.de3 = DeconvModule(256, 256, True)
+        self.de4 = DeconvModule(256*2, 128) #skip connection
+        self.de5 = DeconvModule(128, 128, True)
+        self.de6 = DeconvModule(128*2, 64) #skip connection
+        self.de7 = DeconvModule(64, 64, True)
+        self.de8 = DeconvModule(64*2, 32) #skip connection
+        self.de9 = DeconvModule(32, 16, True)
+        self.de10 = DeconvModule(16*2, 1) #skip connection
 
-        self.sigmoid = nn.Sigmoid()
+        #self.sigmoid = nn.Sigmoid()
 
     def forward(self, x):
         # Perform forward pass of the network
@@ -78,17 +81,22 @@ class SegmentationModel(nn.Module):
         x = self.en7(x)
         x7 = x
         x = self.en8(x)
+        x = self.en9(x)
+        x9 = x
+        x = self.en10(x)
 
         x = self.de1(x)
-        x = self.de2(torch.concat((x,x7),1))
+        x = self.de2(torch.concat((x,x9),1))
         x = self.de3(x)
-        x = self.de4(torch.concat((x,x5),1))
+        x = self.de4(torch.concat((x,x7),1))
         x = self.de5(x)
-        x = self.de6(torch.concat((x,x3),1))
+        x = self.de6(torch.concat((x,x5),1))
         x = self.de7(x)
-        x = self.de8(torch.concat((x,x1),1))
+        x = self.de8(torch.concat((x,x3),1))
+        x = self.de9(x)
+        x = self.de10(torch.concat((x,x1),1))
 
-        x = self.sigmoid(x)
+        #x = self.sigmoid(x)
 
         return x
 
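A quick shape check for the deepened encoder/decoder above (a sketch, assuming INPUT_DIM = 224): with four pool/depool pairs, the spatial size must be divisible by 16, which 224 satisfies (224 / 2^4 = 14).

```python
# Sanity-check sketch: the four pool/depool pairs should round-trip the input size exactly.
import torch
from seg_model import SegmentationModel

model = SegmentationModel([3, 16, 32, 64, 1], [False, True, True])
model.eval()
with torch.no_grad():
    y = model(torch.randn(2, 3, 224, 224))
print(y.shape)  # expected: torch.Size([2, 1, 224, 224]) -- raw logits, since the sigmoid is disabled
```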
diff --git a/train_seg.py b/train_seg.py
index 3954c6acb4dc37f101cc2a9083b84ce5017f1b26..5c18a47b5049bb7df37fc4759d4cb26030227caa 100644
--- a/train_seg.py
+++ b/train_seg.py
@@ -2,15 +2,33 @@ import torch.optim as optim
 import torch.nn as nn
 from torch.utils.data import DataLoader
 import torch
+import numpy as np
 from tqdm import tqdm
 
 from constants import INPUT_DIM
 
+def jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor, threshold: float = 0.5) -> float:
+    # Convert logits to probabilities using the sigmoid function
+    y_pred = torch.sigmoid(y_pred.cpu())
+    y_true = y_true.cpu().numpy()
+    # Convert probabilities to binary mask using the threshold
+    y_pred = (y_pred >= threshold).numpy()
+
+    # Calculate the intersection and union of the two masks
+    intersection = np.logical_and(y_pred, y_true)
+    union = np.logical_or(y_pred, y_true)
+
+    # Calculate the Jaccard score
+    jaccard = np.sum(intersection) / np.sum(union)
+
+    return jaccard
+
 def train(model: nn.Module, criterion: nn.Module, optimizer: optim.Optimizer, loader: DataLoader, device: str):
     epoch_loss = 0
+    epoch_acc = 0
     model.train()
-    loop = tqdm(loader, desc="\tTraining", ncols=80, mininterval=0.1)
+    loop = tqdm(loader, desc="\tTraining", ncols=100, mininterval=0.1)
     for x, y in loop:
         x, y = x.to(device), y.to(device) # Move data to GPU if available
         optimizer.zero_grad()
@@ -18,50 +36,118 @@ def train(model: nn.Module, criterion: nn.Module, optimizer: optim.Optimizer, loader: DataLoader, device: str):
         loss = criterion(y_pred, y)
         loss.backward()
         optimizer.step()
+        acc = jaccard_score(y_pred, y)
         epoch_loss += loss.item()
-        loop.set_postfix(loss = loss.item())
+        epoch_acc += acc
+        loop.set_postfix(loss = loss.item(), accuracy=acc)
 
-    return epoch_loss / len(loader)
+    return epoch_loss / len(loader), epoch_acc / len(loader)
 
 def eval(model: nn.Module, criterion: nn.Module, loader: DataLoader, device: str):
     epoch_loss = 0
+    epoch_acc = 0
     model.eval()
-    loop = tqdm(loader, desc="\tValidation", ncols=80, mininterval=0.1)
+    loop = tqdm(loader, desc="\tValidation", ncols=100, mininterval=0.1)
     for x, y in loop:
         x, y = x.to(device), y.to(device) # Move data to GPU if available
         y_pred = model(x)
         loss = criterion(y_pred, y)
+        acc = jaccard_score(y_pred, y)
         epoch_loss += loss.item()
-        loop.set_postfix(loss = loss.item())
+        epoch_acc += acc
+        loop.set_postfix(loss = loss.item(), accuracy=acc)
 
-    return epoch_loss / len(loader)
+    return epoch_loss / len(loader), epoch_acc / len(loader)
 
 def train_epochs(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, epochs: int):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
-    optimizer = optim.Adam(model.parameters(), lr=0.02)
+    lr = 0.0001
+    momentum = 0.9
+    weight_decay = 0.0001
+    nesterov = True
+    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
     criterion = nn.BCEWithLogitsLoss()
     criterion.to(device)
 
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True)
 
     best_val_loss = float('inf')
-    train_metrics = {'train_loss': [], 'val_loss': []}
+    train_metrics = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
     for epoch in range(1, epochs+1):
         print(f"Epoch: {epoch}")
         print("-"*20)
 
-        train_loss = train(model, criterion, optimizer, train_loader, device)
-        print(f'\tTrain Loss: {train_loss:.3f}')
-        val_loss = eval(model, criterion, val_loader, device)
-        print(f'\tValidation Loss: {val_loss:.3f}')
+        train_loss, train_acc = train(model, criterion, optimizer, train_loader, device)
+        val_loss, val_acc = eval(model, criterion, val_loader, device)
+
+        print(f'\t     Train Loss: {train_loss:.3f}, Train Accuracy: {train_acc:.3f}')
+        print(f'\tValidation Loss: {val_loss:.3f}, Validation Accuracy: {val_acc:.3f}')
 
         train_metrics['train_loss'].append(train_loss)
+        train_metrics['train_acc'].append(train_acc)
         train_metrics['val_loss'].append(val_loss)
+        train_metrics['val_acc'].append(val_acc)
 
         if val_loss < best_val_loss:
             best_val_loss = val_loss
             torch.save(model.state_dict(), f"Models/segmentation{INPUT_DIM}skip.pt")
+        scheduler.step(val_loss)
+
+def train_epochs_classifier(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, epochs: int):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+    lr = 0.0001
+    momentum = 0.9
+    weight_decay = 0.0001
+    nesterov = True
+    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
+    criterion = nn.BCEWithLogitsLoss()
+    criterion.to(device)
+
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True)
+
+    best_val_loss = float('inf')
+    train_metrics = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
+    for epoch in range(1, epochs+1):
+        print(f"Epoch: {epoch}")
+        print("-"*20)
+
+        train_loss, train_acc = train(model, criterion, optimizer, train_loader, device)
+        val_loss, val_acc = eval(model, criterion, val_loader, device)
+
+        print(f'\t     Train Loss: {train_loss:.3f}, Train Accuracy: {train_acc:.3f}')
+        print(f'\tValidation Loss: {val_loss:.3f}, Validation Accuracy: {val_acc:.3f}')
+
+        train_metrics['train_loss'].append(train_loss)
+        train_metrics['train_acc'].append(train_acc)
+        train_metrics['val_loss'].append(val_loss)
+        train_metrics['val_acc'].append(val_acc)
+
+        scheduler.step(val_loss)
+
+    # second phase: unfreeze the segmentation branch and fine-tune end to end
+    model.train_seg = True
+
+    best_val_loss = float('inf')
+    for epoch in range(1, epochs+1):
+        print(f"Epoch: {epoch}")
+        print("-"*20)
+
+        train_loss, train_acc = train(model, criterion, optimizer, train_loader, device)
+        val_loss, val_acc = eval(model, criterion, val_loader, device)
+
+        print(f'\t     Train Loss: {train_loss:.3f}, Train Accuracy: {train_acc:.3f}')
+        print(f'\tValidation Loss: {val_loss:.3f}, Validation Accuracy: {val_acc:.3f}')
+
+        train_metrics['train_loss'].append(train_loss)
+        train_metrics['train_acc'].append(train_acc)
+        train_metrics['val_loss'].append(val_loss)
+        train_metrics['val_acc'].append(val_acc)
+
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
+            torch.save(model.state_dict(), f"Models/segmentation_classifier{INPUT_DIM}skip.pt")
+        scheduler.step(val_loss)
\ No newline at end of file
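Finally, a small illustrative check of the new `jaccard_score` metric on hand-built logits (a sketch; values are chosen so the thresholded mask is unambiguous):

```python
# Worked example: sigmoid([3, -3, 3, -3]) thresholds at 0.5 to [1, 0, 1, 0].
# Against ground truth [1, 0, 1, 1]: intersection = 2, union = 3 -> IoU ~ 0.667.
import torch
from train_seg import jaccard_score

y_true = torch.tensor([[1.0, 0.0, 1.0, 1.0]])
y_pred_logits = torch.tensor([[3.0, -3.0, 3.0, -3.0]])
print(jaccard_score(y_pred_logits, y_true))  # ~0.6667
```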