From cd709e8285680ecefb1f2fb86ac2051781493ced Mon Sep 17 00:00:00 2001 From: JamesTrewern <trewern.james@gmail.com> Date: Mon, 15 May 2023 15:06:17 +0100 Subject: [PATCH] train seg classifer get model update --- main.py | 2 +- seg_classifier_model.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index eb0d8d24..9ef7120b 100644 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ def seg_main_classifier(): def eval_main(seg, model_run="densenet121_run20"): if seg: - model = seg_classifier_model.get_model("Models/SegClassifier/classifier.pt", "Models/SegClassifier/seg.pt") + model = seg_classifier_model.get_model_trained("Models/SegClassifier/classifier.pt", "Models/SegClassifier/seg.pt") model.to("cuda") else: model = get_model("densenet121", 7, False) diff --git a/seg_classifier_model.py b/seg_classifier_model.py index bb39ad2a..3027d843 100644 --- a/seg_classifier_model.py +++ b/seg_classifier_model.py @@ -32,7 +32,7 @@ class SegClassifier(nn.Module): self.seg_model.load_state_dict(torch.load(seg_path)) self.classifier.load_state_dict(torch.load(classifier_path)) -def get_model(classifier_path: str, seg_path: str): +def get_model_trained(classifier_path: str, seg_path: str): classifier = models.densenet121(pretrained=True) in_features = classifier.classifier.in_features classifier.classifier = nn.Linear(in_features, 7) @@ -40,3 +40,13 @@ def get_model(classifier_path: str, seg_path: str): model = SegClassifier(seg_model, classifier) model.load(classifier_path,seg_path) return model + +def get_model(classifier_path: str, seg_path: str): + classifier = models.densenet121(pretrained=True) + in_features = classifier.classifier.in_features + classifier.classifier = nn.Linear(in_features, 7) + classifier.load_state_dict(torch.load(classifier_path)) + seg_model = SegmentationModel([3,16,32,64,1],[False, True, True]) + seg_model.load_state_dict(torch.load(seg_path)) + model = SegClassifier(seg_model, classifier) + return model \ No newline at end of file -- GitLab