From 5bc839afae95a8f44cb6dc10dd14627262094570 Mon Sep 17 00:00:00 2001
From: gent <jw02425@surrey.ac.uk>
Date: Sat, 20 Jan 2024 15:47:51 +0000
Subject: [PATCH] import torch

---
 vitookit/evaluation/eval_cls.py      |  2 +-
 vitookit/evaluation/eval_cls_ffcv.py | 11 ++++-------
 vitookit/utils/helper.py             |  1 +
 3 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/vitookit/evaluation/eval_cls.py b/vitookit/evaluation/eval_cls.py
index e0896cf..de4568b 100644
--- a/vitookit/evaluation/eval_cls.py
+++ b/vitookit/evaluation/eval_cls.py
@@ -265,7 +265,7 @@ def main(args):
     misc.init_distributed_mode(args)
 
     print(args)
-
+    import torch
     device = torch.device(args.device)
 
     # fix the seed for reproducibility
diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py
index e9d399e..0ee6485 100644
--- a/vitookit/evaluation/eval_cls_ffcv.py
+++ b/vitookit/evaluation/eval_cls_ffcv.py
@@ -1,13 +1,10 @@
 #!/usr/bin/env python
-# Copyright (c) ByteDance, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+
 
 """
-Mostly copy-paste from DEiT library:
-https://github.com/facebookresearch/deit/blob/main/main.py
+Example:
+vitrun  --nproc_per_node=3 eval_cls_ffcv.py --train_path $train_path --val_path $val_path  --gin VisionTransformer.global_pool='\"avg\"'  -w wandb:dlib/EfficientSSL/xsa4wubh  --batch_size 360 --output_dir outputs/cls
+
 """
 from PIL import Image # hack to avoid `CXXABI_1.3.9' not found error
 
diff --git a/vitookit/utils/helper.py b/vitookit/utils/helper.py
index 6eb89d3..d16aefa 100644
--- a/vitookit/utils/helper.py
+++ b/vitookit/utils/helper.py
@@ -106,6 +106,7 @@ def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
     Re-start from checkpoint
     """
     if not os.path.isfile(ckp_path):
+        print("the file doesn't exist")
         return
     print("Found checkpoint at {}".format(ckp_path))
     if ckp_path.startswith('https'):
-- 
GitLab