diff --git a/main.py b/main.py index eb0d8d24f46fdb3c51dea3266f457671c031803c..9ef7120b946c3314d65bbb104a5ac0861a144dc2 100644 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ def seg_main_classifier(): def eval_main(seg, model_run="densenet121_run20"): if seg: - model = seg_classifier_model.get_model("Models/SegClassifier/classifier.pt", "Models/SegClassifier/seg.pt") + model = seg_classifier_model.get_model_trained("Models/SegClassifier/classifier.pt", "Models/SegClassifier/seg.pt") model.to("cuda") else: model = get_model("densenet121", 7, False) diff --git a/seg_classifier_model.py b/seg_classifier_model.py index bb39ad2af00bf0ba416a676c143bd1adc19d1487..3027d8435bbf25e419537c8fca23085f891274b6 100644 --- a/seg_classifier_model.py +++ b/seg_classifier_model.py @@ -32,7 +32,7 @@ class SegClassifier(nn.Module): self.seg_model.load_state_dict(torch.load(seg_path)) self.classifier.load_state_dict(torch.load(classifier_path)) -def get_model(classifier_path: str, seg_path: str): +def get_model_trained(classifier_path: str, seg_path: str): classifier = models.densenet121(pretrained=True) in_features = classifier.classifier.in_features classifier.classifier = nn.Linear(in_features, 7) @@ -40,3 +40,13 @@ def get_model(classifier_path: str, seg_path: str): model = SegClassifier(seg_model, classifier) model.load(classifier_path,seg_path) return model + +def get_model(classifier_path: str, seg_path: str): + classifier = models.densenet121(pretrained=True) + in_features = classifier.classifier.in_features + classifier.classifier = nn.Linear(in_features, 7) + classifier.load_state_dict(torch.load(classifier_path)) + seg_model = SegmentationModel([3,16,32,64,1],[False, True, True]) + seg_model.load_state_dict(torch.load(seg_path)) + model = SegClassifier(seg_model, classifier) + return model \ No newline at end of file