From 4dfeb51d604119d6a54b5a36044f88609f81c215 Mon Sep 17 00:00:00 2001
From: "Wu, Jiantao (PG/R - Comp Sci & Elec Eng)" <jiantao.wu@surrey.ac.uk>
Date: Tue, 13 Feb 2024 13:37:04 +0000
Subject: [PATCH] optimize submitit and vitrun

---
 bin/submitit                         | 38 +++++++++++++++++++++++-----
 bin/vitrun                           | 21 ++++++++++-----
 vitookit/evaluation/eval_cls_ffcv.py |  7 +----
 vitookit/utils/helper.py             |  1 +
 4 files changed, 47 insertions(+), 20 deletions(-)

diff --git a/bin/submitit b/bin/submitit
index 2b7aabf..ea85887 100755
--- a/bin/submitit
+++ b/bin/submitit
@@ -6,7 +6,7 @@
 # --------------------------------------------------------
 # A script to run multinode training with submitit.
 # --------------------------------------------------------
-from PIL import Image
+
 import argparse
 import os
 import uuid
@@ -29,13 +29,15 @@ def parse_args():
     parser.add_argument("--partition", default="big", type=str, help="Partition where to submit")
     parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")
     parser.add_argument( "--job_dir", default='',type=str,)
+    parser.add_argument('--fast_dir',default='', help="The directory of the fast disk to load the datasets")
     
     args, known= parser.parse_known_args()
     return args
 
 
 def get_shared_folder(root) -> Path:
-    p = Path(f"{root}/experiments")
+    root = root.replace("%j", "shared")
+    p = Path(root)
     os.makedirs(str(p), exist_ok=True)
     if Path(root).is_dir():
         return p
@@ -65,27 +67,48 @@ class Trainer(object):
 
     def __call__(self):
         self._setup_gpu_args()
-        print("passing args", self.module_args)
+        
+        # move the dataset to fast_dir
+        fast_dir = self.args.fast_dir
+        if fast_dir:
+            import shutil
+            for key,value in self.module_args.__dict__.items():
+                if isinstance(value,str) and '.ffcv' in value:
+                    os.makedirs(fast_dir, exist_ok=True)
+                    # Copy the file
+                    new_path = shutil.copy(value, fast_dir)
+                    self.module_args.__dict__[key] = new_path
         self.module.main(self.module_args)
 
     def checkpoint(self):
+        print("Checkpointing")
         import os
         import submitit
+        job_env = submitit.JobEnvironment()
+        print("Requeuing ", self.args)
         
-        checkpoint_file = os.path.join(self.module_args.output_dir, "checkpoint.pth")
+        output_dir = Path(str(self.args.job_dir))
+        
+        checkpoint_file = os.path.join(output_dir, "checkpoint.pth")  
+        self.args.dist_url = get_init_file(output_dir).as_uri()
+        empty_trainer = type(self)(self.args)      
         if os.path.exists(checkpoint_file):
-            self.args.resume = checkpoint_file
-        print("Requeuing ", self.module_args)
-        empty_trainer = type(self)(self.module_args)
+            empty_trainer.module_args.resume = checkpoint_file
+        
+        print("Requeueing with ", empty_trainer.module_args)
         return submitit.helpers.DelayedSubmission(empty_trainer)
 
     def _setup_gpu_args(self):
         import submitit
         module_args = self.module_args
         job_env = submitit.JobEnvironment()
+        output_dir = Path(str(self.args.job_dir).replace("%j", str(job_env.job_id)))
+        module_args.output_dir = output_dir
+        
         module_args.gpu = job_env.local_rank
         module_args.rank = job_env.global_rank
         module_args.world_size = job_env.num_tasks
+        print("Setting up GPU args", module_args)
         print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
 
 
@@ -122,6 +145,7 @@ def main():
 
     executor.update_parameters(name="eval")
     args.dist_url = get_init_file(args.job_dir).as_uri()
+    print("args:", args)
     trainer = Trainer(args)
     job = executor.submit(trainer)
     
diff --git a/bin/vitrun b/bin/vitrun
index b1843c6..261f64c 100644
--- a/bin/vitrun
+++ b/bin/vitrun
@@ -1,9 +1,7 @@
 #!/usr/bin/env python
 
-import os
-import distutils.sysconfig as sysconfig
 import pkg_resources
-import sys
+# import sys
 
 def get_install_location(package_name):
     try:
@@ -15,7 +13,16 @@ def get_install_location(package_name):
 # Replace 'numpy' with the name of the package you're interested in
 pack_path = pkg_resources.get_distribution('vitookit').location
 
-# replace the program name with the absolute path of the script you want to run
-argv = [ os.path.join(pack_path,'vitookit','evaluation',i) if ".py" in i else i for i in sys.argv[1:]]
-# Pass all script arguments to eval_cls.py
-os.system(f"torchrun {' '.join(argv)}")
\ No newline at end of file
+# # replace the program name with the absolute path of the script you want to run
+# argv = [ os.path.join(pack_path,'vitookit','evaluation',i) if ".py" in i else i for i in sys.argv[1:]]
+# # Pass all script arguments to eval_cls.py
+# os.system(f"torchrun {' '.join(argv)}")
+
+import re
+import sys, os
+from torch.distributed.run import main
+if __name__ == '__main__':
+    sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
+    sys.argv[1] = os.path.join(pack_path,'vitookit','evaluation',sys.argv[1])
+    print(sys.argv)
+    sys.exit(main())
diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py
index 82d7b54..6f97c52 100644
--- a/vitookit/evaluation/eval_cls_ffcv.py
+++ b/vitookit/evaluation/eval_cls_ffcv.py
@@ -207,11 +207,6 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
         if not math.isfinite(loss_value):
             print("Loss is {}, stopping training".format(loss_value))
             sys.exit(1)
-        # this attribute is added by timm on one optimizer (adahessian)
-        # is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
-
-        # if model_ema is not None:
-        #     model_ema.update(model)
             
         if wandb.run: 
             wandb.log({'train/loss':loss})
@@ -349,7 +344,7 @@ def main(args):
             if epoch == dres.end_ramp:
                 print("enhance augmentation!", epoch, dres.end_ramp)
                 ## enhance augmentation, see efficientnetv2
-                data_loader_train = Loader(args.train_path, pipelines=ThreeAugmentPipeline(color_jitter=0.4),
+                data_loader_train = Loader(args.train_path, pipelines=ThreeAugmentPipeline(color_jitter=0.3),
                         batch_size=args.batch_size, num_workers=args.num_workers, 
                         order=order, distributed=args.distributed,seed=args.seed)
             dres(data_loader_train,epoch,True)
diff --git a/vitookit/utils/helper.py b/vitookit/utils/helper.py
index 24026e9..5580362 100644
--- a/vitookit/utils/helper.py
+++ b/vitookit/utils/helper.py
@@ -41,6 +41,7 @@ def aug_parse(parser: argparse.ArgumentParser):
             yaml.dump(vars(args), f)
             
         open(output_dir/"config.gin",'w').write(gin.config_str(),)
+    
     return args
 
 
-- 
GitLab