From b9d6f375d28a63d6b6e137c2c6ab4abcbebe4ec0 Mon Sep 17 00:00:00 2001
From: "Wu, Jiantao (PG/R - Comp Sci & Elec Eng)" <jiantao.wu@surrey.ac.uk>
Date: Sat, 10 Feb 2024 23:23:05 +0000
Subject: [PATCH] Add a submitit launcher for Slurm jobs
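
Add a bin/submitit launcher that submits evaluation modules to Slurm
through submitit. The launcher imports the module named by --module,
lets that module's own parser re-parse the remaining command-line
flags, runs one task per GPU across the requested nodes, and requeues
from checkpoint.pth on preemption or timeout. Example:

    bin/submitit --module vitookit.evaluation.eval_cls_ffcv \
        --train_path ~/data/ffcv/IN1K_train_500_95.ffcv \
        --val_path ~/data/ffcv/IN1K_val_500_95.ffcv

A usage example is added to the README under a new "Slurm" section, and
aug_parse now uses parse_known_args() so that launcher-only flags do
not crash the module's parser.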

---
 README.MD                            |  11 ++
 bin/submitit                         | 135 +++++++++++++++++++++++++++
 setup.py                             |   1 +
 vitookit/evaluation/eval_cls.py      |   7 +-
 vitookit/evaluation/eval_cls_ffcv.py |   4 +-
 vitookit/utils/helper.py             |   2 +-
 6 files changed, 153 insertions(+), 7 deletions(-)
 create mode 100755 bin/submitit
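
Note: submitit requeues interrupted jobs through Trainer.checkpoint(),
which sets the module's resume argument to <job_dir>/checkpoint.pth
when one exists, so a preempted evaluation resumes instead of
restarting from scratch.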

diff --git a/README.MD b/README.MD
index 615884b..c6e201b 100644
--- a/README.MD
+++ b/README.MD
@@ -47,6 +47,17 @@ condor_submit condor/eval_weka_cls.submit model_dir=outputs/dinosara/base ARCH=v
 condor_submit condor/eval_weka_seg.submit model_dir=outputs/dinosara/base
 ```
 
+## Slurm
+
+```bash
+bin/submitit --module vitookit.evaluation.eval_cls_ffcv --train_path ~/data/ffcv/IN1K_train_500_95.ffcv --val_path ~/data/ffcv/IN1K_val_500_95.ffcv --gin VisionTransformer.global_pool='\"avg\"' -w wandb:dlib/EfficientSSL/lsx2qmys
+```
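+
+Launcher flags (`--ngpus`, `--nodes`, `--partition`, `-t`, `--mem`, `--job_dir`) are defined in `bin/submitit`; unrecognised flags are forwarded to the evaluation module. For example, a two-node job could be submitted as:
+
+```bash
+bin/submitit --module vitookit.evaluation.eval_cls --ngpus 8 --nodes 2 --partition big
+```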
 
 ## Test examples
 
diff --git a/bin/submitit b/bin/submitit
new file mode 100755
index 0000000..2b7aabf
--- /dev/null
+++ b/bin/submitit
@@ -0,0 +1,135 @@
+#!/usr/bin/env python
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# A script to run multinode training with submitit.
+# --------------------------------------------------------
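+#
+# Usage sketch (launcher flags are defined in parse_args below; flags the
+# launcher does not recognise are forwarded to the --module's own parser):
+#   bin/submitit --module vitookit.evaluation.eval_cls --ngpus 8 --partition big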
+from PIL import Image  # hack to avoid `CXXABI_1.3.9' not found error
+import argparse
+import os
+import uuid
+from pathlib import Path
+
+import submitit
+import importlib
+
+from vitookit.utils.helper import aug_parse
+
+def parse_args():
+    parser = argparse.ArgumentParser("Submitit for evaluation")
+    parser.add_argument("--module", default="vitookit.evaluation.eval_cls", type=str, help="Module to run")
+    parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
+    parser.add_argument("--nodes", default=1, type=int, help="Number of nodes to request")
+    parser.add_argument("-t", "--timeout", default=1440, type=int, help="Duration of the job")
+    parser.add_argument("--mem", default=400, type=float, help="Memory to request")
+
+    parser.add_argument("--partition", default="big", type=str, help="Partition where to submit")
+    parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")
+    parser.add_argument( "--job_dir", default='',type=str,)
+    
+    args, _unknown = parser.parse_known_args()  # unknown flags are forwarded to the target module's parser
+    return args
+
+
+def get_shared_folder(root) -> Path:
+    if not Path(root).is_dir():
+        raise RuntimeError("No shared folder available")
+    p = Path(f"{root}/experiments")
+    os.makedirs(str(p), exist_ok=True)
+    return p
+
+
+def get_init_file(root):
+    # Init file must not exist, but its parent dir must exist.
+    os.makedirs(str(get_shared_folder(root)), exist_ok=True)
+    init_file = get_shared_folder(root) / f"{uuid.uuid4().hex}_init"
+    if init_file.exists():
+        os.remove(str(init_file))
+    return init_file
+
+
+class Trainer(object):
+    def __init__(self, args):
+        self.args = args
+        self.module = importlib.import_module(args.module)
+        
+        # reassign args parsed by the module's own parser
+        parser = self.module.get_args_parser()
+        module_args = aug_parse(parser)
+        module_args.output_dir = args.job_dir
+        module_args.dist_url = args.dist_url
+        self.module_args = module_args
+
+    def __call__(self):
+        self._setup_gpu_args()
+        print("passing args", self.module_args)
+        self.module.main(self.module_args)
+
+    def checkpoint(self):
+        # Called by submitit on preemption/timeout: resume from the latest
+        # checkpoint if one exists, then requeue this trainer.
+        checkpoint_file = os.path.join(self.module_args.output_dir, "checkpoint.pth")
+        if os.path.exists(checkpoint_file):
+            self.module_args.resume = checkpoint_file
+        print("Requeuing ", self.module_args)
+        return submitit.helpers.DelayedSubmission(self)
+
+    def _setup_gpu_args(self):
+        # Map the submitit/Slurm job environment onto the module's distributed args.
+        job_env = submitit.JobEnvironment()
+        module_args = self.module_args
+        module_args.output_dir = module_args.output_dir.replace("%j", str(job_env.job_id))
+        module_args.gpu = job_env.local_rank
+        module_args.rank = job_env.global_rank
+        module_args.world_size = job_env.num_tasks
+        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
+
+
+def main():
+    args = parse_args()
+    if args.job_dir == '':
+        args.job_dir = "outputs/experiments/%j"  # %j is replaced by the Slurm job id
+    args.job_dir = os.path.abspath(args.job_dir)
+    # The folder depends on the job id, making experiments easy to track.
+    executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
+
+    num_gpus_per_node = args.ngpus
+    nodes = args.nodes
+    timeout_min = args.timeout
+
+    partition = args.partition
+    kwargs = {}
+    
+    if args.comment:
+        kwargs['slurm_comment'] = args.comment
+
+    executor.update_parameters(
+        mem_gb=args.mem,
+        gpus_per_node=num_gpus_per_node,
+        tasks_per_node=num_gpus_per_node,  # one task per GPU
+        cpus_per_task=10,
+        nodes=nodes,
+        timeout_min=timeout_min,  # max is 60 * 72
+        # Below are cluster dependent parameters
+        slurm_partition=partition,
+        slurm_signal_delay_s=120,
+        **kwargs
+    )
+
+    executor.update_parameters(name="eval")
+    args.dist_url = get_init_file(args.job_dir).as_uri()  # file:// URI used as the distributed init method
+    trainer = Trainer(args)
+    job = executor.submit(trainer)
+    
+    print("Submitted job_id:", job.job_id)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/setup.py b/setup.py
index f065d7c..0ef5dde 100644
--- a/setup.py
+++ b/setup.py
@@ -15,5 +15,6 @@ setup(
     },
     scripts=[
         'bin/vitrun',
+        'bin/submitit',
     ]
 )
\ No newline at end of file
diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py
index 7ef3a16..beb6297 100644
--- a/vitookit/evaluation/eval_cls.py
+++ b/vitookit/evaluation/eval_cls.py
@@ -9,7 +9,7 @@
 Mostly copy-paste from DEiT library:
 https://github.com/facebookresearch/deit/blob/main/main.py
 """
-from PIL import Image # hack to avoid `CXXABI_1.3.9' not found error
+# from PIL import Image # hack to avoid `CXXABI_1.3.9' not found error
 
 import argparse
 import datetime
@@ -178,7 +178,7 @@ def get_args_parser():
 def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
                     device: torch.device, epoch: int, loss_scaler,lr_scheduler, max_norm: float = 0,
-                     mixup_fn: Optional[Mixup] = None,
+                     mixup_fn: Optional[Mixup] = None, accum_iter=1
                     ):
     model.train(True)
     metric_logger = misc.MetricLogger(delimiter="  ")
@@ -186,7 +186,6 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
     header = 'Epoch: [{}]'.format(epoch)
     print_freq = max(len(data_loader)//20,20)
     
-    accum_iter = args.accum_iter
     for itr,(samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
         samples = samples.to(device, non_blocking=True)
         targets = targets.to(device, non_blocking=True)
@@ -421,7 +420,7 @@ def main(args):
         train_stats = train_one_epoch(
             model, criterion, data_loader_train,
             optimizer, device, epoch, loss_scaler,lr_scheduler,
-            args.clip_grad,  mixup_fn,
+            args.clip_grad,  mixup_fn, accum_iter=args.accum_iter
         )
         
         checkpoint_paths = ['checkpoint.pth']
diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py
index 0db7702..82d7b54 100644
--- a/vitookit/evaluation/eval_cls_ffcv.py
+++ b/vitookit/evaluation/eval_cls_ffcv.py
@@ -173,6 +173,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
                     device: torch.device, epoch: int, loss_scaler,lr_scheduler, max_norm: float = 0,
                      mixup_fn: Optional[Mixup] = None,
+                     accum_iter = 1
                     ):
     model.train(True)
     metric_logger = misc.MetricLogger(delimiter="  ")
@@ -180,7 +181,6 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
     header = 'Epoch: [{}]'.format(epoch)
     print_freq = max(len(data_loader)//20,20)
     
-    accum_iter = args.accum_iter
     for itr,(samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
         samples = samples.to(device, non_blocking=True)
         targets = targets.to(device, non_blocking=True)
@@ -357,7 +357,7 @@ def main(args):
         train_stats = train_one_epoch(
             model, criterion, data_loader_train,
             optimizer, device, epoch, loss_scaler,lr_scheduler,
-            args.clip_grad,  mixup_fn,
+            args.clip_grad,  mixup_fn, accum_iter=args.accum_iter
         )
         
         checkpoint_paths = ['checkpoint.pth']
diff --git a/vitookit/utils/helper.py b/vitookit/utils/helper.py
index d16aefa..24026e9 100644
--- a/vitookit/utils/helper.py
+++ b/vitookit/utils/helper.py
@@ -29,7 +29,7 @@ def aug_parse(parser: argparse.ArgumentParser):
     parser.add_argument('--gin', nargs='+', 
                         help='Overrides config values. e.g. --gin "section.option=value"')
    
-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()  # ignore launcher-only flags (e.g. from bin/submitit)
     if args.output_dir:
         output_dir=Path(args.output_dir)
         output_dir.mkdir(parents=True, exist_ok=True)
-- 
GitLab