diff --git a/alex_trainL.py b/alex_trainL.py index 084d953265b3bdd8b62ac7e5960d2dbb979facf4..e5f802d703dc7a57efd67f657e95c5f9955ab1f4 100644 --- a/alex_trainL.py +++ b/alex_trainL.py @@ -102,8 +102,8 @@ def train_model(args, model_id, dataset_id): # Naive 80/20 (change this I guess) #print(len(seq_datasets[0])) - train_data = ConcatDataset(seq_datasets[:14]) # 14 - val_data = ConcatDataset(seq_datasets[14:]) + train_data = ConcatDataset(seq_datasets[:1]) # 14 + val_data = ConcatDataset(seq_datasets[:1]) print(f"train: {seq_names[:14]}") print(f"val: {seq_names[14:]}")