From 107b54da6af83ace73d1ce6d8f46909876076bde Mon Sep 17 00:00:00 2001 From: "Wu, Jiantao (PG/R - Comp Sci & Elec Eng)" <jiantao.wu@surrey.ac.uk> Date: Wed, 28 Feb 2024 13:57:53 +0000 Subject: [PATCH] change dres --- vitookit/datasets/dres.py | 70 +++++++++++++++++++++++++ vitookit/datasets/ffcv_transform.py | 78 ---------------------------- vitookit/datasets/sa.py | 6 +-- vitookit/datasets/transform.py | 5 +- vitookit/evaluation/eval_cls_ffcv.py | 8 +-- 5 files changed, 78 insertions(+), 89 deletions(-) create mode 100644 vitookit/datasets/dres.py diff --git a/vitookit/datasets/dres.py b/vitookit/datasets/dres.py new file mode 100644 index 0000000..9d6d149 --- /dev/null +++ b/vitookit/datasets/dres.py @@ -0,0 +1,70 @@ + +from ffcv import Loader +import gin +from .ffcv_transform import ThreeAugmentPipeline + +@gin.configurable +class DynamicResolution: + def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED, + scheme=1): + schemes ={ + 1: [ + dict(res=160,lower_scale=0.08, upper_scale=1,color_jitter=False), + dict(res=192,lower_scale=0.08, upper_scale=1,color_jitter=False), + dict(res=224,lower_scale=0.08, upper_scale=1,color_jitter=True), + ], + 2:[ + dict(res=160,lower_scale=0.96, upper_scale=1,color_jitter=False), + dict(res=192,lower_scale=0.38, upper_scale=1,color_jitter=False), + dict(res=224,lower_scale=0.08, upper_scale=1,color_jitter=True), + ], + 3:[ + dict(res=160,lower_scale=0.08, upper_scale=0.4624,color_jitter=False), + dict(res=192,lower_scale=0.08, upper_scale=0.7056,color_jitter=False), + dict(res=224,lower_scale=0.08, upper_scale=1,color_jitter=True), + ], + 4:[ + dict(res=160,lower_scale=0.20, upper_scale=0.634,color_jitter=False), + dict(res=192,lower_scale=0.137, upper_scale=0.81,color_jitter=False), + dict(res=224,lower_scale=0.08, upper_scale=1,color_jitter=True), + ], + 5: [ + dict(res=160,lower_scale=0.08, upper_scale=1,color_jitter=True), + dict(res=192,lower_scale=0.08, upper_scale=1,color_jitter=True), + dict(res=224,lower_scale=0.08, upper_scale=1,color_jitter=True), + ], + + } + self.scheme = schemes[scheme] + self.start_ramp = start_ramp + self.end_ramp = end_ramp + + + def get_config(self, epoch): + if epoch <= self.start_ramp: + return self.scheme[0] + elif epoch>=self.end_ramp: + return self.scheme[-1] + else: + i = (epoch-self.start_ramp) * (len(self.scheme)-1) // (self.end_ramp-self.start_ramp) + return self.scheme[i] + + def __call__(self, loader: Loader, epoch,is_ffcv=False): + config = self.get_config(epoch) + print(", ".join([f"{k}={v}" for k,v in config.items()])) + + img_size = config['res'] + lower_scale = config['lower_scale'] + upper_scale = config['upper_scale'] + color_jitter = config['color_jitter'] + + if is_ffcv: + pipelines = ThreeAugmentPipeline(img_size,scale=(lower_scale,upper_scale),color_jitter=color_jitter) + loader.compile_pipeline(pipelines) + else: + # todo: change dres + pipelines = loader.dataset.transforms + + decoder = loader.dataset.transforms.transform.transforms[0] + decoder.size=(img_size,img_size) + decoder.scale = (lower_scale,upper_scale) diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py index ae25390..db86abc 100644 --- a/vitookit/datasets/ffcv_transform.py +++ b/vitookit/datasets/ffcv_transform.py @@ -36,84 +36,6 @@ IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255 IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255 -@gin.configurable -class DynamicResolution: - def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED, - scheme=1): - schemes ={ - 1: [ - 
dict(res=160,lower_scale=0.08, upper_scale=1), - dict(res=192,lower_scale=0.08, upper_scale=1), - dict(res=224,lower_scale=0.08, upper_scale=1), - ], - 2:[ - dict(res=160,lower_scale=0.96, upper_scale=1), - dict(res=192,lower_scale=0.38, upper_scale=1), - dict(res=224,lower_scale=0.08, upper_scale=1), - ], - 3:[ - dict(res=160,lower_scale=0.08, upper_scale=0.4624), - dict(res=192,lower_scale=0.08, upper_scale=0.7056), - dict(res=224,lower_scale=0.08, upper_scale=1), - ], - 4:[ - dict(res=160,lower_scale=0.20, upper_scale=0.634), - dict(res=192,lower_scale=0.137, upper_scale=0.81), - dict(res=224,lower_scale=0.08, upper_scale=1), - ], - 5: [ - dict(res=160,lower_scale=0.20, upper_scale=1), - dict(res=192,lower_scale=0.137, upper_scale=1), - dict(res=224,lower_scale=0.08, upper_scale=1), - ], - - } - self.scheme = schemes[scheme] - self.start_ramp = start_ramp - self.end_ramp = end_ramp - - - def get_resolution(self, epoch): - if epoch<=(self.start_ramp+self.end_ramp)/2: - return self.scheme[0]['res'] - elif epoch<=self.end_ramp: - return self.scheme[1]['res'] - else: - return self.scheme[2]['res'] - - def get_lower_scale(self, epoch): - if epoch<=(self.start_ramp+self.end_ramp)/2: - return self.scheme[0]['lower_scale'] - elif epoch<=self.end_ramp: - return self.scheme[1]['lower_scale'] - else: - return self.scheme[2]['lower_scale'] - - def get_upper_scale(self, epoch): - if epoch<=(self.start_ramp+self.end_ramp)/2: - return self.scheme[0]['upper_scale'] - elif epoch<=self.end_ramp: - return self.scheme[1]['upper_scale'] - else: - return self.scheme[2]['upper_scale'] - - def __call__(self, loader, epoch,is_ffcv=False): - img_size = self.get_resolution(epoch) - lower_scale = self.get_lower_scale(epoch) - upper_scale = self.get_upper_scale(epoch) - print(f"resolution = {img_size}, scale = ({lower_scale:.2f},{upper_scale:.2f}) ") - - if is_ffcv: - pipeline=loader.pipeline_specs['image'] - if pipeline.decoder.output_size[0] != img_size: - pipeline.decoder.scale = (lower_scale,upper_scale) - pipeline.decoder.output_size = (img_size,img_size) - loader.generate_code() - else: - decoder = loader.dataset.transforms.transform.transforms[0] - decoder.size=(img_size,img_size) - decoder.scale = (lower_scale,upper_scale) - @gin.configurable def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0)): image_pipeline = [ diff --git a/vitookit/datasets/sa.py b/vitookit/datasets/sa.py index 38e7aed..a5e5ff7 100644 --- a/vitookit/datasets/sa.py +++ b/vitookit/datasets/sa.py @@ -117,6 +117,6 @@ class SA1BDataset: m = maskUtils.decode(rle) return m - -data=SA1BDataset('/vol/research/datasets/still/SA-1B',['sa_258805','sa_2601447','sa_6313499','sa_2592075','sa_258406','sa_2601844','sa_6353015','sa_10958302','sa_6344773']) # if you want a subset of the dataset -img, mask, class_ids = data[0] \ No newline at end of file +if __name__ == '__main__': + data=SA1BDataset('/vol/research/datasets/still/SA-1B',['sa_258805','sa_2601447','sa_6313499','sa_2592075','sa_258406','sa_2601844','sa_6353015','sa_10958302','sa_6344773']) # if you want a subset of the dataset + img, mask, class_ids = data[0] \ No newline at end of file diff --git a/vitookit/datasets/transform.py b/vitookit/datasets/transform.py index 5d017f7..a628005 100644 --- a/vitookit/datasets/transform.py +++ b/vitookit/datasets/transform.py @@ -5,7 +5,7 @@ import numpy as np from torchvision import transforms import torch import gin - +from torch import nn class PermutePatch(object): """ @@ -103,6 +103,9 @@ class GrayScale(object): else: return img 
+class ThreeAugmentation(nn.Module): + def __init__(self,): + super().__init__() def three_augmentation(args = None): img_size = args.input_size diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py index fa9e32a..2450174 100644 --- a/vitookit/evaluation/eval_cls_ffcv.py +++ b/vitookit/evaluation/eval_cls_ffcv.py @@ -43,7 +43,7 @@ from timm.layers import trunc_normal_ from ffcv import Loader from ffcv.loader import OrderOption - +from datasets.dres import DynamicResolution def get_args_parser(): parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) @@ -342,12 +342,6 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): if dres: - if epoch == dres.end_ramp: - print("enhance augmentation!", epoch, dres.end_ramp) - ## enhance augmentation, see efficientnetv2 - data_loader_train = Loader(args.train_path, pipelines=ThreeAugmentPipeline(color_jitter=0.3), - batch_size=args.batch_size, num_workers=args.num_workers, - order=order, distributed=args.distributed,seed=args.seed) dres(data_loader_train,epoch,True) train_stats = train_one_epoch( -- GitLab
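
Editor's note, for context only (not part of the commit): a minimal sketch of how the relocated DynamicResolution scheduler is meant to be driven once per epoch, mirroring the `dres(data_loader_train, epoch, True)` call that remains in eval_cls_ffcv.py after this patch. The ramp epochs, scheme number, and the bare epoch loop below are illustrative assumptions, not values taken from the patch.

    # Hypothetical usage sketch; start_ramp, end_ramp, and scheme are example values.
    from vitookit.datasets.dres import DynamicResolution

    dres = DynamicResolution(start_ramp=10, end_ramp=60, scheme=2)

    for epoch in range(100):
        # get_config returns the scheduled dict with res, lower_scale,
        # upper_scale, and color_jitter for this epoch.
        cfg = dres.get_config(epoch)
        print(epoch, cfg)
        # In actual training the patch applies the schedule with
        # dres(data_loader_train, epoch, True), which rebuilds the FFCV
        # ThreeAugmentPipeline at the scheduled resolution and crop scale.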