From 26feb36fec4afd765f78d4b770f75546bb451ce6 Mon Sep 17 00:00:00 2001
From: ah02299 <ah02299@surrey.ac.uk>
Date: Wed, 14 Feb 2024 13:43:51 +0000
Subject: [PATCH] noooo

---
 alex_trainL.py          | 4 ++--
 alex_train_onesample.py | 6 +++---
 edenn/datasets/dsec.py  | 3 ---
 3 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/alex_trainL.py b/alex_trainL.py
index c9316d3..b0aa425 100644
--- a/alex_trainL.py
+++ b/alex_trainL.py
@@ -108,8 +108,8 @@ def train_model(args, model_id, dataset_id):
         print(f"train: {seq_datasets[:14]}")
         print(f"val: {seq_datasets[14:]}")
 
-        train_loader = DataLoader(train_data, batch_size=args.batch_size, num_workers=5, shuffle= False, pin_memory = True)
-        val_loader = DataLoader(val_data, batch_size=args.batch_size, num_workers=5, shuffle = False, pin_memory = True)
+        train_loader = DataLoader(train_data, batch_size=args.batch_size, num_workers=1, shuffle= False, pin_memory = True)
+        val_loader = DataLoader(val_data, batch_size=args.batch_size, num_workers=1, shuffle = False, pin_memory = True)
 
         #print(f"train sequences: {seq_names[:14]}")
         #print(f"validation sequences: {seq_names[14:]}")
diff --git a/alex_train_onesample.py b/alex_train_onesample.py
index 993678d..453ef84 100644
--- a/alex_train_onesample.py
+++ b/alex_train_onesample.py
@@ -76,12 +76,12 @@ def train_model(args, model_id, dataset_id):
         for sample in seq_names:
             print(sample)
             # Take a smaller spatial crop to get more granularity in the temporal dimension. (same crop as in E-RAFT)
-            seq_datasets.append(SequenceEDENN(Path(root_folder),"zurich_city_05_a",num_bins=args.n_bins, crop_window=(288, 384), mode="val"))
+            seq_datasets.append(SequenceEDENN(Path(root_folder),sample,num_bins=args.n_bins, crop_window=(288, 384), mode="val"))
 
         # Naive 80/20 (change this I guess)
         print(len(seq_datasets[0]))
-        train_data = ConcatDataset(seq_datasets[:1])
-        val_data = ConcatDataset(seq_datasets[:1])
+        train_data = ConcatDataset(seq_datasets[:14])
+        val_data = ConcatDataset(seq_datasets[14:])
 
         print(f"train sequences: {seq_names[:14]}")
         print(f"validation sequences: {seq_names[14:]}")
diff --git a/edenn/datasets/dsec.py b/edenn/datasets/dsec.py
index 523690a..570fd5e 100644
--- a/edenn/datasets/dsec.py
+++ b/edenn/datasets/dsec.py
@@ -197,9 +197,6 @@ class SequenceEDENN(Dataset):
     def get_data_sample(self, index, crop_window=None, flip=None):
         # Get full event stream in between flow GT readings.
 
-        print(f"index: {index}")
-        print(self.flow_names[index])
-        print(self.flow_root)
         ts_start = self.flow_timestamps[index][0]
         ts_end =  self.flow_timestamps[index][1]
 
-- 
GitLab