From 7f5a9212aff0b2176d7da4115d386080a2c686f4 Mon Sep 17 00:00:00 2001
From: gent <jw02425@surrey.ac.uk>
Date: Sat, 27 Jan 2024 20:29:03 +0000
Subject: [PATCH] Update evaluate function to return predictions

---
 vitookit/evaluation/eval_cls.py      | 9 ++++++---
 vitookit/evaluation/eval_cls_ffcv.py | 2 +-
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py
index 089f80a..8b24b12 100644
--- a/vitookit/evaluation/eval_cls.py
+++ b/vitookit/evaluation/eval_cls.py
@@ -230,7 +230,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
 
 
 @torch.no_grad()
-def evaluate(data_loader, model, device):
+def evaluate(data_loader, model, device, return_preds=False):
     criterion = torch.nn.CrossEntropyLoss()
 
     metric_logger = misc.MetricLogger(delimiter=" ")
@@ -260,7 +260,10 @@ def evaluate(data_loader, model, device):
     print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
           .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
     preds = torch.cat(preds)
-    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, preds
+    if return_preds:
+        return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, preds
+    else:
+        return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
 
 
 def main(args):
@@ -393,7 +396,7 @@ def main(args):
 
     if args.eval:
         assert args.world_size == 1
-        test_stats, preds = evaluate(data_loader_val, model_without_ddp, device)
+        test_stats, preds = evaluate(data_loader_val, model_without_ddp, device, True)
         print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
         if args.output_dir and misc.is_main_process():
             with (output_dir / "log.txt").open("a") as f:
diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py
index 3f7de32..4299123 100644
--- a/vitookit/evaluation/eval_cls_ffcv.py
+++ b/vitookit/evaluation/eval_cls_ffcv.py
@@ -317,7 +317,7 @@ def main(args):
 
     if args.eval:
         assert args.world_size == 1
-        test_stats, preds = evaluate(data_loader_val, model_without_ddp, device)
+        test_stats, preds = evaluate(data_loader_val, model_without_ddp, device, True)
         print(f"Accuracy of the network on the test images: {test_stats['acc1']:.1f}%")
         if args.output_dir and misc.is_main_process():
             with (output_dir / "log.txt").open("a") as f:
--
GitLab
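
Usage note (not part of the patch): a minimal sketch of how the patched evaluate() is meant to be called, assuming vitookit is installed and eval_cls imports cleanly outside a distributed run. The tiny model and loader below are made-up stand-ins for illustration; only the call pattern with return_preds comes from the diff above.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from vitookit.evaluation.eval_cls import evaluate

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Made-up stand-ins so the call pattern is concrete; any real classifier
    # and validation loader yielding (images, targets) batches would do.
    model = torch.nn.Linear(16, 10).to(device)
    loader = DataLoader(
        TensorDataset(torch.randn(64, 16), torch.randint(0, 10, (64,))),
        batch_size=32,
    )

    # Default call (return_preds=False): just the metrics dict (acc1, acc5, loss).
    stats = evaluate(loader, model, device)
    print(stats["acc1"], stats["loss"])

    # Opt in to predictions, as main() now does by passing True positionally:
    # returns the metrics dict plus the concatenated per-batch predictions tensor.
    stats, preds = evaluate(loader, model, device, return_preds=True)
    print(preds.shape)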