diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py
index 1d2a722ed187050abe050c35baee589cc0db47d5..ffb51ac1045d71895cf1db4c998d5b39227ad322 100644
--- a/vitookit/datasets/ffcv_transform.py
+++ b/vitookit/datasets/ffcv_transform.py
@@ -575,7 +575,7 @@ def convolution(image: np.ndarray, kernel: list | tuple, output: np.ndarray) ->
     return output
 
 
-class ThreeAugmentation(Operation):
+class FFCVThreeAugmentation(Operation):
     def __init__(
         self, threshold=128, radius_min=0.1, radius_max=2.
     ):
@@ -621,20 +621,7 @@ class ThreeAugmentation(Operation):
         mem_alloc = AllocationQuery(previous_state.shape,dtype=previous_state.dtype)
         return replace(previous_state,jit_mode=True), mem_alloc
 
-class RandomChoice(nn.Module):
-    """Apply single transformation randomly picked from a list. This transform does not support torchscript."""
-
-    def __init__(self, transforms, p=None):
-        super().__init__()
-        self.transforms = transforms
-        self.p = p
 
-    def __call__(self, *args):
-        t = random.choices(self.transforms, weights=self.p)[0]
-        return t(*args)
-
-    def __repr__(self) -> str:
-        return f"{super().__repr__()}(p={self.p})"
 
 from torchvision.transforms import InterpolationMode
 
@@ -652,32 +639,50 @@ class RandDownSampling(nn.Module):
         down = F.resize(x,(nh,hw),interpolation=InterpolationMode.BICUBIC)
         up = F.resize(down,(h,w),interpolation=InterpolationMode.BICUBIC)
         return up
-    
+
+class ThreeAugmentation(nn.Module):
+    """Apply single transformation randomly picked from a list. This transform does not support torchscript."""
+
+    def __init__(self, ):
+        super().__init__()
+        self.guassian_blur = tfms.GaussianBlur(5,sigma=(0.1,2))
+        
+
+    def __call__(self, x):
+        mb = math.ceil(len(x)/3)
+        perm = torch.randperm(3)
+        for i in range(3):
+            if perm[i] == 0:
+                x[i*mb:(i+1)*mb] = self.guassian_blur(x[i*mb:(i+1)*mb])
+            elif perm[i] == 1:
+                x[i*mb:(i+1)*mb] = F.solarize(x[i*mb:(i+1)*mb], 0)
+            else:
+                x[i*mb:(i+1)*mb] = F.rgb_to_grayscale(x[i*mb:(i+1)*mb],1)
+        return x
+
+    def __repr__(self) -> str:
+        return f"{super().__repr__()}(p={self.p})"
+        
 @gin.configurable
 def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,downscale=1):
     """
     ThreeAugmentPipeline
     """
     image_pipeline = (
-            # first_tfl 
-            [   RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,),
-                RandomHorizontalFlip(),]+
-            # second_tfl
-            (   [RandomColorJitter(jitter_prob=0.8, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter, hue=0, seed=None)] if color_jitter else []) +
-            # final_tfl
-            [
-                NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32),
-                ToTensor(), 
-                # ToDevice(torch.device('cuda')),        
-                ToTorchImage(),
-                # ThreeAugmentation()                
-                RandomChoice([
-                    tfms.RandomSolarize(0,1),
-                    tfms.RandomGrayscale(1),
-                    tfms.GaussianBlur(5,sigma=(0.1,2)),
-                ]),
-                RandDownSampling((downscale,1)),
-            ])
+        # first_tfl 
+        [   RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,),
+            RandomHorizontalFlip(),]+
+        # second_tfl
+        (   [RandomColorJitter(jitter_prob=1, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter,)] if color_jitter else []) + 
+        # final_tfl
+        [
+            NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32),
+            ToTensor(), 
+            # ToDevice(torch.device('cuda')),        
+            ToTorchImage(),
+            ThreeAugmentation(),
+        ]) 
+        
     label_pipeline = [IntDecoder(), ToTensor(),View(-1)]
     # Pipeline for each data field
     pipelines = {