From 1fa1f666391cb209111cb6641b778229ec91ce81 Mon Sep 17 00:00:00 2001 From: gent <jw02425@surrey.ac.uk> Date: Sun, 3 Mar 2024 07:53:44 +0000 Subject: [PATCH] build_model with timm --- README.MD | 2 +- vitookit/evaluation/eval_cls.py | 5 +---- vitookit/models/build_model.py | 8 ++++---- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/README.MD b/README.MD index e23fd4f..2399924 100644 --- a/README.MD +++ b/README.MD @@ -13,7 +13,7 @@ See the available evaluations in [evaluation protocols](EVALUATION.md). ## commands ```bash -vitrun train_cls.py --data_location=../data/IMNET --gin VisionTransformer.global_pool='"avg"' -w wandb:dlib/EfficientSSL/lsx2qmys +vitrun train_cls.py --data_location=../data/IMNET --gin build_model.model_name='"vit_tiny_patch16_224"' build_model.global_pool='"avg"' -w wandb:dlib/EfficientSSL/ezuz0x4u --layer_decay=0.75 ``` ## condor diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py index dbca91f..75a5b4e 100644 --- a/vitookit/evaluation/eval_cls.py +++ b/vitookit/evaluation/eval_cls.py @@ -147,9 +147,6 @@ def get_args_parser(): help='dataset path') parser.add_argument('--data_set', default='IN1K', type=str, help='ImageNet dataset path') - parser.add_argument('--inat_category', default='name', - choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], - type=str, help='semantic granularity') parser.add_argument('--output_dir', default=None, type=str, help='path where to save, empty for no saving') @@ -336,7 +333,7 @@ def main(args): print(f"Model built.") # load weights to evaluate - model = build_model(num_classes=args.nb_classes,drop_path_rate=args.drop_path,) + model = build_model(num_classes=args.nb_classes,drop_path_rate=args.drop_path) if args.pretrained_weights: load_pretrained_weights(model, args.pretrained_weights, checkpoint_key=args.checkpoint_key, prefix=args.prefix) if args.compile: diff --git a/vitookit/models/build_model.py b/vitookit/models/build_model.py index a821e61..2af15db 100644 --- a/vitookit/models/build_model.py +++ b/vitookit/models/build_model.py @@ -1,7 +1,7 @@ import gin -from .vision_transformer import vit_base,vit_small, vit_tiny +import timm @gin.configurable() -def build_model(*args,model_fn=vit_tiny,**kwargs): - model = model_fn(*args,**kwargs) - return model +def build_model(model_name='vit_base_patch16_224', **kwargs): + model = timm.create_model(model_name, **kwargs) + return model \ No newline at end of file -- GitLab