diff --git a/bin/submitit b/bin/submitit
index 2b7aabf3c0aa4528dd73ae4f1324d9f471d939f9..ea8588705760a6fd939b62f5a032afacab72df4d 100755
--- a/bin/submitit
+++ b/bin/submitit
@@ -6,7 +6,7 @@
 # --------------------------------------------------------
 # A script to run multinode training with submitit.
 # --------------------------------------------------------
-from PIL import Image
+
 import argparse
 import os
 import uuid
@@ -29,13 +29,15 @@ def parse_args():
     parser.add_argument("--partition", default="big", type=str, help="Partition where to submit")
     parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")
     parser.add_argument( "--job_dir", default='',type=str,)
+    parser.add_argument('--fast_dir',default='', help="The dictory of fast disk to load the datasets")
     args, known= parser.parse_known_args()
     return args
 
 
 def get_shared_folder(root) -> Path:
-    p = Path(f"{root}/experiments")
+    root = root.replace("%j", "shared")
+    p = Path(root)
     os.makedirs(str(p), exist_ok=True)
     if Path(root).is_dir():
         return p
@@ -65,27 +67,48 @@ class Trainer(object):
     def __call__(self):
         self._setup_gpu_args()
-        print("passing args", self.module_args)
+
+        # move the dataset to fast_dir
+        fast_dir = self.args.fast_dir
+        if fast_dir:
+            import shutil
+            for key,value in self.module_args.__dict__.items():
+                if isinstance(value,str) and '.ffcv' in value:
+                    os.makedirs(fast_dir, exist_ok=True)
+                    # Copy the file
+                    new_path = shutil.copy(value, fast_dir)
+                    self.module_args.__dict__[key] = new_path
         self.module.main(self.module_args)
 
     def checkpoint(self):
+        print("Checkpointing")
         import os
         import submitit
+        job_env = submitit.JobEnvironment()
+        print("Requeuing ", self.args)
 
-        checkpoint_file = os.path.join(self.module_args.output_dir, "checkpoint.pth")
+        output_dir = Path(str(self.args.job_dir))
+
+        checkpoint_file = os.path.join(output_dir, "checkpoint.pth")
+        self.args.dist_url = get_init_file(output_dir).as_uri()
+        empty_trainer = type(self)(self.args)
         if os.path.exists(checkpoint_file):
-            self.args.resume = checkpoint_file
-        print("Requeuing ", self.module_args)
-        empty_trainer = type(self)(self.module_args)
+            empty_trainer.module_args.resume = checkpoint_file
+
+        print("Requeueing with ", empty_trainer.module_args)
         return submitit.helpers.DelayedSubmission(empty_trainer)
 
     def _setup_gpu_args(self):
         import submitit
         module_args = self.module_args
         job_env = submitit.JobEnvironment()
+        output_dir = Path(str(self.args.job_dir).replace("%j", str(job_env.job_id)))
+        module_args.output_dir = output_dir
+
+        module_args.gpu = job_env.local_rank
         module_args.rank = job_env.global_rank
         module_args.world_size = job_env.num_tasks
+        print("Setting up GPU args", module_args)
         print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
@@ -122,6 +145,7 @@ def main():
     executor.update_parameters(name="eval")
 
     args.dist_url = get_init_file(args.job_dir).as_uri()
+    print("args:", args)
     trainer = Trainer(args)
     job = executor.submit(trainer)
diff --git a/bin/vitrun b/bin/vitrun
index b1843c6fb537715b9c1c030324c892857bdca2d6..261f64cbfe47ae25c2bbfe406f202f894445475a 100644
--- a/bin/vitrun
+++ b/bin/vitrun
@@ -1,9 +1,7 @@
 #!/usr/bin/env python
 
-import os
-import distutils.sysconfig as sysconfig
 import pkg_resources
-import sys
+# import sys
 
 def get_install_location(package_name):
     try:
@@ -15,7 +13,16 @@ def get_install_location(package_name):
 
 # Replace 'numpy' with the name of the package you're interested in
 pack_path = pkg_resources.get_distribution('vitookit').location
-# replace the program name with the absolute path of the script you want to run
-argv = [ os.path.join(pack_path,'vitookit','evaluation',i) if ".py" in i else i for i in sys.argv[1:]]
-# Pass all script arguments to eval_cls.py
-os.system(f"torchrun {' '.join(argv)}")
\ No newline at end of file
+# # replace the program name with the absolute path of the script you want to run
+# argv = [ os.path.join(pack_path,'vitookit','evaluation',i) if ".py" in i else i for i in sys.argv[1:]]
+# # Pass all script arguments to eval_cls.py
+# os.system(f"torchrun {' '.join(argv)}")
+
+import re
+import sys, os
+from torch.distributed.run import main
+if __name__ == '__main__':
+    sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
+    sys.argv[1] = os.path.join(pack_path,'vitookit','evaluation',sys.argv[1])
+    print(sys.argv)
+    sys.exit(main())
diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py
index 82d7b54e79e74221cfdcc2df787b48b40eb2f3b6..6f97c52ab1fe04384defc1518628e38438e45a69 100644
--- a/vitookit/evaluation/eval_cls_ffcv.py
+++ b/vitookit/evaluation/eval_cls_ffcv.py
@@ -207,11 +207,6 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
         if not math.isfinite(loss_value):
             print("Loss is {}, stopping training".format(loss_value))
             sys.exit(1)
-        # this attribute is added by timm on one optimizer (adahessian)
-        # is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
-
-        # if model_ema is not None:
-        #     model_ema.update(model)
 
         if wandb.run:
             wandb.log({'train/loss':loss})
@@ -349,7 +344,7 @@ def main(args):
         if epoch == dres.end_ramp:
             print("enhance augmentation!", epoch, dres.end_ramp)
             ## enhance augmentation, see efficientnetv2
-            data_loader_train = Loader(args.train_path, pipelines=ThreeAugmentPipeline(color_jitter=0.4),
+            data_loader_train = Loader(args.train_path, pipelines=ThreeAugmentPipeline(color_jitter=0.3),
                         batch_size=args.batch_size, num_workers=args.num_workers,
                         order=order, distributed=args.distributed,seed=args.seed)
             dres(data_loader_train,epoch,True)
diff --git a/vitookit/utils/helper.py b/vitookit/utils/helper.py
index 24026e97004211f17c5ff5a3ab5f2e5ed04f023f..5580362b4bde2e71c5e24dc6b549eade29ae2399 100644
--- a/vitookit/utils/helper.py
+++ b/vitookit/utils/helper.py
@@ -41,6 +41,7 @@ def aug_parse(parser: argparse.ArgumentParser):
             yaml.dump(vars(args), f)
         open(output_dir/"config.gin",'w').write(gin.config_str(),)
 
+    return args