diff --git a/.idea/misc.xml b/.idea/misc.xml
index c4d39259202f7675b57dbaaad9a079a6daba4e40..971b53d78c8b5de31aaa8b89a37f31bcd8a95a94 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -1,7 +1,7 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="GithubDefaultAccount">
-    <option name="defaultAccountId" value="93306623-ad79-4b24-b96e-9f1ce5d3c2a0" />
+  <component name="Black">
+    <option name="sdkName" value="video_classification" />
   </component>
   <component name="ProjectRootManager" version="2" project-jdk-name="video_classification" project-jdk-type="Python SDK" />
 </project>
\ No newline at end of file
diff --git a/notebooks/experimentation.ipynb b/notebooks/experimentation.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..edc749cc73f83f11b47f109da5b0eb20a35f6cff
--- /dev/null
+++ b/notebooks/experimentation.ipynb
@@ -0,0 +1,2261 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "gpuType": "A100",
+      "machine_shape": "hm",
+      "collapsed_sections": [
+        "S8cK8UUk16O6"
+      ]
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU",
+    "widgets": {
+      "application/vnd.jupyter.widget-state+json": {
+        "5445dbf6343740f9b89f1c2133762370": {
+          "model_module": "@jupyter-widgets/controls",
+          "model_name": "VBoxModel",
+          "model_module_version": "1.5.0",
+          "state": {
+            "_dom_classes": [],
+            "_model_module": "@jupyter-widgets/controls",
+            "_model_module_version": "1.5.0",
+            "_model_name": "VBoxModel",
+            "_view_count": null,
+            "_view_module": "@jupyter-widgets/controls",
+            "_view_module_version": "1.5.0",
+            "_view_name": "VBoxView",
+            "box_style": "",
+            "children": [
+              "IPY_MODEL_4f31997fd8704ba3b8a6ce1c0b7988a4",
+              "IPY_MODEL_b369d7a80aae4d59aebbe5c0a0673bd4"
+            ],
+            "layout": "IPY_MODEL_e0509e15a96e496d96df8d6a82df3782"
+          }
+        },
+        "4f31997fd8704ba3b8a6ce1c0b7988a4": {
+          "model_module": "@jupyter-widgets/controls",
+          "model_name": "LabelModel",
+          "model_module_version": "1.5.0",
+          "state": {
+            "_dom_classes": [],
+            "_model_module": "@jupyter-widgets/controls",
+            "_model_module_version": "1.5.0",
+            "_model_name": "LabelModel",
+            "_view_count": null,
+            "_view_module": "@jupyter-widgets/controls",
+            "_view_module_version": "1.5.0",
+            "_view_name": "LabelView",
+            "description": "",
+            "description_tooltip": null,
+            "layout": "IPY_MODEL_06a3702b5df94291bc87a20f7b1ef6b5",
+            "placeholder": "​",
+            "style": "IPY_MODEL_be79c9ec0adb4d2699b58354ad8493b7",
+            "value": "0.030 MB of 0.030 MB uploaded\r"
+          }
+        },
+        "b369d7a80aae4d59aebbe5c0a0673bd4": {
+          "model_module": "@jupyter-widgets/controls",
+          "model_name": "FloatProgressModel",
+          "model_module_version": "1.5.0",
+          "state": {
+            "_dom_classes": [],
+            "_model_module": "@jupyter-widgets/controls",
+            "_model_module_version": "1.5.0",
+            "_model_name": "FloatProgressModel",
+            "_view_count": null,
+            "_view_module": "@jupyter-widgets/controls",
+            "_view_module_version": "1.5.0",
+            "_view_name": "ProgressView",
+            "bar_style": "",
+            "description": "",
+            "description_tooltip": null,
+            "layout": "IPY_MODEL_8e26e03388da4aea8dfbfb93e7a54f87",
+            "max": 1,
+            "min": 0,
+            "orientation": "horizontal",
+            "style": "IPY_MODEL_4af18f329a83473783f0f32151b79476",
+            "value": 1
+          }
+        },
+        "e0509e15a96e496d96df8d6a82df3782": {
+          "model_module": "@jupyter-widgets/base",
+          "model_name": "LayoutModel",
+          "model_module_version": "1.2.0",
+          "state": {
+            "_model_module": "@jupyter-widgets/base",
+            "_model_module_version": "1.2.0",
+            "_model_name": "LayoutModel",
+            "_view_count": null,
+            "_view_module": "@jupyter-widgets/base",
+            "_view_module_version": "1.2.0",
+            "_view_name": "LayoutView",
+            "align_content": null,
+            "align_items": null,
+            "align_self": null,
+            "border": null,
+            "bottom": null,
+            "display": null,
+            "flex": null,
+            "flex_flow": null,
+            "grid_area": null,
+            "grid_auto_columns": null,
+            "grid_auto_flow": null,
+            "grid_auto_rows": null,
+            "grid_column": null,
+            "grid_gap": null,
+            "grid_row": null,
+            "grid_template_areas": null,
+            "grid_template_columns": null,
+            "grid_template_rows": null,
+            "height": null,
+            "justify_content": null,
+            "justify_items": null,
+            "left": null,
+            "margin": null,
+            "max_height": null,
+            "max_width": null,
+            "min_height": null,
+            "min_width": null,
+            "object_fit": null,
+            "object_position": null,
+            "order": null,
+            "overflow": null,
+            "overflow_x": null,
+            "overflow_y": null,
+            "padding": null,
+            "right": null,
+            "top": null,
+            "visibility": null,
+            "width": null
+          }
+        },
+        "06a3702b5df94291bc87a20f7b1ef6b5": {
+          "model_module": "@jupyter-widgets/base",
+          "model_name": "LayoutModel",
+          "model_module_version": "1.2.0",
+          "state": {
+            "_model_module": "@jupyter-widgets/base",
+            "_model_module_version": "1.2.0",
+            "_model_name": "LayoutModel",
+            "_view_count": null,
+            "_view_module": "@jupyter-widgets/base",
+            "_view_module_version": "1.2.0",
+            "_view_name": "LayoutView",
+            "align_content": null,
+            "align_items": null,
+            "align_self": null,
+            "border": null,
+            "bottom": null,
+            "display": null,
+            "flex": null,
+            "flex_flow": null,
+            "grid_area": null,
+            "grid_auto_columns": null,
+            "grid_auto_flow": null,
+            "grid_auto_rows": null,
+            "grid_column": null,
+            "grid_gap": null,
+            "grid_row": null,
+            "grid_template_areas": null,
+            "grid_template_columns": null,
+            "grid_template_rows": null,
+            "height": null,
+            "justify_content": null,
+            "justify_items": null,
+            "left": null,
+            "margin": null,
+            "max_height": null,
+            "max_width": null,
+            "min_height": null,
+            "min_width": null,
+            "object_fit": null,
+            "object_position": null,
+            "order": null,
+            "overflow": null,
+            "overflow_x": null,
+            "overflow_y": null,
+            "padding": null,
+            "right": null,
+            "top": null,
+            "visibility": null,
+            "width": null
+          }
+        },
+        "be79c9ec0adb4d2699b58354ad8493b7": {
+          "model_module": "@jupyter-widgets/controls",
+          "model_name": "DescriptionStyleModel",
+          "model_module_version": "1.5.0",
+          "state": {
+            "_model_module": "@jupyter-widgets/controls",
+            "_model_module_version": "1.5.0",
+            "_model_name": "DescriptionStyleModel",
+            "_view_count": null,
+            "_view_module": "@jupyter-widgets/base",
+            "_view_module_version": "1.2.0",
+            "_view_name": "StyleView",
+            "description_width": ""
+          }
+        },
+        "8e26e03388da4aea8dfbfb93e7a54f87": {
+          "model_module": "@jupyter-widgets/base",
+          "model_name": "LayoutModel",
+          "model_module_version": "1.2.0",
+          "state": {
+            "_model_module": "@jupyter-widgets/base",
+            "_model_module_version": "1.2.0",
+            "_model_name": "LayoutModel",
+            "_view_count": null,
+            "_view_module": "@jupyter-widgets/base",
+            "_view_module_version": "1.2.0",
+            "_view_name": "LayoutView",
+            "align_content": null,
+            "align_items": null,
+            "align_self": null,
+            "border": null,
+            "bottom": null,
+            "display": null,
+            "flex": null,
+            "flex_flow": null,
+            "grid_area": null,
+            "grid_auto_columns": null,
+            "grid_auto_flow": null,
+            "grid_auto_rows": null,
+            "grid_column": null,
+            "grid_gap": null,
+            "grid_row": null,
+            "grid_template_areas": null,
+            "grid_template_columns": null,
+            "grid_template_rows": null,
+            "height": null,
+            "justify_content": null,
+            "justify_items": null,
+            "left": null,
+            "margin": null,
+            "max_height": null,
+            "max_width": null,
+            "min_height": null,
+            "min_width": null,
+            "object_fit": null,
+            "object_position": null,
+            "order": null,
+            "overflow": null,
+            "overflow_x": null,
+            "overflow_y": null,
+            "padding": null,
+            "right": null,
+            "top": null,
+            "visibility": null,
+            "width": null
+          }
+        },
+        "4af18f329a83473783f0f32151b79476": {
+          "model_module": "@jupyter-widgets/controls",
+          "model_name": "ProgressStyleModel",
+          "model_module_version": "1.5.0",
+          "state": {
+            "_model_module": "@jupyter-widgets/controls",
+            "_model_module_version": "1.5.0",
+            "_model_name": "ProgressStyleModel",
+            "_view_count": null,
+            "_view_module": "@jupyter-widgets/base",
+            "_view_module_version": "1.2.0",
+            "_view_name": "StyleView",
+            "bar_color": null,
+            "description_width": ""
+          }
+        }
+      }
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Setup"
+      ],
+      "metadata": {
+        "id": "S8cK8UUk16O6"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "!pip install wandb\n",
+        "\n",
+        "import wandb\n",
+        "wandb.login()"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "collapsed": true,
+        "id": "--B8rqMm_83Y",
+        "outputId": "a65d1ee5-12e7-410b-c1f0-be59421b6db3"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Requirement already satisfied: wandb in /usr/local/lib/python3.10/dist-packages (0.17.8)\n",
+            "Requirement already satisfied: click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.7)\n",
+            "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (0.4.0)\n",
+            "Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.1.43)\n",
+            "Requirement already satisfied: platformdirs in /usr/local/lib/python3.10/dist-packages (from wandb) (4.2.2)\n",
+            "Requirement already satisfied: protobuf!=4.21.0,<6,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.20.3)\n",
+            "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (5.9.5)\n",
+            "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from wandb) (6.0.2)\n",
+            "Requirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (2.32.3)\n",
+            "Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (2.13.0)\n",
+            "Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb) (1.3.3)\n",
+            "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (71.0.4)\n",
+            "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
+            "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from gitpython!=3.1.29,>=1.0.0->wandb) (4.0.11)\n",
+            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.3.2)\n",
+            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.8)\n",
+            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (2.0.7)\n",
+            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (2024.7.4)\n",
+            "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb) (5.0.1)\n"
+          ]
+        },
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33msanchitv7\u001b[0m (\u001b[33msanchitv7-university-of-surrey\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
+          ]
+        },
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "True"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 4
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "%env WANDB_PROJECT=video_classification\n",
+        "%env WANDB_LOG_MODEL=\"checkpoint\""
+      ],
+      "metadata": {
+        "id": "BQf0IsFozEB_",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "3410c5d3-1f0f-4624-d923-3c1252ff720d"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "env: WANDB_PROJECT=video_classification\n",
+            "env: WANDB_LOG_MODEL=\"checkpoint\"\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "%%capture\n",
+        "import importlib\n",
+        "\n",
+        "def install_if_missing(package_name):\n",
+        "    try:\n",
+        "        importlib.import_module(package_name)\n",
+        "        print(f\"'{package_name}' is already installed.\")\n",
+        "    except ImportError:\n",
+        "        print(f\"'{package_name}' not found. Installing...\")\n",
+        "        !pip install {package_name}\n",
+        "\n",
+        "\n",
+        "install_if_missing('evaluate')\n",
+        "install_if_missing('datasets')\n",
+        "install_if_missing('torchsummary')"
+      ],
+      "metadata": {
+        "collapsed": true,
+        "id": "FpugRyrUzJDC"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "dataset_zip = 'BVIArtefact_8_crops_all_videos'"
+      ],
+      "metadata": {
+        "id": "nHwmAod7BIek"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import os\n",
+        "\n",
+        "if os.getcwd() != '/content':\n",
+        "    %cd /content\n",
+        "\n",
+        "if not os.path.isdir('/content/data'):\n",
+        "    !mkdir data\n",
+        "\n",
+        "if not os.path.isdir(f'/content/data/{dataset_zip}'):\n",
+        "    !scp /content/drive/MyDrive/MSProjectMisc/{dataset_zip}.zip .\n",
+        "    !unzip -q {dataset_zip}.zip\n",
+        "\n",
+        "    !rm -rf /content/__MACOSX\n",
+        "    !mv {dataset_zip}/ data/\n",
+        "    !rm -rf {dataset_zip}.zip\n",
+        "    %cd /content/data/{dataset_zip}\n",
+        "    !pwd\n",
+        "    !find ./train ./val ./test -name \"._*\" -type f -delete\n",
+        "\n",
+        "    %cd ../\n",
+        "    %cd ../"
+      ],
+      "metadata": {
+        "id": "YCCFP89gle4J",
+        "collapsed": true
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# @title # Set Seed\n",
+        "import random\n",
+        "import numpy as np\n",
+        "import torch\n",
+        "import transformers\n",
+        "\n",
+        "def set_seed(seed: int = 42):\n",
+        "    \"\"\"\n",
+        "    Set seeds for reproducibility.\n",
+        "\n",
+        "    Args:\n",
+        "        seed (int): The seed to set for all random number generators.\n",
+        "    \"\"\"\n",
+        "    random.seed(seed)\n",
+        "    np.random.seed(seed)\n",
+        "    torch.manual_seed(seed)\n",
+        "    torch.cuda.manual_seed(seed)\n",
+        "    torch.cuda.manual_seed_all(seed)  # if using multi-GPU\n",
+        "    torch.backends.cudnn.deterministic = True\n",
+        "    torch.backends.cudnn.benchmark = False\n",
+        "\n",
+        "    # Set seed for Hugging Face transformers\n",
+        "    transformers.set_seed(seed)\n",
+        "\n",
+        "# At the start of your script\n",
+        "set_seed(42)"
+      ],
+      "metadata": {
+        "id": "b5Hdnlzr3jnh"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# @title #Dataset Declaration\n",
+        "import torch\n",
+        "import os\n",
+        "import json\n",
+        "from torch.utils.data import Dataset\n",
+        "from typing import Dict, List\n",
+        "from torchvision import transforms\n",
+        "\n",
+        "class BVIArtefactDataset(Dataset):\n",
+        "    def __init__(self, root_dir: str, split: str = 'train'):\n",
+        "        self.root_dir = os.path.join(root_dir, split)\n",
+        "        self.labels_file = os.path.join(self.root_dir, 'labels.json')\n",
+        "        self.crop_files = [f for f in os.listdir(self.root_dir) if f.endswith('.pt')]\n",
+        "        with open(self.labels_file, 'r') as f:\n",
+        "            self.labels = json.load(f)\n",
+        "        self.label_names = list(next(iter(self.labels.values())).keys())\n",
+        "        self.id2label = {i: label for i, label in enumerate(self.label_names)}\n",
+        "        self.label2id = {label: i for i, label in self.id2label.items()}\n",
+        "        self.transform = transforms.Compose([\n",
+        "            transforms.RandomHorizontalFlip(),\n",
+        "            transforms.RandomRotation(10)\n",
+        "        ]) if split == 'train' else None\n",
+        "\n",
+        "    def __len__(self):\n",
+        "        return len(self.crop_files)\n",
+        "\n",
+        "    def __getitem__(self, idx):\n",
+        "        crop_file = self.crop_files[idx]\n",
+        "        crop_path = os.path.join(self.root_dir, crop_file)\n",
+        "        # load video crops\n",
+        "        video = torch.load(crop_path,\n",
+        "                           map_location='cpu',\n",
+        "                           weights_only=False)\n",
+        "        # Get labels\n",
+        "        label_dict = self.labels[crop_file]\n",
+        "        labels = torch.tensor([label_dict[name] for name in self.label_names], dtype=torch.float32)\n",
+        "        return {\"pixel_values\": video, \"labels\": labels}\n",
+        "\n",
+        "    def get_label_mappings(self):\n",
+        "        return self.id2label, self.label2id"
+      ],
+      "metadata": {
+        "id": "W3egs-vYwGWi",
+        "cellView": "form"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Inspect Dataset"
+      ],
+      "metadata": {
+        "id": "OyjCVZr3zovr"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from pprint import pprint\n",
+        "# Create datasets\n",
+        "root_dir = f'data/{dataset_zip}'\n",
+        "train_dataset = BVIArtefactDataset(root_dir, split='train')\n",
+        "val_dataset = BVIArtefactDataset(root_dir, split='val')\n",
+        "test_dataset = BVIArtefactDataset(root_dir, split='test')\n",
+        "\n",
+        "# Check the shape of an item\n",
+        "print('Shape of one patch of a video:')\n",
+        "print(train_dataset[0]['pixel_values'].shape)  # Should print torch.Size([8, 3, 224, 224])\n",
+        "print()\n",
+        "print('Shape of the label:')\n",
+        "print(train_dataset[0]['labels'].shape)  # Should print torch.Size([num_labe\n",
+        "\n",
+        "id2label, label2id = BVIArtefactDataset(root_dir, split='train').get_label_mappings()\n",
+        "\n",
+        "num_labels = len(id2label)\n",
+        "\n",
+        "print(f'\\nindex to label mapping:')\n",
+        "pprint(id2label)\n",
+        "print()\n",
+        "print(f'label to index mapping:')\n",
+        "pprint(label2id)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "NEpCmHAgkhp8",
+        "outputId": "5574f74c-f77d-499f-c9d2-4e9c10361587"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Shape of one patch of a video:\n",
+            "torch.Size([8, 3, 224, 224])\n",
+            "\n",
+            "Shape of the label:\n",
+            "torch.Size([9])\n",
+            "\n",
+            "index to label mapping:\n",
+            "{0: 'black_screen',\n",
+            " 1: 'frame_drop',\n",
+            " 2: 'spatial_blur',\n",
+            " 3: 'transmission_error',\n",
+            " 4: 'aliasing',\n",
+            " 5: 'banding',\n",
+            " 6: 'dark_scenes',\n",
+            " 7: 'graininess',\n",
+            " 8: 'motion_blur'}\n",
+            "\n",
+            "label to index mapping:\n",
+            "{'aliasing': 4,\n",
+            " 'banding': 5,\n",
+            " 'black_screen': 0,\n",
+            " 'dark_scenes': 6,\n",
+            " 'frame_drop': 1,\n",
+            " 'graininess': 7,\n",
+            " 'motion_blur': 8,\n",
+            " 'spatial_blur': 2,\n",
+            " 'transmission_error': 3}\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "train_dataset[0]['labels']"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "FEsAgMicnwMI",
+        "outputId": "35f13d94-6386-48e8-9ee0-1227bdd0f386"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "tensor([0., 1., 0., 0., 0., 0., 0., 1., 1.])"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 12
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import imageio\n",
+        "import numpy as np\n",
+        "from IPython.display import Image, display\n",
+        "import torch\n",
+        "import random\n",
+        "\n",
+        "def unnormalize_img(img):\n",
+        "    \"\"\"Un-normalizes the image pixels.\"\"\"\n",
+        "    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)\n",
+        "    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)\n",
+        "    img = (img * std) + mean\n",
+        "    img = (img * 255).clamp(0, 255).byte()\n",
+        "    return img\n",
+        "\n",
+        "def create_gif(video_tensor, filename=\"sample.gif\"):\n",
+        "    \"\"\"Prepares a GIF from a video tensor.\n",
+        "\n",
+        "    The video tensor is expected to have the following shape:\n",
+        "    (num_frames, num_channels, height, width).\n",
+        "    \"\"\"\n",
+        "    frames = []\n",
+        "    for video_frame in video_tensor:\n",
+        "        frame_unnormalized = unnormalize_img(video_frame).permute(1, 2, 0).numpy()\n",
+        "        frames.append(frame_unnormalized)\n",
+        "    kargs = {\"duration\": 0.25}\n",
+        "    imageio.mimsave(filename, frames, \"GIF\", **kargs)\n",
+        "    return filename\n",
+        "\n",
+        "def display_gif(video_tensor, gif_name=\"sample.gif\"):\n",
+        "    \"\"\"Prepares and displays a GIF from a video tensor.\"\"\"\n",
+        "    gif_filename = create_gif(video_tensor, gif_name)\n",
+        "    return Image(filename=gif_filename)\n",
+        "\n",
+        "# Get a sample from your dataset\n",
+        "random_sample = random.randint(0, len(train_dataset))\n",
+        "\n",
+        "sample = train_dataset[random_sample]\n",
+        "video_tensor = sample['pixel_values']\n",
+        "\n",
+        "# Display the GIF\n",
+        "display(display_gif(video_tensor, \"sample_crop.gif\"))\n",
+        "pprint({id2label[i]: sample['labels'][i] for i in range(len(sample['labels']))})"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 402
+        },
+        "id": "ovOL5Rtqwb3_",
+        "outputId": "8503307d-1326-42d2-fb7f-2f190f284da4"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "display_data",
+          "data": {
+            "image/gif": "R0lGODlh4ADgAIEAAAAAAAAAAAAAAAAAACwAAAAA4ADgAEAI/wABCBxIsKDBgwgTKlzIsKHDhxAjSpxIsaLFixgzatzIsaPHjyBDihxJsqTJkyhTqlzJsqXLlzBjypxJs6bNmzhz6tzJs6fPn0CDCh1KtKjRo0iTKl3KtKnTp1CjSp1KtarVq1izat3KtavXr2DDih1LtqzZs2jTql3Ltq3bt3Djyp1Lt67du3jz6t3Lt6/fv4ADCx5MuLDhw4gTK17MuLHjx5AjS55MubLly5gza97MubPnz6BDix5NurTp06hTq17NurXr17Bjy55Nu7bt27hz697Nu7fv38CDCx9OvLjx48iTK1/OvLnz59CjS59Ovbr169iza9/Ovbv37+DDi3AfT768+fPo06tfz769+/fw48ufT7++/fv48+vfz7+///8ABijggAQWaOCBCCao4IIMNujggxBGKOGEFFZo4YUYZqjhhhx26OGHIIYo4ogklmjiiSimqOKKLLbo4oswxijjjDTWaOONOOao4448ahUQADs=\n",
+            "text/plain": [
+              "<IPython.core.display.Image object>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "{'aliasing': tensor(0.),\n",
+            " 'banding': tensor(1.),\n",
+            " 'black_screen': tensor(1.),\n",
+            " 'dark_scenes': tensor(0.),\n",
+            " 'frame_drop': tensor(1.),\n",
+            " 'graininess': tensor(0.),\n",
+            " 'motion_blur': tensor(0.),\n",
+            " 'spatial_blur': tensor(0.),\n",
+            " 'transmission_error': tensor(1.)}\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# @title # Utility Functions\n",
+        "import datetime\n",
+        "import os\n",
+        "import json\n",
+        "from typing import Dict, Any\n",
+        "import pytz\n",
+        "\n",
+        "# Set the timezone to UTC+1\n",
+        "timezone = pytz.timezone('Europe/London')  # This is an example; choose a timezone with UTC+1\n",
+        "\n",
+        "\n",
+        "def get_run_name():\n",
+        "    # Generate a timestamp\n",
+        "    timestamp = datetime.datetime.now(timezone).strftime(\"%Y-%m-%d--%H-%M-%S\")\n",
+        "\n",
+        "    # Create a unique run name\n",
+        "    run_name = f\"run-{timestamp}\"\n",
+        "\n",
+        "    # Create directories for this run\n",
+        "    os.makedirs(f\"./logs/{run_name}\", exist_ok=True)\n",
+        "    os.makedirs(f\"./results/{run_name}\", exist_ok=True)\n",
+        "\n",
+        "    return run_name\n",
+        "\n",
+        "\n",
+        "def log_run_info(run_name, model, training_args):\n",
+        "\n",
+        "    run_info_path = os.path.join(\"./logs\", run_name, \"run_info.txt\")\n",
+        "    with open(run_info_path, \"w\") as f:\n",
+        "        f.write(f\"Run Name: {run_name}\\n\")\n",
+        "        f.write(f\"Model: {type(model).__name__}\\n\")\n",
+        "        f.write(\"Training Arguments:\\n\")\n",
+        "        for arg, value in vars(training_args).items():\n",
+        "            f.write(f\"  {arg}: {value}\\n\")\n",
+        "\n",
+        "\n",
+        "def save_test_results_to_json(test_results: Dict[str, Any], artifacts: list):\n",
+        "    \"\"\"\n",
+        "    Format test results and save them to a JSON file.\n",
+        "\n",
+        "    Args:\n",
+        "    test_results (Dict[str, Any]): Dictionary containing test results\n",
+        "    artifacts (list): List of artifact names\n",
+        "    output_file (str): Path to the output JSON file\n",
+        "    \"\"\"\n",
+        "\n",
+        "    output_file = os.path.join(\"./logs\", run_name, \"test_set_results.json\")\n",
+        "\n",
+        "    # formatted_results = {\n",
+        "    #     \"overall_metrics\": {\n",
+        "    #         \"loss\": test_results['eval_loss'],\n",
+        "    #         \"accuracy\": test_results['eval_accuracy'],\n",
+        "    #         \"f1_score\": test_results['eval_f1'],\n",
+        "    #         \"auc\": test_results['eval_auc'],\n",
+        "    #         \"combined_auc_focus\": test_results['source_artefacts_avg_auc']\n",
+        "    #     },\n",
+        "    #     \"artifact_metrics\": {},\n",
+        "    #     \"additional_info\": {\n",
+        "    #         \"runtime\": test_results['eval_runtime'],\n",
+        "    #         \"samples_per_second\": test_results['eval_samples_per_second'],\n",
+        "    #         \"steps_per_second\": test_results['eval_steps_per_second'],\n",
+        "    #         \"epoch\": test_results['epoch']\n",
+        "    #     }\n",
+        "    # }\n",
+        "\n",
+        "    # for artifact in artifacts:\n",
+        "    #     formatted_results[\"artifact_metrics\"][artifact] = {\n",
+        "    #         \"accuracy\": test_results[f'eval_accuracy_{artifact}'],\n",
+        "    #         \"f1_score\": test_results[f'eval_f1_{artifact}'],\n",
+        "    #         \"auc\": test_results[f'eval_auc_{artifact}']\n",
+        "    #     }\n",
+        "\n",
+        "    with open(output_file, 'w') as f:\n",
+        "        json.dump(test_results, f, indent=2)\n",
+        "\n",
+        "    print(f\"Test results saved to {output_file}\")\n",
+        "\n",
+        "# Function to print metrics for a specific artifact\n",
+        "def print_artifact_metrics(artifact, metrics):\n",
+        "    print(f\"\\n{artifact.capitalize()} Metrics:\")\n",
+        "    print(f\"  Accuracy: {metrics[f'eval_accuracy_{artifact}']:.4f}\")\n",
+        "    print(f\"  F1 Score: {metrics[f'eval_f1_{artifact}']:.4f}\")\n",
+        "    print(f\"  AUC:      {metrics[f'eval_auc_{artifact}']:.4f}\")"
+      ],
+      "metadata": {
+        "id": "KrAU5_dt1p1N"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# @title # Model Declaration\n",
+        "\n",
+        "import torch\n",
+        "import torch.nn as nn\n",
+        "from transformers import TimesformerModel\n",
+        "from torch.nn import functional as F\n",
+        "\n",
+        "\n",
+        "def focal_loss(logits, labels, alpha=0.25, gamma=2.0):\n",
+        "    \"\"\"\n",
+        "    Focal loss for multi-label classification.\n",
+        "    \"\"\"\n",
+        "    BCE_loss = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')\n",
+        "    pt = torch.exp(-BCE_loss)\n",
+        "    focal_loss = alpha * (1-pt)**gamma * BCE_loss\n",
+        "    return focal_loss.mean()\n",
+        "\n",
+        "\n",
+        "class MultiTaskTimeSformer(nn.Module):\n",
+        "    def __init__(self, num_labels, id2label, label2id, dataset, num_layers_to_finetune=4):\n",
+        "        \"\"\"\n",
+        "        Multi-task TimeSformer model for video artifact detection.\n",
+        "\n",
+        "        Args:\n",
+        "        num_labels: Number of artifact classes\n",
+        "        id2label: Mapping from label index to label name\n",
+        "        label2id: Mapping from label name to label index\n",
+        "        dataset: Dataset used for calculating initial weights\n",
+        "        num_layers_to_finetune: Number of layers to fine-tune in the TimeSformer backbone\n",
+        "        \"\"\"\n",
+        "        super().__init__()\n",
+        "        self.num_labels = num_labels\n",
+        "        self.id2label = id2label\n",
+        "        self.label2id = label2id\n",
+        "\n",
+        "        # TimeSformer backbone\n",
+        "        # Using a pre-trained model for transfer learning\n",
+        "        self.timesformer = TimesformerModel.from_pretrained(\n",
+        "            \"facebook/timesformer-base-finetuned-k400\",\n",
+        "            ignore_mismatched_sizes=True\n",
+        "        )\n",
+        "\n",
+        "        # Task-specific adaptation layers\n",
+        "        # These layers adapt the TimeSformer's output for each specific artifact detection task\n",
+        "        # self.task_adapters = nn.ModuleDict({\n",
+        "        #     label: nn.Sequential(\n",
+        "        #         nn.Linear(768, 512),\n",
+        "        #         nn.ReLU(),\n",
+        "        #         nn.Dropout(0.1),  # Dropout for regularization\n",
+        "        #         nn.Linear(512, 256),\n",
+        "        #         nn.ReLU(),\n",
+        "        #         nn.Dropout(0.1),\n",
+        "        #         nn.Linear(256, 1)\n",
+        "        #     ) for label in id2label.values()\n",
+        "        # })\n",
+        "\n",
+        "        self.task_adapters = nn.ModuleDict({\n",
+        "            label: nn.Sequential(\n",
+        "                nn.Linear(768, 512),\n",
+        "                nn.LayerNorm(512),\n",
+        "                nn.ReLU(),\n",
+        "                nn.Dropout(0.3),\n",
+        "                nn.Linear(512, 256),\n",
+        "                nn.LayerNorm(256),\n",
+        "                nn.ReLU(),\n",
+        "                nn.Dropout(0.2),\n",
+        "                nn.Linear(256, 128),\n",
+        "                nn.LayerNorm(128),\n",
+        "                nn.ReLU(),\n",
+        "                nn.Dropout(0.1),\n",
+        "                nn.Linear(128, 1)\n",
+        "            ) for label in id2label.values()\n",
+        "        })\n",
+        "\n",
+        "        # self.task_adapters = nn.ModuleDict({\n",
+        "        #     label: nn.Sequential(\n",
+        "        #         nn.Linear(768, 256),\n",
+        "        #         nn.LayerNorm(256), # Layer Normalisatoin\n",
+        "        #         nn.ReLU(),\n",
+        "        #         nn.Dropout(0.3),  # Dropout for regularization\n",
+        "        #         nn.Linear(256, 1)\n",
+        "        #     ) for label in id2label.values()\n",
+        "        # })\n",
+        "\n",
+        "        # Calculate and set initial task weights\n",
+        "        # This helps address class imbalance from the start of training\n",
+        "        initial_weights = self.calculate_initial_weights(dataset)\n",
+        "        # self.class_weights = nn.Parameter(initial_weights, requires_grad=False)\n",
+        "        self.task_weights = nn.Parameter(initial_weights)\n",
+        "\n",
+        "        # Freeze all layers except the last few\n",
+        "        # This allows for fine-tuning of the model while preserving pre-trained knowledge\n",
+        "        self.freeze_layers(num_layers_to_finetune)\n",
+        "\n",
+        "    def calculate_initial_weights(self, dataset):\n",
+        "        \"\"\"\n",
+        "        Calculate class weights based on class frequencies and positive/negative ratios.\n",
+        "        This helps address class imbalance in multi-label classification.\n",
+        "\n",
+        "        Args:\n",
+        "        dataset: The dataset containing labels\n",
+        "\n",
+        "        Returns:\n",
+        "        combined_weights: Tensor of weights for each class\n",
+        "        \"\"\"\n",
+        "        num_samples = len(dataset)\n",
+        "        num_labels = len(dataset.id2label)\n",
+        "\n",
+        "        class_frequencies = torch.zeros(num_labels)\n",
+        "        pos_neg_ratios = torch.zeros(num_labels)\n",
+        "\n",
+        "        # Calculate class frequencies\n",
+        "        for i in range(num_samples):\n",
+        "            labels = dataset[i]['labels']\n",
+        "            class_frequencies += labels\n",
+        "\n",
+        "        # Calculate positive/negative ratios\n",
+        "        for i in range(num_labels):\n",
+        "            pos_count = class_frequencies[i]\n",
+        "            neg_count = num_samples - pos_count\n",
+        "            pos_neg_ratios[i] = neg_count / pos_count if pos_count > 0 else 1.0\n",
+        "\n",
+        "        # Inverse class frequencies (rare classes get higher weights)\n",
+        "        class_weights = 1.0 / class_frequencies\n",
+        "        class_weights[class_weights == float('inf')] = 0  # Handle division by zero\n",
+        "\n",
+        "        # Normalize weights\n",
+        "        class_weights = class_weights / class_weights.sum()\n",
+        "        pos_neg_ratios = pos_neg_ratios / pos_neg_ratios.sum()\n",
+        "\n",
+        "        # Combine class frequencies and pos/neg ratios\n",
+        "        # This balances between overall class rarity and class imbalance within each label\n",
+        "        combined_weights = (class_weights + pos_neg_ratios) / 2\n",
+        "\n",
+        "        return combined_weights\n",
+        "\n",
+        "\n",
+        "    # def calculate_class_weights(self, dataset):\n",
+        "    #     num_samples = len(dataset)\n",
+        "    #     class_counts = torch.zeros(self.num_labels)\n",
+        "\n",
+        "    #     for i in range(num_samples):\n",
+        "    #         labels = dataset[i]['labels']\n",
+        "    #         class_counts += labels\n",
+        "\n",
+        "    #     class_weights = num_samples / (self.num_labels * class_counts)\n",
+        "    #     class_weights = torch.clamp(class_weights, min=0.1, max=10.0)  # Prevent extreme values\n",
+        "\n",
+        "    #     return nn.Parameter(class_weights, requires_grad=False)\n",
+        "\n",
+        "    def freeze_layers(self, num_layers_to_finetune):\n",
+        "        \"\"\"\n",
+        "        Freeze layers of the TimeSformer backbone except for the last few.\n",
+        "        This is a form of transfer learning that preserves most of the pre-trained weights.\n",
+        "        \"\"\"\n",
+        "        # Freeze all layers\n",
+        "        for param in self.timesformer.parameters():\n",
+        "            param.requires_grad = False\n",
+        "\n",
+        "        # Unfreeze the last few layers\n",
+        "        for i, layer in enumerate(reversed(list(self.timesformer.encoder.layer))):\n",
+        "            if i < num_layers_to_finetune:\n",
+        "                for param in layer.parameters():\n",
+        "                    param.requires_grad = True\n",
+        "            else:\n",
+        "                break\n",
+        "\n",
+        "        # Always unfreeze the task adapters and task weights\n",
+        "        # This allows these task-specific components to be fully trained\n",
+        "        for adapter in self.task_adapters.values():\n",
+        "            for param in adapter.parameters():\n",
+        "                param.requires_grad = True\n",
+        "\n",
+        "        self.task_weights.requires_grad = True\n",
+        "\n",
+        "\n",
+        "    def forward(self, pixel_values, labels=None):\n",
+        "        \"\"\"\n",
+        "        Forward pass of the model.\n",
+        "\n",
+        "        Args:\n",
+        "        pixel_values: Input video tensor\n",
+        "        labels: Ground truth labels (optional)\n",
+        "\n",
+        "        Returns:\n",
+        "        Dictionary containing loss and logits (if labels provided), or just logits\n",
+        "        \"\"\"\n",
+        "        outputs = self.timesformer(pixel_values=pixel_values)\n",
+        "        pooled_output = outputs.last_hidden_state[:, 0]  # Use CLS token for classification\n",
+        "\n",
+        "        # Task-specific adaptation and prediction\n",
+        "        logits = torch.cat([self.task_adapters[label](pooled_output)\n",
+        "                            for label in self.id2label.values()], dim=1)\n",
+        "\n",
+        "        if labels is not None:\n",
+        "            # Binary Cross Entropy loss for multi-label classification\n",
+        "            # loss_fct = nn.BCEWithLogitsLoss(reduction='none')\n",
+        "            loss_fct = focal_loss\n",
+        "            loss = loss_fct(logits, labels)\n",
+        "\n",
+        "            # # use raw weights\n",
+        "            # class_weights = self.class_weights.to(loss.device)\n",
+        "            # task_weights = self.task_weights.to(loss.device)\n",
+        "\n",
+        "            # OR\n",
+        "            # Use sigmoid\n",
+        "            # task_weights = torch.sigmoid(self.task_weights)\n",
+        "            # class_weights = torch.sigmoid(self.class_weights)\n",
+        "\n",
+        "            # # OR\n",
+        "            # # Use softmax\n",
+        "            task_weights = F.softmax(self.task_weights, dim=0)\n",
+        "            # class_weights = F.softmax(self.class_weights, dim=0)\n",
+        "\n",
+        "            # weighted_loss = (loss * task_weights * class_weights).mean()\n",
+        "            weighted_loss = (loss * task_weights).mean()\n",
+        "\n",
+        "            return {\"loss\": weighted_loss, \"logits\": logits}\n",
+        "        return {\"logits\": logits}"
+      ],
+      "metadata": {
+        "id": "3-Pi5RSawGNd"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "source": [
+        "# @title Compute Metrics on specific or all artefacts\n",
+        "\n",
+        "import numpy as np\n",
+        "import torch\n",
+        "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score\n",
+        "\n",
+        "\n",
+        "artefacts_in_focus = ['motion_blur',\n",
+        "                      'dark_scenes',\n",
+        "                      'graininess',\n",
+        "                      'aliasing',\n",
+        "                      'banding']\n",
+        "USE_ALL_ARTEFACTS = False\n",
+        "\n",
+        "if USE_ALL_ARTEFACTS:\n",
+        "    artefacts_in_focus = list(label2id.keys())\n",
+        "\n",
+        "def compute_metrics(eval_pred):\n",
+        "    logits, labels = eval_pred\n",
+        "    probabilities = torch.sigmoid(torch.tensor(logits))\n",
+        "    predictions = (probabilities > 0.5).numpy()\n",
+        "\n",
+        "    results = {}\n",
+        "    auc_scores = []\n",
+        "    f1_scores = []\n",
+        "    accuracy_scores = []\n",
+        "    for artifact in artefacts_in_focus:\n",
+        "        i = label2id[artifact]  # Get the index of the artifact\n",
+        "        task_accuracy = accuracy_score(labels[:, i], predictions[:, i])\n",
+        "        task_f1 = f1_score(labels[:, i], predictions[:, i], average='binary', zero_division=0)\n",
+        "        task_auc = roc_auc_score(labels[:, i], probabilities[:, i])\n",
+        "        auc_scores.append(task_auc)\n",
+        "        f1_scores.append(task_f1)\n",
+        "        accuracy_scores.append(task_accuracy)\n",
+        "\n",
+        "        results.update({\n",
+        "            f\"accuracy_{artifact}\": task_accuracy,\n",
+        "            f\"f1_{artifact}\": task_f1,\n",
+        "            f\"auc_{artifact}\": task_auc\n",
+        "        })\n",
+        "\n",
+        "    # Compute overall metrics\n",
+        "    results.update({\n",
+        "        \"accuracy\": accuracy_score(labels.flatten(), predictions.flatten()),\n",
+        "        \"f1\": f1_score(labels, predictions, average='samples', zero_division=0),\n",
+        "        \"auc\": roc_auc_score(labels, probabilities, average='samples')\n",
+        "    })\n",
+        "\n",
+        "    # Add combined AUC for artifacts in focus\n",
+        "    results[\"source_artefacts_avg_auc\"] = np.mean(auc_scores)\n",
+        "    results[\"source_artefacts_avg_f1\"] = np.mean(f1_scores)\n",
+        "    results[\"source_artefacts_avg_accuracy\"] = np.mean(accuracy_scores)\n",
+        "\n",
+        "    return results\n",
+        "\n",
+        "# def compute_metrics(eval_pred):\n",
+        "#     logits, labels = eval_pred\n",
+        "#     probabilities = torch.sigmoid(torch.tensor(logits))\n",
+        "#     predictions = (probabilities > 0.5).numpy()\n",
+        "\n",
+        "#     results = {}\n",
+        "#     for i in range(len(id2label)):\n",
+        "#         task_accuracy = accuracy_score(labels[:, i], predictions[:, i])\n",
+        "#         task_f1 = f1_score(labels[:, i], predictions[:, i], average='binary', zero_division=0)\n",
+        "#         task_auc = roc_auc_score(labels[:, i], probabilities[:, i])\n",
+        "#         # task_precision = precision_score(labels[:, i], predictions[:, i], average='binary', zero_division=0)\n",
+        "#         # task_recall = recall_score(labels[:, i], predictions[:, i], average='binary', zero_division=0)\n",
+        "#         # Compute ROC AUC for each task\n",
+        "\n",
+        "#         results.update({\n",
+        "#             f\"accuracy_{id2label[i]}\": task_accuracy,\n",
+        "#             f\"f1_{id2label[i]}\": task_f1,\n",
+        "#             f\"auc_{id2label[i]}\": task_auc\n",
+        "#             # f\"precision_{id2label[i]}\": task_precision,\n",
+        "#             # f\"recall_{id2label[i]}\": task_recall,\n",
+        "#         })\n",
+        "\n",
+        "#     # Compute overall metrics\n",
+        "#     results.update({\n",
+        "#         \"accuracy\": accuracy_score(labels.flatten(), predictions.flatten()),\n",
+        "#         \"f1\": f1_score(labels, predictions, average='samples', zero_division=0),\n",
+        "#         \"auc\": roc_auc_score(labels, probabilities, average='samples')\n",
+        "#         # \"precision\": precision_score(labels, predictions, average='samples', zero_division=0),\n",
+        "#         # \"recall\": recall_score(labels, predictions, average='samples', zero_division=0),\n",
+        "#     })\n",
+        "#    return results"
+      ],
+      "cell_type": "code",
+      "metadata": {
+        "id": "DYeBpurb3szb"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# import torchsummary\n",
+        "\n",
+        "model = MultiTaskTimeSformer(num_labels,\n",
+        "                             id2label,\n",
+        "                             label2id,\n",
+        "                             train_dataset,\n",
+        "                             num_layers_to_finetune=8)\n",
+        "\n",
+        "\n",
+        "# You need to define input size to calcualte parameters\n",
+        "# torchsummary.summary(model, input_size=(8, 3, 224, 224))"
+      ],
+      "metadata": {
+        "id": "NqyI88XTwYXM",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "32d39e67-c168-44c4-9abe-f02e2f841ab9"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n",
+            "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
+            "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
+            "You will be able to reuse this secret in all of your notebooks.\n",
+            "Please note that authentication is recommended but still optional to access public models or datasets.\n",
+            "  warnings.warn(\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "model(train_dataset[0]['pixel_values'].unsqueeze(0), train_dataset[0]['labels'].unsqueeze(0))"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "pZR6e7jTq1bM",
+        "outputId": "4919cf51-0ecf-4d1b-ca67-7bf0e6874b77"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "{'loss': tensor(0.0048, grad_fn=<MeanBackward0>),\n",
+              " 'logits': tensor([[-0.0173,  0.1177,  0.0437, -0.4189, -0.0314,  0.1233,  0.2152, -0.0879,\n",
+              "           0.0947]], grad_fn=<CatBackward0>)}"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 18
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from transformers import EarlyStoppingCallback, Trainer, TrainingArguments\n",
+        "\n",
+        "# Early stopping callback\n",
+        "early_stopping_callback = EarlyStoppingCallback(\n",
+        "    early_stopping_patience=3,\n",
+        "    early_stopping_threshold=0.01\n",
+        ")\n",
+        "\n",
+        "# Get run name\n",
+        "run_name = get_run_name()\n",
+        "\n",
+        "# Set training parameters\n",
+        "batch_size = 32\n",
+        "gradient_accumulation_steps = 3\n",
+        "max_steps = 500\n",
+        "\n",
+        "# Print run parameters\n",
+        "print(f\"Run Name: {run_name}\")\n",
+        "print(f\"Batch Size: {batch_size}\")\n",
+        "print(f\"Gradient Accumulation Steps: {gradient_accumulation_steps}\")\n",
+        "print(f\"Max Steps: {max_steps}\")"
+      ],
+      "metadata": {
+        "id": "xRnjq61dXKNB",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "df22cd27-0c8f-4e2c-a832-c52dd44efc8b"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Run Name: run-2024-09-03--04-44-54\n",
+            "Batch Size: 32\n",
+            "Gradient Accumulation Steps: 3\n",
+            "Max Steps: 500\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Define the training arguments\n",
+        "training_args = TrainingArguments(\n",
+        "    output_dir=f\"./results/{run_name}\",\n",
+        "    max_steps=max_steps,  # Set maximum number of training steps\n",
+        "    per_device_train_batch_size=batch_size,\n",
+        "    per_device_eval_batch_size=batch_size,\n",
+        "    gradient_accumulation_steps=gradient_accumulation_steps,\n",
+        "    # logging_first_step=True,\n",
+        "    warmup_ratio=0.05,  # Warmup for 10% of total steps\n",
+        "    weight_decay=0.01,  # Weight decay\n",
+        "    learning_rate=1e-5,  # Learning rate\n",
+        "    logging_dir=f\"./logs/{run_name}\",\n",
+        "    logging_steps=10,  # Log every X steps\n",
+        "    eval_strategy=\"steps\",  # Evaluate based on steps\n",
+        "    eval_steps=20,  # Evaluate every X steps\n",
+        "    save_strategy=\"steps\",\n",
+        "    save_steps=60,  # Save every X steps\n",
+        "    load_best_model_at_end=True,\n",
+        "    metric_for_best_model=\"source_artefacts_avg_auc\",\n",
+        "    greater_is_better=True,  # Maximize the metric\n",
+        "    save_total_limit=2,  # Keep the 2 best checkpoints\n",
+        "    seed=42,  # Set a seed for reproducibility\n",
+        "    fp16=True,  # Enable mixed precision training\n",
+        "    report_to=\"wandb\",\n",
+        "    run_name=run_name,\n",
+        "    dataloader_num_workers=12  # Adjust as needed,\n",
+        ")\n",
+        "\n",
+        "# Create a trainer\n",
+        "trainer = Trainer(\n",
+        "    model=model,\n",
+        "    args=training_args,\n",
+        "    train_dataset=train_dataset,\n",
+        "    eval_dataset=val_dataset,\n",
+        "    compute_metrics=compute_metrics,\n",
+        "    callbacks=[early_stopping_callback],\n",
+        ")\n",
+        "\n",
+        "# Start training\n",
+        "trainer.train()"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 1000
+        },
+        "id": "wM5JRu1hx75X",
+        "outputId": "032a9f7c-6ffc-4f21-f9ec-744e1077578b",
+        "collapsed": true
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py:482: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n",
+            "  self.scaler = torch.cuda.amp.GradScaler(**kwargs)\n",
+            "max_steps is given, it will override any value given in num_train_epochs\n"
+          ]
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<IPython.core.display.HTML object>"
+            ],
+            "text/html": [
+              "Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to <a href='https://wandb.me/wandb-init' target=\"_blank\">the W&B docs</a>."
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<IPython.core.display.HTML object>"
+            ],
+            "text/html": [
+              "Tracking run with wandb version 0.17.8"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<IPython.core.display.HTML object>"
+            ],
+            "text/html": [
+              "Run data is saved locally in <code>/content/wandb/run-20240903_034455-suem3kz6</code>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<IPython.core.display.HTML object>"
+            ],
+            "text/html": [
+              "Syncing run <strong><a href='https://wandb.ai/sanchitv7-university-of-surrey/video_classification/runs/suem3kz6' target=\"_blank\">run-2024-09-03--04-44-54</a></strong> to <a href='https://wandb.ai/sanchitv7-university-of-surrey/video_classification' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<IPython.core.display.HTML object>"
+            ],
+            "text/html": [
+              " View project at <a href='https://wandb.ai/sanchitv7-university-of-surrey/video_classification' target=\"_blank\">https://wandb.ai/sanchitv7-university-of-surrey/video_classification</a>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<IPython.core.display.HTML object>"
+            ],
+            "text/html": [
+              " View run at <a href='https://wandb.ai/sanchitv7-university-of-surrey/video_classification/runs/suem3kz6' target=\"_blank\">https://wandb.ai/sanchitv7-university-of-surrey/video_classification/runs/suem3kz6</a>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+            "  self.pid = os.fork()\n"
+          ]
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<IPython.core.display.HTML object>"
+            ],
+            "text/html": [
+              "\n",
+              "    <div>\n",
+              "      \n",
+              "      <progress value='300' max='500' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+              "      [300/500 08:32 < 05:43, 0.58 it/s, Epoch 10/18]\n",
+              "    </div>\n",
+              "    <table border=\"1\" class=\"dataframe\">\n",
+              "  <thead>\n",
+              " <tr style=\"text-align: left;\">\n",
+              "      <th>Step</th>\n",
+              "      <th>Training Loss</th>\n",
+              "      <th>Validation Loss</th>\n",
+              "      <th>Accuracy Motion Blur</th>\n",
+              "      <th>F1 Motion Blur</th>\n",
+              "      <th>Auc Motion Blur</th>\n",
+              "      <th>Accuracy Dark Scenes</th>\n",
+              "      <th>F1 Dark Scenes</th>\n",
+              "      <th>Auc Dark Scenes</th>\n",
+              "      <th>Accuracy Graininess</th>\n",
+              "      <th>F1 Graininess</th>\n",
+              "      <th>Auc Graininess</th>\n",
+              "      <th>Accuracy Aliasing</th>\n",
+              "      <th>F1 Aliasing</th>\n",
+              "      <th>Auc Aliasing</th>\n",
+              "      <th>Accuracy Banding</th>\n",
+              "      <th>F1 Banding</th>\n",
+              "      <th>Auc Banding</th>\n",
+              "      <th>Accuracy</th>\n",
+              "      <th>F1</th>\n",
+              "      <th>Auc</th>\n",
+              "      <th>Source Artefacts Avg Auc</th>\n",
+              "      <th>Source Artefacts Avg F1</th>\n",
+              "      <th>Source Artefacts Avg Accuracy</th>\n",
+              "    </tr>\n",
+              "  </thead>\n",
+              "  <tbody>\n",
+              "    <tr>\n",
+              "      <td>20</td>\n",
+              "      <td>0.005400</td>\n",
+              "      <td>0.004893</td>\n",
+              "      <td>0.536972</td>\n",
+              "      <td>0.263305</td>\n",
+              "      <td>0.497133</td>\n",
+              "      <td>0.625000</td>\n",
+              "      <td>0.419619</td>\n",
+              "      <td>0.602496</td>\n",
+              "      <td>0.718310</td>\n",
+              "      <td>0.200000</td>\n",
+              "      <td>0.537953</td>\n",
+              "      <td>0.654930</td>\n",
+              "      <td>0.155172</td>\n",
+              "      <td>0.513067</td>\n",
+              "      <td>0.554577</td>\n",
+              "      <td>0.216718</td>\n",
+              "      <td>0.466600</td>\n",
+              "      <td>0.567488</td>\n",
+              "      <td>0.363818</td>\n",
+              "      <td>0.543714</td>\n",
+              "      <td>0.523450</td>\n",
+              "      <td>0.250963</td>\n",
+              "      <td>0.617958</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <td>40</td>\n",
+              "      <td>0.004900</td>\n",
+              "      <td>0.004784</td>\n",
+              "      <td>0.589789</td>\n",
+              "      <td>0.016878</td>\n",
+              "      <td>0.536157</td>\n",
+              "      <td>0.811620</td>\n",
+              "      <td>0.467662</td>\n",
+              "      <td>0.727021</td>\n",
+              "      <td>0.746479</td>\n",
+              "      <td>0.132530</td>\n",
+              "      <td>0.556286</td>\n",
+              "      <td>0.686620</td>\n",
+              "      <td>0.135922</td>\n",
+              "      <td>0.566906</td>\n",
+              "      <td>0.691901</td>\n",
+              "      <td>0.033149</td>\n",
+              "      <td>0.533720</td>\n",
+              "      <td>0.614241</td>\n",
+              "      <td>0.292834</td>\n",
+              "      <td>0.595156</td>\n",
+              "      <td>0.584018</td>\n",
+              "      <td>0.157228</td>\n",
+              "      <td>0.705282</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <td>60</td>\n",
+              "      <td>0.004700</td>\n",
+              "      <td>0.004639</td>\n",
+              "      <td>0.595070</td>\n",
+              "      <td>0.087302</td>\n",
+              "      <td>0.559896</td>\n",
+              "      <td>0.830986</td>\n",
+              "      <td>0.589744</td>\n",
+              "      <td>0.758148</td>\n",
+              "      <td>0.744718</td>\n",
+              "      <td>0.189944</td>\n",
+              "      <td>0.565980</td>\n",
+              "      <td>0.700704</td>\n",
+              "      <td>0.267241</td>\n",
+              "      <td>0.588996</td>\n",
+              "      <td>0.700704</td>\n",
+              "      <td>0.212963</td>\n",
+              "      <td>0.583549</td>\n",
+              "      <td>0.617567</td>\n",
+              "      <td>0.345792</td>\n",
+              "      <td>0.615963</td>\n",
+              "      <td>0.611314</td>\n",
+              "      <td>0.269439</td>\n",
+              "      <td>0.714437</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <td>80</td>\n",
+              "      <td>0.004600</td>\n",
+              "      <td>0.004614</td>\n",
+              "      <td>0.610915</td>\n",
+              "      <td>0.159696</td>\n",
+              "      <td>0.577484</td>\n",
+              "      <td>0.830986</td>\n",
+              "      <td>0.606557</td>\n",
+              "      <td>0.769114</td>\n",
+              "      <td>0.750000</td>\n",
+              "      <td>0.236559</td>\n",
+              "      <td>0.585041</td>\n",
+              "      <td>0.705986</td>\n",
+              "      <td>0.283262</td>\n",
+              "      <td>0.595453</td>\n",
+              "      <td>0.734155</td>\n",
+              "      <td>0.334802</td>\n",
+              "      <td>0.602284</td>\n",
+              "      <td>0.620892</td>\n",
+              "      <td>0.352606</td>\n",
+              "      <td>0.624883</td>\n",
+              "      <td>0.625875</td>\n",
+              "      <td>0.324175</td>\n",
+              "      <td>0.726408</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <td>100</td>\n",
+              "      <td>0.004500</td>\n",
+              "      <td>0.004587</td>\n",
+              "      <td>0.625000</td>\n",
+              "      <td>0.202247</td>\n",
+              "      <td>0.599330</td>\n",
+              "      <td>0.825704</td>\n",
+              "      <td>0.605578</td>\n",
+              "      <td>0.781578</td>\n",
+              "      <td>0.757042</td>\n",
+              "      <td>0.233333</td>\n",
+              "      <td>0.582422</td>\n",
+              "      <td>0.713028</td>\n",
+              "      <td>0.345382</td>\n",
+              "      <td>0.611274</td>\n",
+              "      <td>0.748239</td>\n",
+              "      <td>0.401674</td>\n",
+              "      <td>0.616644</td>\n",
+              "      <td>0.634585</td>\n",
+              "      <td>0.378922</td>\n",
+              "      <td>0.637860</td>\n",
+              "      <td>0.638250</td>\n",
+              "      <td>0.357643</td>\n",
+              "      <td>0.733803</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <td>120</td>\n",
+              "      <td>0.004400</td>\n",
+              "      <td>0.004570</td>\n",
+              "      <td>0.626761</td>\n",
+              "      <td>0.237410</td>\n",
+              "      <td>0.605180</td>\n",
+              "      <td>0.836268</td>\n",
+              "      <td>0.620408</td>\n",
+              "      <td>0.782437</td>\n",
+              "      <td>0.753521</td>\n",
+              "      <td>0.239130</td>\n",
+              "      <td>0.585964</td>\n",
+              "      <td>0.709507</td>\n",
+              "      <td>0.331984</td>\n",
+              "      <td>0.623427</td>\n",
+              "      <td>0.762324</td>\n",
+              "      <td>0.478764</td>\n",
+              "      <td>0.626481</td>\n",
+              "      <td>0.629304</td>\n",
+              "      <td>0.385952</td>\n",
+              "      <td>0.646639</td>\n",
+              "      <td>0.644698</td>\n",
+              "      <td>0.381539</td>\n",
+              "      <td>0.737676</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <td>140</td>\n",
+              "      <td>0.004300</td>\n",
+              "      <td>0.004569</td>\n",
+              "      <td>0.632042</td>\n",
+              "      <td>0.256228</td>\n",
+              "      <td>0.609035</td>\n",
+              "      <td>0.836268</td>\n",
+              "      <td>0.623482</td>\n",
+              "      <td>0.784984</td>\n",
+              "      <td>0.755282</td>\n",
+              "      <td>0.240437</td>\n",
+              "      <td>0.599663</td>\n",
+              "      <td>0.713028</td>\n",
+              "      <td>0.340081</td>\n",
+              "      <td>0.633747</td>\n",
+              "      <td>0.774648</td>\n",
+              "      <td>0.492063</td>\n",
+              "      <td>0.634085</td>\n",
+              "      <td>0.631651</td>\n",
+              "      <td>0.389567</td>\n",
+              "      <td>0.653623</td>\n",
+              "      <td>0.652303</td>\n",
+              "      <td>0.390458</td>\n",
+              "      <td>0.742254</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <td>160</td>\n",
+              "      <td>0.004300</td>\n",
+              "      <td>0.004562</td>\n",
+              "      <td>0.642606</td>\n",
+              "      <td>0.302405</td>\n",
+              "      <td>0.619516</td>\n",
+              "      <td>0.839789</td>\n",
+              "      <td>0.628571</td>\n",
+              "      <td>0.787539</td>\n",
+              "      <td>0.751761</td>\n",
+              "      <td>0.261780</td>\n",
+              "      <td>0.609606</td>\n",
+              "      <td>0.714789</td>\n",
+              "      <td>0.386364</td>\n",
+              "      <td>0.637617</td>\n",
+              "      <td>0.776408</td>\n",
+              "      <td>0.520755</td>\n",
+              "      <td>0.637143</td>\n",
+              "      <td>0.635172</td>\n",
+              "      <td>0.389261</td>\n",
+              "      <td>0.654359</td>\n",
+              "      <td>0.658284</td>\n",
+              "      <td>0.419975</td>\n",
+              "      <td>0.745070</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <td>180</td>\n",
+              "      <td>0.004300</td>\n",
+              "      <td>0.004550</td>\n",
+              "      <td>0.658451</td>\n",
+              "      <td>0.370130</td>\n",
+              "      <td>0.630811</td>\n",
+              "      <td>0.839789</td>\n",
+              "      <td>0.634538</td>\n",
+              "      <td>0.786483</td>\n",
+              "      <td>0.750000</td>\n",
+              "      <td>0.275510</td>\n",
+              "      <td>0.623828</td>\n",
+              "      <td>0.709507</td>\n",
+              "      <td>0.395604</td>\n",
+              "      <td>0.641669</td>\n",
+              "      <td>0.776408</td>\n",
+              "      <td>0.538182</td>\n",
+              "      <td>0.643549</td>\n",
+              "      <td>0.632825</td>\n",
+              "      <td>0.410094</td>\n",
+              "      <td>0.661050</td>\n",
+              "      <td>0.665268</td>\n",
+              "      <td>0.442793</td>\n",
+              "      <td>0.746831</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <td>200</td>\n",
+              "      <td>0.004200</td>\n",
+              "      <td>0.004557</td>\n",
+              "      <td>0.661972</td>\n",
+              "      <td>0.380645</td>\n",
+              "      <td>0.633685</td>\n",
+              "      <td>0.841549</td>\n",
+              "      <td>0.640000</td>\n",
+              "      <td>0.793321</td>\n",
+              "      <td>0.755282</td>\n",
+              "      <td>0.315271</td>\n",
+              "      <td>0.631605</td>\n",
+              "      <td>0.716549</td>\n",
+              "      <td>0.383142</td>\n",
+              "      <td>0.650061</td>\n",
+              "      <td>0.779930</td>\n",
+              "      <td>0.542125</td>\n",
+              "      <td>0.643854</td>\n",
+              "      <td>0.635759</td>\n",
+              "      <td>0.403682</td>\n",
+              "      <td>0.666880</td>\n",
+              "      <td>0.670505</td>\n",
+              "      <td>0.452236</td>\n",
+              "      <td>0.751056</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <td>220</td>\n",
+              "      <td>0.004200</td>\n",
+              "      <td>0.004563</td>\n",
+              "      <td>0.665493</td>\n",
+              "      <td>0.383117</td>\n",
+              "      <td>0.635577</td>\n",
+              "      <td>0.845070</td>\n",
+              "      <td>0.645161</td>\n",
+              "      <td>0.797415</td>\n",
+              "      <td>0.758803</td>\n",
+              "      <td>0.311558</td>\n",
+              "      <td>0.633070</td>\n",
+              "      <td>0.721831</td>\n",
+              "      <td>0.423358</td>\n",
+              "      <td>0.659874</td>\n",
+              "      <td>0.786972</td>\n",
+              "      <td>0.546816</td>\n",
+              "      <td>0.646406</td>\n",
+              "      <td>0.639085</td>\n",
+              "      <td>0.414961</td>\n",
+              "      <td>0.668390</td>\n",
+              "      <td>0.674468</td>\n",
+              "      <td>0.462002</td>\n",
+              "      <td>0.755634</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <td>240</td>\n",
+              "      <td>0.004100</td>\n",
+              "      <td>0.004577</td>\n",
+              "      <td>0.670775</td>\n",
+              "      <td>0.398714</td>\n",
+              "      <td>0.636090</td>\n",
+              "      <td>0.841549</td>\n",
+              "      <td>0.640000</td>\n",
+              "      <td>0.795393</td>\n",
+              "      <td>0.755282</td>\n",
+              "      <td>0.328502</td>\n",
+              "      <td>0.635547</td>\n",
+              "      <td>0.723592</td>\n",
+              "      <td>0.407547</td>\n",
+              "      <td>0.665070</td>\n",
+              "      <td>0.779930</td>\n",
+              "      <td>0.548736</td>\n",
+              "      <td>0.652865</td>\n",
+              "      <td>0.641041</td>\n",
+              "      <td>0.421542</td>\n",
+              "      <td>0.671348</td>\n",
+              "      <td>0.676993</td>\n",
+              "      <td>0.464700</td>\n",
+              "      <td>0.754225</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <td>260</td>\n",
+              "      <td>0.004000</td>\n",
+              "      <td>0.004589</td>\n",
+              "      <td>0.674296</td>\n",
+              "      <td>0.416404</td>\n",
+              "      <td>0.640381</td>\n",
+              "      <td>0.839789</td>\n",
+              "      <td>0.637450</td>\n",
+              "      <td>0.794770</td>\n",
+              "      <td>0.765845</td>\n",
+              "      <td>0.338308</td>\n",
+              "      <td>0.642702</td>\n",
+              "      <td>0.725352</td>\n",
+              "      <td>0.430657</td>\n",
+              "      <td>0.669078</td>\n",
+              "      <td>0.771127</td>\n",
+              "      <td>0.542254</td>\n",
+              "      <td>0.653966</td>\n",
+              "      <td>0.642410</td>\n",
+              "      <td>0.428757</td>\n",
+              "      <td>0.672502</td>\n",
+              "      <td>0.680179</td>\n",
+              "      <td>0.473015</td>\n",
+              "      <td>0.755282</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <td>280</td>\n",
+              "      <td>0.004000</td>\n",
+              "      <td>0.004598</td>\n",
+              "      <td>0.679577</td>\n",
+              "      <td>0.431250</td>\n",
+              "      <td>0.642402</td>\n",
+              "      <td>0.843310</td>\n",
+              "      <td>0.639676</td>\n",
+              "      <td>0.794082</td>\n",
+              "      <td>0.769366</td>\n",
+              "      <td>0.360976</td>\n",
+              "      <td>0.646245</td>\n",
+              "      <td>0.725352</td>\n",
+              "      <td>0.438849</td>\n",
+              "      <td>0.674354</td>\n",
+              "      <td>0.772887</td>\n",
+              "      <td>0.544170</td>\n",
+              "      <td>0.653787</td>\n",
+              "      <td>0.641432</td>\n",
+              "      <td>0.428405</td>\n",
+              "      <td>0.675171</td>\n",
+              "      <td>0.682174</td>\n",
+              "      <td>0.482984</td>\n",
+              "      <td>0.758099</td>\n",
+              "    </tr>\n",
+              "    <tr>\n",
+              "      <td>300</td>\n",
+              "      <td>0.003900</td>\n",
+              "      <td>0.004605</td>\n",
+              "      <td>0.683099</td>\n",
+              "      <td>0.437500</td>\n",
+              "      <td>0.642299</td>\n",
+              "      <td>0.839789</td>\n",
+              "      <td>0.637450</td>\n",
+              "      <td>0.796408</td>\n",
+              "      <td>0.765845</td>\n",
+              "      <td>0.369668</td>\n",
+              "      <td>0.642294</td>\n",
+              "      <td>0.721831</td>\n",
+              "      <td>0.443662</td>\n",
+              "      <td>0.680224</td>\n",
+              "      <td>0.774648</td>\n",
+              "      <td>0.552448</td>\n",
+              "      <td>0.655878</td>\n",
+              "      <td>0.640845</td>\n",
+              "      <td>0.426792</td>\n",
+              "      <td>0.675000</td>\n",
+              "      <td>0.683421</td>\n",
+              "      <td>0.488146</td>\n",
+              "      <td>0.757042</td>\n",
+              "    </tr>\n",
+              "  </tbody>\n",
+              "</table><p>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+            "  self.pid = os.fork()\n"
+          ]
+        },
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "TrainOutput(global_step=300, training_loss=0.004406587146222591, metrics={'train_runtime': 524.3153, 'train_samples_per_second': 91.548, 'train_steps_per_second': 0.954, 'total_flos': 0.0, 'train_loss': 0.004406587146222591, 'epoch': 10.714285714285714})"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 20
+        }
+      ]
+    },
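+    {
+      "cell_type": "code",
+      "source": [
+        "# Minimal sketch: with metric_for_best_model set, the Trainer records the best score\n",
+        "# and checkpoint path in its TrainerState. Assumes `trainer` and `training_args`\n",
+        "# from the training cell above.\n",
+        "print(f\"Best {training_args.metric_for_best_model}: {trainer.state.best_metric}\")\n",
+        "print(f\"Best checkpoint: {trainer.state.best_model_checkpoint}\")"
+      ],
+      "metadata": {},
+      "execution_count": null,
+      "outputs": []
+    },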
+    {
+      "cell_type": "code",
+      "source": [
+        "# wandb.finish()"
+      ],
+      "metadata": {
+        "id": "BVTEmrXU21zV",
+        "collapsed": true
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# After training\n",
+        "test_results = trainer.evaluate(eval_dataset=test_dataset)\n",
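+        "# evaluate() returns a metrics dict whose keys carry the \"eval_\" prefix (eval_loss, eval_auc, ...)\n",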
+        "\n",
+        "log_run_info(run_name, model, training_args)\n",
+        "save_test_results_to_json(test_results, artefacts_in_focus)\n",
+        "\n",
+        "# run_info_path = os.path.join(\"./logs\", run_name, \"run_info.txt\")\n",
+        "\n",
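+        "# wandb.save with a glob symlinks the matching log files into the W&B run directory so they sync with the run\n",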
+        "wandb.save(os.path.join(\"./logs\", run_name, '*'))\n",
+        "\n",
+        "wandb.finish()\n",
+        "\n",
+        "# Print the metrics\n",
+        "print(\"Test set metrics:\")\n",
+        "# for key, value in test_results.items():\n",
+        "#     print(f\"{key}: {value}\")\n",
+        "\n",
+        "\n",
+        "# # List of artifacts\n",
+        "# artifacts = [\n",
+        "#     \"motion_blur\",\n",
+        "#     \"dark_scenes\",\n",
+        "#     \"graininess\",\n",
+        "#     \"aliasing\",\n",
+        "#     \"banding\",\n",
+        "#     \"transmission_error\",\n",
+        "#     \"spatial_blur\",\n",
+        "#     \"black_screen\",\n",
+        "#     \"frame_drop\"\n",
+        "#     ]\n",
+        "\n",
+        "\n",
+        "# Print overall metrics\n",
+        "print(\"Overall Metrics:\")\n",
+        "print(f\"Loss:     {test_results['eval_loss']:.4f}\")\n",
+        "print(f\"Accuracy: {test_results['eval_accuracy']:.4f}\")\n",
+        "print(f\"F1 Score: {test_results['eval_f1']:.4f}\")\n",
+        "print(f\"AUC:      {test_results['eval_auc']:.4f}\")\n",
+        "\n",
+        "# Print metrics for each artifact\n",
+        "for artefact in artefacts_in_focus:\n",
+        "    print_artifact_metrics(artefact, test_results)\n",
+        "\n",
+        "# Print additional information\n",
+        "print(\"\\nAdditional Information:\")\n",
+        "print(f\"Runtime:                 {test_results['eval_runtime']:.4f} seconds\")\n",
+        "print(f\"Samples per second:      {test_results['eval_samples_per_second']:.2f}\")\n",
+        "print(f\"Steps per second:        {test_results['eval_steps_per_second']:.2f}\")\n",
+        "print(f\"Epoch:                   {test_results['epoch']:.2f}\")"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 1000,
+          "referenced_widgets": [
+            "5445dbf6343740f9b89f1c2133762370",
+            "4f31997fd8704ba3b8a6ce1c0b7988a4",
+            "b369d7a80aae4d59aebbe5c0a0673bd4",
+            "e0509e15a96e496d96df8d6a82df3782",
+            "06a3702b5df94291bc87a20f7b1ef6b5",
+            "be79c9ec0adb4d2699b58354ad8493b7",
+            "8e26e03388da4aea8dfbfb93e7a54f87",
+            "4af18f329a83473783f0f32151b79476"
+          ]
+        },
+        "id": "_PjDoIcNdDNl",
+        "outputId": "c723de83-c878-4cf7-a067-fc7c5ff3c66a"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+            "  self.pid = os.fork()\n"
+          ]
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<IPython.core.display.HTML object>"
+            ],
+            "text/html": [
+              "\n",
+              "    <div>\n",
+              "      \n",
+              "      <progress value='19' max='19' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+              "      [19/19 00:03]\n",
+              "    </div>\n",
+              "    "
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Symlinked 2 files into the W&B run directory, call wandb.save again to sync new files.\n"
+          ]
+        },
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Test results saved to ./logs/run-2024-09-03--04-44-54/test_set_results.json\n"
+          ]
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "VBox(children=(Label(value='0.007 MB of 0.007 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
+            ],
+            "application/vnd.jupyter.widget-view+json": {
+              "version_major": 2,
+              "version_minor": 0,
+              "model_id": "5445dbf6343740f9b89f1c2133762370"
+            }
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<IPython.core.display.HTML object>"
+            ],
+            "text/html": [
+              "<style>\n",
+              "    table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
+              "    .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
+              "    .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
+              "    </style>\n",
+              "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy</td><td>▁▅▆▆▇▇▇▇▇▇██████</td></tr><tr><td>eval/accuracy_aliasing</td><td>▁▄▆▆▇▆▇▇▆▇██████</td></tr><tr><td>eval/accuracy_banding</td><td>▁▅▅▆▇▇█████████▆</td></tr><tr><td>eval/accuracy_dark_scenes</td><td>▁▇██▇██████████▄</td></tr><tr><td>eval/accuracy_graininess</td><td>▁▅▅▅▆▆▆▆▅▆▇▆███▁</td></tr><tr><td>eval/accuracy_motion_blur</td><td>▁▂▂▃▃▃▃▄▄▄▄▄▄▄▄█</td></tr><tr><td>eval/auc</td><td>▁▃▄▅▆▆▆▆▇▇▇▇▇▇▇█</td></tr><tr><td>eval/auc_aliasing</td><td>▁▃▄▄▅▆▆▆▆▇▇▇███▇</td></tr><tr><td>eval/auc_banding</td><td>▁▃▅▆▇▇▇▇███████▇</td></tr><tr><td>eval/auc_dark_scenes</td><td>▁▅▇▇▇▇█████████▆</td></tr><tr><td>eval/auc_graininess</td><td>▁▂▂▃▃▃▄▄▅▅▅▅▆▆▆█</td></tr><tr><td>eval/auc_motion_blur</td><td>▁▂▂▃▃▃▄▄▄▄▄▄▄▄▄█</td></tr><tr><td>eval/f1</td><td>▄▁▃▄▅▅▅▅▆▆▆▇▇▇▇█</td></tr><tr><td>eval/f1_aliasing</td><td>▁▁▄▄▅▅▅▆▆▆▇▇▇▇▇█</td></tr><tr><td>eval/f1_banding</td><td>▃▁▃▅▆▇▇████████▇</td></tr><tr><td>eval/f1_dark_scenes</td><td>▁▂▆▇▇▇▇▇███████▅</td></tr><tr><td>eval/f1_graininess</td><td>▂▁▂▃▃▃▃▃▄▄▄▅▅▅▅█</td></tr><tr><td>eval/f1_motion_blur</td><td>▄▁▂▃▃▄▄▅▅▆▆▆▆▆▆█</td></tr><tr><td>eval/loss</td><td>█▆▃▂▂▁▁▁▁▁▁▂▂▂▂▂</td></tr><tr><td>eval/runtime</td><td>█▃▃▁▃▂▃▂▃▂▂▂▂▂▂▄</td></tr><tr><td>eval/samples_per_second</td><td>▁▆▆█▆▆▆▇▆▆▇▆▇▇▇▆</td></tr><tr><td>eval/source_artefacts_avg_accuracy</td><td>▁▅▆▆▇▇▇▇▇███████</td></tr><tr><td>eval/source_artefacts_avg_auc</td><td>▁▃▄▅▅▆▆▆▆▇▇▇▇▇▇█</td></tr><tr><td>eval/source_artefacts_avg_f1</td><td>▃▁▃▄▅▅▆▆▇▇▇▇▇▇▇█</td></tr><tr><td>eval/steps_per_second</td><td>▁▆▆█▆▆▆▇▆▆▇▆▇▇▇▇</td></tr><tr><td>train/epoch</td><td>▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇██████</td></tr><tr><td>train/global_step</td><td>▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇██████</td></tr><tr><td>train/grad_norm</td><td>█▅▃▄▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▂▂▁▁▁▁</td></tr><tr><td>train/learning_rate</td><td>▁▆███▇▇▇▆▆▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▁</td></tr><tr><td>train/loss</td><td>██▆▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy</td><td>0.64346</td></tr><tr><td>eval/accuracy_aliasing</td><td>0.72432</td></tr><tr><td>eval/accuracy_banding</td><td>0.73459</td></tr><tr><td>eval/accuracy_dark_scenes</td><td>0.73288</td></tr><tr><td>eval/accuracy_graininess</td><td>0.71918</td></tr><tr><td>eval/accuracy_motion_blur</td><td>0.83048</td></tr><tr><td>eval/auc</td><td>0.68896</td></tr><tr><td>eval/auc_aliasing</td><td>0.66432</td></tr><tr><td>eval/auc_banding</td><td>0.63605</td></tr><tr><td>eval/auc_dark_scenes</td><td>0.72998</td></tr><tr><td>eval/auc_graininess</td><td>0.69241</td></tr><tr><td>eval/auc_motion_blur</td><td>0.80945</td></tr><tr><td>eval/f1</td><td>0.45053</td></tr><tr><td>eval/f1_aliasing</td><td>0.47213</td></tr><tr><td>eval/f1_banding</td><td>0.47458</td></tr><tr><td>eval/f1_dark_scenes</td><td>0.55172</td></tr><tr><td>eval/f1_graininess</td><td>0.51479</td></tr><tr><td>eval/f1_motion_blur</td><td>0.57511</td></tr><tr><td>eval/loss</td><td>0.0046</td></tr><tr><td>eval/runtime</td><td>6.1607</td></tr><tr><td>eval/samples_per_second</td><td>94.795</td></tr><tr><td>eval/source_artefacts_avg_accuracy</td><td>0.74829</td></tr><tr><td>eval/source_artefacts_avg_auc</td><td>0.70644</td></tr><tr><td>eval/source_artefacts_avg_f1</td><td>0.51767</td></tr><tr><td>eval/steps_per_second</td><td>3.084</td></tr><tr><td>total_flos</td><td>0.0</td></tr><tr><td>train/epoch</td><td>10.71429</td></tr><tr><td>train/global_step</td><td>300</td></tr><tr><td>train/grad_norm</td><td>0.00885</td></tr><tr><td>train/learning_rate</td><td>0.0</td></tr><tr><td>train/loss</td><td>0.0039</td></tr><tr><td>train_loss</td><td>0.00441</td></tr><tr><td>train_runtime</td><td>524.3153</td></tr><tr><td>train_samples_per_second</td><td>91.548</td></tr><tr><td>train_steps_per_second</td><td>0.954</td></tr></table><br/></div></div>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<IPython.core.display.HTML object>"
+            ],
+            "text/html": [
+              " View run <strong style=\"color:#cdcd00\">run-2024-09-03--04-44-54</strong> at: <a href='https://wandb.ai/sanchitv7-university-of-surrey/video_classification/runs/suem3kz6' target=\"_blank\">https://wandb.ai/sanchitv7-university-of-surrey/video_classification/runs/suem3kz6</a><br/> View project at: <a href='https://wandb.ai/sanchitv7-university-of-surrey/video_classification' target=\"_blank\">https://wandb.ai/sanchitv7-university-of-surrey/video_classification</a><br/>Synced 5 W&B file(s), 0 media file(s), 1 artifact file(s) and 2 other file(s)"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<IPython.core.display.HTML object>"
+            ],
+            "text/html": [
+              "Find logs at: <code>./wandb/run-20240903_034455-suem3kz6/logs</code>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<IPython.core.display.HTML object>"
+            ],
+            "text/html": [
+              "The new W&B backend becomes opt-out in version 0.18.0; try it out with `wandb.require(\"core\")`! See https://wandb.me/wandb-core for more information."
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Test set metrics:\n",
+            "Overall Metrics:\n",
+            "Loss:     0.0046\n",
+            "Accuracy: 0.6435\n",
+            "F1 Score: 0.4505\n",
+            "AUC:      0.6890\n",
+            "\n",
+            "Motion_blur Metrics:\n",
+            "  Accuracy: 0.8305\n",
+            "  F1 Score: 0.5751\n",
+            "  AUC:      0.8094\n",
+            "\n",
+            "Dark_scenes Metrics:\n",
+            "  Accuracy: 0.7329\n",
+            "  F1 Score: 0.5517\n",
+            "  AUC:      0.7300\n",
+            "\n",
+            "Graininess Metrics:\n",
+            "  Accuracy: 0.7192\n",
+            "  F1 Score: 0.5148\n",
+            "  AUC:      0.6924\n",
+            "\n",
+            "Aliasing Metrics:\n",
+            "  Accuracy: 0.7243\n",
+            "  F1 Score: 0.4721\n",
+            "  AUC:      0.6643\n",
+            "\n",
+            "Banding Metrics:\n",
+            "  Accuracy: 0.7346\n",
+            "  F1 Score: 0.4746\n",
+            "  AUC:      0.6361\n",
+            "\n",
+            "Additional Information:\n",
+            "Runtime:                 6.1607 seconds\n",
+            "Samples per second:      94.80\n",
+            "Steps per second:        3.08\n",
+            "Epoch:                   10.71\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "run_name"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 35
+        },
+        "id": "FrFfu6P6vT5g",
+        "outputId": "38dc4f60-8cf9-4203-ed84-a87aa18d88e3"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "'run-2024-09-03--01-02-23'"
+            ],
+            "application/vnd.google.colaboratory.intrinsic+json": {
+              "type": "string"
+            }
+          },
+          "metadata": {},
+          "execution_count": 168
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# @title Save or Not\n",
+        "# @markdown **Save the model or not**\n",
+        "\n",
+        "# form input for saving or not\n",
+        "save_logs_or_not = True # @param {type: \"boolean\"}\n",
+        "save_model_or_not = False # @param {type: \"boolean\"}\n",
+        "\n",
+        "if save_logs_or_not:\n",
+        "    # Move logs to Drive (run folder is named with today's date)\n",
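+        "    # Drive is mounted locally, so a plain cp -r would also work here\n",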
+        "    !scp -r logs/{run_name}/ /content/drive/MyDrive/MSProjectMisc\n",
+        "\n",
+        "if save_model_or_not:\n",
+        "    # Save the final model\n",
+        "    final_model_path = os.path.join(\"./results\", run_name)\n",
+        "\n",
+        "    trainer.save_model(final_model_path)\n",
+        "    print(f\"Final model saved to {final_model_path}\")\n",
+        "\n",
+        "    # Move the final model to drive\n",
+        "    !scp -r {final_model_path} /content/drive/MyDrive/MSProjectMisc/{run_name}/"
+      ],
+      "metadata": {
+        "id": "3haYagRvmtMj",
+        "cellView": "form"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# !rm -rf logs results\n",
+        "\n",
+        "# !scp -r logs_successful_run_2/ /content/drive/MyDrive/MSProjectMisc"
+      ],
+      "metadata": {
+        "id": "U-e7LLQeJOFC"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# from safetensors import safe_open\n",
+        "# model = MultiTaskTimeSformer(num_labels, id2label, label2id, train_dataset)\n",
+        "\n",
+        "# # Load the saved weights\n",
+        "# checkpoint_path = \"/content/results/run-2024-08-30--00-49-32/checkpoint-400/model.safetensors\"\n",
+        "# with safe_open(checkpoint_path, framework=\"pt\", device=\"cpu\") as f:\n",
+        "#     state_dict = {key: f.get_tensor(key) for key in f.keys()}\n",
+        "\n",
+        "# model.load_state_dict(state_dict)\n",
+        "\n",
+        "# # Set the model to evaluation mode\n",
+        "# model.eval()\n",
+        "\n",
+        "# print(\"Model loaded successfully!\")"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "uuLHkPhNzp4R",
+        "outputId": "c2702569-133b-4b3c-b0bb-ae730ea0e9c8"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Model loaded successfully!\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [],
+      "metadata": {
+        "id": "_I9XUa4nz-yl"
+      },
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
\ No newline at end of file