Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
C
COM3015 Group Project
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Trewern, James R (PG/R - Comp Sci & Elec Eng)
COM3015 Group Project
Commits
be7e96b7
Commit
be7e96b7
authored
2 years ago
by
JamesTrewern
Browse files
Options
Downloads
Patches
Plain Diff
sef classifier WIP
parent
bb70496d
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
main.py
+9
-1
9 additions, 1 deletion
main.py
seg_classifier_model.py
+5
-5
5 additions, 5 deletions
seg_classifier_model.py
train_seg.py
+3
-3
3 additions, 3 deletions
train_seg.py
with
17 additions
and
9 deletions
main.py
+
9
−
1
View file @
be7e96b7
...
...
@@ -6,9 +6,10 @@ from torchsummary import summary
import
datahandler
import
train_seg
import
seg_model
from
constants
import
EPOCHS
,
MODEL_NAME
from
constants
import
EPOCHS
,
MODEL_NAME
,
INPUT_DIM
from
train
import
train
from
model
import
get_model
import
seg_classifier_model
def
main
():
model
=
get_model
(
MODEL_NAME
,
7
,
False
)
...
...
@@ -22,6 +23,11 @@ def seg_main():
model
.
load_state_dict
(
torch
.
load
(
f
"
Models/segmentation
{
INPUT_DIM
}
skip.pt
"
))
train_seg
.
train_epochs
(
model
,
train_loader
,
val_loader
,
EPOCHS
)
def
seg_main_classifier
():
model
=
seg_classifier_model
.
get_model
(
"
Models/densenet121_run20.pt
"
,
f
"
Models/segmentation
{
INPUT_DIM
}
skip.pt
"
)
train_loader
,
val_loader
=
datahandler
.
getHamDataLoaders
()
train_seg
.
train_epochs
(
model
,
train_loader
,
val_loader
,
EPOCHS
)
# When run from command line it can take an additional argument:
# if you add the additional argument with parameter 'seg' then it'll run the segmentation training loop
# otherwise it'll just run the main model training loop
...
...
@@ -30,6 +36,8 @@ if __name__ == '__main__':
if
sys
.
argv
[
1
]
==
"
seg
"
:
print
(
"
SEGMENTATION
"
)
seg_main
()
if
sys
.
argv
[
1
]
==
"
segc
"
:
seg_main_classifier
()
else
:
print
(
"
NORMAL
"
)
main
()
...
...
This diff is collapsed.
Click to expand it.
seg_classifier_model.py
+
5
−
5
View file @
be7e96b7
...
...
@@ -7,15 +7,15 @@ from seg_model import SegmentationModel
class
SegClassifier
(
nn
.
Module
):
def
__init__
(
self
,
seg_model
,
classifier
,
train_seg
=
False
):
super
(
SegClassifier
,
self
).
__init__
()
self
.
seg_model
=
seg_model
self
.
classifier
=
classifier
self
.
train_seg
=
train_seg
#modifiy classifier first layer not add 4th input channel
new_conv
=
nn
.
Conv2d
(
4
,
64
,
kernel_size
=
(
7
,
7
),
stride
=
(
2
,
2
),
padding
=
(
3
,
3
),
bias
=
False
)
model
.
features
.
conv0
=
new_conv
self
.
classifier
.
features
.
conv0
=
new_conv
def
forward
(
x
):
def
forward
(
self
,
x
):
if
self
.
train_seg
:
seg_mask
=
self
.
seg_model
(
x
)
else
:
...
...
@@ -30,7 +30,7 @@ def get_model(classifier_path: str, seg_path: str):
classifier
.
classifier
=
nn
.
Linear
(
in_features
,
7
)
classifier
.
load_state_dict
(
torch
.
load
(
classifier_path
))
model
=
seg_model
.
SegmentationModel
([
3
,
16
,
32
,
64
,
1
],[
False
,
True
,
True
])
seg_model
=
SegmentationModel
([
3
,
16
,
32
,
64
,
1
],[
False
,
True
,
True
])
#summary(model, (3,128,128), batch_size=16,device="cpu")
model
.
load_state_dict
(
torch
.
load
(
seg_path
))
seg_
model
.
load_state_dict
(
torch
.
load
(
seg_path
))
return
SegClassifier
(
seg_model
,
classifier
)
This diff is collapsed.
Click to expand it.
train_seg.py
+
3
−
3
View file @
be7e96b7
...
...
@@ -32,7 +32,7 @@ def train(model: nn.Module, criterion: nn.Module, optimizer: optim.Optimizer, lo
for
x
,
y
in
loop
:
x
,
y
=
x
.
to
(
device
),
y
.
to
(
device
)
# Move data to GPU if available
optimizer
.
zero_grad
()
y_pred
=
model
(
x
)
y_pred
=
torch
.
argmax
(
model
(
x
)
)
loss
=
criterion
(
y_pred
,
y
)
loss
.
backward
()
optimizer
.
step
()
...
...
@@ -51,7 +51,7 @@ def eval(model: nn.Module, criterion: nn.Module, loader: DataLoader, device: str
loop
=
tqdm
(
loader
,
desc
=
"
\t
Validation
"
,
ncols
=
100
,
mininterval
=
0.1
)
for
x
,
y
in
loop
:
x
,
y
=
x
.
to
(
device
),
y
.
to
(
device
)
# Move data to GPU if available
y_pred
=
model
(
x
)
y_pred
=
torch
.
argmax
(
model
(
x
)
)
loss
=
criterion
(
y_pred
,
y
)
acc
=
jaccard_score
(
y_pred
,
y
)
epoch_loss
+=
loss
.
item
()
...
...
@@ -104,7 +104,7 @@ def train_epochs_classifier(model: nn.Module, train_loader: DataLoader, val_load
weight_decay
=
0.0001
nesterov
=
True
optimizer
=
optim
.
SGD
(
model
.
parameters
(),
lr
=
lr
,
momentum
=
momentum
,
weight_decay
=
weight_decay
,
nesterov
=
nesterov
)
criterion
=
nn
.
BCEWithLogitsLoss
(
)
nn
.
CrossEntropyLoss
().
to
(
device
)
criterion
.
to
(
device
)
scheduler
=
torch
.
optim
.
lr_scheduler
.
ReduceLROnPlateau
(
optimizer
,
mode
=
'
min
'
,
verbose
=
True
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment