Skip to content
Snippets Groups Projects
Commit 69995a5f authored by JamesTrewern's avatar JamesTrewern
Browse files

WIP

parent c6c00517
No related branches found
No related tags found
1 merge request!2Quick fixes
EPOCHS = 16
BATCH_SIZE = 32
INPUT_DIM = 32
\ No newline at end of file
INPUT_DIM = 32
VAL_SPLIT = 2000
TEST_SPLIT = 500
\ No newline at end of file
......@@ -29,26 +29,21 @@ def getHamDataLoaders():
df_ham, df_ham_single_image = setupHamData()
# Then returns two dataframes, df_ham_val is a 20% split of all single images, df_ham_train is the rest
df_ham_train, df_ham_val = createHamTrainTest(df_ham, df_ham_single_image)
df_ham_train, df_ham_val = createHamTrainVal(df_ham, df_ham_single_image)
# Define the training set using the table train_df and using our defined transitions (train_transform)
train_transform, _ = getTrainValTransform(INPUT_DIM, 0.48215827, 0.24348505)
train_transform, val_transform = getTrainValTransform(INPUT_DIM, 0.48215827, 0.24348505)
training_set = HAM10000(df_ham_train, transform=train_transform)
trainLoader = DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
# Same for the validation set:
validation_set = HAM10000(df_ham_val, transform=train_transform)
validation_set = HAM10000(df_ham_val, transform=val_transform)
valLoader = DataLoader(validation_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
return trainLoader, valLoader
def getHamTestLoader():
testLoader = None
return testLoader
# Excludes rows that are in the val set
# Identifies if an image is part of the train or val
def get_val_rows(x):
......@@ -160,7 +155,7 @@ def equalSampling(df_ham_train):
return concat_df_ham_train
def createHamTrainTest(df_ham, df_ham_single_image):
def createHamTrainVal(df_ham, df_ham_single_image):
# create a val set knowing none of the images have multi in the train set
y = df_ham_single_image['cell_type_idx']
......@@ -178,29 +173,17 @@ def createHamTrainTest(df_ham, df_ham_single_image):
count = 0
for x in df_ham['train_or_val']:
#print(x)
if x in val_list:
df_ham.iloc[count, df_ham.columns.get_loc('train_or_val')] = 'val'
#print((df_ham.iloc[count, df_ham.columns.get_loc('train_or_val')]), "\t val")
count += 1
else:
df_ham.iloc[count, df_ham.columns.get_loc('train_or_val')] = 'train'
#print((df_ham.iloc[count, df_ham.columns.get_loc('train_or_val')]), "\t train")
count += 1
#print("\n df_ham After: \n", df_ham['train_or_val'], "\n\n")
#print("\n df_ham After: \n", df_ham, "\n\n")
# filter out the train rows
df_ham_train = df_ham[df_ham['train_or_val'] == 'train']
#print("Length of df_ham_train set: ", len(df_ham_train))
#print("Length of df_ham_val set: ", len(df_ham_val))
# df_ham_val is a 20% split of all single images, df_ham_train is the rest
#print("\nTraining samples: ", df_ham_train[['cell_type', 'cell_type_idx']].value_counts(), "\n")
#print("\nTest Samples (No multi): ", df_ham_val[['cell_type', 'cell_type_idx']].value_counts(), "\n")
#print("\nTotal Samples: ", df_ham[['cell_type', 'cell_type_idx']].value_counts(), "\n")
# Equilising training samples
concat_df_ham_train = equalSampling(df_ham_train)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment