diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py
index ffb51ac1045d71895cf1db4c998d5b39227ad322..ae25390193a116313c31c1c03eb0b1d2c7a7a36a 100644
--- a/vitookit/datasets/ffcv_transform.py
+++ b/vitookit/datasets/ffcv_transform.py
@@ -24,7 +24,7 @@ import gin
 from cv2 import GaussianBlur
 from scipy.ndimage import gaussian_filter
 
-from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage,RandomHorizontalFlip, View
+from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage,RandomHorizontalFlip, View, Convert
 from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder, SimpleRGBImageDecoder, CenterCropRGBImageDecoder
 
 import torch
@@ -114,101 +114,16 @@ class DynamicResolution:
             decoder.size=(img_size,img_size)
             decoder.scale = (lower_scale,upper_scale)
                 
-
-@gin.configurable
-class DynamicResolution1:
-    def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED,  
-                    res=(224,224,0),lower_scale=(0.08,0.08,0), upper_scale=(1,1,0)):
-        self.start_ramp = start_ramp
-        self.end_ramp = end_ramp
-        self.resolution = res
-        self.lower_scale = lower_scale
-        self.upper_scale = upper_scale
-        
-        
-    def get_resolution(self, epoch):    
-        
-        lv,rv, step = self.resolution
-        
-        if step==0:
-            assert lv==rv
-            return rv
-        
-        if epoch <= self.start_ramp:
-            return lv
-
-        if epoch >= self.end_ramp:
-            return rv
-        
-        interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv])
-        num_steps = int(((interp[0]-lv) / step))
-        return num_steps*step + lv
-    
-    def get_lower_scale(self, epoch):    
-        
-        lv,rv, step = self.lower_scale
-        
-        if step==0:
-            assert lv==rv
-            return rv
-        
-        if epoch <= self.start_ramp:
-            return lv
-
-        if epoch >= self.end_ramp:
-            return rv
-
-        
-        interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv])
-        num_steps = int(((interp[0]-lv) / step))
-        return num_steps*step + lv
-    
-    def get_upper_scale(self, epoch):    
-        
-        lv,rv, step = self.upper_scale
-        
-        if step==0:
-            assert lv==rv
-            return rv
-        
-        if epoch <= self.start_ramp:
-            return lv
-
-        if epoch >= self.end_ramp:
-            return rv
-
-        interp = np.interp([epoch], [self.start_ramp, self.end_ramp], [lv, rv])
-        num_steps = int(((interp[0]-lv) / step))
-        return num_steps*step + lv
-    
-
-    def __call__(self, loader, epoch,is_ffcv=False):
-        img_size = self.get_resolution(epoch)
-        lower_scale = self.get_lower_scale(epoch)
-        upper_scale = self.get_upper_scale(epoch)
-        print(f"resolution = {img_size}, scale = ({lower_scale:.2f},{upper_scale:.2f}) ")
-
-        if is_ffcv:
-            pipeline=loader.pipeline_specs['image']
-            if pipeline.decoder.output_size[0] != img_size:
-                pipeline.decoder.scale = (lower_scale,upper_scale)
-                pipeline.decoder.output_size = (img_size,img_size)
-                loader.generate_code()
-        else:
-            decoder = loader.dataset.transforms.transform.transforms[0]
-            decoder.size=(img_size,img_size)
-            decoder.scale = (lower_scale,upper_scale)
-                
-                
 @gin.configurable
 def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0)):
     image_pipeline = [
-            RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,ratio=ratio),
-            RandomHorizontalFlip(),
-            NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16),
+            RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,ratio=ratio,),
+            RandomHorizontalFlip(),            
             ToTensor(), 
-            ToDevice(torch.device('cuda')),        
+            ToDevice(torch.device('cuda')),
             ToTorchImage(),
+            Convert(torch.float16),
+            tfms.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255], inplace=True),
             ]
     label_pipeline = [IntDecoder(), ToTensor(),ToDevice(torch.device('cuda'))]
     # Pipeline for each data field
@@ -222,10 +137,11 @@ def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0)):
 def ValPipeline(img_size=224,ratio= 224/256):
     image_pipeline = [
             CenterCropRGBImageDecoder((img_size, img_size), ratio),
-            NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16),
             ToTensor(), 
             ToDevice(torch.device('cuda')),        
             ToTorchImage(),
+            Convert(torch.float16),
+            tfms.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255], inplace=True),
             ]
     label_pipeline = [IntDecoder(), ToTensor(),ToDevice(torch.device('cuda')),View(-1)]
     # Pipeline for each data field
@@ -575,54 +491,6 @@ def convolution(image: np.ndarray, kernel: list | tuple, output: np.ndarray) ->
     return output
 
 
-class FFCVThreeAugmentation(Operation):
-    def __init__(
-        self, threshold=128, radius_min=0.1, radius_max=2.
-    ):
-        super().__init__()
-        self.threshold = threshold
-        self.radius_min = radius_min
-        self.radius_max = radius_max
-
-    def generate_code(self) -> Callable:
-        my_range = Compiler.get_iterator()
-        threshold = self.threshold
-        radius_min = self.radius_min
-        radius_max = self.radius_max
-        
-        def randchoice(images, dst):
-            for i in my_range(images.shape[0]):
-                idx = random.randint(0, 2)
-                if idx == 0:       
-                    # solarize             
-                    mask = images[i] >= threshold
-                    dst[i] = np.where(mask, 255 - images[i], images[i])
-                elif idx == 1:
-                    # grayscale
-                    dst[i] = (
-                        0.2989 * images[i, ..., 0:1]
-                        + 0.5870 * images[i, ..., 1:2]
-                        + 0.1140 * images[i, ..., 2:3]
-                    )
-                else:
-                    sigma = np.random.uniform(radius_min, radius_max)      
-                    kernel_width = int(3 * sigma)
-                    if kernel_width % 2 == 0:
-                        kernel_width = kernel_width + 1  # make sure kernel width only sth 3,5,7 etc
-                    kernel = generate_gaussian_filter(sigma,filter_shape=(kernel_width, kernel_width))
-                    convolution(images[i], kernel, dst[i])
-                    
-            return dst
-        randchoice.parallel = True
-        return randchoice
-
-    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
-        # No updates to state or extra memory necessary!
-        mem_alloc = AllocationQuery(previous_state.shape,dtype=previous_state.dtype)
-        return replace(previous_state,jit_mode=True), mem_alloc
-
-
-
 from torchvision.transforms import InterpolationMode
 
 @gin.configurable
@@ -664,7 +532,7 @@ class ThreeAugmentation(nn.Module):
         return f"{super().__repr__()}(p={self.p})"
         
 @gin.configurable
-def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downscale=1):
+def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None):
     """
     ThreeAugmentPipeline
     """