diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py index 089f80a9f8479c0a26352a4bdcf02dcd559694c3..8b24b128c7f77db60d92efec51cb06809bcf6906 100644 --- a/vitookit/evaluation/eval_cls.py +++ b/vitookit/evaluation/eval_cls.py @@ -230,7 +230,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, @torch.no_grad() -def evaluate(data_loader, model, device): +def evaluate(data_loader, model, device,return_preds=False ): criterion = torch.nn.CrossEntropyLoss() metric_logger = misc.MetricLogger(delimiter=" ") @@ -260,7 +260,10 @@ def evaluate(data_loader, model, device): print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) preds = torch.cat(preds) - return {k: meter.global_avg for k, meter in metric_logger.meters.items()},preds + if return_preds: + return {k: meter.global_avg for k, meter in metric_logger.meters.items()},preds + else: + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} def main(args): @@ -393,7 +396,7 @@ def main(args): if args.eval: assert args.world_size == 1 - test_stats, preds = evaluate(data_loader_val, model_without_ddp, device) + test_stats, preds = evaluate(data_loader_val, model_without_ddp, device, True) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") if args.output_dir and misc.is_main_process(): with (output_dir / "log.txt").open("a") as f: diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py index 3f7de32c816ed474a13e1597a506b112a5a06dbd..42991232c1e74782e26365a14346436227d4ad54 100644 --- a/vitookit/evaluation/eval_cls_ffcv.py +++ b/vitookit/evaluation/eval_cls_ffcv.py @@ -317,7 +317,7 @@ def main(args): if args.eval: assert args.world_size == 1 - test_stats, preds = evaluate(data_loader_val, model_without_ddp, device) + test_stats, preds = evaluate(data_loader_val, model_without_ddp, device,True) print(f"Accuracy of the network on the test images: {test_stats['acc1']:.1f}%") if args.output_dir and misc.is_main_process(): with (output_dir / "log.txt").open("a") as f: