diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..89d749e18474bf6bb8deeb490582271d6d415bc9
--- /dev/null
+++ b/vitookit/datasets/ffcv_transform.py
@@ -0,0 +1,396 @@
+"""
+Copyright (c) Meta Platforms, Inc. and affiliates.
+All rights reserved.
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+Copied from https://github.com/facebookresearch/FFCV-SSL/blob/6458e33f0753e7a35bc639517a763350a0fc2177/ffcv/transforms/colorjitter.py
+"""
+
+
+import numpy as np
+from collections.abc import Sequence
+from typing import Callable, Optional, Tuple
+from dataclasses import replace
+from ffcv.pipeline.allocation_query import AllocationQuery
+from ffcv.pipeline.operation import Operation
+from ffcv.pipeline.state import State
+from ffcv.pipeline.compiler import Compiler
+import numba as nb
+import numbers
+import math
+import random
+from numba import njit, jit
+import gin
+from cv2 import GaussianBlur
+from scipy.ndimage import gaussian_filter
+
+from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage,RandomHorizontalFlip, View
+from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder, SimpleRGBImageDecoder, CenterCropRGBImageDecoder
+
+import torch
+import torchvision.transforms as tfms
+from torch import nn
+
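+# Means/stds are scaled to [0, 255] because FFCV's NormalizeImage operates on raw uint8 pixels.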
+IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
+IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
+
+@njit(parallel=False, fastmath=True, inline="always")
+def apply_cj(
+    im,
+    apply_bri,
+    bri_ratio,
+    apply_cont,
+    cont_ratio,
+    apply_sat,
+    sat_ratio,
+    apply_hue,
+    hue_factor,
+):
+
+    gray = (
+        np.float32(0.2989) * im[..., 0]
+        + np.float32(0.5870) * im[..., 1]
+        + np.float32(0.1140) * im[..., 2]
+    )
+    one = np.float32(1)
+    # Brightness
+    if apply_bri:
+        im = im * bri_ratio
+
+    # Contrast
+    if apply_cont:
+        im = cont_ratio * im + (one - cont_ratio) * np.float32(gray.mean())
+
+    # Saturation
+    if apply_sat:
+        im[..., 0] = sat_ratio * im[..., 0] + (one - sat_ratio) * gray
+        im[..., 1] = sat_ratio * im[..., 1] + (one - sat_ratio) * gray
+        im[..., 2] = sat_ratio * im[..., 2] + (one - sat_ratio) * gray
+
+    # Hue
+    if apply_hue:
+        hue_factor_radians = hue_factor * 2.0 * np.pi
+        cosA = np.cos(hue_factor_radians)
+        sinA = np.sin(hue_factor_radians)
+        v1, v2, v3 = 1.0 / 3.0, np.sqrt(1.0 / 3.0), (1.0 - cosA)
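+        # Rotate the RGB cube about its gray diagonal (1, 1, 1) by
+        # hue_factor * 2 * pi radians: a standard linear approximation of a hue shift.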
+        hue_matrix = [
+            [
+                cosA + v3 / 3.0,
+                v1 * v3 - v2 * sinA,
+                v1 * v3 + v2 * sinA,
+            ],
+            [
+                v1 * v3 + v2 * sinA,
+                cosA + v1 * v3,
+                v1 * v3 - v2 * sinA,
+            ],
+            [
+                v1 * v3 - v2 * sinA,
+                v1 * v3 + v2 * sinA,
+                cosA + v1 * v3,
+            ],
+        ]
+        hue_matrix = np.array(hue_matrix, dtype=np.float64).T
+        for row in nb.prange(im.shape[0]):
+            im[row] = im[row] @ hue_matrix
+    return np.clip(im, 0, 255).astype(np.uint8)
+
+
+class RandomColorJitter(Operation):
+    """Add ColorJitter with probability jitter_prob.
+    Operates on raw arrays (not tensors), ranging from 0 to 255.
+
+    see https://github.com/pytorch/vision/blob/28557e0cfe9113a5285330542264f03e4ba74535/torchvision/transforms/functional_tensor.py#L165
+     and https://sanje2v.wordpress.com/2021/01/11/accelerating-data-transforms/
+    Parameters
+    ----------
+    jitter_prob : float, The probability with which to apply ColorJitter.
+    brightness (float or tuple of float (min, max)): How much to jitter brightness.
+        brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
+        or the given [min, max]. Should be non negative numbers.
+    contrast (float or tuple of float (min, max)): How much to jitter contrast.
+        contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
+        or the given [min, max]. Should be non negative numbers.
+    saturation (float or tuple of float (min, max)): How much to jitter saturation.
+        saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
+        or the given [min, max]. Should be non negative numbers.
+    hue (float or tuple of float (min, max)): How much to jitter hue.
+        hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
+        Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
+    """
+
+    def __init__(
+        self,
+        jitter_prob=0.5,
+        brightness=0.8,
+        contrast=0.4,
+        saturation=0.4,
+        hue=0.2,
+        seed=None,
+    ):
+        super().__init__()
+        self.jitter_prob = jitter_prob
+
+        self.brightness = self._check_input(brightness, "brightness")
+        self.contrast = self._check_input(contrast, "contrast")
+        self.saturation = self._check_input(saturation, "saturation")
+        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5))
+        self.seed = seed
+        assert self.jitter_prob >= 0 and self.jitter_prob <= 1
+
+    def _check_input(
+        self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True
+    ):
+        if isinstance(value, numbers.Number):
+            if value < 0:
+                raise ValueError(
+                    f"If {name} is a single number, it must be non negative."
+                )
+            value = [center - float(value), center + float(value)]
+            if clip_first_on_zero:
+                value[0] = max(value[0], 0.0)
+        elif isinstance(value, (tuple, list)) and len(value) == 2:
+            if not bound[0] <= value[0] <= value[1] <= bound[1]:
+                raise ValueError(f"{name} values should be between {bound}")
+        else:
+            raise TypeError(
+                f"{name} should be a single number or a list/tuple with length 2."
+            )
+
+        # if value is 0 or (1., 1.) for brightness/contrast/saturation
+        # or (0., 0.) for hue, do nothing
+        if value[0] == value[1] == center:
+            setattr(self, f"apply_{name}", False)
+        else:
+            setattr(self, f"apply_{name}", True)
+        return tuple(value)
+
+    def generate_code(self) -> Callable:
+        my_range = Compiler.get_iterator()
+
+        jitter_prob = self.jitter_prob
+
+        apply_bri = self.apply_brightness
+        bri = self.brightness
+
+        apply_cont = self.apply_contrast
+        cont = self.contrast
+
+        apply_sat = self.apply_saturation
+        sat = self.saturation
+
+        apply_hue = self.apply_hue
+        hue = self.hue
+
+        seed = self.seed
+        if seed is None:
+
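+            # Unseeded path: draw per-image factors on the fly from NumPy's
+            # global RNG (fast, but not reproducible across runs).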
+            def color_jitter(images, _):
+                for i in my_range(images.shape[0]):
+                    if np.random.rand() > jitter_prob:
+                        continue
+
+                    images[i] = apply_cj(
+                        images[i].astype("float64"),
+                        apply_bri,
+                        np.random.uniform(bri[0], bri[1]),
+                        apply_cont,
+                        np.random.uniform(cont[0], cont[1]),
+                        apply_sat,
+                        np.random.uniform(sat[0], sat[1]),
+                        apply_hue,
+                        np.random.uniform(hue[0], hue[1]),
+                    )
+                return images
+
+            color_jitter.is_parallel = True
+            return color_jitter
+
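+        # Seeded path: pre-draw all per-image factors from `random` seeded with
+        # (seed + counter), so each batch is reproducible across runs.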
+        def color_jitter(images, _, counter):
+
+            random.seed(seed + counter)
+            N = images.shape[0]
+            values = np.zeros(N)
+            bris = np.zeros(N)
+            conts = np.zeros(N)
+            sats = np.zeros(N)
+            hues = np.zeros(N)
+            for i in range(N):
+                values[i] = np.float32(random.uniform(0, 1))
+                bris[i] = np.float32(random.uniform(bri[0], bri[1]))
+                conts[i] = np.float32(random.uniform(cont[0], cont[1]))
+                sats[i] = np.float32(random.uniform(sat[0], sat[1]))
+                hues[i] = np.float32(random.uniform(hue[0], hue[1]))
+            for i in my_range(N):
+                if values[i] > jitter_prob:
+                    continue
+                images[i] = apply_cj(
+                    images[i].astype("float64"),
+                    apply_bri,
+                    bris[i],
+                    apply_cont,
+                    conts[i],
+                    apply_sat,
+                    sats[i],
+                    apply_hue,
+                    hues[i],
+                )
+            return images
+
+        color_jitter.is_parallel = True
+        color_jitter.with_counter = True
+        return color_jitter
+
+    def declare_state_and_memory(
+        self, previous_state: State
+    ) -> Tuple[State, Optional[AllocationQuery]]:
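+        # Run the generated kernel under numba JIT; no scratch memory is requested (None).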
+        return (replace(previous_state, jit_mode=True), None)
+
+class Grayscale(Operation):
+    """Add Gaussian Blur with probability blur_prob.
+    Operates on raw arrays (not tensors).
+
+    Parameters
+    ----------
+    blur_prob : float
+        The probability with which to flip each image in the batch
+        horizontally.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def generate_code(self) -> Callable:
+        my_range = Compiler.get_iterator()
+        
+        def grayscale(images, _):
+            for i in my_range(images.shape[0]):
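+                # ITU-R BT.601 luma weights; the (H, W, 1) result broadcasts to all three channels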
+                images[i] = (
+                    0.2989 * images[i, ..., 0:1]
+                    + 0.5870 * images[i, ..., 1:2]
+                    + 0.1140 * images[i, ..., 2:3]
+                )
+            return images
+        grayscale.is_parallel = True
+        return grayscale
+
+    def declare_state_and_memory(
+        self, previous_state: State
+    ) -> Tuple[State, Optional[AllocationQuery]]:
+        return (previous_state, None)
+
+
+
+class Solarization(Operation):
+    """Solarize the image randomly with a given probability by inverting all pixel
+    values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+    where ... means it can have an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+    Parameters
+    ----------
+        solarization_prob (float): probability of the image being solarized. Default value is 0.5
+        threshold (float): all pixels equal or above this value are inverted.
+    """
+
+    def __init__(
+        self, threshold: float = 128,
+    ):
+        super().__init__()
+        self.threshold = threshold
+
+    def generate_code(self) -> Callable:
+        my_range = Compiler.get_iterator()
+        threshold = self.threshold
+        def solarize(images, _):
+            for i in my_range(images.shape[0]):
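+                # invert every pixel at or above the threshold (classic solarization)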
+                mask = images[i] >= threshold
+                images[i] = np.where(mask, 255 - images[i], images[i])
+            return images
+
+        solarize.is_parallel = True
+        return solarize
+
+    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
+        # No updates to state or extra memory necessary!
+        return previous_state, None
+  
+
+class ThreeAugmentation(Operation):
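+    """Apply one of {solarize, grayscale, Gaussian blur} per image, as in
+    DeiT III's Three Augmentation (the blur branch is still a TODO below).
+    Operates on raw arrays (not tensors).
+    """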
+    def __init__(
+        self, threshold=128, radius_min=0.1, radius_max=2.
+    ):
+        super().__init__()
+        self.threshold = threshold
+        self.radius_min = radius_min
+        self.radius_max = radius_max
+
+    def generate_code(self) -> Callable:
+        my_range = Compiler.get_iterator()
+        threshold = self.threshold
+        radius_min = self.radius_min
+        radius_max = self.radius_max
+        
+        def randchoice(images, _):
+            for i in my_range(images.shape[0]):
+                idx = random.randint(0, 2)
+                if idx == 0:       
+                    # solarize             
+                    mask = images[i] >= threshold
+                    images[i] = np.where(mask, 255 - images[i], images[i])
+                elif idx == 1:
+                    # grayscale
+                    images[i] = (
+                        0.2989 * images[i, ..., 0:1]
+                        + 0.5870 * images[i, ..., 1:2]
+                        + 0.1140 * images[i, ..., 2:3]
+                    )
+                else:
+                    # TODO: GaussianBlur is not implemented yet; the sampled
+                    # radius is unused and the image passes through unchanged.
+                    radius = np.random.uniform(radius_min, radius_max)
+                    # images[i] = gaussian_filter(images[i], radius)
+            return images
+        # randchoice.is_parallel = True
+        
+        return randchoice
+
+    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
+        # No updates to state or extra memory necessary!
+        return previous_state, None
+
+@gin.configurable
+def ThreeAugmentPipeline(img_size=224, scale=(0.08, 1), color_jitter=None):
+    """Build FFCV pipelines for DeiT III's Three Augmentation: random resized
+    crop + horizontal flip, one of {solarize, grayscale, blur} per image, an
+    optional color jitter, and ImageNet normalization.
+    """
+    image_pipeline = (
+            # first_tfl 
+            [   RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,),
+                RandomHorizontalFlip(),]+
+            # second_tfl
+            [   ThreeAugmentation(),] +
+            (   [RandomColorJitter(jitter_prob=1, brightness=color_jitter, contrast=color_jitter, saturation=color_jitter, hue=0, seed=None)] if color_jitter else []) +
+            # final_tfl
+            [
+                NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16),
+                ToTensor(), 
+                # ToDevice(torch.device('cuda')),        
+                ToTorchImage(),
+            ])
+    label_pipeline = [IntDecoder(), ToTensor(), ToDevice(torch.device('cuda')), View(-1)]
+    # Pipeline for each data field
+    pipelines = {
+        'image': image_pipeline,
+        'label': label_pipeline
+    } 
+    return pipelines   
+
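+# A minimal sketch (hypothetical paths, not part of this module) of writing the
+# .ffcv file consumed below with FFCV's DatasetWriter; the field names must
+# match the pipeline keys ('image', 'label'):
+#
+#     from ffcv.writer import DatasetWriter
+#     from ffcv.fields import RGBImageField, IntField
+#     from torchvision.datasets import ImageFolder
+#     ds = ImageFolder("/path/to/imagenet/val")
+#     writer = DatasetWriter("IN1K_val.ffcv",
+#                            {'image': RGBImageField(max_resolution=512),
+#                             'label': IntField()})
+#     writer.from_indexed_dataset(ds)
+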
+if __name__ == "__main__":
+    from ffcv import Loader
+    path = "/users/jw02425/projects/EfficientSSL/data/IN1K_val_1000.ffcv"
+    loader = Loader(path, batch_size=4, pipelines=ThreeAugmentPipeline())
+    
+    for batch in loader:
+        print(batch)
+        break
\ No newline at end of file
diff --git a/vitookit/datasets/transform.py b/vitookit/datasets/transform.py
index 33750245324bf46bd19c875bf7db3ec5e7f9800a..4fdfb0f384d9e218e1fd047bc1a7da72f2697084 100644
--- a/vitookit/datasets/transform.py
+++ b/vitookit/datasets/transform.py
@@ -6,7 +6,7 @@ from torchvision import transforms
 import torch
 import gin
 
-from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage,RandomHorizontalFlip
+from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage,RandomHorizontalFlip,View
 from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder, SimpleRGBImageDecoder, CenterCropRGBImageDecoder
 
 IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
@@ -39,7 +39,7 @@ def ValPipeline(img_size=224,ratio= 224/256):
             ToDevice(torch.device('cuda')),        
             ToTorchImage(),
             ]
-    label_pipeline = [IntDecoder(), ToTensor(),ToDevice(torch.device('cuda'))]
+    label_pipeline = [IntDecoder(), ToTensor(), ToDevice(torch.device('cuda')), View(-1)]
     # Pipeline for each data field
     pipelines = {
         'image': image_pipeline,
@@ -143,35 +143,6 @@ class GrayScale(object):
         else:
             return img
 
-from torchvision.transforms import functional as F
-class RandomResizedCrop(transforms.RandomResizedCrop):
-    """
-    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
-    This may lead to results different with torchvision's version.
-    Following BYOL's TF code:
-    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
-    """
-    @staticmethod
-    def get_params(img, scale, ratio):
-        width, height = F._get_image_size(img)
-        area = height * width
-
-        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
-        log_ratio = torch.log(torch.tensor(ratio))
-        aspect_ratio = torch.exp(
-            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
-        ).item()
-
-        w = int(round(math.sqrt(target_area * aspect_ratio)))
-        h = int(round(math.sqrt(target_area / aspect_ratio)))
-
-        w = min(w, width)
-        h = min(h, height)
-
-        i = torch.randint(0, height - h + 1, size=(1,)).item()
-        j = torch.randint(0, width - w + 1, size=(1,)).item()
-
-        return i, j, h, w
     
 def three_augmentation(args = None):
     img_size = args.input_size
@@ -188,7 +159,7 @@ def three_augmentation(args = None):
         ]
     else:
         primary_tfl = [
-            RandomResizedCrop(
+            transforms.RandomResizedCrop(
                 img_size, scale=scale, interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.RandomHorizontalFlip()
         ]
@@ -206,4 +177,5 @@ def three_augmentation(args = None):
                 mean=torch.tensor(mean),
                 std=torch.tensor(std))
         ]
-    return transforms.Compose(primary_tfl+secondary_tfl+final_tfl)
\ No newline at end of file
+    return transforms.Compose(primary_tfl+secondary_tfl+final_tfl)
+
diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py
index b5cf20eab66a45738d975132fe3c93465f56e5c6..d06c688632b723d5ea4deadf2613218a5006904d 100644
--- a/vitookit/evaluation/eval_cls.py
+++ b/vitookit/evaluation/eval_cls.py
@@ -196,7 +196,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
 
         with torch.cuda.amp.autocast():
             outputs = model(samples)
-            loss = criterion(samples, outputs, targets)
+            loss = criterion(outputs, targets)
         
         loss /= accum_iter
         loss_scaler(loss, optimizer, clip_grad=max_norm,
diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a477d16fb5ec0ac310b2b0396d0e87f8ac607f5
--- /dev/null
+++ b/vitookit/evaluation/eval_cls_ffcv.py
@@ -0,0 +1,427 @@
+#!/usr/bin/env python
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Mostly copy-paste from DEiT library:
+https://github.com/facebookresearch/deit/blob/main/main.py
+"""
+from PIL import Image # hack to avoid `CXXABI_1.3.9' not found error
+
+import argparse
+import datetime
+import numpy as np
+import time
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import json
+import os
+import math
+import sys
+import copy
+import scipy.io as scio
+from vitookit.datasets.ffcv_transform import ThreeAugmentPipeline
+from vitookit.datasets.transform import ValPipeline, three_augmentation
+from vitookit.utils.helper import *
+from vitookit.utils import misc
+from vitookit.models.build_model import build_model
+from vitookit.datasets import build_dataset
+import wandb
+
+from pathlib import Path
+from typing import Iterable, Optional
+from torch.nn import functional as F
+
+from timm.models import create_model
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.scheduler import create_scheduler
+from timm.optim import create_optimizer
+from timm.utils import NativeScaler, get_state_dict, ModelEma, accuracy
+from timm.data import Mixup, create_transform
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import trunc_normal_
+
+from ffcv import Loader
+from ffcv.loader import OrderOption
+
+def get_args_parser():
+    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
+    parser.add_argument('--batch_size', default=128, type=int,
+                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
+    parser.add_argument('--accum_iter', default=1, type=int,
+                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
+    parser.add_argument('--epochs', default=100, type=int)
+    parser.add_argument('--ckpt_freq', default=5, type=int)
+
+    # Model parameters
+    parser.add_argument("--compile", action='store_true', default=False, help="compile model with PyTorch 2.0")
+    parser.add_argument("--prefix", default=None, type=str, help="prefix of the model name")
+    
+    parser.add_argument('--input_size', default=224, type=int, help='images input size')
+    parser.add_argument('-w', '--pretrained_weights', default='', type=str, help="""Path to pretrained 
+        weights to evaluate. Set to `download` to automatically load the pretrained DINO from url.
+        Otherwise the model is randomly initialized""")
+    parser.add_argument("--checkpoint_key", default=None, type=str, help='Key to use in the checkpoint (example: "teacher")')
+    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
+                        help='Dropout rate (default: 0.)')
+    parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
+                        help='Attention dropout rate (default: 0.)')
+    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
+                        help='Drop path rate (default: 0.1)')
+
+
+    # Optimizer parameters
+    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+                        help='Optimizer (default: "adamw")')
+    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
+                        help='Optimizer Epsilon (default: 1e-8)')
+    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
+                        help='Optimizer Betas (default: None, use opt default)')
+    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
+                        help='Clip gradient norm (default: None, no clipping)')
+    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                        help='SGD momentum (default: 0.9)')
+    parser.add_argument('--weight_decay', type=float, default=0.05,
+                        help='weight decay (default: 0.05)')
+    parser.add_argument('--layer_decay', type=float, default=0.75)
+    # Learning rate schedule parameters
+    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
+                        help='LR scheduler (default: "cosine")')
+    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
+                        help='learning rate (default: 5e-4)')
+    parser.add_argument('--lr_noise', type=float, nargs='+', default=None, metavar='pct, pct',
+                        help='learning rate noise on/off epoch percentages')
+    
+    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
+                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
+
+    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
+                        help='epochs to warmup LR, if scheduler supports')
+    parser.add_argument('--decay_rate', '--dr', type=float, default=0.1, metavar='RATE',
+                        help='LR decay rate (default: 0.1)')
+
+    # Augmentation parameters
+    parser.add_argument('--ThreeAugment', action='store_true', default=False) #3augment
+    parser.add_argument('--src',action='store_true', default=False, 
+                        help="Use Simple Random Crop (SRC) or Random Resized Crop (RRC). Use SRC when there is less risk of overfitting, such as on ImageNet-21k.")
+    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
+                        help='Color jitter factor (enabled only when not using Auto/RandAug)')
+    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
+                        help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
+    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
+
+    # * Random Erase params
+    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
+                        help='Random erase prob (default: 0.25)')
+    parser.add_argument('--remode', type=str, default='pixel',
+                        help='Random erase mode (default: "pixel")')
+    parser.add_argument('--recount', type=int, default=1,
+                        help='Random erase count (default: 1)')
+    parser.add_argument('--resplit', action='store_true', default=False,
+                        help='Do not random erase first (clean) augmentation split')
+
+    # * Mixup params
+    parser.add_argument('--mixup', type=float, default=0.8,
+                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
+    parser.add_argument('--cutmix', type=float, default=1.0,
+                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
+    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
+                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+    parser.add_argument('--mixup_prob', type=float, default=1.0,
+                        help='Probability of performing mixup or cutmix when either/both is enabled')
+    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
+                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
+    parser.add_argument('--mixup_mode', type=str, default='batch',
+                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
+
+    # * Finetuning params
+    parser.add_argument('--disable_weight_decay_on_bias_norm', action='store_true', default=False)
+    parser.add_argument('--init_scale', default=1.0, type=float)
+
+    # Dataset parameters
+    parser.add_argument('--train_path', type=str, required=True, help='path to train dataset')
+    parser.add_argument('--val_path', type=str, required=True, help='path to test dataset')
+    parser.add_argument('--nb_classes', type=int, default=1000, help='number of classes')
+
+    parser.add_argument('--output_dir', default=None, type=str,
+                        help='path where to save, empty for no saving')
+    parser.add_argument('--device', default='cuda',
+                        help='device to use for training / testing')
+    parser.add_argument('--seed', default=0, type=int)
+    parser.add_argument('--resume', default='', help='resume from checkpoint')
+    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+                        help='start epoch')
+    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+    parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation')
+    parser.add_argument('--num_workers', default=10, type=int)
+    parser.add_argument('--pin_mem', action='store_true',
+                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
+                        help='')
+    parser.set_defaults(pin_mem=True)
+
+    # distributed training parameters
+    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+
+    return parser
+
+
+def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
+                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
+                    device: torch.device, epoch: int, loss_scaler,lr_scheduler, max_norm: float = 0,
+                     mixup_fn: Optional[Mixup] = None,
+                    ):
+    model.train(True)
+    metric_logger = misc.MetricLogger(delimiter="  ")
+    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    header = 'Epoch: [{}]'.format(epoch)
+    print_freq = max(len(data_loader)//20,20)
+    
+    accum_iter = args.accum_iter
+    for itr,(samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+        samples = samples.to(device, non_blocking=True)
+        targets = targets.to(device, non_blocking=True)
+
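+        # step the scheduler with a fractional epoch so the cosine decay is smooth per iteration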
+        lr_scheduler.step(epoch+itr/len(data_loader))
+        if mixup_fn is not None:
+            samples, targets = mixup_fn(samples, targets)
+
+        with torch.cuda.amp.autocast():
+            outputs = model(samples)
+            loss = criterion(outputs, targets)
+        
+        loss /= accum_iter
+        loss_scaler(loss, optimizer, clip_grad=max_norm,
+                    parameters=model.parameters(), create_graph=False,
+                    need_update=(itr + 1) % accum_iter == 0)
+        if (itr + 1) % accum_iter == 0:
+            optimizer.zero_grad()
+            
+        torch.cuda.synchronize()
+        # log metrics
+        loss_value = loss.item()
+
+        if not math.isfinite(loss_value):
+            print("Loss is {}, stopping training".format(loss_value))
+            sys.exit(1)
+        # this attribute is added by timm on one optimizer (adahessian)
+        # is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
+
+        # if model_ema is not None:
+        #     model_ema.update(model)
+            
+        if wandb.run:
+            wandb.log({'train/loss': loss_value})
+        metric_logger.update(loss=loss_value)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print("Averaged stats:", metric_logger)
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluate(data_loader, model, device):
+    criterion = torch.nn.CrossEntropyLoss()
+
+    metric_logger = misc.MetricLogger(delimiter="  ")
+    header = 'Test:'
+
+    # switch to evaluation mode
+    model.eval()
+
+    for images, target in metric_logger.log_every(data_loader, 10, header):
+        images = images.to(device, non_blocking=True)
+        target = target.to(device, non_blocking=True, dtype=torch.long)
+
+        # compute output
+        with torch.cuda.amp.autocast():
+            output = model(images)
+            loss = criterion(output, target)
+
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+        batch_size = images.shape[0]
+        metric_logger.update(loss=loss.item())
+        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
+        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
+          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
+
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+def main(args):
+    misc.init_distributed_mode(args)
+
+    print(args)
+
+    device = torch.device(args.device)
+
+    # fix the seed for reproducibility
+    misc.fix_random_seeds(args.seed)
+
+    cudnn.benchmark = True
+    
+    
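+    # QUASI_RANDOM lowers memory pressure but is not supported with distributed loading, so use RANDOM there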
+    order = OrderOption.RANDOM if args.distributed else OrderOption.QUASI_RANDOM
+    data_loader_train =  Loader(args.train_path, pipelines=ThreeAugmentPipeline(),
+                        batch_size=args.batch_size, num_workers=args.num_workers, 
+                        order=order, distributed=args.distributed,seed=args.seed)
+    
+
+    data_loader_val =  Loader(args.val_path, pipelines=ValPipeline(),
+                        batch_size=args.batch_size, num_workers=args.num_workers, 
+                        distributed=args.distributed,seed=args.seed)
+
+    mixup_fn = None
+    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
+    if mixup_active:
+        print("Mixup is activated!")
+        mixup_fn = Mixup(
+            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
+            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
+            label_smoothing=args.smoothing, num_classes=args.nb_classes)
+
+    
+    print(f"Model  built.")
+    # load weights to evaluate
+    
+    model = build_model(num_classes=args.nb_classes,drop_path_rate=args.drop_path,)
+    if args.pretrained_weights:
+        load_pretrained_weights(model, args.pretrained_weights, checkpoint_key=args.checkpoint_key, prefix=args.prefix)
+    if args.compile:
+        model = torch.compile(model)    
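+    # re-initialize the classifier head with a tiny std (MAE-style fine-tuning)
+    # so the pretrained trunk dominates early updates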
+    trunc_normal_(model.head.weight, std=2e-5)
+    
+    model.to(device)
+
+    model_without_ddp = model
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print('number of params:', n_parameters)
+    
+    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
+
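+    # linear lr scaling rule: scale the base lr by the effective batch size relative to 256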
+    linear_scaled_lr = args.lr * eff_batch_size / 256.0
+    
+    print("base lr: %.2e" % args.lr )
+    print("actual lr: %.2e" % linear_scaled_lr)
+    args.lr = linear_scaled_lr
+    
+    print("accumulate grad iterations: %d" % args.accum_iter)
+    print("effective batch size: %d" % eff_batch_size)
+    
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+    optimizer = create_optimizer(args, model_without_ddp)
+    loss_scaler = NativeScaler()
+
+    lr_scheduler, _ = create_scheduler(args, optimizer)
+
+    if mixup_fn is not None:
+        # smoothing is handled with mixup label transform
+        criterion = SoftTargetCrossEntropy()
+    elif args.smoothing > 0.:
+        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
+    else:
+        criterion = torch.nn.CrossEntropyLoss()
+
+    print("criterion = %s" % str(criterion))
+    
+
+    output_dir = Path(args.output_dir) if args.output_dir else None
+    if args.resume:        
+        run_variables={"args":dict(),"epoch":0}
+        restart_from_checkpoint(args.resume,
+                                optimizer=optimizer,
+                                model=model_without_ddp,
+                                scaler=loss_scaler,
+                                run_variables=run_variables)
+        # args = run_variables['args']
+        args.start_epoch = run_variables["epoch"] + 1
+
+    if args.eval:
+        test_stats = evaluate(data_loader_val, model_without_ddp, device)
+        print(f"Accuracy of the network on the test images: {test_stats['acc1']:.1f}%")
+        if args.output_dir and misc.is_main_process():
+            with (output_dir / "log.txt").open("a") as f:
+                f.write(json.dumps(test_stats) + "\n")
+        sys.exit(0)
+
+    print(f"Start training for {args.epochs} epochs from {args.start_epoch}")
+    start_time = time.time()
+    max_accuracy = 0.0
+    if args.output_dir and misc.is_main_process():
+        try:
+            wandb.init(job_type='finetune',dir=args.output_dir,resume=True, 
+                   config=args.__dict__)
+        except Exception:
+            # wandb logging is optional; continue without it if init fails
+            pass
+        
+    for epoch in range(args.start_epoch, args.epochs):
+
+        train_stats = train_one_epoch(
+            model, criterion, data_loader_train,
+            optimizer, device, epoch, loss_scaler,lr_scheduler,
+            args.clip_grad,  mixup_fn,
+        )
+        
+        checkpoint_paths = ['checkpoint.pth']
+        
+        if epoch%args.ckpt_freq==0 or epoch==args.epochs-1:
+            test_stats = evaluate(data_loader_val, model, device)
+            print(f"Accuracy of the network on test images: {test_stats['acc1']:.1f}%")
+            
+            if (test_stats["acc1"] >= max_accuracy):
+                # always only save best checkpoint till now
+                checkpoint_paths += [ 'checkpoint_best.pth']
+                
+        
+            max_accuracy = max(max_accuracy, test_stats["acc1"])
+            print(f'Max accuracy: {max_accuracy:.2f}%')
+
+            log_stats = {**{f'train/{k}': v for k, v in train_stats.items()},
+                        **{f'test/{k}': v for k, v in test_stats.items()},
+                        'epoch': epoch,
+                        'n_parameters': n_parameters}
+        else:
+            log_stats = {**{f'train/{k}': v for k, v in train_stats.items()},
+                        'epoch': epoch,
+                        'n_parameters': n_parameters}
+            
+        # only save checkpoint on rank 0
+        if output_dir and misc.is_main_process():
+            if epoch%args.ckpt_freq==0 or epoch==args.epochs-1:
+                for checkpoint_path in checkpoint_paths:
+                    misc.save_on_master({
+                        'model': model_without_ddp.state_dict(),
+                        'optimizer': optimizer.state_dict(),
+                        'lr_scheduler': lr_scheduler.state_dict(),
+                        'epoch': epoch,
+                        'scaler': loss_scaler.state_dict(),
+                        'args': args,
+                    }, output_dir / checkpoint_path)
+                
+            if wandb.run: wandb.log(log_stats)
+            with (output_dir / "log.txt").open("a") as f:
+                f.write(json.dumps(log_stats) + "\n")
+            
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
+    args = aug_parse(parser)
+    main(args)