diff --git a/main.py b/main.py
index 07c70bfc0ffc6740f7349e4bcdb4c93f7828c62d..37943864d4af16b61bedb0388778e1eaac91add1 100644
--- a/main.py
+++ b/main.py
@@ -6,9 +6,10 @@ from torchsummary import summary
 import datahandler
 import train_seg
 import seg_model
-from constants import EPOCHS, MODEL_NAME
+from constants import EPOCHS, MODEL_NAME, INPUT_DIM
 from train import train
 from model import get_model
+import seg_classifier_model
 
 def main():
     model = get_model(MODEL_NAME, 7, False)
@@ -22,6 +23,11 @@ def seg_main():
     model.load_state_dict(torch.load(f"Models/segmentation{INPUT_DIM}skip.pt"))
     train_seg.train_epochs(model, train_loader, val_loader, EPOCHS)
 
+def seg_main_classifier():
+    model = seg_classifier_model.get_model("Models/densenet121_run20.pt", f"Models/segmentation{INPUT_DIM}skip.pt")
+    train_loader, val_loader = datahandler.getHamDataLoaders()
+    train_seg.train_epochs(model, train_loader, val_loader, EPOCHS)
+
 # When run from command line it can take an additional argument:
 # if you add the additional argument with parameter 'seg' then it'll run the segmentation training loop
 # otherwise it'll just run the main model training loop
@@ -30,6 +36,8 @@ if __name__ == '__main__':
         if sys.argv[1] == "seg":
             print("SEGMENTATION")
             seg_main()
+        elif sys.argv[1] == "segc":
+            seg_main_classifier()
         else:
             print("NORMAL")
             main()
diff --git a/seg_classifier_model.py b/seg_classifier_model.py
index f212f72e3fa1e4f1178931c1e38a25764dac9103..09aecd54e5c6adc28c5ee4cd6da9c1f7023bfaca 100644
--- a/seg_classifier_model.py
+++ b/seg_classifier_model.py
@@ -7,15 +7,15 @@ from seg_model import SegmentationModel
 
 class SegClassifier(nn.Module):
     def __init__(self, seg_model, classifier, train_seg = False):
+        super(SegClassifier, self).__init__()
         self.seg_model = seg_model
         self.classifier = classifier
         self.train_seg = train_seg
 
         #modifiy classifier first layer not add 4th input channel
         new_conv = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
-        model.features.conv0 = new_conv
+        self.classifier.features.conv0 = new_conv
-
-    def forward(x):
+    def forward(self, x):
         if self.train_seg:
             seg_mask = self.seg_model(x)
         else:
@@ -30,7 +30,7 @@ def get_model(classifier_path: str, seg_path: str):
     classifier.classifier = nn.Linear(in_features, 7)
     classifier.load_state_dict(torch.load(classifier_path))
 
-    model = seg_model.SegmentationModel([3,16,32,64,1],[False, True, True])
+    seg_model = SegmentationModel([3,16,32,64,1],[False, True, True])
     #summary(model, (3,128,128), batch_size=16,device="cpu")
-    model.load_state_dict(torch.load(seg_path))
+    seg_model.load_state_dict(torch.load(seg_path))
     return SegClassifier(seg_model, classifier)
diff --git a/train_seg.py b/train_seg.py
index 5c18a47b5049bb7df37fc4759d4cb26030227caa..3dc18c3f0254d0a4994a7d16c5d37503169f3525 100644
--- a/train_seg.py
+++ b/train_seg.py
@@ -104,7 +104,7 @@ def train_epochs_classifier(model: nn.Module, train_loader: DataLoader, val_load
     weight_decay = 0.0001
     nesterov = True
     optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
-    criterion = nn.BCEWithLogitsLoss()
+    criterion = nn.CrossEntropyLoss()
     criterion.to(device)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                            verbose=True)