diff --git a/alex_trainL.py b/alex_trainL.py index c9316d3be03857e8c9d1366dcf65d47430e07c7d..b0aa4251266f222f087fb78155ca4def9818df12 100644 --- a/alex_trainL.py +++ b/alex_trainL.py @@ -108,8 +108,8 @@ def train_model(args, model_id, dataset_id): print(f"train: {seq_datasets[:14]}") print(f"val: {seq_datasets[14:]}") - train_loader = DataLoader(train_data, batch_size=args.batch_size, num_workers=5, shuffle= False, pin_memory = True) - val_loader = DataLoader(val_data, batch_size=args.batch_size, num_workers=5, shuffle = False, pin_memory = True) + train_loader = DataLoader(train_data, batch_size=args.batch_size, num_workers=1, shuffle= False, pin_memory = True) + val_loader = DataLoader(val_data, batch_size=args.batch_size, num_workers=1, shuffle = False, pin_memory = True) #print(f"train sequences: {seq_names[:14]}") #print(f"validation sequences: {seq_names[14:]}") diff --git a/alex_train_onesample.py b/alex_train_onesample.py index 993678d50358a675f0af30361e2991af9d08e231..453ef8494ecf56ed0a85949aef2ab7b18bb9018e 100644 --- a/alex_train_onesample.py +++ b/alex_train_onesample.py @@ -76,12 +76,12 @@ def train_model(args, model_id, dataset_id): for sample in seq_names: print(sample) # Take a smaller spatial crop to get more granularity in the temporal dimension. (same crop as in E-RAFT) - seq_datasets.append(SequenceEDENN(Path(root_folder),"zurich_city_05_a",num_bins=args.n_bins, crop_window=(288, 384), mode="val")) + seq_datasets.append(SequenceEDENN(Path(root_folder),sample,num_bins=args.n_bins, crop_window=(288, 384), mode="val")) # Naive 80/20 (change this I guess) print(len(seq_datasets[0])) - train_data = ConcatDataset(seq_datasets[:1]) - val_data = ConcatDataset(seq_datasets[:1]) + train_data = ConcatDataset(seq_datasets[:14]) + val_data = ConcatDataset(seq_datasets[14:]) print(f"train sequences: {seq_names[:14]}") print(f"validation sequences: {seq_names[14:]}") diff --git a/edenn/datasets/dsec.py b/edenn/datasets/dsec.py index 523690a73510e0ec1a0cc034cd32bf382d8d74d5..570fd5e6ca7e2712706f12272844c50d143621a8 100644 --- a/edenn/datasets/dsec.py +++ b/edenn/datasets/dsec.py @@ -197,9 +197,6 @@ class SequenceEDENN(Dataset): def get_data_sample(self, index, crop_window=None, flip=None): # Get full event stream in between flow GT readings. - print(f"index: {index}") - print(self.flow_names[index]) - print(self.flow_root) ts_start = self.flow_timestamps[index][0] ts_end = self.flow_timestamps[index][1]