From 7f5a9212aff0b2176d7da4115d386080a2c686f4 Mon Sep 17 00:00:00 2001
From: gent <jw02425@surrey.ac.uk>
Date: Sat, 27 Jan 2024 20:29:03 +0000
Subject: [PATCH] Update evaluate function to return predictions

---
 vitookit/evaluation/eval_cls.py      | 9 ++++++---
 vitookit/evaluation/eval_cls_ffcv.py | 2 +-
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py
index 089f80a..8b24b12 100644
--- a/vitookit/evaluation/eval_cls.py
+++ b/vitookit/evaluation/eval_cls.py
@@ -230,7 +230,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
 
 
 @torch.no_grad()
-def evaluate(data_loader, model, device):
+def evaluate(data_loader, model, device,return_preds=False ):
     criterion = torch.nn.CrossEntropyLoss()
 
     metric_logger = misc.MetricLogger(delimiter="  ")
@@ -260,7 +260,10 @@ def evaluate(data_loader, model, device):
     print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
           .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
     preds = torch.cat(preds)
-    return {k: meter.global_avg for k, meter in metric_logger.meters.items()},preds
+    if return_preds:
+        return {k: meter.global_avg for k, meter in metric_logger.meters.items()},preds
+    else:
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
 
 
 def main(args):
@@ -393,7 +396,7 @@ def main(args):
 
     if args.eval:
         assert args.world_size == 1
-        test_stats, preds = evaluate(data_loader_val, model_without_ddp, device)
+        test_stats, preds = evaluate(data_loader_val, model_without_ddp, device, True)
         print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
         if args.output_dir and misc.is_main_process():
             with (output_dir / "log.txt").open("a") as f:
diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py
index 3f7de32..4299123 100644
--- a/vitookit/evaluation/eval_cls_ffcv.py
+++ b/vitookit/evaluation/eval_cls_ffcv.py
@@ -317,7 +317,7 @@ def main(args):
 
     if args.eval:
         assert args.world_size == 1
-        test_stats, preds = evaluate(data_loader_val, model_without_ddp, device)
+        test_stats, preds = evaluate(data_loader_val, model_without_ddp, device,True)
         print(f"Accuracy of the network on the test images: {test_stats['acc1']:.1f}%")
         if args.output_dir and misc.is_main_process():
             with (output_dir / "log.txt").open("a") as f:
-- 
GitLab