diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py
index bc93642b350feea76691e44d51f72bbdebd23a89..46aaecfae896e283e367d005dc01107dcd666dfe 100644
--- a/vitookit/datasets/ffcv_transform.py
+++ b/vitookit/datasets/ffcv_transform.py
@@ -474,8 +474,28 @@ class RandomChoice(nn.Module):
     def __repr__(self) -> str:
         return f"{super().__repr__()}(p={self.p})"
 
+from torchvision.transforms import InterpolationMode
+
+class RandDownSampling(nn.Module):
+    """Resize the input down by a random ratio, then back up to its original size.
+
+    Args:
+        r: ``(low, high)`` bounds passed to ``random.uniform`` to draw the per-call ratio.
+    """
+    def __init__(self, r=(3/4,1)) -> None:
+        super().__init__()
+        self.r = r
+    def forward(self,x):
+        h, w = x.shape[-2:]
+        r = random.uniform(*self.r)
+        # Clamp to one pixel so a tiny input or a small ratio never requests a 0-sized resize.
+        nh, nw = max(1, int(h*r)), max(1, int(w*r))
+        down = F.resize(x,(nh,nw),interpolation=InterpolationMode.BICUBIC)
+        up = F.resize(down,(h,w),interpolation=InterpolationMode.BICUBIC)
+        return up
+
 @gin.configurable
-def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None):
+def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsampling=False):
     """
     ThreeAugmentPipeline
     """
@@ -496,9 +516,11 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None):
             RandomChoice([
                 tfms.RandomSolarize(0,1),
                 tfms.RandomGrayscale(1),
-                tfms.GaussianBlur(7,sigma=(0.1,2)),
+                tfms.GaussianBlur(5,sigma=(0.1,2)),
             ])
-        ])
+        ] +
+        # Parenthesised on purpose: "a + b if c else d" parses as "(a + b) if c else d".
+        ([RandDownSampling((3/4,1))] if downsampling else []))
     label_pipeline = [IntDecoder(), ToTensor(),View(-1)]
     # Pipeline for each data field
     pipelines = {