From 813e9fb5b7f7d430b2caeba6ad67767b84f4cbac Mon Sep 17 00:00:00 2001 From: gent <jw02425@surrey.ac.uk> Date: Tue, 23 Jan 2024 21:36:21 +0000 Subject: [PATCH] Add RandDownSampling configurable class --- vitookit/datasets/ffcv_transform.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py index 46aaecf..34a8564 100644 --- a/vitookit/datasets/ffcv_transform.py +++ b/vitookit/datasets/ffcv_transform.py @@ -476,6 +476,7 @@ class RandomChoice(nn.Module): from torchvision.transforms import InterpolationMode +@gin.configurable class RandDownSampling(nn.Module): def __init__(self, r=(3/4,1)) -> None: super().__init__() @@ -483,13 +484,15 @@ class RandDownSampling(nn.Module): def forward(self,x): h, w = x.shape[-2:] r = random.uniform(*self.r) + if r == 1: + return x nh,hw = int(h*r),int(w*r) down = F.resize(x,(nh,hw),interpolation=InterpolationMode.BICUBIC) up = F.resize(down,(h,w),interpolation=InterpolationMode.BICUBIC) return up @gin.configurable -def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsampling=False): +def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downscale=1): """ ThreeAugmentPipeline """ @@ -498,7 +501,6 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsamp [ RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,), RandomHorizontalFlip(),]+ # second_tfl - # [ ThreeAugmentation(),] + ( [RandomColorJitter(jitter_prob=1, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter, hue=0, seed=None)] if color_jitter else []) + # final_tfl [ @@ -511,9 +513,9 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsamp tfms.RandomSolarize(0,1), tfms.RandomGrayscale(1), tfms.GaussianBlur(5,sigma=(0.1,2)), - ]) - ] + - [RandDownSampling((3/4,1))] if downsampling else []) + ]), + RandDownSampling((downscale,1)), + ]) label_pipeline = [IntDecoder(), ToTensor(),View(-1)] # Pipeline for each data field pipelines = { -- GitLab