Skip to content
Snippets Groups Projects
Commit be7e96b7 authored by JamesTrewern's avatar JamesTrewern
Browse files

seg classifier WIP

parent bb70496d
No related branches found
No related tags found
No related merge requests found
......@@ -6,9 +6,10 @@ from torchsummary import summary
import datahandler
import train_seg
import seg_model
from constants import EPOCHS, MODEL_NAME
from constants import EPOCHS, MODEL_NAME, INPUT_DIM
from train import train
from model import get_model
import seg_classifier_model
def main():
model = get_model(MODEL_NAME, 7, False)
......@@ -22,6 +23,11 @@ def seg_main():
model.load_state_dict(torch.load(f"Models/segmentation{INPUT_DIM}skip.pt"))
train_seg.train_epochs(model, train_loader, val_loader, EPOCHS)
def seg_main_classifier():
    """Train the combined segmentation + classification model.

    Loads the pretrained DenseNet classifier and segmentation checkpoints,
    wraps them via seg_classifier_model.get_model, and runs the standard
    training loop over the HAM data loaders for EPOCHS epochs.
    """
    combined = seg_classifier_model.get_model(
        "Models/densenet121_run20.pt",
        f"Models/segmentation{INPUT_DIM}skip.pt",
    )
    loaders = datahandler.getHamDataLoaders()
    train_seg.train_epochs(combined, loaders[0], loaders[1], EPOCHS)
# When run from command line it can take an additional argument:
# if you add the additional argument with parameter 'seg' then it'll run the segmentation training loop
# otherwise it'll just run the main model training loop
......@@ -30,6 +36,8 @@ if __name__ == '__main__':
# Dispatch on the optional command-line argument:
#   "seg"  -> segmentation training loop
#   "segc" -> combined segmentation+classifier training loop
#   otherwise -> the normal classifier training loop
if sys.argv[1] == "seg":
    print("SEGMENTATION")
    seg_main()
# Fix: this must be `elif`, not a second `if` — otherwise the "seg" case also
# falls through to the `else` branch and runs main() after seg_main().
elif sys.argv[1] == "segc":
    seg_main_classifier()
else:
    print("NORMAL")
    main()
......
......@@ -7,15 +7,15 @@ from seg_model import SegmentationModel
class SegClassifier(nn.Module):
def __init__(self, seg_model, classifier, train_seg = False):
super(SegClassifier, self).__init__()
self.seg_model = seg_model
self.classifier = classifier
self.train_seg = train_seg
#modifiy classifier first layer not add 4th input channel
new_conv = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model.features.conv0 = new_conv
self.classifier.features.conv0 = new_conv
def forward(x):
def forward(self, x):
if self.train_seg:
seg_mask = self.seg_model(x)
else:
......@@ -30,7 +30,7 @@ def get_model(classifier_path: str, seg_path: str):
classifier.classifier = nn.Linear(in_features, 7)
classifier.load_state_dict(torch.load(classifier_path))
model = seg_model.SegmentationModel([3,16,32,64,1],[False, True, True])
seg_model = SegmentationModel([3,16,32,64,1],[False, True, True])
#summary(model, (3,128,128), batch_size=16,device="cpu")
model.load_state_dict(torch.load(seg_path))
seg_model.load_state_dict(torch.load(seg_path))
return SegClassifier(seg_model, classifier)
......@@ -32,7 +32,7 @@ def train(model: nn.Module, criterion: nn.Module, optimizer: optim.Optimizer, lo
for x, y in loop:
    x, y = x.to(device), y.to(device)  # Move data to GPU if available
    optimizer.zero_grad()
    # Fix: the loss must be computed on the raw logits. torch.argmax is
    # non-differentiable, so loss.backward() would have no gradient path into
    # the model; called without dim= it also collapses the whole batch to a
    # single scalar index. (CrossEntropyLoss expects (N, C) logits, not
    # predicted class indices.)
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
......@@ -51,7 +51,7 @@ def eval(model: nn.Module, criterion: nn.Module, loader: DataLoader, device: str
loop = tqdm(loader, desc="\tValidation", ncols=100, mininterval=0.1)
for x, y in loop:
    x, y = x.to(device), y.to(device)  # Move data to GPU if available
    # Fix: feed the raw (N, C) logits to the loss; torch.argmax() with no
    # dim argument flattens the batch to one scalar index, which is wrong for
    # both the loss and the metric. Take the per-sample argmax (dim=1) only
    # for the Jaccard score.
    logits = model(x)
    loss = criterion(logits, y)
    y_pred = torch.argmax(logits, dim=1)
    acc = jaccard_score(y_pred, y)
    epoch_loss += loss.item()
......@@ -104,7 +104,7 @@ def train_epochs_classifier(model: nn.Module, train_loader: DataLoader, val_load
weight_decay = 0.0001
nesterov = True
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
# Fix: bind the loss to `criterion` — the previous revision created a
# CrossEntropyLoss as a bare expression, discarded it, and left `criterion`
# unbound, so the following `criterion.to(device)` raised NameError.
criterion = nn.CrossEntropyLoss().to(device)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment