diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py index 37d5d68d9122742e29ae25dcddf409eeb560bde6..b5cf20eab66a45738d975132fe3c93465f56e5c6 100644 --- a/vitookit/evaluation/eval_cls.py +++ b/vitookit/evaluation/eval_cls.py @@ -335,7 +335,7 @@ def main(args): if args.pretrained_weights: load_pretrained_weights(model, args.pretrained_weights, checkpoint_key=args.checkpoint_key, prefix=args.prefix) if args.compile: - return torch.compile(model) + model = torch.compile(model) trunc_normal_(model.head.weight, std=2e-5) model.to(device)