From 813e9fb5b7f7d430b2caeba6ad67767b84f4cbac Mon Sep 17 00:00:00 2001
From: gent <jw02425@surrey.ac.uk>
Date: Tue, 23 Jan 2024 21:36:21 +0000
Subject: [PATCH] Add RandDownSampling configurable class

---
 vitookit/datasets/ffcv_transform.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py
index 46aaecf..34a8564 100644
--- a/vitookit/datasets/ffcv_transform.py
+++ b/vitookit/datasets/ffcv_transform.py
@@ -476,6 +476,7 @@ class RandomChoice(nn.Module):
 
 from torchvision.transforms import InterpolationMode
 
+@gin.configurable
 class RandDownSampling(nn.Module):
     def __init__(self, r=(3/4,1)) -> None:
         super().__init__()
@@ -483,13 +484,15 @@ class RandDownSampling(nn.Module):
     def forward(self,x):
         h, w = x.shape[-2:]
         r = random.uniform(*self.r)
+        if r == 1:
+            return x
         nh,hw = int(h*r),int(w*r)
         down = F.resize(x,(nh,hw),interpolation=InterpolationMode.BICUBIC)
         up = F.resize(down,(h,w),interpolation=InterpolationMode.BICUBIC)
         return up
     
 @gin.configurable
-def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsampling=False):
+def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downscale=1):
     """
     ThreeAugmentPipeline
     """
@@ -498,7 +501,6 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsamp
             [   RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,),
                 RandomHorizontalFlip(),]+
             # second_tfl
-            # [   ThreeAugmentation(),] +
             (   [RandomColorJitter(jitter_prob=1, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter, hue=0, seed=None)] if color_jitter else []) +
             # final_tfl
             [
@@ -511,9 +513,9 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsamp
                     tfms.RandomSolarize(0,1),
                     tfms.RandomGrayscale(1),
                     tfms.GaussianBlur(5,sigma=(0.1,2)),
-                ])
-            ] + 
-            [RandDownSampling((3/4,1))] if downsampling else [])
+                ]),
+                RandDownSampling((downscale,1)),
+            ])
     label_pipeline = [IntDecoder(), ToTensor(),View(-1)]
     # Pipeline for each data field
     pipelines = {
-- 
GitLab