From 7cb28a451fa37d20055d4afc9c175e0df943c899 Mon Sep 17 00:00:00 2001 From: "Wu, Jiantao (PG/R - Comp Sci & Elec Eng)" <jiantao.wu@surrey.ac.uk> Date: Fri, 23 Feb 2024 16:48:50 +0000 Subject: [PATCH] call gin.parse before launching the program --- vitookit/evaluation/eval_cls_ffcv.py | 3 ++- vitookit/evaluation/eval_linear.py | 6 ++++-- vitookit/evaluation/eval_linear_ffcv.py | 6 ++++-- vitookit/utils/submitit.py | 4 ++++ 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py index a3cf9f8..fa9e32a 100644 --- a/vitookit/evaluation/eval_cls_ffcv.py +++ b/vitookit/evaluation/eval_cls_ffcv.py @@ -221,7 +221,8 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, def main(args): misc.init_distributed_mode(args) - print(args) + print("args: ", args) + print("configure: ", gin.config_str()) import torch device = torch.device(args.device) diff --git a/vitookit/evaluation/eval_linear.py b/vitookit/evaluation/eval_linear.py index c2c5d75..a28c365 100644 --- a/vitookit/evaluation/eval_linear.py +++ b/vitookit/evaluation/eval_linear.py @@ -32,7 +32,7 @@ from vitookit.utils.helper import aug_parse, load_pretrained_weights, log_metric from timm.models.layers import trunc_normal_ from vitookit.utils.lars import LARS - +import gin @@ -120,7 +120,9 @@ def get_args_parser(): def main(args): misc.init_distributed_mode(args) - + print("args: ", args) + print("configure: ", gin.config_str()) + print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) print("{}".format(args).replace(', ', ',\n')) diff --git a/vitookit/evaluation/eval_linear_ffcv.py b/vitookit/evaluation/eval_linear_ffcv.py index 5c9e1bd..b28277e 100644 --- a/vitookit/evaluation/eval_linear_ffcv.py +++ b/vitookit/evaluation/eval_linear_ffcv.py @@ -34,7 +34,7 @@ from vitookit.utils.lars import LARS from vitookit.datasets.ffcv_transform import SimplePipeline, ValPipeline from ffcv import Loader from ffcv.loader import OrderOption - +import gin def get_args_parser(): @@ -120,7 +120,9 @@ def get_args_parser(): def main(args): misc.init_distributed_mode(args) - args.distributed = True + print("args: ", args) + print("configure: ", gin.config_str()) + print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) print("{}".format(args).replace(', ', ',\n')) diff --git a/vitookit/utils/submitit.py b/vitookit/utils/submitit.py index 3dc5233..4c73dd0 100644 --- a/vitookit/utils/submitit.py +++ b/vitookit/utils/submitit.py @@ -115,6 +115,10 @@ class Trainer(object): module_args.world_size = job_env.num_tasks module_args.comment = f"Job {job_env.job_id} on {job_env.num_tasks} GPUs" + + import gin + if not gin.config_is_locked(): + gin.parse_config_files_and_bindings(module_args.cfgs,module_args.gin) print("Setting up GPU args", module_args) print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") -- GitLab