diff --git a/test_dl.ipynb b/test_dl.ipynb deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py index d63f894818405b129cbfb86b01ef98e6a7feb065..c5144c1645bfb5a0af6e145ec56c0b8daf3e8bea 100644 --- a/vitookit/datasets/ffcv_transform.py +++ b/vitookit/datasets/ffcv_transform.py @@ -37,43 +37,85 @@ IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255 @gin.configurable class DynamicResolution: - def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED, min_res=160, max_res=224, step=32): + def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED, + res=(224,224,0),lower_scale=(0.2,0.2,0), upper_scale=(1,1,0)): self.start_ramp = start_ramp self.end_ramp = end_ramp - self.min_res = min_res - self.max_res = max_res - self.step = step - - assert min_res <= max_res + self.resolution = res + self.lower_scale = lower_scale + self.upper_scale = upper_scale def get_resolution(self, epoch): + + lv,rv, step = self.resolution + + if step==0: + assert lv==rv + return rv + + if epoch <= self.start_ramp: + return lv + + if epoch >= self.end_ramp: + return rv + + interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv]) + num_steps = int(np.round((interp[0]-lv) / step)) + return num_steps*step + lv + + def get_lower_scale(self, epoch): + + lv,rv, step = self.lower_scale + + if step==0: + assert lv==rv + return rv + + if epoch <= self.start_ramp: + return lv + + if epoch >= self.end_ramp: + return rv + + + interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv]) + num_steps = int(np.round((interp[0]-lv) / step)) + return num_steps*step + lv + + def get_upper_scale(self, epoch): + + lv,rv, step = self.upper_scale + + if step==0: + assert lv==rv + return rv + if epoch <= self.start_ramp: - return self.min_res + return lv if epoch >= self.end_ramp: - return self.max_res + return rv - # otherwise, linearly interpolate to the nearest multiple of 32 - interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [self.min_res, self.max_res]) - final_res = int(np.round(interp[0] / self.step)) * self.step - return final_res + interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv]) + num_steps = int(np.round((interp[0]-lv) / step)) + return num_steps*step + lv def __call__(self, loader, epoch,is_ffcv=False): img_size = self.get_resolution(epoch) - print("resolution = %s" % str(img_size)) + lower_scale = self.get_lower_scale(epoch) + upper_scale = self.get_upper_scale(epoch) + print(f"resolution = {img_size}, scale = ({lower_scale:.2f},{upper_scale:.2f}) ") if is_ffcv: pipeline=loader.pipeline_specs['image'] if pipeline.decoder.output_size[0] != img_size: - k = img_size/self.max_res - pipeline.decoder.scale = (0.2*k,1) + pipeline.decoder.scale = (lower_scale,upper_scale) pipeline.decoder.output_size = (img_size,img_size) loader.generate_code() else: decoder = loader.dataset.transforms.transform.transforms[0] - k = img_size/self.max_res decoder.size=(img_size,img_size) - decoder.scale = (0.2*k,1) + decoder.scale = (lower_scale,upper_scale) @gin.configurable diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py index 3c2863a3b2d98d840db75a3328329db57d5154ae..607c6ad64e2e5a6ec13e26583a6d133122752b5b 100644 --- a/vitookit/evaluation/eval_cls.py +++ b/vitookit/evaluation/eval_cls.py @@ -29,7 +29,7 @@ from vitookit.datasets.transform import three_augmentation from vitookit.utils.helper import * from vitookit.utils import misc from vitookit.models.build_model import build_model -from vitookit.datasets.build_dataset import build_dataset +from vitookit.datasets.build_dataset import build_dataset, build_transform import wandb @@ -104,6 +104,7 @@ def get_args_parser(): help='LR decay rate (default: 0.1)') # Augmentation parameters + parser.add_argument('--ThreeAugment', action='store_true', default=False) #3augment parser.add_argument('--src',action='store_true', default=False, help="Use Simple Random Crop (SRC) or Random Resized Crop (RRC). Use SRC when there is less risk of overfitting, such as on ImageNet-21k.") @@ -276,7 +277,8 @@ def main(args): if args.ThreeAugment: transform = three_augmentation(args) else: - transform = None + transform = build_transform(is_train=True, args=args) + dataset_train, args.nb_classes = build_dataset(args=args, is_train=True, trnsfrm=transform) dataset_val, _ = build_dataset(is_train=False, args=args)