diff --git a/bin/viz_ffcv.py b/bin/viz_ffcv.py
index 314ead8d69907f3de7e3ed604cca8c31c587cd96..d979f31177dd66dddeb6353e0200fe242c6e1739 100644
--- a/bin/viz_ffcv.py
+++ b/bin/viz_ffcv.py
@@ -6,6 +6,9 @@ from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage,Ran
 from ffcv.fields.decoders import CenterCropRGBImageDecoder, IntDecoder
 from torchvision.utils import save_image
 import numpy as np
+import time
+from tqdm import tqdm
+
 
 IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
 IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
@@ -18,7 +21,7 @@ parser.add_argument('--img_size', type=int, default=224)
 
 args = parser.parse_args()
 image_pipeline = [
-        CenterCropRGBImageDecoder((args.img_size, args.img_size), 1),
+        CenterCropRGBImageDecoder((args.img_size, args.img_size), 224/256),
         ToTensor(),       
         ToTorchImage(),
         ]
@@ -29,8 +32,14 @@ pipelines = {
     'label': label_pipeline,
 } 
 loader = Loader(args.data_path, pipelines=pipelines,
-                        batch_size=args.num_samples, num_workers=10, 
+                batch_size=args.num_samples, num_workers=10, 
                         )
+
+start = time.time()
+for batch in tqdm(loader):
+    pass
+print('Time taken', time.time()-start, 's to load', len(loader.indices), 'samples')
+
 nrow=int(np.sqrt(args.num_samples))
 for batch in loader:
     x, y  = batch
diff --git a/vitookit/datasets/transform.py b/vitookit/datasets/transform.py
index 2532c1d8531e9fb2aa87b8998b7119a307c7ad54..33750245324bf46bd19c875bf7db3ec5e7f9800a 100644
--- a/vitookit/datasets/transform.py
+++ b/vitookit/datasets/transform.py
@@ -31,7 +31,7 @@ def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0)):
     return pipelines    
 
 @gin.configurable
-def ValPipeline(img_size=224,ratio= 4.0/3.0):
+def ValPipeline(img_size=224,ratio= 224/256):
     image_pipeline = [
             CenterCropRGBImageDecoder((img_size, img_size), ratio),
             NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16),
diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py
index a3775db7029aca7c683b3f43972fe3ccddac1da6..e6d5b97f2c11f133d1dc5a824cefba13198115c1 100644
--- a/vitookit/evaluation/eval_cls.py
+++ b/vitookit/evaluation/eval_cls.py
@@ -54,7 +54,7 @@ def get_args_parser():
     parser.add_argument('--accum_iter', default=1, type=int,
                         help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
     parser.add_argument('--epochs', default=100, type=int)
-    parser.add_argument('--ckpt_freq', default=10, type=int)
+    parser.add_argument('--ckpt_freq', default=5, type=int)
 
     # Model parameters
     parser.add_argument("--compile", action='store_true', default=False, help="compile model with PyTorch 2.0")
@@ -337,12 +337,12 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
 
         with torch.cuda.amp.autocast():
             outputs = model(samples)
-            loss = criterion(samples, outputs, targets)
+            loss = criterion(outputs, targets)
         
         loss /= accum_iter
         loss_scaler(loss, optimizer, clip_grad=max_norm,
                     parameters=model.parameters(), create_graph=False,
-                    update_grad=(itr + 1) % accum_iter == 0)
+                    need_update=(itr + 1) % accum_iter == 0)
         if (itr + 1) % accum_iter == 0:
             optimizer.zero_grad()
             
@@ -529,31 +529,6 @@ def main(args):
         criterion = torch.nn.CrossEntropyLoss()
 
     print("criterion = %s" % str(criterion))
-
-    teacher_model = None
-    # if args.distillation_type != 'none':
-    #     assert args.teacher_path, 'need to specify teacher-path when using distillation'
-    #     print(f"Creating teacher model: {args.teacher_model}")
-    #     teacher_model = create_model(
-    #         args.teacher_model,
-    #         pretrained=False,
-    #         num_classes=args.nb_classes,
-    #         head_type=args.head_type,
-    #     )
-    #     if args.teacher_path.startswith('https'):
-    #         checkpoint = torch.hub.load_state_dict_from_url(
-    #             args.teacher_path, map_location='cpu', check_hash=True)
-    #     else:
-    #         checkpoint = torch.load(args.teacher_path, map_location='cpu')
-    #     teacher_model.load_state_dict(checkpoint['model'])
-    #     teacher_model.to(device)
-    #     teacher_model.eval()
-
-    # wrap the criterion in our custom DistillationLoss, which
-    # just dispatches to the original criterion if args.distillation_type is 'none'
-    # criterion = DistillationLoss(
-    #     criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
-    # )
     
 
     output_dir = Path(args.output_dir) if args.output_dir else None
@@ -563,9 +538,8 @@ def main(args):
                                 optimizer=optimizer,
                                 model=model_without_ddp,
                                 scaler=loss_scaler,
-                                model_ema=model_ema,
                                 run_variables=run_variables)
-        args = run_variables['args']
+        # args = run_variables['args']
         args.start_epoch = run_variables["epoch"] + 1
 
     if args.eval:
@@ -627,7 +601,6 @@ def main(args):
                         'optimizer': optimizer.state_dict(),
                         'lr_scheduler': lr_scheduler.state_dict(),
                         'epoch': epoch,
-                        'model_ema': get_state_dict(model_ema),
                         'scaler': loss_scaler.state_dict(),
                         'args': args,
                     }, checkpoint_path)
diff --git a/vitookit/evaluation/eval_distill.py b/vitookit/evaluation/eval_distill.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4f163caf4567d9bc0981ee3ea7b5710eb2fc6b5
--- /dev/null
+++ b/vitookit/evaluation/eval_distill.py
@@ -0,0 +1,651 @@
+#!/usr/bin/env python
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Mostly copy-paste from the DeiT library:
+https://github.com/facebookresearch/deit/blob/main/main.py
+"""
+from PIL import Image # hack to avoid `CXXABI_1.3.9' not found error
+
+import argparse
+import datetime
+import numpy as np
+import time
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import json
+import os
+import math
+import sys
+import copy
+import scipy.io as scio
+from vitookit.datasets.transform import three_augmentation
+from vitookit.utils.helper import *
+from vitookit.utils import misc
+from vitookit.models.build_model import build_model
+from vitookit.datasets import build_dataset
+import wandb
+
+
+from pathlib import Path
+from typing import Iterable, Optional
+from torch.nn import functional as F
+
+from timm.models import create_model
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.scheduler import create_scheduler
+from timm.optim import create_optimizer
+from timm.utils import NativeScaler, get_state_dict, ModelEma, accuracy
+from timm.data import Mixup, create_transform
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import trunc_normal_
+
+
+def get_args_parser():
+    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
+    parser.add_argument('--batch_size', default=128, type=int,
+                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
+    parser.add_argument('--accum_iter', default=1, type=int,
+                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
+    parser.add_argument('--epochs', default=100, type=int)
+    parser.add_argument('--ckpt_freq', default=5, type=int)
+
+    # Model parameters
+    parser.add_argument("--compile", action='store_true', default=False, help="compile model with PyTorch 2.0")
+    parser.add_argument("--prefix", default=None, type=str, help="prefix of the model name")
+    
+    parser.add_argument('--input_size', default=224, type=int, help='images input size')
+    parser.add_argument('-w', '--pretrained_weights', default='', type=str, help="""Path to pretrained 
+        weights to evaluate. Set to `download` to automatically load the pretrained DINO from url.
+        Otherwise the model is randomly initialized""")
+    parser.add_argument("--checkpoint_key", default=None, type=str, help='Key to use in the checkpoint (example: "teacher")')
+    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
+                        help='Dropout rate (default: 0.)')
+    parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
+                        help='Attention dropout rate (default: 0.)')
+    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
+                        help='Drop path rate (default: 0.1)')
+
+    parser.add_argument('--model_ema', action='store_true')
+    parser.add_argument('--no_model_ema', action='store_false', dest='model_ema')
+    parser.set_defaults(model_ema=False)
+    parser.add_argument('--model_ema_decay', type=float, default=0.99996, help='')
+    parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')
+
+    # Optimizer parameters
+    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+                        help='Optimizer (default: "adamw")')
+    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
+                        help='Optimizer Epsilon (default: 1e-8)')
+    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
+                        help='Optimizer Betas (default: None, use opt default)')
+    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
+                        help='Clip gradient norm (default: None, no clipping)')
+    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                        help='SGD momentum (default: 0.9)')
+    parser.add_argument('--weight_decay', type=float, default=0.05,
+                        help='weight decay (default: 0.05)')
+    parser.add_argument('--layer_decay', type=float, default=0.75)
+    # Learning rate schedule parameters
+    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
+                        help='LR scheduler (default: "cosine")')
+    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
+                        help='learning rate (default: 5e-4)')
+    parser.add_argument('--lr_noise', type=float, nargs='+', default=None, metavar='pct, pct',
+                        help='learning rate noise on/off epoch percentages')
+    parser.add_argument('--lr_noise_pct', type=float, default=0.67, metavar='PERCENT',
+                        help='learning rate noise limit percent (default: 0.67)')
+    parser.add_argument('--lr_noise_std', type=float, default=1.0, metavar='STDDEV',
+                        help='learning rate noise std-dev (default: 1.0)')
+    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
+                        help='warmup learning rate (default: 1e-6)')
+    parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
+                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
+
+    parser.add_argument('--decay_epochs', type=float, default=30, metavar='N',
+                        help='epoch interval to decay LR')
+    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
+                        help='epochs to warmup LR, if scheduler supports')
+    parser.add_argument('--cooldown_epochs', type=int, default=10, metavar='N',
+                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
+    parser.add_argument('--patience_epochs', type=int, default=10, metavar='N',
+                        help='patience epochs for Plateau LR scheduler (default: 10)')
+    parser.add_argument('--decay_rate', '--dr', type=float, default=0.1, metavar='RATE',
+                        help='LR decay rate (default: 0.1)')
+
+    # Augmentation parameters
+    parser.add_argument('--ThreeAugment', action='store_true', default=False) #3augment
+    parser.add_argument('--src',action='store_true', default=False, 
+                        help="Use Simple Random Crop (SRC) or Random Resized Crop (RRC). Use SRC when there is less risk of overfitting, such as on ImageNet-21k.")
+    parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
+                        help='Color jitter factor (default: 0.4)')
+    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
+                        help='Use AutoAugment policy, e.g. "v0" or "original" '
+                             '(default: rand-m9-mstd0.5-inc1)')
+    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
+    parser.add_argument('--train_interpolation', type=str, default='bicubic',
+                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
+
+    parser.add_argument('--repeated_aug', action='store_true')
+    parser.add_argument('--no_repeated_aug', action='store_false', dest='repeated_aug')
+    parser.set_defaults(repeated_aug=True)
+
+    # * Random Erase params
+    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
+                        help='Random erase prob (default: 0.25)')
+    parser.add_argument('--remode', type=str, default='pixel',
+                        help='Random erase mode (default: "pixel")')
+    parser.add_argument('--recount', type=int, default=1,
+                        help='Random erase count (default: 1)')
+    parser.add_argument('--resplit', action='store_true', default=False,
+                        help='Do not random erase first (clean) augmentation split')
+
+    # * Mixup params
+    parser.add_argument('--mixup', type=float, default=0.8,
+                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
+    parser.add_argument('--cutmix', type=float, default=1.0,
+                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
+    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
+                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+    parser.add_argument('--mixup_prob', type=float, default=1.0,
+                        help='Probability of performing mixup or cutmix when either/both is enabled')
+    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
+                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
+    parser.add_argument('--mixup_mode', type=str, default='batch',
+                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
+
+    # Distillation parameters
+    parser.add_argument('--teacher_model', default='regnety_160', type=str, metavar='MODEL',
+                        help='Name of teacher model to train (default: "regnety_160")')
+    parser.add_argument('--teacher_path', type=str, default='')
+    parser.add_argument('--distillation_type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
+    parser.add_argument('--distillation_alpha', default=0.5, type=float, help="")
+    parser.add_argument('--distillation_tau', default=1.0, type=float, help="")
+
+    # * Finetuning params
+    parser.add_argument('--disable_weight_decay_on_bias_norm', action='store_true', default=False)
+    parser.add_argument('--init_scale', default=1.0, type=float)
+
+    # Dataset parameters
+    parser.add_argument('--data_location', default='/datasets01/imagenet_full_size/061417/', type=str,
+                        help='dataset path')
+    parser.add_argument('--data_set', default='IN1K', type=str,
+                        help='dataset name (e.g. IN1K)')
+    parser.add_argument('--inat_category', default='name',
+                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
+                        type=str, help='semantic granularity')
+
+    parser.add_argument('--output_dir', default=None, type=str,
+                        help='path where to save, empty for no saving')
+    parser.add_argument('--device', default='cuda',
+                        help='device to use for training / testing')
+    parser.add_argument('--seed', default=0, type=int)
+    parser.add_argument('--resume', default='', help='resume from checkpoint')
+    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+                        help='start epoch')
+    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+    parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation')
+    parser.add_argument('--num_workers', default=10, type=int)
+    parser.add_argument('--pin_mem', action='store_true',
+                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
+                        help='')
+    parser.set_defaults(pin_mem=True)
+
+    # distributed training parameters
+    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+
+    return parser
+
+
+
+class RASampler(torch.utils.data.Sampler):
+    """Sampler that restricts data loading to a subset of the dataset for distributed,
+    with repeated augmentation.
+    It ensures that each augmented version of a sample will be visible to a
+    different process (GPU).
+    Heavily based on torch.utils.data.DistributedSampler
+    """
+
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
+        self.total_size = self.num_samples * self.num_replicas
+        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
+        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
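+        # e.g. for ImageNet-1k (N = 1,281,167) with 3x repetition and world_size = 8:
+        # num_samples = ceil(3N/8) = 480,438 indices are iterated per rank, but only
+        # num_selected_samples = floor(N//256*256/8) = 160,128 are yielded, so an epoch
+        # still covers roughly N/world_size samples despite the repeated augmentation.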
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        if self.shuffle:
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = list(range(len(self.dataset)))
+
+        # add extra samples to make it evenly divisible
+        indices = [ele for ele in indices for i in range(3)]
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices[:self.num_selected_samples])
+
+    def __len__(self):
+        return self.num_selected_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+class DistillationLoss(torch.nn.Module):
+    """
+    This module wraps a standard criterion and adds an extra knowledge distillation loss by
+    taking a teacher model prediction and using it as additional supervision.
+    """
+    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
+                 distillation_type: str, alpha: float, tau: float):
+        super().__init__()
+        self.base_criterion = base_criterion
+        self.teacher_model = teacher_model
+        assert distillation_type in ['none', 'soft', 'hard']
+        self.distillation_type = distillation_type
+        self.alpha = alpha
+        self.tau = tau
+
+    def forward(self, inputs, outputs, labels):
+        """
+        Args:
+            inputs: The original inputs that are fed to the teacher model
+            outputs: the outputs of the model to be trained. It is expected to be
+                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
+                in the first position and the distillation predictions as the second output
+            labels: the labels for the base criterion
+        """
+        outputs_kd = None
+        if not isinstance(outputs, torch.Tensor):
+            # assume that the model outputs a tuple of [outputs, outputs_kd]
+            outputs, outputs_kd = outputs
+        base_loss = self.base_criterion(outputs, labels)
+        if self.distillation_type == 'none':
+            return base_loss
+
+        if outputs_kd is None:
+            raise ValueError("When knowledge distillation is enabled, the model is "
+                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
+                             "class_token and the dist_token")
+        # don't backprop through the teacher
+        with torch.no_grad():
+            teacher_outputs = self.teacher_model(inputs)
+
+        if self.distillation_type == 'soft':
+            T = self.tau
+            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
+            # with slight modifications
+            distillation_loss = F.kl_div(
+                F.log_softmax(outputs_kd / T, dim=1),
+                F.log_softmax(teacher_outputs / T, dim=1),
+                reduction='sum',
+                log_target=True
+            ) * (T * T) / outputs_kd.numel()
+        elif self.distillation_type == 'hard':
+            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
+
+        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
+        return loss
+
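+# A minimal usage sketch (variable names are illustrative): main() below wraps the base
+# criterion as
+#     criterion = DistillationLoss(base_criterion, teacher_model, 'soft', alpha=0.5, tau=1.0)
+#     loss = criterion(samples, model(samples), targets)
+# For 'soft'/'hard' the model must return a (class_logits, dist_logits) tuple; with
+# distillation_type='none' the teacher is never queried and only the base loss is returned.
+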
+def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
+                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
+                    device: torch.device, epoch: int, loss_scaler,lr_scheduler, max_norm: float = 0,
+                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
+                    ):
+    model.train(True)
+    metric_logger = misc.MetricLogger(delimiter="  ")
+    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    header = 'Epoch: [{}]'.format(epoch)
+    print_freq = max(len(data_loader)//20,20)
+    
+    accum_iter = args.accum_iter
+    for itr,(samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+        samples = samples.to(device, non_blocking=True)
+        targets = targets.to(device, non_blocking=True)
+
+        lr_scheduler.step(epoch+itr/len(data_loader))
+        if mixup_fn is not None:
+            samples, targets = mixup_fn(samples, targets)
+
+        with torch.cuda.amp.autocast():
+            outputs = model(samples)
+            loss = criterion(samples, outputs, targets)
+        
+        loss /= accum_iter
+        loss_scaler(loss, optimizer, clip_grad=max_norm,
+                    parameters=model.parameters(), create_graph=False,
+                    need_update=(itr + 1) % accum_iter == 0)
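+        # the scaler backpropagates every micro-batch but only unscales, clips and steps
+        # the optimizer when need_update is True (every accum_iter iterations), which is
+        # why zero_grad below is gated on the same condition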
+        if (itr + 1) % accum_iter == 0:
+            optimizer.zero_grad()
+            
+        torch.cuda.synchronize()
+        # log metrics
+        loss_value = loss.item()
+
+        if not math.isfinite(loss_value):
+            print("Loss is {}, stopping training".format(loss_value))
+            sys.exit(1)
+        # this attribute is added by timm on one optimizer (adahessian)
+        # is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
+
+        # if model_ema is not None:
+        #     model_ema.update(model)
+            
+        if wandb.run: 
+            wandb.log({'train/loss': loss_value})
+        metric_logger.update(loss=loss_value)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print("Averaged stats:", metric_logger)
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluate(data_loader, model, device):
+    criterion = torch.nn.CrossEntropyLoss()
+
+    metric_logger = misc.MetricLogger(delimiter="  ")
+    header = 'Test:'
+
+    # switch to evaluation mode
+    model.eval()
+
+    for images, target in metric_logger.log_every(data_loader, 10, header):
+        images = images.to(device, non_blocking=True)
+        target = target.to(device, non_blocking=True, dtype=torch.long)
+
+        # compute output
+        with torch.cuda.amp.autocast():
+            output = model(images)
+            loss = criterion(output, target)
+
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+        batch_size = images.shape[0]
+        metric_logger.update(loss=loss.item())
+        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
+        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
+          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
+
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+def main(args):
+    misc.init_distributed_mode(args)
+
+    print(args)
+
+    if args.distillation_type != 'none' and args.pretrained_weights and not args.eval:
+        raise NotImplementedError("Finetuning with distillation not yet supported")
+
+    device = torch.device(args.device)
+
+    # fix the seed for reproducibility
+    misc.fix_random_seeds(args.seed)
+
+    cudnn.benchmark = True
+    
+    if args.ThreeAugment:
+        transform = three_augmentation(args)
+    else:
+        transform = None
+    dataset_train, args.nb_classes = build_dataset(args=args, is_train=True, trnsfrm=transform)
+    dataset_val, _ = build_dataset(is_train=False, args=args)
+    
+        
+    print("Load dataset:", dataset_train)
+
+    if True:  # args.distributed:
+        num_tasks = misc.get_world_size()
+        global_rank = misc.get_rank()
+        if args.repeated_aug:
+            sampler_train = RASampler(
+                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+            )
+        else:
+            sampler_train = torch.utils.data.DistributedSampler(
+                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+            )
+        if args.dist_eval:
+            if len(dataset_val) % num_tasks != 0:
+                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
+                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
+                      'equal num of samples per-process.')
+            sampler_val = torch.utils.data.DistributedSampler(
+                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
+        else:
+            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+    else:
+        sampler_train = torch.utils.data.RandomSampler(dataset_train)
+        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+
+    data_loader_train = torch.utils.data.DataLoader(
+        dataset_train, sampler=sampler_train,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_mem,
+        drop_last=True,
+    )
+
+    data_loader_val = torch.utils.data.DataLoader(
+        dataset_val, sampler=sampler_val,
+        batch_size=int(1.5 * args.batch_size),
+        num_workers=args.num_workers,
+        pin_memory=args.pin_mem,
+        drop_last=False
+    )
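+    # a 1.5x larger eval batch typically fits in memory since no activations are kept for
+    # backprop; drop_last=False so every validation sample is scored once per process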
+
+    mixup_fn = None
+    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
+    if mixup_active:
+        print("Mixup is activated!")
+        mixup_fn = Mixup(
+            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
+            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
+            label_smoothing=args.smoothing, num_classes=args.nb_classes)
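+        # Mixup/CutMix replace integer labels with soft target distributions, which is why
+        # SoftTargetCrossEntropy is selected as the base criterion when mixup is active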
+
+    
+    print(f"Model  built.")
+    # load weights to evaluate
+    
+    model = build_model(num_classes=args.nb_classes,drop_path_rate=args.drop_path,)
+    if args.pretrained_weights:
+        load_pretrained_weights(model, args.pretrained_weights, checkpoint_key=args.checkpoint_key, prefix=args.prefix)
+    if args.compile:
+        return torch.compile(model)    
+    trunc_normal_(model.head.weight, std=2e-5)
+    
+    model.to(device)
+
+    model_ema = None
+    if args.model_ema:
+        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
+        model_ema = ModelEma(
+            model,
+            decay=args.model_ema_decay,
+            device='cpu' if args.model_ema_force_cpu else '',
+            resume='')
+
+    model_without_ddp = model
+    if True:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print('number of params:', n_parameters)
+    
+    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
+
+    linear_scaled_lr = args.lr * eff_batch_size / 256.0
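+    # e.g. with the default base lr of 5e-4 and an effective batch of 1024
+    # (128 per GPU x 8 GPUs x accum_iter = 1): actual lr = 5e-4 * 1024 / 256 = 2e-3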
+    
+    print("base lr: %.2e" % args.lr )
+    print("actual lr: %.2e" % linear_scaled_lr)
+    args.lr = linear_scaled_lr
+    
+    print("accumulate grad iterations: %d" % args.accum_iter)
+    print("effective batch size: %d" % eff_batch_size)
+    
+    optimizer = create_optimizer(args, model_without_ddp)
+    loss_scaler = NativeScaler()
+
+    lr_scheduler, _ = create_scheduler(args, optimizer)
+
+    criterion = LabelSmoothingCrossEntropy()
+
+    if mixup_active:
+        # smoothing is handled with mixup label transform
+        criterion = SoftTargetCrossEntropy()
+    elif args.smoothing:
+        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
+    else:
+        criterion = torch.nn.CrossEntropyLoss()
+        
+    if getattr(args, 'bce_loss', False):  # no --bce_loss flag is defined in this parser, so this defaults to False
+        criterion = torch.nn.BCEWithLogitsLoss()
+    print("criterion = %s" % str(criterion))
+
+    teacher_model = None
+    if args.distillation_type != 'none':
+        assert args.teacher_path, 'need to specify teacher-path when using distillation'
+        print(f"Creating teacher model: {args.teacher_model}")
+        teacher_model = create_model(
+            args.teacher_model,
+            pretrained=False,
+            num_classes=args.nb_classes,
+            global_pool='avg',  # --head_type is not defined in this parser; use timm's average-pooled head (as in DeiT)
+        )
+        if args.teacher_path.startswith('https'):
+            checkpoint = torch.hub.load_state_dict_from_url(
+                args.teacher_path, map_location='cpu', check_hash=True)
+        else:
+            checkpoint = torch.load(args.teacher_path, map_location='cpu')
+        teacher_model.load_state_dict(checkpoint['model'])
+        teacher_model.to(device)
+        teacher_model.eval()
+
+    # wrap the criterion in our custom DistillationLoss, which
+    # just dispatches to the original criterion if args.distillation_type is 'none'
+    criterion = DistillationLoss(
+        criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
+    )
+    
+    output_dir = Path(args.output_dir) if args.output_dir else None
+    if args.resume:        
+        run_variables={"args":dict(),"epoch":0}
+        restart_from_checkpoint(args.resume,
+                                optimizer=optimizer,
+                                model=model_without_ddp,
+                                scaler=loss_scaler,
+                                model_ema=model_ema,
+                                run_variables=run_variables)
+        # args = run_variables['args']
+        args.start_epoch = run_variables["epoch"] + 1
+
+    if args.eval:
+        test_stats = evaluate(data_loader_val, model_without_ddp, device)
+        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
+        if args.output_dir and misc.is_main_process():
+            with (output_dir / "log.txt").open("a") as f:
+                f.write(json.dumps(test_stats) + "\n")
+        return
+
+    print(f"Start training for {args.epochs} epochs from {args.start_epoch}")
+    start_time = time.time()
+    max_accuracy = 0.0
+    if args.output_dir and misc.is_main_process():
+        try:
+            wandb.init(job_type='distill',dir=args.output_dir,resume=True, 
+                   config=args.__dict__)
+        except Exception:
+            pass  # wandb logging is optional; continue without it if init fails
+        
+    for epoch in range(args.start_epoch, args.epochs):
+        data_loader_train.sampler.set_epoch(epoch)
+
+        train_stats = train_one_epoch(
+            model, criterion, data_loader_train,
+            optimizer, device, epoch, loss_scaler,lr_scheduler,
+            args.clip_grad, model_ema, mixup_fn,
+        )
+        
+        checkpoint_paths = [output_dir / 'checkpoint.pth'] if output_dir else []
+        
+        if epoch%10==0 or epoch==args.epochs-1:
+            test_stats = evaluate(data_loader_val, model, device)
+            print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
+            
+            if output_dir and test_stats["acc1"] >= max_accuracy:
+                # always keep only the best checkpoint so far
+                checkpoint_paths += [output_dir / 'checkpoint_best.pth']
+                
+        
+            max_accuracy = max(max_accuracy, test_stats["acc1"])
+            print(f'Max accuracy: {max_accuracy:.2f}%')
+
+            log_stats = {**{f'train/{k}': v for k, v in train_stats.items()},
+                        **{f'test/{k}': v for k, v in test_stats.items()},
+                        'epoch': epoch,
+                        'n_parameters': n_parameters}
+        else:
+            log_stats = {**{f'train/{k}': v for k, v in train_stats.items()},
+                        'epoch': epoch,
+                        'n_parameters': n_parameters}
+            
+        # only save checkpoint on rank 0
+        if args.output_dir and misc.is_main_process():
+            if epoch%args.ckpt_freq==0 or epoch==args.epochs-1:
+                for checkpoint_path in checkpoint_paths:
+                    misc.save_on_master({
+                        'model': model_without_ddp.state_dict(),
+                        'optimizer': optimizer.state_dict(),
+                        'lr_scheduler': lr_scheduler.state_dict(),
+                        'epoch': epoch,
+                        'model_ema': get_state_dict(model_ema) if model_ema else None,
+                        'scaler': loss_scaler.state_dict(),
+                        'args': args,
+                    }, checkpoint_path)
+                
+            if wandb.run: wandb.log(log_stats)
+            with (output_dir / "log.txt").open("a") as f:
+                f.write(json.dumps(log_stats) + "\n")
+            
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
+    args = aug_parse(parser)
+    main(args)
diff --git a/vitookit/evaluation/eval_linear.py b/vitookit/evaluation/eval_linear.py
index 598f8dd8b263464526143928216870bddb096ef8..8a2ec0eea121f5ce162e83c417b0aa80b563d6e1 100644
--- a/vitookit/evaluation/eval_linear.py
+++ b/vitookit/evaluation/eval_linear.py
@@ -41,7 +41,7 @@ def get_args_parser():
     parser.add_argument('--batch_size', default=512, type=int,
                         help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
     parser.add_argument('--epochs', default=90, type=int)
-    parser.add_argument('--ckpt_freq', default=10, type=int)
+    parser.add_argument('--ckpt_freq', default=5, type=int)
     parser.add_argument('--accum_iter', default=1, type=int,
                         help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
 
@@ -251,7 +251,7 @@ def main(args):
                                 model=model_without_ddp,
                                 scaler=loss_scaler,
                                 run_variables=run_variables)
-        args = run_variables['args']
+        # args = run_variables['args']
         args.start_epoch = run_variables["epoch"] + 1
         print("Resuming from epoch %d" % args.start_epoch)
 
diff --git a/vitookit/evaluation/eval_linear_ffcv.py b/vitookit/evaluation/eval_linear_ffcv.py
index 3e50d031f57b601e7a08ade85ec4b49c888e41b9..5368cc36fa1a4b30cbdeb968844c936fbe2c036b 100644
--- a/vitookit/evaluation/eval_linear_ffcv.py
+++ b/vitookit/evaluation/eval_linear_ffcv.py
@@ -43,7 +43,7 @@ def get_args_parser():
     parser.add_argument('--batch_size', default=512, type=int,
                         help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
     parser.add_argument('--epochs', default=90, type=int)
-    parser.add_argument('--ckpt_freq', default=10, type=int)
+    parser.add_argument('--ckpt_freq', default=5, type=int)
     parser.add_argument('--accum_iter', default=1, type=int,
                         help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
 
@@ -136,15 +136,8 @@ def main(args):
     np.random.seed(seed)
 
     cudnn.benchmark = True
-    order = OrderOption.RANDOM if args.distributed else OrderOption.QUASI_RANDOM
-        
     
-
-    data_loader_train =  Loader(args.train_path, pipelines=SimplePipeline(),
-                        batch_size=args.batch_size, num_workers=args.num_workers, 
-                        order=order, distributed=args.distributed,seed=args.seed)
-    
-    data_loader_val =  Loader(args.val_path, pipelines=ValPipeline(ratio=1),
+    data_loader_val =  Loader(args.val_path, pipelines=ValPipeline(),
                         batch_size=args.batch_size, num_workers=args.num_workers, 
                         distributed=args.distributed,seed=args.seed)
     
@@ -206,6 +199,7 @@ def main(args):
 
     print("criterion = %s" % str(criterion))
     
+    
     if args.resume:        
         run_variables={"args":dict(),"epoch":0}
         restart_from_checkpoint(args.resume,
@@ -213,9 +207,10 @@ def main(args):
                                 model=model_without_ddp,
                                 scaler=loss_scaler,
                                 run_variables=run_variables)
-        args = run_variables['args']
+        # args = run_variables['args']
         args.start_epoch = run_variables["epoch"] + 1
-
+        print("resume from epoch %d" % args.start_epoch)
+        
     if args.eval:
         test_stats = evaluate(data_loader_val, model, device)
         print(f"Accuracy of the network on the test images: {test_stats['acc1']:.1f}%")
@@ -226,7 +221,13 @@ def main(args):
     max_accuracy = 0.0
     
     output_dir = Path(args.output_dir) if args.output_dir else None
-        
+    
+    
+    order = OrderOption.RANDOM if args.distributed else OrderOption.QUASI_RANDOM
+    data_loader_train =  Loader(args.train_path, pipelines=SimplePipeline(),
+                        batch_size=args.batch_size, num_workers=args.num_workers, 
+                        order=order, distributed=args.distributed,seed=args.seed)
+    
     for epoch in range(args.start_epoch, args.epochs):
         
         train_stats = train_one_epoch(
@@ -238,19 +239,19 @@ def main(args):
                         'epoch': epoch,
                         'n_parameters': n_parameters}
         
-        ckpt_path = output_dir / 'checkpoint.pth'
+        ckpt_path = 'checkpoint.pth'
         if epoch%args.ckpt_freq==0 or epoch == args.epochs-1:
             test_stats = evaluate(data_loader_val, model, device)
             print(f"Accuracy of the network on the test images: {test_stats['acc1']:.1f}%")
             print(f'Max accuracy: {max_accuracy:.2f}%')
             log_stats.update({f'test/{k}': v for k, v in test_stats.items()})
             if max_accuracy < test_stats["acc1"]:
-                ckpt_path = output_dir / 'best_checkpoint.pth'
+                ckpt_path = 'best_checkpoint.pth'
             max_accuracy = max(max_accuracy, test_stats["acc1"])
         else:
             test_stats={}
 
-        if args.output_dir and misc.is_main_process():
+        if output_dir and misc.is_main_process():
             if epoch % args.ckpt_freq == 0 or epoch == args.epochs-1:
                 misc.save_on_master({
                         'model': model_without_ddp.state_dict(),
@@ -258,7 +259,7 @@ def main(args):
                         'epoch': epoch,
                         'scaler': loss_scaler.state_dict(),
                         'args': args,
-                    }, ckpt_path) 
+                    }, output_dir / ckpt_path) 
             
             if wandb.run:
                 wandb.log(log_stats)