diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py
index 46aaecfae896e283e367d005dc01107dcd666dfe..34a85648fe0ae89e705c1567e0f278eed120bdf8 100644
--- a/vitookit/datasets/ffcv_transform.py
+++ b/vitookit/datasets/ffcv_transform.py
@@ -476,6 +476,7 @@ class RandomChoice(nn.Module):
 
 from torchvision.transforms import InterpolationMode
 
+@gin.configurable
 class RandDownSampling(nn.Module):
     def __init__(self, r=(3/4,1)) -> None:
         super().__init__()
@@ -483,13 +484,15 @@ class RandDownSampling(nn.Module):
     def forward(self,x):
         h, w = x.shape[-2:]
         r = random.uniform(*self.r)
+        if r == 1:
+            return x
         nh,hw = int(h*r),int(w*r)
         down = F.resize(x,(nh,hw),interpolation=InterpolationMode.BICUBIC)
         up = F.resize(down,(h,w),interpolation=InterpolationMode.BICUBIC)
         return up
     
 @gin.configurable
-def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsampling=False):
+def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downscale=1):
     """
     ThreeAugmentPipeline
     """
@@ -498,7 +501,6 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsamp
             [   RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,),
                 RandomHorizontalFlip(),]+
             # second_tfl
-            # [   ThreeAugmentation(),] +
             (   [RandomColorJitter(jitter_prob=1, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter, hue=0, seed=None)] if color_jitter else []) +
             # final_tfl
             [
@@ -511,9 +513,9 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsamp
                     tfms.RandomSolarize(0,1),
                     tfms.RandomGrayscale(1),
                     tfms.GaussianBlur(5,sigma=(0.1,2)),
-                ])
-            ] + 
-            [RandDownSampling((3/4,1))] if downsampling else [])
+                ]),
+                RandDownSampling((downscale,1)),
+            ])
     label_pipeline = [IntDecoder(), ToTensor(),View(-1)]
     # Pipeline for each data field
     pipelines = {