Skip to content
Snippets Groups Projects
Commit 7f5a9212 authored by gent's avatar gent
Browse files

Update evaluate function to return predictions

parent 421621d9
No related branches found
No related tags found
No related merge requests found
...@@ -230,7 +230,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, ...@@ -230,7 +230,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
@torch.no_grad() @torch.no_grad()
def evaluate(data_loader, model, device): def evaluate(data_loader, model, device,return_preds=False ):
criterion = torch.nn.CrossEntropyLoss() criterion = torch.nn.CrossEntropyLoss()
metric_logger = misc.MetricLogger(delimiter=" ") metric_logger = misc.MetricLogger(delimiter=" ")
...@@ -260,7 +260,10 @@ def evaluate(data_loader, model, device): ...@@ -260,7 +260,10 @@ def evaluate(data_loader, model, device):
print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
preds = torch.cat(preds) preds = torch.cat(preds)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()},preds if return_preds:
return {k: meter.global_avg for k, meter in metric_logger.meters.items()},preds
else:
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def main(args): def main(args):
...@@ -393,7 +396,7 @@ def main(args): ...@@ -393,7 +396,7 @@ def main(args):
if args.eval: if args.eval:
assert args.world_size == 1 assert args.world_size == 1
test_stats, preds = evaluate(data_loader_val, model_without_ddp, device) test_stats, preds = evaluate(data_loader_val, model_without_ddp, device, True)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
if args.output_dir and misc.is_main_process(): if args.output_dir and misc.is_main_process():
with (output_dir / "log.txt").open("a") as f: with (output_dir / "log.txt").open("a") as f:
......
...@@ -317,7 +317,7 @@ def main(args): ...@@ -317,7 +317,7 @@ def main(args):
if args.eval: if args.eval:
assert args.world_size == 1 assert args.world_size == 1
test_stats, preds = evaluate(data_loader_val, model_without_ddp, device) test_stats, preds = evaluate(data_loader_val, model_without_ddp, device,True)
print(f"Accuracy of the network on the test images: {test_stats['acc1']:.1f}%") print(f"Accuracy of the network on the test images: {test_stats['acc1']:.1f}%")
if args.output_dir and misc.is_main_process(): if args.output_dir and misc.is_main_process():
with (output_dir / "log.txt").open("a") as f: with (output_dir / "log.txt").open("a") as f:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment