diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py
index e6d5b97f2c11f133d1dc5a824cefba13198115c1..37d5d68d9122742e29ae25dcddf409eeb560bde6 100644
--- a/vitookit/evaluation/eval_cls.py
+++ b/vitookit/evaluation/eval_cls.py
@@ -72,11 +72,6 @@ def get_args_parser():
     parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
                         help='Drop path rate (default: 0.1)')
 
-    parser.add_argument('--model_ema', action='store_true')
-    parser.add_argument('--no_model_ema', action='store_false', dest='model_ema')
-    parser.set_defaults(model_ema=False)
-    parser.add_argument('--model_ema_decay', type=float, default=0.99996, help='')
-    parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')
 
     # Optimizer parameters
     parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
@@ -99,23 +94,12 @@ def get_args_parser():
                         help='learning rate (default: 5e-4)')
     parser.add_argument('--lr_noise', type=float, nargs='+', default=None, metavar='pct, pct',
                         help='learning rate noise on/off epoch percentages')
-    parser.add_argument('--lr_noise_pct', type=float, default=0.67, metavar='PERCENT',
-                        help='learning rate noise limit percent (default: 0.67)')
-    parser.add_argument('--lr_noise_std', type=float, default=1.0, metavar='STDDEV',
-                        help='learning rate noise std-dev (default: 1.0)')
-    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
-                        help='warmup learning rate (default: 1e-6)')
-    parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
-                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
+
+    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
+                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
 
-    parser.add_argument('--decay_epochs', type=float, default=30, metavar='N',
-                        help='epoch interval to decay LR')
     parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                         help='epochs to warmup LR, if scheduler supports')
-    parser.add_argument('--cooldown_epochs', type=int, default=10, metavar='N',
-                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
-    parser.add_argument('--patience_epochs', type=int, default=10, metavar='N',
-                        help='patience epochs for Plateau LR scheduler (default: 10')
     parser.add_argument('--decay_rate', '--dr', type=float, default=0.1, metavar='RATE',
                         help='LR decay rate (default: 0.1)')
 
@@ -123,18 +107,11 @@ def get_args_parser():
     parser.add_argument('--ThreeAugment', action='store_true', default=False) #3augment
     parser.add_argument('--src',action='store_true', default=False, 
                         help="Use Simple Random Crop (SRC) or Random Resized Crop (RRC). Use SRC when there is less risk of overfitting, such as on ImageNet-21k.")
-    parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
-                        help='Color jitter factor (default: 0.4)')
+    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
+                        help='Color jitter factor (enabled only when not using Auto/RandAug)')
     parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
-                        help='Use AutoAugment policy. "v0" or "original". " + \
-                             "(default: rand-m9-mstd0.5-inc1)'),
+                        help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)'),
     parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
-    parser.add_argument('--train_interpolation', type=str, default='bicubic',
-                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
-
-    parser.add_argument('--repeated_aug', action='store_true')
-    parser.add_argument('--no_repeated_aug', action='store_false', dest='repeated_aug')
-    parser.set_defaults(repeated_aug=True)
 
     # * Random Erase params
     parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
@@ -160,14 +137,6 @@ def get_args_parser():
     parser.add_argument('--mixup_mode', type=str, default='batch',
                         help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
 
-    # Distillation parameters
-    parser.add_argument('--teacher_model', default='regnety_160', type=str, metavar='MODEL',
-                        help='Name of teacher model to train (default: "regnety_160"')
-    parser.add_argument('--teacher_path', type=str, default='')
-    parser.add_argument('--distillation_type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
-    parser.add_argument('--distillation_alpha', default=0.5, type=float, help="")
-    parser.add_argument('--distillation_tau', default=1.0, type=float, help="")
-
     # * Finetuning params
     parser.add_argument('--disable_weight_decay_on_bias_norm', action='store_true', default=False)
     parser.add_argument('--init_scale', default=1.0, type=float)
@@ -205,120 +174,10 @@ def get_args_parser():
     return parser
 
 
-
-class RASampler(torch.utils.data.Sampler):
-    """Sampler that restricts data loading to a subset of the dataset for distributed,
-    with repeated augmentation.
-    It ensures that different each augmented version of a sample will be visible to a
-    different process (GPU)
-    Heavily based on torch.utils.data.DistributedSampler
-    """
-
-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
-        if num_replicas is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            num_replicas = dist.get_world_size()
-        if rank is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            rank = dist.get_rank()
-        self.dataset = dataset
-        self.num_replicas = num_replicas
-        self.rank = rank
-        self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
-        self.total_size = self.num_samples * self.num_replicas
-        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
-        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
-        self.shuffle = shuffle
-
-    def __iter__(self):
-        # deterministically shuffle based on epoch
-        g = torch.Generator()
-        g.manual_seed(self.epoch)
-        if self.shuffle:
-            indices = torch.randperm(len(self.dataset), generator=g).tolist()
-        else:
-            indices = list(range(len(self.dataset)))
-
-        # add extra samples to make it evenly divisible
-        indices = [ele for ele in indices for i in range(3)]
-        indices += indices[:(self.total_size - len(indices))]
-        assert len(indices) == self.total_size
-
-        # subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
-        assert len(indices) == self.num_samples
-
-        return iter(indices[:self.num_selected_samples])
-
-    def __len__(self):
-        return self.num_selected_samples
-
-    def set_epoch(self, epoch):
-        self.epoch = epoch
-
-class DistillationLoss(torch.nn.Module):
-    """
-    This module wraps a standard criterion and adds an extra knowledge distillation loss by
-    taking a teacher model prediction and using it as additional supervision.
-    """
-    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
-                 distillation_type: str, alpha: float, tau: float):
-        super().__init__()
-        self.base_criterion = base_criterion
-        self.teacher_model = teacher_model
-        assert distillation_type in ['none', 'soft', 'hard']
-        self.distillation_type = distillation_type
-        self.alpha = alpha
-        self.tau = tau
-
-    def forward(self, inputs, outputs, labels):
-        """
-        Args:
-            inputs: The original inputs that are feed to the teacher model
-            outputs: the outputs of the model to be trained. It is expected to be
-                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
-                in the first position and the distillation predictions as the second output
-            labels: the labels for the base criterion
-        """
-        outputs_kd = None
-        if not isinstance(outputs, torch.Tensor):
-            # assume that the model outputs a tuple of [outputs, outputs_kd]
-            outputs, outputs_kd = outputs
-        base_loss = self.base_criterion(outputs, labels)
-        if self.distillation_type == 'none':
-            return base_loss
-
-        if outputs_kd is None:
-            raise ValueError("When knowledge distillation is enabled, the model is "
-                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
-                             "class_token and the dist_token")
-        # don't backprop throught the teacher
-        with torch.no_grad():
-            teacher_outputs = self.teacher_model(inputs)
-
-        if self.distillation_type == 'soft':
-            T = self.tau
-            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
-            # with slight modifications
-            distillation_loss = F.kl_div(
-                F.log_softmax(outputs_kd / T, dim=1),
-                F.log_softmax(teacher_outputs / T, dim=1),
-                reduction='sum',
-                log_target=True
-            ) * (T * T) / outputs_kd.numel()
-        elif self.distillation_type == 'hard':
-            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
-
-        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
-        return loss
-
-def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
+def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
                     device: torch.device, epoch: int, loss_scaler,lr_scheduler, max_norm: float = 0,
-                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
+                    mixup_fn: Optional[Mixup] = None,
                     ):
     model.train(True)
     metric_logger = misc.MetricLogger(delimiter="  ")
@@ -337,7 +196,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
 
         with torch.cuda.amp.autocast():
             outputs = model(samples)
-            loss = criterion(samples, outputs, targets)
+            loss = criterion(outputs, targets)
         
         loss /= accum_iter
         loss_scaler(loss, optimizer, clip_grad=max_norm,
@@ -401,14 +260,12 @@ def evaluate(data_loader, model, device):
 
     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
 
+
 def main(args):
     misc.init_distributed_mode(args)
 
     print(args)
 
-    if args.distillation_type != 'none' and args.pretrained_weights and not args.eval:
-        raise NotImplementedError("Finetuning with distillation not yet supported")
-
     device = torch.device(args.device)
 
     # fix the seed for reproducibility
@@ -429,21 +286,16 @@ def main(args):
     if True:  # args.distributed:
         num_tasks = misc.get_world_size()
         global_rank = misc.get_rank()
-        if args.repeated_aug:
-            sampler_train = RASampler(
-                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
-            )
-        else:
-            sampler_train = torch.utils.data.DistributedSampler(
-                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
-            )
+        sampler_train = torch.utils.data.DistributedSampler(
+            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+        )
         if args.dist_eval:
             if len(dataset_val) % num_tasks != 0:
                 print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                       'This will slightly alter validation results as extra duplicate entries are added to achieve '
                       'equal num of samples per-process.')
             sampler_val = torch.utils.data.DistributedSampler(
-                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
+                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True)
         else:
             sampler_val = torch.utils.data.SequentialSampler(dataset_val)
     else:
@@ -488,19 +340,7 @@ def main(args):
     
     model.to(device)
 
-    model_ema = None
-    if args.model_ema:
-        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
-        model_ema = ModelEma(
-            model,
-            decay=args.model_ema_decay,
-            device='cpu' if args.model_ema_force_cpu else '',
-            resume='')
-
     model_without_ddp = model
-    if True:
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
-        model_without_ddp = model.module
     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
     print('number of params:', n_parameters)
     
@@ -515,6 +355,9 @@ def main(args):
     print("accumulate grad iterations: %d" % args.accum_iter)
     print("effective batch size: %d" % eff_batch_size)
     
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
     optimizer = create_optimizer(args, model_without_ddp)
     loss_scaler = NativeScaler()
 
@@ -548,7 +391,7 @@ def main(args):
         if args.output_dir and misc.is_main_process():
             with (output_dir / "log.txt").open("a") as f:
                 f.write(json.dumps(test_stats) + "\n")
-        return
+        exit(0)
 
     print(f"Start training for {args.epochs} epochs from {args.start_epoch}")
     start_time = time.time()
@@ -566,18 +409,18 @@ def main(args):
         train_stats = train_one_epoch(
             model, criterion, data_loader_train,
             optimizer, device, epoch, loss_scaler,lr_scheduler,
-            args.clip_grad, model_ema, mixup_fn,
+            args.clip_grad, mixup_fn,
         )
         
-        checkpoint_paths = [output_dir / 'checkpoint.pth']
+        checkpoint_paths = ['checkpoint.pth']
         
-        if epoch%10==0 or epoch==args.epochs-1:
+        if epoch%args.ckpt_freq==0 or epoch==args.epochs-1:
             test_stats = evaluate(data_loader_val, model, device)
             print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
             
             if (test_stats["acc1"] >= max_accuracy):
                 # always only save best checkpoint till now
-                checkpoint_paths += [output_dir / 'checkpoint_best.pth']
+                checkpoint_paths += ['checkpoint_best.pth']
                 
         
             max_accuracy = max(max_accuracy, test_stats["acc1"])
@@ -593,7 +436,7 @@ def main(args):
                         'n_parameters': n_parameters}
             
         # only save checkpoint on rank 0
-        if args.output_dir and misc.is_main_process():
+        if output_dir and misc.is_main_process():
             if epoch%args.ckpt_freq==0 or epoch==args.epochs-1:
                 for checkpoint_path in checkpoint_paths:
                     misc.save_on_master({
@@ -603,7 +446,7 @@ def main(args):
                         'epoch': epoch,
                         'scaler': loss_scaler.state_dict(),
                         'args': args,
-                    }, checkpoint_path)
+                    }, output_dir / checkpoint_path)
                 
             if wandb.run: wandb.log(log_stats)
             with (output_dir / "log.txt").open("a") as f: