From f53030bf2ba1b54862f7af576b847bc2720b2292 Mon Sep 17 00:00:00 2001
From: gent <jw02425@surrey.ac.uk>
Date: Thu, 18 Jan 2024 21:11:03 +0000
Subject: [PATCH] Fix shape assignment in convolution function and import torch
 in evaluation scripts

---
 vitookit/datasets/ffcv_transform.py     | 2 +-
 vitookit/evaluation/eval_cls_ffcv.py    | 2 +-
 vitookit/evaluation/eval_linear_ffcv.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/vitookit/datasets/ffcv_transform.py b/vitookit/datasets/ffcv_transform.py
index a8ea589..036b593 100644
--- a/vitookit/datasets/ffcv_transform.py
+++ b/vitookit/datasets/ffcv_transform.py
@@ -356,7 +356,7 @@ def convolution(image: np.ndarray, kernel: list | tuple, output: np.ndarray) ->
     else:
         raise Exception('Shape of image not supported')
 
-    m_k, n_k, _ = kernel.shape
+    m_k, n_k = kernel.shape
 
     y_strides = m_i - m_k + 1  # possible number of strides in y direction
     x_strides = n_i - n_k + 1  # possible number of strides in x direction
diff --git a/vitookit/evaluation/eval_cls_ffcv.py b/vitookit/evaluation/eval_cls_ffcv.py
index b71e3aa..70a79f0 100644
--- a/vitookit/evaluation/eval_cls_ffcv.py
+++ b/vitookit/evaluation/eval_cls_ffcv.py
@@ -263,7 +263,7 @@ def main(args):
     misc.init_distributed_mode(args)
 
     print(args)
-
+    import torch
     device = torch.device(args.device)
 
     # fix the seed for reproducibility
diff --git a/vitookit/evaluation/eval_linear_ffcv.py b/vitookit/evaluation/eval_linear_ffcv.py
index 26cabe0..143b6a8 100644
--- a/vitookit/evaluation/eval_linear_ffcv.py
+++ b/vitookit/evaluation/eval_linear_ffcv.py
@@ -126,7 +126,7 @@ def main(args):
     print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
     print("{}".format(args).replace(', ', ',\n'))
 
-    import torch
+    
     device = torch.device(args.device)
 
     # fix the seed for reproducibility
-- 
GitLab