diff --git a/vitookit/evaluation/eval_cls1_ffcv.py b/vitookit/evaluation/eval_cls1_ffcv.py
new file mode 100644
index 0000000000000000000000000000000000000000..310a3c73f7aa870207de813854e49b324b072381
--- /dev/null
+++ b/vitookit/evaluation/eval_cls1_ffcv.py
@@ -0,0 +1,426 @@
+#!/usr/bin/env python
+
+
+"""
+Example:
+vitrun  --nproc_per_node=3 eval_cls_ffcv.py --train_path $train_path --val_path $val_path  --gin VisionTransformer.global_pool='\"avg\"'  -w wandb:dlib/EfficientSSL/xsa4wubh  --batch_size 360 --output_dir outputs/cls
+
+"""
+from PIL import Image # hack to avoid `CXXABI_1.3.9' not found error
+
+import argparse
+import datetime
+import numpy as np
+import time
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import json
+import os
+import math
+import sys
+import copy
+import scipy.io as scio
+from vitookit.datasets.ffcv_transform import *
+from vitookit.utils.helper import *
+from vitookit.utils import misc
+from vitookit.models.build_model import build_model
+import wandb
+
+from pathlib import Path
+from typing import Iterable, Optional
+from torch.nn import functional as F
+
+
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.scheduler import create_scheduler
+from timm.optim import create_optimizer
+from timm.utils import NativeScaler, get_state_dict, ModelEma, accuracy
+from timm.data import Mixup
+from timm.layers import trunc_normal_
+
+from ffcv import Loader
+from ffcv.loader import OrderOption
+
+def get_args_parser():
+    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
+    parser.add_argument('--batch_size', default=128, type=int,
+                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
+    parser.add_argument('--accum_iter', default=1, type=int,
+                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
+    parser.add_argument('--epochs', default=100, type=int)
+    parser.add_argument('--ckpt_freq', default=5, type=int)
+
+    # Model parameters
+    parser.add_argument("--compile", action='store_true', default=False, help="compile model with PyTorch 2.0")
+    parser.add_argument("--prefix", default=None, type=str, help="prefix of the model name")
+    
+    parser.add_argument('--input_size', default=224, type=int, help='images input size')
+    parser.add_argument('-w', '--pretrained_weights', default='', type=str, help="""Path to pretrained 
+        weights to evaluate. Set to `download` to automatically load the pretrained DINO from url.
+        Otherwise the model is randomly initialized""")
+    parser.add_argument("--checkpoint_key", default=None, type=str, help='Key to use in the checkpoint (example: "teacher")')
+    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
+                        help='Dropout rate (default: 0.)')
+    parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
+                        help='Attention dropout rate (default: 0.)')
+    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
+                        help='Drop path rate (default: 0.1)')
+
+
+    # Optimizer parameters
+    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+                        help='Optimizer (default: "adamw"')
+    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
+                        help='Optimizer Epsilon (default: 1e-8)')
+    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
+                        help='Optimizer Betas (default: None, use opt default)')
+    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
+                        help='Clip gradient norm (default: None, no clipping)')
+    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                        help='SGD momentum (default: 0.9)')
+    parser.add_argument('--weight_decay', type=float, default=0.05,
+                        help='weight decay (default: 0.05)')
+    parser.add_argument('--layer_decay', type=float, default=0.75)
+    # Learning rate schedule parameters
+    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
+                        help='LR scheduler (default: "cosine"')
+    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
+                        help='learning rate (default: 5e-4)')
+    parser.add_argument('--lr_noise', type=float, nargs='+', default=None, metavar='pct, pct',
+                        help='learning rate noise on/off epoch percentages')
+    
+    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
+                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
+
+    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
+                        help='epochs to warmup LR, if scheduler supports')
+    parser.add_argument('--decay_rate', '--dr', type=float, default=0.1, metavar='RATE',
+                        help='LR decay rate (default: 0.1)')
+
+    # Augmentation parameters
+    parser.add_argument('--ThreeAugment', action='store_true', default=True) #3augment
+    parser.add_argument('--src',action='store_true', default=False, 
+                        help="Use Simple Random Crop (SRC) or Random Resized Crop (RRC). Use SRC when there is less risk of overfitting, such as on ImageNet-21k.")
+    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
+                        help='Color jitter factor (enabled only when not using Auto/RandAug)')
+    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
+                        help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'),
+    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
+
+    # * Random Erase params
+    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
+                        help='Random erase prob (default: 0.25)')
+    parser.add_argument('--remode', type=str, default='pixel',
+                        help='Random erase mode (default: "pixel")')
+    parser.add_argument('--recount', type=int, default=1,
+                        help='Random erase count (default: 1)')
+    parser.add_argument('--resplit', action='store_true', default=False,
+                        help='Do not random erase first (clean) augmentation split')
+
+    # * Mixup params
+    parser.add_argument('--mixup', type=float, default=0.8,
+                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
+    parser.add_argument('--cutmix', type=float, default=1.0,
+                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
+    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
+                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+    parser.add_argument('--mixup_prob', type=float, default=1.0,
+                        help='Probability of performing mixup or cutmix when either/both is enabled')
+    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
+                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
+    parser.add_argument('--mixup_mode', type=str, default='batch',
+                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
+
+    # * Finetuning params
+    parser.add_argument('--disable_weight_decay_on_bias_norm', action='store_true', default=False)
+    parser.add_argument('--init_scale', default=1.0, type=float)
+
+    # Dataset parameters
+    parser.add_argument('--train_path', type=str, required=True, help='path to train dataset')
+    parser.add_argument('--val_path', type=str, required=True, help='path to test dataset')
+    parser.add_argument('--nb_classes', type=int, default=1000, help='number of classes')
+
+    parser.add_argument('--output_dir', default=None, type=str,
+                        help='path where to save, empty for no saving')
+    parser.add_argument('--device', default='cuda',
+                        help='device to use for training / testing')
+    parser.add_argument('--seed', default=0, type=int)
+    parser.add_argument('--resume', default='', help='resume from checkpoint')
+    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+                        help='start epoch')
+    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+    parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation')
+    parser.add_argument('--num_workers', default=10, type=int)
+    parser.add_argument('--pin_mem', action='store_true',
+                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
+                        help='')
+    parser.set_defaults(pin_mem=True)
+
+    # distributed training parameters
+    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+
+    return parser
+
+
+def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
+                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
+                    device: torch.device, epoch: int, loss_scaler,lr_scheduler, max_norm: float = 0,
+                     mixup_fn: Optional[Mixup] = None,
+                    ):
+    model.train(True)
+    metric_logger = misc.MetricLogger(delimiter="  ")
+    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    header = 'Epoch: [{}]'.format(epoch)
+    print_freq = max(len(data_loader)//20,20)
+    
+    accum_iter = args.accum_iter
+    for itr,(samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+        samples = samples.to(device, non_blocking=True)
+        targets = targets.to(device, non_blocking=True)
+
+        lr_scheduler.step(epoch+itr/len(data_loader))
+        if mixup_fn is not None:
+            samples, targets = mixup_fn(samples, targets)
+
+        with torch.cuda.amp.autocast():
+            outputs = model(samples)
+            loss = criterion(outputs, targets)
+        
+        loss /= accum_iter
+        loss_scaler(loss, optimizer, clip_grad=max_norm,
+                    parameters=model.parameters(), create_graph=False,
+                    need_update=(itr + 1) % accum_iter == 0)
+        if (itr + 1) % accum_iter == 0:
+            optimizer.zero_grad()
+            
+        torch.cuda.synchronize()
+        # log metrics
+        loss_value = loss.item()
+
+        if not math.isfinite(loss_value):
+            print("Loss is {}, stopping training".format(loss_value))
+            sys.exit(1)
+        # this attribute is added by timm on one optimizer (adahessian)
+        # is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
+
+        # if model_ema is not None:
+        #     model_ema.update(model)
+            
+        if wandb.run: 
+            wandb.log({'train/loss':loss})
+        metric_logger.update(loss=loss_value)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print("Averaged stats:", metric_logger)
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluate(data_loader, model, device):
+    criterion = torch.nn.CrossEntropyLoss()
+
+    metric_logger = misc.MetricLogger(delimiter="  ")
+    header = 'Test:'
+
+    # switch to evaluation mode
+    model.eval()
+
+    for images, target in metric_logger.log_every(data_loader, 10, header):
+        images = images.to(device, non_blocking=True)
+        target = target.to(device, non_blocking=True, dtype=torch.long)
+
+        # compute output
+        with torch.cuda.amp.autocast():
+            output = model(images)
+            loss = criterion(output, target)
+
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+        batch_size = images.shape[0]
+        metric_logger.update(loss=loss.item())
+        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
+        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
+          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
+
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+def main(args):
+    misc.init_distributed_mode(args)
+
+    print(args)
+    import torch
+    device = torch.device(args.device)
+
+    # fix the seed for reproducibility
+    misc.fix_random_seeds(args.seed)
+
+    cudnn.benchmark = True
+    
+    
+    order = OrderOption.RANDOM if args.distributed else OrderOption.QUASI_RANDOM
+    data_loader_train =  Loader(args.train_path, pipelines=ThreeAugmentPipeline(),
+                        batch_size=args.batch_size, num_workers=args.num_workers, 
+                        order=order, distributed=args.distributed,seed=args.seed)
+    
+
+    data_loader_val =  Loader(args.val_path, pipelines=ValPipeline(),
+                        batch_size=args.batch_size, num_workers=args.num_workers, 
+                        distributed=args.distributed,seed=args.seed)
+
+    mixup_fn = None
+    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
+    if mixup_active:
+        print("Mixup is activated!")
+        mixup_fn = Mixup(
+            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
+            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
+            label_smoothing=args.smoothing, num_classes=args.nb_classes)
+
+    
+    print(f"Model  built.")
+    # load weights to evaluate
+    
+    model = build_model(num_classes=args.nb_classes,drop_path_rate=args.drop_path,)
+    if args.pretrained_weights:
+        load_pretrained_weights(model, args.pretrained_weights, checkpoint_key=args.checkpoint_key, prefix=args.prefix)
+    if args.compile:
+        model = torch.compile(model)    
+        import torch._dynamo
+        torch._dynamo.config.suppress_errors = True 
+    trunc_normal_(model.head.weight, std=2e-5)
+    
+    model.to(device)
+
+    model_without_ddp = model
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print('number of params:', n_parameters)
+    
+    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
+
+    linear_scaled_lr = args.lr * eff_batch_size / 256.0
+    
+    print("base lr: %.2e" % args.lr )
+    print("actual lr: %.2e" % linear_scaled_lr)
+    args.lr = linear_scaled_lr
+    
+    print("accumulate grad iterations: %d" % args.accum_iter)
+    print("effective batch size: %d" % eff_batch_size)
+    
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+    optimizer = create_optimizer(args, model_without_ddp)
+    # hack to optimize patch embedding
+    print([ i.shape for i in optimizer.param_groups[1]['params']])
+    optimizer.param_groups[1]['lr_scale'] = 1.0
+    loss_scaler = NativeScaler()
+
+    lr_scheduler, _ = create_scheduler(args, optimizer)
+
+    if mixup_fn is not None:
+        # smoothing is handled with mixup label transform
+        criterion = SoftTargetCrossEntropy()
+    elif args.smoothing > 0.:
+        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
+    else:
+        criterion = torch.nn.CrossEntropyLoss()
+
+    print("criterion = %s" % str(criterion))
+    
+
+    output_dir = Path(args.output_dir) if args.output_dir else None
+    if args.resume:        
+        run_variables={"args":dict(),"epoch":0}
+        restart_from_checkpoint(args.resume,
+                                optimizer=optimizer,
+                                model=model_without_ddp,
+                                scaler=loss_scaler,
+                                run_variables=run_variables)
+        # args = run_variables['args']
+        args.start_epoch = run_variables["epoch"] + 1
+
+    if args.eval:
+        test_stats = evaluate(data_loader_val, model_without_ddp, device)
+        print(f"Accuracy of the network on the test images: {test_stats['acc1']:.1f}%")
+        if args.output_dir and misc.is_main_process():
+            with (output_dir / "log.txt").open("a") as f:
+                f.write(json.dumps(test_stats) + "\n")
+        exit(0)
+
+    print(f"Start training for {args.epochs} epochs from {args.start_epoch}")
+    start_time = time.time()
+    max_accuracy = 0.0
+    if args.output_dir and misc.is_main_process():
+        try:
+            wandb.init(job_type='finetune',dir=args.output_dir,resume=True, 
+                   config=args.__dict__)
+        except:
+            pass
+        
+    for epoch in range(args.start_epoch, args.epochs):
+
+        train_stats = train_one_epoch(
+            model, criterion, data_loader_train,
+            optimizer, device, epoch, loss_scaler,lr_scheduler,
+            args.clip_grad,  mixup_fn,
+        )
+        
+        checkpoint_paths = ['checkpoint.pth']
+        
+        if epoch%args.ckpt_freq==0 or epoch==args.epochs-1:
+            test_stats = evaluate(data_loader_val, model, device)
+            print(f"Accuracy of the network on test images: {test_stats['acc1']:.1f}%")
+            
+            if (test_stats["acc1"] >= max_accuracy):
+                # always only save best checkpoint till now
+                checkpoint_paths += [ 'checkpoint_best.pth']
+                
+        
+            max_accuracy = max(max_accuracy, test_stats["acc1"])
+            print(f'Max accuracy: {max_accuracy:.2f}%')
+
+            log_stats = {**{f'train/{k}': v for k, v in train_stats.items()},
+                        **{f'test/{k}': v for k, v in test_stats.items()},
+                        'epoch': epoch,
+                        'n_parameters': n_parameters}
+        else:
+            log_stats = {**{f'train/{k}': v for k, v in train_stats.items()},
+                        'epoch': epoch,
+                        'n_parameters': n_parameters}
+            
+        # only save checkpoint on rank 0
+        if output_dir and misc.is_main_process():
+            if epoch%args.ckpt_freq==0 or epoch==args.epochs-1:
+                for checkpoint_path in checkpoint_paths:
+                    misc.save_on_master({
+                        'model': model_without_ddp.state_dict(),
+                        'optimizer': optimizer.state_dict(),
+                        'lr_scheduler': lr_scheduler.state_dict(),
+                        'epoch': epoch,
+                        'scaler': loss_scaler.state_dict(),
+                        'args': args,
+                    }, output_dir / checkpoint_path)
+                
+            if wandb.run: wandb.log(log_stats)
+            with (output_dir / "log.txt").open("a") as f:
+                f.write(json.dumps(log_stats) + "\n")
+            
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
+    args = aug_parse(parser)
+    main(args)
diff --git a/vitookit/models/vision_transformer.py b/vitookit/models/vision_transformer.py
index 90c030f2f28fcab70803c8ef375fb073a4c3841f..c412514d54924682abb075f348467cc6fe957b20 100644
--- a/vitookit/models/vision_transformer.py
+++ b/vitookit/models/vision_transformer.py
@@ -15,7 +15,7 @@ class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
         super(VisionTransformer, self).__init__(global_pool=global_pool,dynamic_img_size = dynamic_img_size, **kwargs)
         self.global_pool = global_pool
         print("global_pool",global_pool,self.norm)
- 
+        
 
 def vit_small(**kwargs):
     model = VisionTransformer(