diff --git a/README.MD b/README.MD index e23fd4f8b47901e06ef37f0fda83c3366c09fd49..2399924641076f57f1d6932b2734a0e12b8d06c0 100644 --- a/README.MD +++ b/README.MD @@ -13,7 +13,7 @@ See the available evaluations in [evaluation protocols](EVALUATION.md). ## commands ```bash -vitrun train_cls.py --data_location=../data/IMNET --gin VisionTransformer.global_pool='"avg"' -w wandb:dlib/EfficientSSL/lsx2qmys +vitrun train_cls.py --data_location=../data/IMNET --gin build_model.model_name='"vit_tiny_patch16_224"' build_model.global_pool='"avg"' -w wandb:dlib/EfficientSSL/ezuz0x4u --layer_decay=0.75 ``` ## condor diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py index dbca91f0d7c27b2779a7a3b46e9dec2f5cdccf25..75a5b4e7d483b36d1b92289a3648eba15e0f4cad 100644 --- a/vitookit/evaluation/eval_cls.py +++ b/vitookit/evaluation/eval_cls.py @@ -147,9 +147,6 @@ def get_args_parser(): help='dataset path') parser.add_argument('--data_set', default='IN1K', type=str, help='ImageNet dataset path') - parser.add_argument('--inat_category', default='name', - choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], - type=str, help='semantic granularity') parser.add_argument('--output_dir', default=None, type=str, help='path where to save, empty for no saving') @@ -336,7 +333,7 @@ def main(args): print(f"Model built.") # load weights to evaluate - model = build_model(num_classes=args.nb_classes,drop_path_rate=args.drop_path,) + model = build_model(num_classes=args.nb_classes,drop_path_rate=args.drop_path) if args.pretrained_weights: load_pretrained_weights(model, args.pretrained_weights, checkpoint_key=args.checkpoint_key, prefix=args.prefix) if args.compile: diff --git a/vitookit/models/build_model.py b/vitookit/models/build_model.py index a821e61d94a0193a00ec477dfb8440355dcabef3..2af15db6f2abf945e55a43674ee0ce32fc51c212 100644 --- a/vitookit/models/build_model.py +++ b/vitookit/models/build_model.py @@ -1,7 +1,7 @@ import gin -from .vision_transformer import vit_base,vit_small, vit_tiny +import timm @gin.configurable() -def build_model(*args,model_fn=vit_tiny,**kwargs): - model = model_fn(*args,**kwargs) - return model +def build_model(model_name='vit_base_patch16_224', **kwargs): + model = timm.create_model(model_name, **kwargs) + return model \ No newline at end of file