diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py index 46aaecfae896e283e367d005dc01107dcd666dfe..34a85648fe0ae89e705c1567e0f278eed120bdf8 100644 --- a/vitookit/datasets/ffcv_transform.py +++ b/vitookit/datasets/ffcv_transform.py @@ -476,6 +476,7 @@ class RandomChoice(nn.Module): from torchvision.transforms import InterpolationMode +@gin.configurable class RandDownSampling(nn.Module): def __init__(self, r=(3/4,1)) -> None: super().__init__() @@ -483,13 +484,15 @@ class RandDownSampling(nn.Module): def forward(self,x): h, w = x.shape[-2:] r = random.uniform(*self.r) + if r == 1: + return x nh,hw = int(h*r),int(w*r) down = F.resize(x,(nh,hw),interpolation=InterpolationMode.BICUBIC) up = F.resize(down,(h,w),interpolation=InterpolationMode.BICUBIC) return up @gin.configurable -def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsampling=False): +def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downscale=1): """ ThreeAugmentPipeline """ @@ -498,7 +501,6 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsamp [ RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,), RandomHorizontalFlip(),]+ # second_tfl - # [ ThreeAugmentation(),] + ( [RandomColorJitter(jitter_prob=1, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter, hue=0, seed=None)] if color_jitter else []) + # final_tfl [ @@ -511,9 +513,9 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsamp tfms.RandomSolarize(0,1), tfms.RandomGrayscale(1), tfms.GaussianBlur(5,sigma=(0.1,2)), - ]) - ] + - [RandDownSampling((3/4,1))] if downsampling else []) + ]), + RandDownSampling((downscale,1)), + ]) label_pipeline = [IntDecoder(), ToTensor(),View(-1)] # Pipeline for each data field pipelines = {