From e09a7015d604054b425ce2fd440025e077f3d965 Mon Sep 17 00:00:00 2001
From: gent <jw02425@surrey.ac.uk>
Date: Tue, 23 Jan 2024 21:13:14 +0000
Subject: [PATCH] add RandDownSampling for compression shift

---
 vitookit/datasets/ffcv_transform.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py
index bc93642..46aaecf 100644
--- a/vitookit/datasets/ffcv_transform.py
+++ b/vitookit/datasets/ffcv_transform.py
@@ -474,8 +474,22 @@ class RandomChoice(nn.Module):
     def __repr__(self) -> str:
         return f"{super().__repr__()}(p={self.p})"
 
+from torchvision.transforms import InterpolationMode
+
+class RandDownSampling(nn.Module):
+    def __init__(self, r=(3/4,1)) -> None:
+        super().__init__()
+        self.r = r
+    def forward(self,x):
+        h, w = x.shape[-2:]
+        r = random.uniform(*self.r)
+        nh,hw = int(h*r),int(w*r)
+        down = F.resize(x,(nh,hw),interpolation=InterpolationMode.BICUBIC)
+        up = F.resize(down,(h,w),interpolation=InterpolationMode.BICUBIC)
+        return up
+    
 @gin.configurable
-def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None):
+def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downsampling=False):
     """
     ThreeAugmentPipeline
     """
@@ -496,9 +510,10 @@ def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None):
                 RandomChoice([
                     tfms.RandomSolarize(0,1),
                     tfms.RandomGrayscale(1),
-                    tfms.GaussianBlur(7,sigma=(0.1,2)),
+                    tfms.GaussianBlur(5,sigma=(0.1,2)),
                 ])
-            ])
+            ] + 
+            [RandDownSampling((3/4,1))] if downsampling else [])
     label_pipeline = [IntDecoder(), ToTensor(),View(-1)]
     # Pipeline for each data field
     pipelines = {
-- 
GitLab