diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py
index 089f80a9f8479c0a26352a4bdcf02dcd559694c3..8b24b128c7f77db60d92efec51cb06809bcf6906 100644
--- a/vitookit/evaluation/eval_cls.py
+++ b/vitookit/evaluation/eval_cls.py
@@ -230,7 +230,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
 
 
 @torch.no_grad()
-def evaluate(data_loader, model, device):
+def evaluate(data_loader, model, device,return_preds=False ):
     criterion = torch.nn.CrossEntropyLoss()
 
     metric_logger = misc.MetricLogger(delimiter="  ")
@@ -260,7 +260,10 @@ def evaluate(data_loader, model, device):
     print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
           .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
     preds = torch.cat(preds)
-    return {k: meter.global_avg for k, meter in metric_logger.meters.items()},preds
+    if return_preds:
+        return {k: meter.global_avg for k, meter in metric_logger.meters.items()},preds
+    else:
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
 
 
 def main(args):
@@ -393,7 +396,7 @@ def main(args):
 
     if args.eval:
         assert args.world_size == 1
-        test_stats, preds = evaluate(data_loader_val, model_without_ddp, device)
+        test_stats, preds = evaluate(data_loader_val, model_without_ddp, device, True)
         print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
         if args.output_dir and misc.is_main_process():
             with (output_dir / "log.txt").open("a") as f:
diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py
index 3f7de32c816ed474a13e1597a506b112a5a06dbd..42991232c1e74782e26365a14346436227d4ad54 100644
--- a/vitookit/evaluation/eval_cls_ffcv.py
+++ b/vitookit/evaluation/eval_cls_ffcv.py
@@ -317,7 +317,7 @@ def main(args):
 
     if args.eval:
         assert args.world_size == 1
-        test_stats, preds = evaluate(data_loader_val, model_without_ddp, device)
+        test_stats, preds = evaluate(data_loader_val, model_without_ddp, device,True)
         print(f"Accuracy of the network on the test images: {test_stats['acc1']:.1f}%")
         if args.output_dir and misc.is_main_process():
             with (output_dir / "log.txt").open("a") as f: