From cd709e8285680ecefb1f2fb86ac2051781493ced Mon Sep 17 00:00:00 2001
From: JamesTrewern <trewern.james@gmail.com>
Date: Mon, 15 May 2023 15:06:17 +0100
Subject: [PATCH] train seg classifer get model update

---
 main.py                 |  2 +-
 seg_classifier_model.py | 12 +++++++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/main.py b/main.py
index eb0d8d24..9ef7120b 100644
--- a/main.py
+++ b/main.py
@@ -33,7 +33,7 @@ def seg_main_classifier():
 
 def eval_main(seg, model_run="densenet121_run20"):
     if seg:
-        model = seg_classifier_model.get_model("Models/SegClassifier/classifier.pt", "Models/SegClassifier/seg.pt")
+        model = seg_classifier_model.get_model_trained("Models/SegClassifier/classifier.pt", "Models/SegClassifier/seg.pt")
         model.to("cuda")
     else:
         model = get_model("densenet121", 7, False)
diff --git a/seg_classifier_model.py b/seg_classifier_model.py
index bb39ad2a..3027d843 100644
--- a/seg_classifier_model.py
+++ b/seg_classifier_model.py
@@ -32,7 +32,7 @@ class SegClassifier(nn.Module):
         self.seg_model.load_state_dict(torch.load(seg_path))
         self.classifier.load_state_dict(torch.load(classifier_path))
 
-def get_model(classifier_path: str, seg_path: str):
+def get_model_trained(classifier_path: str, seg_path: str):
     classifier = models.densenet121(pretrained=True)
     in_features = classifier.classifier.in_features
     classifier.classifier = nn.Linear(in_features, 7)
@@ -40,3 +40,13 @@ def get_model(classifier_path: str, seg_path: str):
     model = SegClassifier(seg_model, classifier)
     model.load(classifier_path,seg_path)
     return model
+
+def get_model(classifier_path: str, seg_path: str):
+    classifier = models.densenet121(pretrained=True)
+    in_features = classifier.classifier.in_features
+    classifier.classifier = nn.Linear(in_features, 7)
+    classifier.load_state_dict(torch.load(classifier_path))
+    seg_model = SegmentationModel([3,16,32,64,1],[False, True, True])
+    seg_model.load_state_dict(torch.load(seg_path))
+    model = SegClassifier(seg_model, classifier)
+    return model
\ No newline at end of file
-- 
GitLab