From 7cb28a451fa37d20055d4afc9c175e0df943c899 Mon Sep 17 00:00:00 2001
From: "Wu, Jiantao (PG/R - Comp Sci & Elec Eng)" <jiantao.wu@surrey.ac.uk>
Date: Fri, 23 Feb 2024 16:48:50 +0000
Subject: [PATCH] call gin.parse before launching the program

---
 vitookit/evaluation/eval_cls_ffcv.py    | 3 ++-
 vitookit/evaluation/eval_linear.py      | 6 ++++--
 vitookit/evaluation/eval_linear_ffcv.py | 6 ++++--
 vitookit/utils/submitit.py              | 4 ++++
 4 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py
index a3cf9f8..fa9e32a 100644
--- a/vitookit/evaluation/eval_cls_ffcv.py
+++ b/vitookit/evaluation/eval_cls_ffcv.py
@@ -221,7 +221,8 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
 def main(args):
     misc.init_distributed_mode(args)
 
-    print(args)
+    print("args: ", args)
+    print("configure: ", gin.config_str())
     import torch
     device = torch.device(args.device)
 
diff --git a/vitookit/evaluation/eval_linear.py b/vitookit/evaluation/eval_linear.py
index c2c5d75..a28c365 100644
--- a/vitookit/evaluation/eval_linear.py
+++ b/vitookit/evaluation/eval_linear.py
@@ -32,7 +32,7 @@ from vitookit.utils.helper import aug_parse, load_pretrained_weights, log_metric
 from timm.models.layers import trunc_normal_
 
 from vitookit.utils.lars import LARS
-
+import gin
 
 
 
@@ -120,7 +120,9 @@ def get_args_parser():
 
 def main(args):
     misc.init_distributed_mode(args)
-
+    print("args: ", args)
+    print("configure: ", gin.config_str())
+    
     print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
     print("{}".format(args).replace(', ', ',\n'))
 
diff --git a/vitookit/evaluation/eval_linear_ffcv.py b/vitookit/evaluation/eval_linear_ffcv.py
index 5c9e1bd..b28277e 100644
--- a/vitookit/evaluation/eval_linear_ffcv.py
+++ b/vitookit/evaluation/eval_linear_ffcv.py
@@ -34,7 +34,7 @@ from vitookit.utils.lars import LARS
 from vitookit.datasets.ffcv_transform import SimplePipeline, ValPipeline
 from ffcv import Loader
 from ffcv.loader import OrderOption
-
+import gin
 
 
 def get_args_parser():
@@ -120,7 +120,9 @@ def get_args_parser():
 
 def main(args):
     misc.init_distributed_mode(args)
-    args.distributed = True
+    print("args: ", args)
+    print("configure: ", gin.config_str())
+    
     print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
     print("{}".format(args).replace(', ', ',\n'))
 
diff --git a/vitookit/utils/submitit.py b/vitookit/utils/submitit.py
index 3dc5233..4c73dd0 100644
--- a/vitookit/utils/submitit.py
+++ b/vitookit/utils/submitit.py
@@ -115,6 +115,10 @@ class Trainer(object):
         module_args.world_size = job_env.num_tasks
         
         module_args.comment = f"Job {job_env.job_id} on {job_env.num_tasks} GPUs"
+        
+        import gin
+        if not gin.config_is_locked():
+            gin.parse_config_files_and_bindings(module_args.cfgs,module_args.gin)
         print("Setting up GPU args", module_args)
         print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
 
-- 
GitLab