diff --git a/alex_trainL.py b/alex_trainL.py index 2d9380ab1d58528adb593c979e15f6c678a86dfb..747852c685c58dcb49528eedd5f811eeabc5230f 100644 --- a/alex_trainL.py +++ b/alex_trainL.py @@ -86,8 +86,8 @@ def train_model(args, model_id, dataset_id): # Naive 80/20 (change this I guess) #print(len(seq_datasets[0])) - train_data = ConcatDataset(seq_datasets[:14]) # 14 - val_data = ConcatDataset(seq_datasets[14:]) + train_data = ConcatDataset(seq_datasets[:1]) # 14 + val_data = ConcatDataset(seq_datasets[:1]) print(train_data)