From e09a7015d604054b425ce2fd440025e077f3d965 Mon Sep 17 00:00:00 2001
From: gent <jw02425@surrey.ac.uk>
Date: Tue, 23 Jan 2024 21:13:14 +0000
Subject: [PATCH] add RandDownSampling for compression shift

---
Review notes (ignored by git-am):
- The optional RandDownSampling tail is now parenthesised. Without the
  parentheses the conditional expression swallowed the whole list, so the
  default downsampling=False produced an empty augmentation list.
- The downsampled size is clamped to at least 1 pixel per side, and the
  misnamed local `hw` is now `nw`.
- The blob hash on the `index` line predates these edits.

 vitookit/datasets/ffcv_transform.py | 24 +++++++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py
index bc93642..46aaecf 100644
--- a/vitookit/datasets/ffcv_transform.py
+++ b/vitookit/datasets/ffcv_transform.py
@@ -474,8 +474,24 @@ class RandomChoice(nn.Module):
     def __repr__(self) -> str:
         return f"{super().__repr__()}(p={self.p})"
 
+from torchvision.transforms import InterpolationMode
+
+class RandDownSampling(nn.Module):
+    """Resize to a random fraction of the input size (ratio drawn uniformly from ``r``), then back to the original size."""
+    def __init__(self, r=(3/4,1)) -> None:
+        super().__init__()
+        self.r = r
+    def forward(self,x):
+        h, w = x.shape[-2:]
+        r = random.uniform(*self.r)
+        # Clamp to 1 so a tiny input or a small ratio never requests a 0-pixel side.
+        nh,nw = max(1,int(h*r)),max(1,int(w*r))
+        down = F.resize(x,(nh,nw),interpolation=InterpolationMode.BICUBIC)
+        up = F.resize(down,(h,w),interpolation=InterpolationMode.BICUBIC)
+        return up
+
 @gin.configurable
-def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None):
+def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsampling=False):
     """
     ThreeAugmentPipeline
     """
@@ -496,9 +512,11 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None):
             RandomChoice([
                 tfms.RandomSolarize(0,1),
                 tfms.RandomGrayscale(1),
-                tfms.GaussianBlur(7,sigma=(0.1,2)),
+                tfms.GaussianBlur(5,sigma=(0.1,2)),
             ])
-        ])
+        ]
+        # Parenthesised: `a + b if c else d` parses as `(a + b) if c else d`.
+        + ([RandDownSampling((3/4,1))] if downsampling else []))
     label_pipeline = [IntDecoder(), ToTensor(),View(-1)]
     # Pipeline for each data field
     pipelines = {
-- 
GitLab