From 107b54da6af83ace73d1ce6d8f46909876076bde Mon Sep 17 00:00:00 2001
From: "Wu, Jiantao (PG/R - Comp Sci & Elec Eng)" <jiantao.wu@surrey.ac.uk>
Date: Wed, 28 Feb 2024 13:57:53 +0000
Subject: [PATCH] change dres

---
 vitookit/datasets/dres.py            | 70 +++++++++++++++++++++++++
 vitookit/datasets/ffcv_transform.py  | 78 ----------------------------
 vitookit/datasets/sa.py              |  6 +--
 vitookit/datasets/transform.py       |  5 +-
 vitookit/evaluation/eval_cls_ffcv.py |  8 +--
 5 files changed, 78 insertions(+), 89 deletions(-)
 create mode 100644 vitookit/datasets/dres.py

diff --git a/vitookit/datasets/dres.py b/vitookit/datasets/dres.py
new file mode 100644
index 0000000..9d6d149
--- /dev/null
+++ b/vitookit/datasets/dres.py
@@ -0,0 +1,70 @@
+
+from ffcv import Loader
+import gin
+from .ffcv_transform import ThreeAugmentPipeline
+
+@gin.configurable
+class DynamicResolution:
+    def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED,  
+                    scheme=1):
+        schemes ={
+            1: [
+                dict(res=160,lower_scale=0.08, upper_scale=1,color_jitter=False),
+                dict(res=192,lower_scale=0.08, upper_scale=1,color_jitter=False),
+                dict(res=224,lower_scale=0.08, upper_scale=1,color_jitter=True),
+            ],
+            2:[
+                dict(res=160,lower_scale=0.96, upper_scale=1,color_jitter=False),
+                dict(res=192,lower_scale=0.38, upper_scale=1,color_jitter=False),
+                dict(res=224,lower_scale=0.08, upper_scale=1,color_jitter=True),
+            ],
+            3:[
+                dict(res=160,lower_scale=0.08, upper_scale=0.4624,color_jitter=False),
+                dict(res=192,lower_scale=0.08, upper_scale=0.7056,color_jitter=False),
+                dict(res=224,lower_scale=0.08, upper_scale=1,color_jitter=True),
+            ],
+            4:[
+                dict(res=160,lower_scale=0.20, upper_scale=0.634,color_jitter=False),
+                dict(res=192,lower_scale=0.137, upper_scale=0.81,color_jitter=False),
+                dict(res=224,lower_scale=0.08, upper_scale=1,color_jitter=True),
+            ],
+            5: [
+                dict(res=160,lower_scale=0.08, upper_scale=1,color_jitter=True),
+                dict(res=192,lower_scale=0.08, upper_scale=1,color_jitter=True),
+                dict(res=224,lower_scale=0.08, upper_scale=1,color_jitter=True),
+            ],
+            
+        }
+        self.scheme = schemes[scheme]
+        self.start_ramp = start_ramp
+        self.end_ramp = end_ramp
+        
+    
+    def get_config(self, epoch):
+        if epoch <= self.start_ramp:
+            return self.scheme[0]
+        elif epoch>=self.end_ramp:
+            return self.scheme[-1]
+        else:
+            i = (epoch-self.start_ramp) * (len(self.scheme)-1) // (self.end_ramp-self.start_ramp)
+            return self.scheme[i]
+    
+    def __call__(self, loader: Loader, epoch,is_ffcv=False):
+        config = self.get_config(epoch)
+        print(", ".join([f"{k}={v}" for k,v in config.items()]))
+        
+        img_size = config['res']
+        lower_scale = config['lower_scale']
+        upper_scale = config['upper_scale']
+        color_jitter = config['color_jitter']
+
+        if is_ffcv:
+            pipelines = ThreeAugmentPipeline(img_size,scale=(lower_scale,upper_scale),color_jitter=color_jitter)
+            loader.compile_pipeline(pipelines)           
+        else:
+            # todo: change dres
+            pipelines = loader.dataset.transforms
+            
+            decoder = loader.dataset.transforms.transform.transforms[0]
+            decoder.size=(img_size,img_size)
+            decoder.scale = (lower_scale,upper_scale)
diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py
index ae25390..db86abc 100644
--- a/vitookit/datasets/ffcv_transform.py
+++ b/vitookit/datasets/ffcv_transform.py
@@ -36,84 +36,6 @@ IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
 IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
 
 
-@gin.configurable
-class DynamicResolution:
-    def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED,  
-                    scheme=1):
-        schemes ={
-            1: [
-                dict(res=160,lower_scale=0.08, upper_scale=1),
-                dict(res=192,lower_scale=0.08, upper_scale=1),
-                dict(res=224,lower_scale=0.08, upper_scale=1),
-            ],
-            2:[
-                dict(res=160,lower_scale=0.96, upper_scale=1),
-                dict(res=192,lower_scale=0.38, upper_scale=1),
-                dict(res=224,lower_scale=0.08, upper_scale=1),
-            ],
-            3:[
-                dict(res=160,lower_scale=0.08, upper_scale=0.4624),
-                dict(res=192,lower_scale=0.08, upper_scale=0.7056),
-                dict(res=224,lower_scale=0.08, upper_scale=1),
-            ],
-            4:[
-                dict(res=160,lower_scale=0.20, upper_scale=0.634),
-                dict(res=192,lower_scale=0.137, upper_scale=0.81),
-                dict(res=224,lower_scale=0.08, upper_scale=1),
-            ],
-            5: [
-                dict(res=160,lower_scale=0.20, upper_scale=1),
-                dict(res=192,lower_scale=0.137, upper_scale=1),
-                dict(res=224,lower_scale=0.08, upper_scale=1),
-            ],
-            
-        }
-        self.scheme = schemes[scheme]
-        self.start_ramp = start_ramp
-        self.end_ramp = end_ramp
-        
-    
-    def get_resolution(self, epoch):
-        if epoch<=(self.start_ramp+self.end_ramp)/2:
-            return self.scheme[0]['res']
-        elif epoch<=self.end_ramp:
-            return self.scheme[1]['res']
-        else:
-            return self.scheme[2]['res']
-    
-    def get_lower_scale(self, epoch):
-        if epoch<=(self.start_ramp+self.end_ramp)/2:
-            return self.scheme[0]['lower_scale']
-        elif epoch<=self.end_ramp:
-            return self.scheme[1]['lower_scale']
-        else:
-            return self.scheme[2]['lower_scale']
-    
-    def get_upper_scale(self, epoch):
-        if epoch<=(self.start_ramp+self.end_ramp)/2:
-            return self.scheme[0]['upper_scale']
-        elif epoch<=self.end_ramp:
-            return self.scheme[1]['upper_scale']
-        else:
-            return self.scheme[2]['upper_scale']
-    
-    def __call__(self, loader, epoch,is_ffcv=False):
-        img_size = self.get_resolution(epoch)
-        lower_scale = self.get_lower_scale(epoch)
-        upper_scale = self.get_upper_scale(epoch)
-        print(f"resolution = {img_size}, scale = ({lower_scale:.2f},{upper_scale:.2f}) ")
-
-        if is_ffcv:
-            pipeline=loader.pipeline_specs['image']
-            if pipeline.decoder.output_size[0] != img_size:
-                pipeline.decoder.scale = (lower_scale,upper_scale)
-                pipeline.decoder.output_size = (img_size,img_size)
-                loader.generate_code()
-        else:
-            decoder = loader.dataset.transforms.transform.transforms[0]
-            decoder.size=(img_size,img_size)
-            decoder.scale = (lower_scale,upper_scale)
-                
 @gin.configurable
 def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0)):
     image_pipeline = [
diff --git a/vitookit/datasets/sa.py b/vitookit/datasets/sa.py
index 38e7aed..a5e5ff7 100644
--- a/vitookit/datasets/sa.py
+++ b/vitookit/datasets/sa.py
@@ -117,6 +117,6 @@ class SA1BDataset:
         m = maskUtils.decode(rle)
         return m
 
-    
-data=SA1BDataset('/vol/research/datasets/still/SA-1B',['sa_258805','sa_2601447','sa_6313499','sa_2592075','sa_258406','sa_2601844','sa_6353015','sa_10958302','sa_6344773']) # if you want a subset of the dataset
-img, mask, class_ids = data[0]
\ No newline at end of file
+if __name__ == '__main__':
+    data=SA1BDataset('/vol/research/datasets/still/SA-1B',['sa_258805','sa_2601447','sa_6313499','sa_2592075','sa_258406','sa_2601844','sa_6353015','sa_10958302','sa_6344773']) # if you want a subset of the dataset
+    img, mask, class_ids = data[0]
\ No newline at end of file
diff --git a/vitookit/datasets/transform.py b/vitookit/datasets/transform.py
index 5d017f7..a628005 100644
--- a/vitookit/datasets/transform.py
+++ b/vitookit/datasets/transform.py
@@ -5,7 +5,7 @@ import numpy as np
 from torchvision import transforms
 import torch
 import gin
-
+from torch import nn
 
 class PermutePatch(object):
     """
@@ -103,6 +103,9 @@ class GrayScale(object):
         else:
             return img
 
+class ThreeAugmentation(nn.Module):
+    def __init__(self,):
+        super().__init__()
     
 def three_augmentation(args = None):
     img_size = args.input_size
diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py
index fa9e32a..2450174 100644
--- a/vitookit/evaluation/eval_cls_ffcv.py
+++ b/vitookit/evaluation/eval_cls_ffcv.py
@@ -43,7 +43,7 @@ from timm.layers import trunc_normal_
 
 from ffcv import Loader
 from ffcv.loader import OrderOption
-
+from vitookit.datasets.dres import DynamicResolution
 
 def get_args_parser():
     parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
@@ -342,12 +342,6 @@ def main(args):
         
     for epoch in range(args.start_epoch, args.epochs):
         if dres:
-            if epoch == dres.end_ramp:
-                print("enhance augmentation!", epoch, dres.end_ramp)
-                ## enhance augmentation, see efficientnetv2
-                data_loader_train = Loader(args.train_path, pipelines=ThreeAugmentPipeline(color_jitter=0.3),
-                        batch_size=args.batch_size, num_workers=args.num_workers, 
-                        order=order, distributed=args.distributed,seed=args.seed)
             dres(data_loader_train,epoch,True)
         
         train_stats = train_one_epoch(
-- 
GitLab