From c46b1316b8430344028386cd8131b0c6f304412c Mon Sep 17 00:00:00 2001 From: gent <jw02425@surrey.ac.uk> Date: Sun, 28 Jan 2024 00:56:37 +0000 Subject: [PATCH] Add DynamicResolution class for dynamic image resolution and scaling --- vitookit/datasets/ffcv_transform.py | 70 ++++++++++++++++++++++++++++ vitookit/evaluation/eval_cls_ffcv.py | 3 +- 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py index 1ad463d..bee28cc 100644 --- a/vitookit/datasets/ffcv_transform.py +++ b/vitookit/datasets/ffcv_transform.py @@ -35,8 +35,78 @@ from torch import nn IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255 IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255 + @gin.configurable class DynamicResolution: + def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED, + scheme=1): + schemes ={ + 1: [ + dict(res=160,lower_scale=0.08, upper_scale=1), + dict(res=192,lower_scale=0.08, upper_scale=1), + dict(res=224,lower_scale=0.08, upper_scale=1), + ], + 2:[ + dict(res=160,lower_scale=0.96, upper_scale=1), + dict(res=192,lower_scale=0.38, upper_scale=1), + dict(res=224,lower_scale=0.08, upper_scale=1), + ], + 3:[ + dict(res=160,lower_scale=0.08, upper_scale=0.4624), + dict(res=192,lower_scale=0.08, upper_scale=0.7056), + dict(res=224,lower_scale=0.08, upper_scale=1), + ] + + } + self.scheme = schemes[scheme] + self.start_ramp = start_ramp + self.end_ramp = end_ramp + + + def get_resolution(self, epoch): + if epoch<=(self.start_ramp+self.end_ramp)/2: + return self.scheme[0]['res'] + elif epoch<=self.end_ramp: + return self.scheme[1]['res'] + else: + return self.scheme[2]['res'] + + def get_lower_scale(self, epoch): + if epoch<=(self.start_ramp+self.end_ramp)/2: + return self.scheme[0]['lower_scale'] + elif epoch<=self.end_ramp: + return self.scheme[1]['lower_scale'] + else: + return self.scheme[2]['lower_scale'] + + def get_upper_scale(self, epoch): + if epoch<=(self.start_ramp+self.end_ramp)/2: + return self.scheme[0]['upper_scale'] + elif epoch<=self.end_ramp: + return self.scheme[1]['upper_scale'] + else: + return self.scheme[2]['upper_scale'] + + def __call__(self, loader, epoch,is_ffcv=False): + img_size = self.get_resolution(epoch) + lower_scale = self.get_lower_scale(epoch) + upper_scale = self.get_upper_scale(epoch) + print(f"resolution = {img_size}, scale = ({lower_scale:.2f},{upper_scale:.2f}) ") + + if is_ffcv: + pipeline=loader.pipeline_specs['image'] + if pipeline.decoder.output_size[0] != img_size: + pipeline.decoder.scale = (lower_scale,upper_scale) + pipeline.decoder.output_size = (img_size,img_size) + loader.generate_code() + else: + decoder = loader.dataset.transforms.transform.transforms[0] + decoder.size=(img_size,img_size) + decoder.scale = (lower_scale,upper_scale) + + +@gin.configurable +class DynamicResolution1: def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED, res=(224,224,0),lower_scale=(0.08,0.08,0), upper_scale=(1,1,0)): self.start_ramp = start_ramp diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py index 4299123..0db7702 100644 --- a/vitookit/evaluation/eval_cls_ffcv.py +++ b/vitookit/evaluation/eval_cls_ffcv.py @@ -347,12 +347,13 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): if dres: if epoch == dres.end_ramp: + print("enhance augmentation!", epoch, dres.end_ramp) ## enhance augmentation, see efficientnetv2 - data_loader_train = Loader(args.train_path, pipelines=ThreeAugmentPipeline(color_jitter=0.4), batch_size=args.batch_size, num_workers=args.num_workers, order=order, distributed=args.distributed,seed=args.seed) dres(data_loader_train,epoch,True) + train_stats = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler,lr_scheduler, -- GitLab