diff --git a/alex_trainL.py b/alex_trainL.py index 885979f3f16b3aa0503d639c222fda5fa6d346a4..c9316d3be03857e8c9d1366dcf65d47430e07c7d 100644 --- a/alex_trainL.py +++ b/alex_trainL.py @@ -78,7 +78,11 @@ def train_model(args, model_id, dataset_id): # Names of the sequences that do have GT flow #seq_names = os.listdir(os.path.join(root_folder, "train_optical_flow")) # List full of individual sequence dataset objects2. Iterate through sequence names creating datasets. - seq_names = ["zurich_city_10_b", "zurich_city_05_b", + + + seq_names = ["zurich_city_10_a","zurich_city_11_a","zurich_city_02_d","zurich_city_11_b", + "zurich_city_10_b", + "zurich_city_05_b", "thun_00_a", "zurich_city_11_c", "zurich_city_08_a", @@ -105,7 +109,7 @@ def train_model(args, model_id, dataset_id): print(f"val: {seq_datasets[14:]}") train_loader = DataLoader(train_data, batch_size=args.batch_size, num_workers=5, shuffle= False, pin_memory = True) - val_loader = DataLoader(train_data, batch_size=args.batch_size, num_workers=5, shuffle = False, pin_memory = True) + val_loader = DataLoader(val_data, batch_size=args.batch_size, num_workers=5, shuffle = False, pin_memory = True) #print(f"train sequences: {seq_names[:14]}") #print(f"validation sequences: {seq_names[14:]}")