From 69995a5fa4a43028215de7d9f43b0198c34a208c Mon Sep 17 00:00:00 2001
From: JamesTrewern <trewern.james@gmail.com>
Date: Thu, 13 Apr 2023 13:50:01 +0100
Subject: [PATCH] WIP

---
 constants.py   |  4 +++-
 datahandler.py | 25 ++++---------------------
 2 files changed, 7 insertions(+), 22 deletions(-)

diff --git a/constants.py b/constants.py
index 7173b6f5..7060ece8 100644
--- a/constants.py
+++ b/constants.py
@@ -1,3 +1,5 @@
 EPOCHS = 16
 BATCH_SIZE = 32
-INPUT_DIM = 32
\ No newline at end of file
+INPUT_DIM = 32
+VAL_SPLIT = 2000
+TEST_SPLIT = 500
\ No newline at end of file
diff --git a/datahandler.py b/datahandler.py
index 48ebf37d..a8d7e856 100644
--- a/datahandler.py
+++ b/datahandler.py
@@ -29,26 +29,21 @@ def getHamDataLoaders():
     df_ham, df_ham_single_image = setupHamData()
 
     # Then returns two dataframes, df_ham_val is a 20% split of all single images, df_ham_train is the rest
-    df_ham_train, df_ham_val = createHamTrainTest(df_ham, df_ham_single_image)
+    df_ham_train, df_ham_val = createHamTrainVal(df_ham, df_ham_single_image)
 
     # Define the training set using the table train_df and using our defined transitions (train_transform)
-    train_transform, _ = getTrainValTransform(INPUT_DIM, 0.48215827, 0.24348505)
+    train_transform, val_transform = getTrainValTransform(INPUT_DIM, 0.48215827, 0.24348505)
     training_set = HAM10000(df_ham_train, transform=train_transform)
     trainLoader = DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
 
     # Same for the validation set:
-    validation_set = HAM10000(df_ham_val, transform=train_transform)
+    validation_set = HAM10000(df_ham_val, transform=val_transform)
     valLoader = DataLoader(validation_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
 
 
     return trainLoader, valLoader
 
 
-def getHamTestLoader():
-    testLoader = None
-    return testLoader
-
-
 # Excludes rows that are in the val set
 # Identifies if an image is part of the train or val
 def get_val_rows(x):
@@ -160,7 +155,7 @@ def equalSampling(df_ham_train):
 
     return concat_df_ham_train
 
-def createHamTrainTest(df_ham, df_ham_single_image):
+def createHamTrainVal(df_ham, df_ham_single_image):
     # create a val set knowing none of the images have multi in the train set
     y = df_ham_single_image['cell_type_idx']
 
@@ -178,29 +173,17 @@ def createHamTrainTest(df_ham, df_ham_single_image):
 
     count = 0
     for x in df_ham['train_or_val']:
-        #print(x)
         if x in val_list:
             df_ham.iloc[count, df_ham.columns.get_loc('train_or_val')] = 'val'
-            #print((df_ham.iloc[count, df_ham.columns.get_loc('train_or_val')]), "\t val")
             count += 1
         else:
             df_ham.iloc[count, df_ham.columns.get_loc('train_or_val')] = 'train'
-            #print((df_ham.iloc[count, df_ham.columns.get_loc('train_or_val')]), "\t train")
             count += 1
-    #print("\n df_ham After: \n", df_ham['train_or_val'], "\n\n")
-    #print("\n df_ham After: \n", df_ham, "\n\n")
-
 
     # filter out the train rows
     df_ham_train = df_ham[df_ham['train_or_val'] == 'train']
-    #print("Length of df_ham_train set: ", len(df_ham_train))
-    #print("Length of df_ham_val set: ", len(df_ham_val))
 
     # df_ham_val is a 20% split of all single images, df_ham_train is the rest
-    #print("\nTraining samples: ", df_ham_train[['cell_type', 'cell_type_idx']].value_counts(), "\n")
-    #print("\nTest Samples (No multi): ", df_ham_val[['cell_type', 'cell_type_idx']].value_counts(), "\n")
-    #print("\nTotal Samples: ", df_ham[['cell_type', 'cell_type_idx']].value_counts(), "\n")
-
     # Equilising training samples
     concat_df_ham_train = equalSampling(df_ham_train)
 
-- 
GitLab