diff --git a/.idea/misc.xml b/.idea/misc.xml index c4d39259202f7675b57dbaaad9a079a6daba4e40..971b53d78c8b5de31aaa8b89a37f31bcd8a95a94 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,7 +1,7 @@ <?xml version="1.0" encoding="UTF-8"?> <project version="4"> - <component name="GithubDefaultAccount"> - <option name="defaultAccountId" value="93306623-ad79-4b24-b96e-9f1ce5d3c2a0" /> + <component name="Black"> + <option name="sdkName" value="video_classification" /> </component> <component name="ProjectRootManager" version="2" project-jdk-name="video_classification" project-jdk-type="Python SDK" /> </project> \ No newline at end of file diff --git a/notebooks/experimentation.ipynb b/notebooks/experimentation.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..edc749cc73f83f11b47f109da5b0eb20a35f6cff --- /dev/null +++ b/notebooks/experimentation.ipynb @@ -0,0 +1,2261 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "A100", + "machine_shape": "hm", + "collapsed_sections": [ + "S8cK8UUk16O6" + ] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU", + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "5445dbf6343740f9b89f1c2133762370": { + "model_module": "@jupyter-widgets/controls", + "model_name": "VBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "VBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "VBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_4f31997fd8704ba3b8a6ce1c0b7988a4", + "IPY_MODEL_b369d7a80aae4d59aebbe5c0a0673bd4" + ], + "layout": "IPY_MODEL_e0509e15a96e496d96df8d6a82df3782" + } + }, + "4f31997fd8704ba3b8a6ce1c0b7988a4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "LabelModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "LabelModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "LabelView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_06a3702b5df94291bc87a20f7b1ef6b5", + "placeholder": "​", + "style": "IPY_MODEL_be79c9ec0adb4d2699b58354ad8493b7", + "value": "0.030 MB of 0.030 MB uploaded\r" + } + }, + "b369d7a80aae4d59aebbe5c0a0673bd4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8e26e03388da4aea8dfbfb93e7a54f87", + "max": 1, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_4af18f329a83473783f0f32151b79476", + "value": 1 + } + }, + "e0509e15a96e496d96df8d6a82df3782": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "06a3702b5df94291bc87a20f7b1ef6b5": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "be79c9ec0adb4d2699b58354ad8493b7": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "8e26e03388da4aea8dfbfb93e7a54f87": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4af18f329a83473783f0f32151b79476": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + } + } + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Setup" + ], + "metadata": { + "id": "S8cK8UUk16O6" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install wandb\n", + "\n", + "import wandb\n", + "wandb.login()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "collapsed": true, + "id": "--B8rqMm_83Y", + "outputId": "a65d1ee5-12e7-410b-c1f0-be59421b6db3" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: wandb in /usr/local/lib/python3.10/dist-packages (0.17.8)\n", + "Requirement already satisfied: click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.7)\n", + "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (0.4.0)\n", + "Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.1.43)\n", + "Requirement already satisfied: platformdirs in /usr/local/lib/python3.10/dist-packages (from wandb) (4.2.2)\n", + "Requirement already satisfied: protobuf!=4.21.0,<6,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.20.3)\n", + "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (5.9.5)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from wandb) (6.0.2)\n", + "Requirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (2.32.3)\n", + "Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (2.13.0)\n", + "Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb) (1.3.3)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (71.0.4)\n", + "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n", + "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from gitpython!=3.1.29,>=1.0.0->wandb) (4.0.11)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.8)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (2024.7.4)\n", + "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb) (5.0.1)\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33msanchitv7\u001b[0m (\u001b[33msanchitv7-university-of-surrey\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "True" + ] + }, + "metadata": {}, + "execution_count": 4 + } + ] + }, + { + "cell_type": "code", + "source": [ + "%env WANDB_PROJECT=video_classification\n", + "%env WANDB_LOG_MODEL=\"checkpoint\"" + ], + "metadata": { + "id": "BQf0IsFozEB_", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "3410c5d3-1f0f-4624-d923-3c1252ff720d" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "env: WANDB_PROJECT=video_classification\n", + "env: WANDB_LOG_MODEL=\"checkpoint\"\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "%%capture\n", + "import importlib\n", + "\n", + "def install_if_missing(package_name):\n", + " try:\n", + " importlib.import_module(package_name)\n", + " print(f\"'{package_name}' is already installed.\")\n", + " except ImportError:\n", + " print(f\"'{package_name}' not found. Installing...\")\n", + " !pip install {package_name}\n", + "\n", + "\n", + "install_if_missing('evaluate')\n", + "install_if_missing('datasets')\n", + "install_if_missing('torchsummary')" + ], + "metadata": { + "collapsed": true, + "id": "FpugRyrUzJDC" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "dataset_zip = 'BVIArtefact_8_crops_all_videos'" + ], + "metadata": { + "id": "nHwmAod7BIek" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "\n", + "if os.getcwd() != '/content':\n", + " %cd /content\n", + "\n", + "if not os.path.isdir('/content/data'):\n", + " !mkdir data\n", + "\n", + "if not os.path.isdir(f'/content/data/{dataset_zip}'):\n", + " !scp /content/drive/MyDrive/MSProjectMisc/{dataset_zip}.zip .\n", + " !unzip -q {dataset_zip}.zip\n", + "\n", + " !rm -rf /content/__MACOSX\n", + " !mv {dataset_zip}/ data/\n", + " !rm -rf {dataset_zip}.zip\n", + " %cd /content/data/{dataset_zip}\n", + " !pwd\n", + " !find ./train ./val ./test -name \"._*\" -type f -delete\n", + "\n", + " %cd ../\n", + " %cd ../" + ], + "metadata": { + "id": "YCCFP89gle4J", + "collapsed": true + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# @title # Set Seed\n", + "import random\n", + "import numpy as np\n", + "import torch\n", + "import transformers\n", + "\n", + "def set_seed(seed: int = 42):\n", + " \"\"\"\n", + " Set seeds for reproducibility.\n", + "\n", + " Args:\n", + " seed (int): The seed to set for all random number generators.\n", + " \"\"\"\n", + " random.seed(seed)\n", + " np.random.seed(seed)\n", + " torch.manual_seed(seed)\n", + " torch.cuda.manual_seed(seed)\n", + " torch.cuda.manual_seed_all(seed) # if using multi-GPU\n", + " torch.backends.cudnn.deterministic = True\n", + " torch.backends.cudnn.benchmark = False\n", + "\n", + " # Set seed for Hugging Face transformers\n", + " transformers.set_seed(seed)\n", + "\n", + "# At the start of your script\n", + "set_seed(42)" + ], + "metadata": { + "id": "b5Hdnlzr3jnh" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# @title #Dataset Declaration\n", + "import torch\n", + "import os\n", + "import json\n", + "from torch.utils.data import Dataset\n", + "from typing import Dict, List\n", + "from torchvision import transforms\n", + "\n", + "class BVIArtefactDataset(Dataset):\n", + " def __init__(self, root_dir: str, split: str = 'train'):\n", + " self.root_dir = os.path.join(root_dir, split)\n", + " self.labels_file = os.path.join(self.root_dir, 'labels.json')\n", + " self.crop_files = [f for f in os.listdir(self.root_dir) if f.endswith('.pt')]\n", + " with open(self.labels_file, 'r') as f:\n", + " self.labels = json.load(f)\n", + " self.label_names = list(next(iter(self.labels.values())).keys())\n", + " self.id2label = {i: label for i, label in enumerate(self.label_names)}\n", + " self.label2id = {label: i for i, label in self.id2label.items()}\n", + " self.transform = transforms.Compose([\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.RandomRotation(10)\n", + " ]) if split == 'train' else None\n", + "\n", + " def __len__(self):\n", + " return len(self.crop_files)\n", + "\n", + " def __getitem__(self, idx):\n", + " crop_file = self.crop_files[idx]\n", + " crop_path = os.path.join(self.root_dir, crop_file)\n", + " # load video crops\n", + " video = torch.load(crop_path,\n", + " map_location='cpu',\n", + " weights_only=False)\n", + " # Get labels\n", + " label_dict = self.labels[crop_file]\n", + " labels = torch.tensor([label_dict[name] for name in self.label_names], dtype=torch.float32)\n", + " return {\"pixel_values\": video, \"labels\": labels}\n", + "\n", + " def get_label_mappings(self):\n", + " return self.id2label, self.label2id" + ], + "metadata": { + "id": "W3egs-vYwGWi", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Inspect Dataset" + ], + "metadata": { + "id": "OyjCVZr3zovr" + } + }, + { + "cell_type": "code", + "source": [ + "from pprint import pprint\n", + "# Create datasets\n", + "root_dir = f'data/{dataset_zip}'\n", + "train_dataset = BVIArtefactDataset(root_dir, split='train')\n", + "val_dataset = BVIArtefactDataset(root_dir, split='val')\n", + "test_dataset = BVIArtefactDataset(root_dir, split='test')\n", + "\n", + "# Check the shape of an item\n", + "print('Shape of one patch of a video:')\n", + "print(train_dataset[0]['pixel_values'].shape) # Should print torch.Size([8, 3, 224, 224])\n", + "print()\n", + "print('Shape of the label:')\n", + "print(train_dataset[0]['labels'].shape) # Should print torch.Size([num_labe\n", + "\n", + "id2label, label2id = BVIArtefactDataset(root_dir, split='train').get_label_mappings()\n", + "\n", + "num_labels = len(id2label)\n", + "\n", + "print(f'\\nindex to label mapping:')\n", + "pprint(id2label)\n", + "print()\n", + "print(f'label to index mapping:')\n", + "pprint(label2id)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NEpCmHAgkhp8", + "outputId": "5574f74c-f77d-499f-c9d2-4e9c10361587" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Shape of one patch of a video:\n", + "torch.Size([8, 3, 224, 224])\n", + "\n", + "Shape of the label:\n", + "torch.Size([9])\n", + "\n", + "index to label mapping:\n", + "{0: 'black_screen',\n", + " 1: 'frame_drop',\n", + " 2: 'spatial_blur',\n", + " 3: 'transmission_error',\n", + " 4: 'aliasing',\n", + " 5: 'banding',\n", + " 6: 'dark_scenes',\n", + " 7: 'graininess',\n", + " 8: 'motion_blur'}\n", + "\n", + "label to index mapping:\n", + "{'aliasing': 4,\n", + " 'banding': 5,\n", + " 'black_screen': 0,\n", + " 'dark_scenes': 6,\n", + " 'frame_drop': 1,\n", + " 'graininess': 7,\n", + " 'motion_blur': 8,\n", + " 'spatial_blur': 2,\n", + " 'transmission_error': 3}\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "train_dataset[0]['labels']" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FEsAgMicnwMI", + "outputId": "35f13d94-6386-48e8-9ee0-1227bdd0f386" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([0., 1., 0., 0., 0., 0., 0., 1., 1.])" + ] + }, + "metadata": {}, + "execution_count": 12 + } + ] + }, + { + "cell_type": "code", + "source": [ + "import imageio\n", + "import numpy as np\n", + "from IPython.display import Image, display\n", + "import torch\n", + "import random\n", + "\n", + "def unnormalize_img(img):\n", + " \"\"\"Un-normalizes the image pixels.\"\"\"\n", + " mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)\n", + " std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)\n", + " img = (img * std) + mean\n", + " img = (img * 255).clamp(0, 255).byte()\n", + " return img\n", + "\n", + "def create_gif(video_tensor, filename=\"sample.gif\"):\n", + " \"\"\"Prepares a GIF from a video tensor.\n", + "\n", + " The video tensor is expected to have the following shape:\n", + " (num_frames, num_channels, height, width).\n", + " \"\"\"\n", + " frames = []\n", + " for video_frame in video_tensor:\n", + " frame_unnormalized = unnormalize_img(video_frame).permute(1, 2, 0).numpy()\n", + " frames.append(frame_unnormalized)\n", + " kargs = {\"duration\": 0.25}\n", + " imageio.mimsave(filename, frames, \"GIF\", **kargs)\n", + " return filename\n", + "\n", + "def display_gif(video_tensor, gif_name=\"sample.gif\"):\n", + " \"\"\"Prepares and displays a GIF from a video tensor.\"\"\"\n", + " gif_filename = create_gif(video_tensor, gif_name)\n", + " return Image(filename=gif_filename)\n", + "\n", + "# Get a sample from your dataset\n", + "random_sample = random.randint(0, len(train_dataset))\n", + "\n", + "sample = train_dataset[random_sample]\n", + "video_tensor = sample['pixel_values']\n", + "\n", + "# Display the GIF\n", + "display(display_gif(video_tensor, \"sample_crop.gif\"))\n", + "pprint({id2label[i]: sample['labels'][i] for i in range(len(sample['labels']))})" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 402 + }, + "id": "ovOL5Rtqwb3_", + "outputId": "8503307d-1326-42d2-fb7f-2f190f284da4" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/gif": "R0lGODlh4ADgAIEAAAAAAAAAAAAAAAAAACwAAAAA4ADgAEAI/wABCBxIsKDBgwgTKlzIsKHDhxAjSpxIsaLFixgzatzIsaPHjyBDihxJsqTJkyhTqlzJsqXLlzBjypxJs6bNmzhz6tzJs6fPn0CDCh1KtKjRo0iTKl3KtKnTp1CjSp1KtarVq1izat3KtavXr2DDih1LtqzZs2jTql3Ltq3bt3Djyp1Lt67du3jz6t3Lt6/fv4ADCx5MuLDhw4gTK17MuLHjx5AjS55MubLly5gza97MubPnz6BDix5NurTp06hTq17NurXr17Bjy55Nu7bt27hz697Nu7fv38CDCx9OvLjx48iTK1/OvLnz59CjS59Ovbr169iza9/Ovbv37+DDi3AfT768+fPo06tfz769+/fw48ufT7++/fv48+vfz7+///8ABijggAQWaOCBCCao4IIMNujggxBGKOGEFFZo4YUYZqjhhhx26OGHIIYo4ogklmjiiSimqOKKLLbo4oswxijjjDTWaOONOOao4448ahUQADs=\n", + "text/plain": [ + "<IPython.core.display.Image object>" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{'aliasing': tensor(0.),\n", + " 'banding': tensor(1.),\n", + " 'black_screen': tensor(1.),\n", + " 'dark_scenes': tensor(0.),\n", + " 'frame_drop': tensor(1.),\n", + " 'graininess': tensor(0.),\n", + " 'motion_blur': tensor(0.),\n", + " 'spatial_blur': tensor(0.),\n", + " 'transmission_error': tensor(1.)}\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# @title # Utility Functions\n", + "import datetime\n", + "import os\n", + "import json\n", + "from typing import Dict, Any\n", + "import pytz\n", + "\n", + "# Set the timezone to UTC+1\n", + "timezone = pytz.timezone('Europe/London') # This is an example; choose a timezone with UTC+1\n", + "\n", + "\n", + "def get_run_name():\n", + " # Generate a timestamp\n", + " timestamp = datetime.datetime.now(timezone).strftime(\"%Y-%m-%d--%H-%M-%S\")\n", + "\n", + " # Create a unique run name\n", + " run_name = f\"run-{timestamp}\"\n", + "\n", + " # Create directories for this run\n", + " os.makedirs(f\"./logs/{run_name}\", exist_ok=True)\n", + " os.makedirs(f\"./results/{run_name}\", exist_ok=True)\n", + "\n", + " return run_name\n", + "\n", + "\n", + "def log_run_info(run_name, model, training_args):\n", + "\n", + " run_info_path = os.path.join(\"./logs\", run_name, \"run_info.txt\")\n", + " with open(run_info_path, \"w\") as f:\n", + " f.write(f\"Run Name: {run_name}\\n\")\n", + " f.write(f\"Model: {type(model).__name__}\\n\")\n", + " f.write(\"Training Arguments:\\n\")\n", + " for arg, value in vars(training_args).items():\n", + " f.write(f\" {arg}: {value}\\n\")\n", + "\n", + "\n", + "def save_test_results_to_json(test_results: Dict[str, Any], artifacts: list):\n", + " \"\"\"\n", + " Format test results and save them to a JSON file.\n", + "\n", + " Args:\n", + " test_results (Dict[str, Any]): Dictionary containing test results\n", + " artifacts (list): List of artifact names\n", + " output_file (str): Path to the output JSON file\n", + " \"\"\"\n", + "\n", + " output_file = os.path.join(\"./logs\", run_name, \"test_set_results.json\")\n", + "\n", + " # formatted_results = {\n", + " # \"overall_metrics\": {\n", + " # \"loss\": test_results['eval_loss'],\n", + " # \"accuracy\": test_results['eval_accuracy'],\n", + " # \"f1_score\": test_results['eval_f1'],\n", + " # \"auc\": test_results['eval_auc'],\n", + " # \"combined_auc_focus\": test_results['source_artefacts_avg_auc']\n", + " # },\n", + " # \"artifact_metrics\": {},\n", + " # \"additional_info\": {\n", + " # \"runtime\": test_results['eval_runtime'],\n", + " # \"samples_per_second\": test_results['eval_samples_per_second'],\n", + " # \"steps_per_second\": test_results['eval_steps_per_second'],\n", + " # \"epoch\": test_results['epoch']\n", + " # }\n", + " # }\n", + "\n", + " # for artifact in artifacts:\n", + " # formatted_results[\"artifact_metrics\"][artifact] = {\n", + " # \"accuracy\": test_results[f'eval_accuracy_{artifact}'],\n", + " # \"f1_score\": test_results[f'eval_f1_{artifact}'],\n", + " # \"auc\": test_results[f'eval_auc_{artifact}']\n", + " # }\n", + "\n", + " with open(output_file, 'w') as f:\n", + " json.dump(test_results, f, indent=2)\n", + "\n", + " print(f\"Test results saved to {output_file}\")\n", + "\n", + "# Function to print metrics for a specific artifact\n", + "def print_artifact_metrics(artifact, metrics):\n", + " print(f\"\\n{artifact.capitalize()} Metrics:\")\n", + " print(f\" Accuracy: {metrics[f'eval_accuracy_{artifact}']:.4f}\")\n", + " print(f\" F1 Score: {metrics[f'eval_f1_{artifact}']:.4f}\")\n", + " print(f\" AUC: {metrics[f'eval_auc_{artifact}']:.4f}\")" + ], + "metadata": { + "id": "KrAU5_dt1p1N" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# @title # Model Declaration\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from transformers import TimesformerModel\n", + "from torch.nn import functional as F\n", + "\n", + "\n", + "def focal_loss(logits, labels, alpha=0.25, gamma=2.0):\n", + " \"\"\"\n", + " Focal loss for multi-label classification.\n", + " \"\"\"\n", + " BCE_loss = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')\n", + " pt = torch.exp(-BCE_loss)\n", + " focal_loss = alpha * (1-pt)**gamma * BCE_loss\n", + " return focal_loss.mean()\n", + "\n", + "\n", + "class MultiTaskTimeSformer(nn.Module):\n", + " def __init__(self, num_labels, id2label, label2id, dataset, num_layers_to_finetune=4):\n", + " \"\"\"\n", + " Multi-task TimeSformer model for video artifact detection.\n", + "\n", + " Args:\n", + " num_labels: Number of artifact classes\n", + " id2label: Mapping from label index to label name\n", + " label2id: Mapping from label name to label index\n", + " dataset: Dataset used for calculating initial weights\n", + " num_layers_to_finetune: Number of layers to fine-tune in the TimeSformer backbone\n", + " \"\"\"\n", + " super().__init__()\n", + " self.num_labels = num_labels\n", + " self.id2label = id2label\n", + " self.label2id = label2id\n", + "\n", + " # TimeSformer backbone\n", + " # Using a pre-trained model for transfer learning\n", + " self.timesformer = TimesformerModel.from_pretrained(\n", + " \"facebook/timesformer-base-finetuned-k400\",\n", + " ignore_mismatched_sizes=True\n", + " )\n", + "\n", + " # Task-specific adaptation layers\n", + " # These layers adapt the TimeSformer's output for each specific artifact detection task\n", + " # self.task_adapters = nn.ModuleDict({\n", + " # label: nn.Sequential(\n", + " # nn.Linear(768, 512),\n", + " # nn.ReLU(),\n", + " # nn.Dropout(0.1), # Dropout for regularization\n", + " # nn.Linear(512, 256),\n", + " # nn.ReLU(),\n", + " # nn.Dropout(0.1),\n", + " # nn.Linear(256, 1)\n", + " # ) for label in id2label.values()\n", + " # })\n", + "\n", + " self.task_adapters = nn.ModuleDict({\n", + " label: nn.Sequential(\n", + " nn.Linear(768, 512),\n", + " nn.LayerNorm(512),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.3),\n", + " nn.Linear(512, 256),\n", + " nn.LayerNorm(256),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.2),\n", + " nn.Linear(256, 128),\n", + " nn.LayerNorm(128),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(128, 1)\n", + " ) for label in id2label.values()\n", + " })\n", + "\n", + " # self.task_adapters = nn.ModuleDict({\n", + " # label: nn.Sequential(\n", + " # nn.Linear(768, 256),\n", + " # nn.LayerNorm(256), # Layer Normalisatoin\n", + " # nn.ReLU(),\n", + " # nn.Dropout(0.3), # Dropout for regularization\n", + " # nn.Linear(256, 1)\n", + " # ) for label in id2label.values()\n", + " # })\n", + "\n", + " # Calculate and set initial task weights\n", + " # This helps address class imbalance from the start of training\n", + " initial_weights = self.calculate_initial_weights(dataset)\n", + " # self.class_weights = nn.Parameter(initial_weights, requires_grad=False)\n", + " self.task_weights = nn.Parameter(initial_weights)\n", + "\n", + " # Freeze all layers except the last few\n", + " # This allows for fine-tuning of the model while preserving pre-trained knowledge\n", + " self.freeze_layers(num_layers_to_finetune)\n", + "\n", + " def calculate_initial_weights(self, dataset):\n", + " \"\"\"\n", + " Calculate class weights based on class frequencies and positive/negative ratios.\n", + " This helps address class imbalance in multi-label classification.\n", + "\n", + " Args:\n", + " dataset: The dataset containing labels\n", + "\n", + " Returns:\n", + " combined_weights: Tensor of weights for each class\n", + " \"\"\"\n", + " num_samples = len(dataset)\n", + " num_labels = len(dataset.id2label)\n", + "\n", + " class_frequencies = torch.zeros(num_labels)\n", + " pos_neg_ratios = torch.zeros(num_labels)\n", + "\n", + " # Calculate class frequencies\n", + " for i in range(num_samples):\n", + " labels = dataset[i]['labels']\n", + " class_frequencies += labels\n", + "\n", + " # Calculate positive/negative ratios\n", + " for i in range(num_labels):\n", + " pos_count = class_frequencies[i]\n", + " neg_count = num_samples - pos_count\n", + " pos_neg_ratios[i] = neg_count / pos_count if pos_count > 0 else 1.0\n", + "\n", + " # Inverse class frequencies (rare classes get higher weights)\n", + " class_weights = 1.0 / class_frequencies\n", + " class_weights[class_weights == float('inf')] = 0 # Handle division by zero\n", + "\n", + " # Normalize weights\n", + " class_weights = class_weights / class_weights.sum()\n", + " pos_neg_ratios = pos_neg_ratios / pos_neg_ratios.sum()\n", + "\n", + " # Combine class frequencies and pos/neg ratios\n", + " # This balances between overall class rarity and class imbalance within each label\n", + " combined_weights = (class_weights + pos_neg_ratios) / 2\n", + "\n", + " return combined_weights\n", + "\n", + "\n", + " # def calculate_class_weights(self, dataset):\n", + " # num_samples = len(dataset)\n", + " # class_counts = torch.zeros(self.num_labels)\n", + "\n", + " # for i in range(num_samples):\n", + " # labels = dataset[i]['labels']\n", + " # class_counts += labels\n", + "\n", + " # class_weights = num_samples / (self.num_labels * class_counts)\n", + " # class_weights = torch.clamp(class_weights, min=0.1, max=10.0) # Prevent extreme values\n", + "\n", + " # return nn.Parameter(class_weights, requires_grad=False)\n", + "\n", + " def freeze_layers(self, num_layers_to_finetune):\n", + " \"\"\"\n", + " Freeze layers of the TimeSformer backbone except for the last few.\n", + " This is a form of transfer learning that preserves most of the pre-trained weights.\n", + " \"\"\"\n", + " # Freeze all layers\n", + " for param in self.timesformer.parameters():\n", + " param.requires_grad = False\n", + "\n", + " # Unfreeze the last few layers\n", + " for i, layer in enumerate(reversed(list(self.timesformer.encoder.layer))):\n", + " if i < num_layers_to_finetune:\n", + " for param in layer.parameters():\n", + " param.requires_grad = True\n", + " else:\n", + " break\n", + "\n", + " # Always unfreeze the task adapters and task weights\n", + " # This allows these task-specific components to be fully trained\n", + " for adapter in self.task_adapters.values():\n", + " for param in adapter.parameters():\n", + " param.requires_grad = True\n", + "\n", + " self.task_weights.requires_grad = True\n", + "\n", + "\n", + " def forward(self, pixel_values, labels=None):\n", + " \"\"\"\n", + " Forward pass of the model.\n", + "\n", + " Args:\n", + " pixel_values: Input video tensor\n", + " labels: Ground truth labels (optional)\n", + "\n", + " Returns:\n", + " Dictionary containing loss and logits (if labels provided), or just logits\n", + " \"\"\"\n", + " outputs = self.timesformer(pixel_values=pixel_values)\n", + " pooled_output = outputs.last_hidden_state[:, 0] # Use CLS token for classification\n", + "\n", + " # Task-specific adaptation and prediction\n", + " logits = torch.cat([self.task_adapters[label](pooled_output)\n", + " for label in self.id2label.values()], dim=1)\n", + "\n", + " if labels is not None:\n", + " # Binary Cross Entropy loss for multi-label classification\n", + " # loss_fct = nn.BCEWithLogitsLoss(reduction='none')\n", + " loss_fct = focal_loss\n", + " loss = loss_fct(logits, labels)\n", + "\n", + " # # use raw weights\n", + " # class_weights = self.class_weights.to(loss.device)\n", + " # task_weights = self.task_weights.to(loss.device)\n", + "\n", + " # OR\n", + " # Use sigmoid\n", + " # task_weights = torch.sigmoid(self.task_weights)\n", + " # class_weights = torch.sigmoid(self.class_weights)\n", + "\n", + " # # OR\n", + " # # Use softmax\n", + " task_weights = F.softmax(self.task_weights, dim=0)\n", + " # class_weights = F.softmax(self.class_weights, dim=0)\n", + "\n", + " # weighted_loss = (loss * task_weights * class_weights).mean()\n", + " weighted_loss = (loss * task_weights).mean()\n", + "\n", + " return {\"loss\": weighted_loss, \"logits\": logits}\n", + " return {\"logits\": logits}" + ], + "metadata": { + "id": "3-Pi5RSawGNd" + }, + "execution_count": null, + "outputs": [] + }, + { + "source": [ + "# @title Compute Metrics on specific or all artefacts\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score\n", + "\n", + "\n", + "artefacts_in_focus = ['motion_blur',\n", + " 'dark_scenes',\n", + " 'graininess',\n", + " 'aliasing',\n", + " 'banding']\n", + "USE_ALL_ARTEFACTS = False\n", + "\n", + "if USE_ALL_ARTEFACTS:\n", + " artefacts_in_focus = list(label2id.keys())\n", + "\n", + "def compute_metrics(eval_pred):\n", + " logits, labels = eval_pred\n", + " probabilities = torch.sigmoid(torch.tensor(logits))\n", + " predictions = (probabilities > 0.5).numpy()\n", + "\n", + " results = {}\n", + " auc_scores = []\n", + " f1_scores = []\n", + " accuracy_scores = []\n", + " for artifact in artefacts_in_focus:\n", + " i = label2id[artifact] # Get the index of the artifact\n", + " task_accuracy = accuracy_score(labels[:, i], predictions[:, i])\n", + " task_f1 = f1_score(labels[:, i], predictions[:, i], average='binary', zero_division=0)\n", + " task_auc = roc_auc_score(labels[:, i], probabilities[:, i])\n", + " auc_scores.append(task_auc)\n", + " f1_scores.append(task_f1)\n", + " accuracy_scores.append(task_accuracy)\n", + "\n", + " results.update({\n", + " f\"accuracy_{artifact}\": task_accuracy,\n", + " f\"f1_{artifact}\": task_f1,\n", + " f\"auc_{artifact}\": task_auc\n", + " })\n", + "\n", + " # Compute overall metrics\n", + " results.update({\n", + " \"accuracy\": accuracy_score(labels.flatten(), predictions.flatten()),\n", + " \"f1\": f1_score(labels, predictions, average='samples', zero_division=0),\n", + " \"auc\": roc_auc_score(labels, probabilities, average='samples')\n", + " })\n", + "\n", + " # Add combined AUC for artifacts in focus\n", + " results[\"source_artefacts_avg_auc\"] = np.mean(auc_scores)\n", + " results[\"source_artefacts_avg_f1\"] = np.mean(f1_scores)\n", + " results[\"source_artefacts_avg_accuracy\"] = np.mean(accuracy_scores)\n", + "\n", + " return results\n", + "\n", + "# def compute_metrics(eval_pred):\n", + "# logits, labels = eval_pred\n", + "# probabilities = torch.sigmoid(torch.tensor(logits))\n", + "# predictions = (probabilities > 0.5).numpy()\n", + "\n", + "# results = {}\n", + "# for i in range(len(id2label)):\n", + "# task_accuracy = accuracy_score(labels[:, i], predictions[:, i])\n", + "# task_f1 = f1_score(labels[:, i], predictions[:, i], average='binary', zero_division=0)\n", + "# task_auc = roc_auc_score(labels[:, i], probabilities[:, i])\n", + "# # task_precision = precision_score(labels[:, i], predictions[:, i], average='binary', zero_division=0)\n", + "# # task_recall = recall_score(labels[:, i], predictions[:, i], average='binary', zero_division=0)\n", + "# # Compute ROC AUC for each task\n", + "\n", + "# results.update({\n", + "# f\"accuracy_{id2label[i]}\": task_accuracy,\n", + "# f\"f1_{id2label[i]}\": task_f1,\n", + "# f\"auc_{id2label[i]}\": task_auc\n", + "# # f\"precision_{id2label[i]}\": task_precision,\n", + "# # f\"recall_{id2label[i]}\": task_recall,\n", + "# })\n", + "\n", + "# # Compute overall metrics\n", + "# results.update({\n", + "# \"accuracy\": accuracy_score(labels.flatten(), predictions.flatten()),\n", + "# \"f1\": f1_score(labels, predictions, average='samples', zero_division=0),\n", + "# \"auc\": roc_auc_score(labels, probabilities, average='samples')\n", + "# # \"precision\": precision_score(labels, predictions, average='samples', zero_division=0),\n", + "# # \"recall\": recall_score(labels, predictions, average='samples', zero_division=0),\n", + "# })\n", + "# return results" + ], + "cell_type": "code", + "metadata": { + "id": "DYeBpurb3szb" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# import torchsummary\n", + "\n", + "model = MultiTaskTimeSformer(num_labels,\n", + " id2label,\n", + " label2id,\n", + " train_dataset,\n", + " num_layers_to_finetune=8)\n", + "\n", + "\n", + "# You need to define input size to calcualte parameters\n", + "# torchsummary.summary(model, input_size=(8, 3, 224, 224))" + ], + "metadata": { + "id": "NqyI88XTwYXM", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "32d39e67-c168-44c4-9abe-f02e2f841ab9" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "model(train_dataset[0]['pixel_values'].unsqueeze(0), train_dataset[0]['labels'].unsqueeze(0))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "pZR6e7jTq1bM", + "outputId": "4919cf51-0ecf-4d1b-ca67-7bf0e6874b77" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'loss': tensor(0.0048, grad_fn=<MeanBackward0>),\n", + " 'logits': tensor([[-0.0173, 0.1177, 0.0437, -0.4189, -0.0314, 0.1233, 0.2152, -0.0879,\n", + " 0.0947]], grad_fn=<CatBackward0>)}" + ] + }, + "metadata": {}, + "execution_count": 18 + } + ] + }, + { + "cell_type": "code", + "source": [ + "from transformers import EarlyStoppingCallback, Trainer, TrainingArguments\n", + "\n", + "# Early stopping callback\n", + "early_stopping_callback = EarlyStoppingCallback(\n", + " early_stopping_patience=3,\n", + " early_stopping_threshold=0.01\n", + ")\n", + "\n", + "# Get run name\n", + "run_name = get_run_name()\n", + "\n", + "# Set training parameters\n", + "batch_size = 32\n", + "gradient_accumulation_steps = 3\n", + "max_steps = 500\n", + "\n", + "# Print run parameters\n", + "print(f\"Run Name: {run_name}\")\n", + "print(f\"Batch Size: {batch_size}\")\n", + "print(f\"Gradient Accumulation Steps: {gradient_accumulation_steps}\")\n", + "print(f\"Max Steps: {max_steps}\")" + ], + "metadata": { + "id": "xRnjq61dXKNB", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "df22cd27-0c8f-4e2c-a832-c52dd44efc8b" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Run Name: run-2024-09-03--04-44-54\n", + "Batch Size: 32\n", + "Gradient Accumulation Steps: 3\n", + "Max Steps: 500\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Define the training arguments\n", + "training_args = TrainingArguments(\n", + " output_dir=f\"./results/{run_name}\",\n", + " max_steps=max_steps, # Set maximum number of training steps\n", + " per_device_train_batch_size=batch_size,\n", + " per_device_eval_batch_size=batch_size,\n", + " gradient_accumulation_steps=gradient_accumulation_steps,\n", + " # logging_first_step=True,\n", + " warmup_ratio=0.05, # Warmup for 10% of total steps\n", + " weight_decay=0.01, # Weight decay\n", + " learning_rate=1e-5, # Learning rate\n", + " logging_dir=f\"./logs/{run_name}\",\n", + " logging_steps=10, # Log every X steps\n", + " eval_strategy=\"steps\", # Evaluate based on steps\n", + " eval_steps=20, # Evaluate every X steps\n", + " save_strategy=\"steps\",\n", + " save_steps=60, # Save every X steps\n", + " load_best_model_at_end=True,\n", + " metric_for_best_model=\"source_artefacts_avg_auc\",\n", + " greater_is_better=True, # Maximize the metric\n", + " save_total_limit=2, # Keep the 2 best checkpoints\n", + " seed=42, # Set a seed for reproducibility\n", + " fp16=True, # Enable mixed precision training\n", + " report_to=\"wandb\",\n", + " run_name=run_name,\n", + " dataloader_num_workers=12 # Adjust as needed,\n", + ")\n", + "\n", + "# Create a trainer\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=val_dataset,\n", + " compute_metrics=compute_metrics,\n", + " callbacks=[early_stopping_callback],\n", + ")\n", + "\n", + "# Start training\n", + "trainer.train()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "wM5JRu1hx75X", + "outputId": "032a9f7c-6ffc-4f21-f9ec-744e1077578b", + "collapsed": true + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py:482: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n", + " self.scaler = torch.cuda.amp.GradScaler(**kwargs)\n", + "max_steps is given, it will override any value given in num_train_epochs\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "<IPython.core.display.HTML object>" + ], + "text/html": [ + "Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to <a href='https://wandb.me/wandb-init' target=\"_blank\">the W&B docs</a>." + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "<IPython.core.display.HTML object>" + ], + "text/html": [ + "Tracking run with wandb version 0.17.8" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "<IPython.core.display.HTML object>" + ], + "text/html": [ + "Run data is saved locally in <code>/content/wandb/run-20240903_034455-suem3kz6</code>" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "<IPython.core.display.HTML object>" + ], + "text/html": [ + "Syncing run <strong><a href='https://wandb.ai/sanchitv7-university-of-surrey/video_classification/runs/suem3kz6' target=\"_blank\">run-2024-09-03--04-44-54</a></strong> to <a href='https://wandb.ai/sanchitv7-university-of-surrey/video_classification' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "<IPython.core.display.HTML object>" + ], + "text/html": [ + " View project at <a href='https://wandb.ai/sanchitv7-university-of-surrey/video_classification' target=\"_blank\">https://wandb.ai/sanchitv7-university-of-surrey/video_classification</a>" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "<IPython.core.display.HTML object>" + ], + "text/html": [ + " View run at <a href='https://wandb.ai/sanchitv7-university-of-surrey/video_classification/runs/suem3kz6' target=\"_blank\">https://wandb.ai/sanchitv7-university-of-surrey/video_classification/runs/suem3kz6</a>" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "<IPython.core.display.HTML object>" + ], + "text/html": [ + "\n", + " <div>\n", + " \n", + " <progress value='300' max='500' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", + " [300/500 08:32 < 05:43, 0.58 it/s, Epoch 10/18]\n", + " </div>\n", + " <table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: left;\">\n", + " <th>Step</th>\n", + " <th>Training Loss</th>\n", + " <th>Validation Loss</th>\n", + " <th>Accuracy Motion Blur</th>\n", + " <th>F1 Motion Blur</th>\n", + " <th>Auc Motion Blur</th>\n", + " <th>Accuracy Dark Scenes</th>\n", + " <th>F1 Dark Scenes</th>\n", + " <th>Auc Dark Scenes</th>\n", + " <th>Accuracy Graininess</th>\n", + " <th>F1 Graininess</th>\n", + " <th>Auc Graininess</th>\n", + " <th>Accuracy Aliasing</th>\n", + " <th>F1 Aliasing</th>\n", + " <th>Auc Aliasing</th>\n", + " <th>Accuracy Banding</th>\n", + " <th>F1 Banding</th>\n", + " <th>Auc Banding</th>\n", + " <th>Accuracy</th>\n", + " <th>F1</th>\n", + " <th>Auc</th>\n", + " <th>Source Artefacts Avg Auc</th>\n", + " <th>Source Artefacts Avg F1</th>\n", + " <th>Source Artefacts Avg Accuracy</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <td>20</td>\n", + " <td>0.005400</td>\n", + " <td>0.004893</td>\n", + " <td>0.536972</td>\n", + " <td>0.263305</td>\n", + " <td>0.497133</td>\n", + " <td>0.625000</td>\n", + " <td>0.419619</td>\n", + " <td>0.602496</td>\n", + " <td>0.718310</td>\n", + " <td>0.200000</td>\n", + " <td>0.537953</td>\n", + " <td>0.654930</td>\n", + " <td>0.155172</td>\n", + " <td>0.513067</td>\n", + " <td>0.554577</td>\n", + " <td>0.216718</td>\n", + " <td>0.466600</td>\n", + " <td>0.567488</td>\n", + " <td>0.363818</td>\n", + " <td>0.543714</td>\n", + " <td>0.523450</td>\n", + " <td>0.250963</td>\n", + " <td>0.617958</td>\n", + " </tr>\n", + " <tr>\n", + " <td>40</td>\n", + " <td>0.004900</td>\n", + " <td>0.004784</td>\n", + " <td>0.589789</td>\n", + " <td>0.016878</td>\n", + " <td>0.536157</td>\n", + " <td>0.811620</td>\n", + " <td>0.467662</td>\n", + " <td>0.727021</td>\n", + " <td>0.746479</td>\n", + " <td>0.132530</td>\n", + " <td>0.556286</td>\n", + " <td>0.686620</td>\n", + " <td>0.135922</td>\n", + " <td>0.566906</td>\n", + " <td>0.691901</td>\n", + " <td>0.033149</td>\n", + " <td>0.533720</td>\n", + " <td>0.614241</td>\n", + " <td>0.292834</td>\n", + " <td>0.595156</td>\n", + " <td>0.584018</td>\n", + " <td>0.157228</td>\n", + " <td>0.705282</td>\n", + " </tr>\n", + " <tr>\n", + " <td>60</td>\n", + " <td>0.004700</td>\n", + " <td>0.004639</td>\n", + " <td>0.595070</td>\n", + " <td>0.087302</td>\n", + " <td>0.559896</td>\n", + " <td>0.830986</td>\n", + " <td>0.589744</td>\n", + " <td>0.758148</td>\n", + " <td>0.744718</td>\n", + " <td>0.189944</td>\n", + " <td>0.565980</td>\n", + " <td>0.700704</td>\n", + " <td>0.267241</td>\n", + " <td>0.588996</td>\n", + " <td>0.700704</td>\n", + " <td>0.212963</td>\n", + " <td>0.583549</td>\n", + " <td>0.617567</td>\n", + " <td>0.345792</td>\n", + " <td>0.615963</td>\n", + " <td>0.611314</td>\n", + " <td>0.269439</td>\n", + " <td>0.714437</td>\n", + " </tr>\n", + " <tr>\n", + " <td>80</td>\n", + " <td>0.004600</td>\n", + " <td>0.004614</td>\n", + " <td>0.610915</td>\n", + " <td>0.159696</td>\n", + " <td>0.577484</td>\n", + " <td>0.830986</td>\n", + " <td>0.606557</td>\n", + " <td>0.769114</td>\n", + " <td>0.750000</td>\n", + " <td>0.236559</td>\n", + " <td>0.585041</td>\n", + " <td>0.705986</td>\n", + " <td>0.283262</td>\n", + " <td>0.595453</td>\n", + " <td>0.734155</td>\n", + " <td>0.334802</td>\n", + " <td>0.602284</td>\n", + " <td>0.620892</td>\n", + " <td>0.352606</td>\n", + " <td>0.624883</td>\n", + " <td>0.625875</td>\n", + " <td>0.324175</td>\n", + " <td>0.726408</td>\n", + " </tr>\n", + " <tr>\n", + " <td>100</td>\n", + " <td>0.004500</td>\n", + " <td>0.004587</td>\n", + " <td>0.625000</td>\n", + " <td>0.202247</td>\n", + " <td>0.599330</td>\n", + " <td>0.825704</td>\n", + " <td>0.605578</td>\n", + " <td>0.781578</td>\n", + " <td>0.757042</td>\n", + " <td>0.233333</td>\n", + " <td>0.582422</td>\n", + " <td>0.713028</td>\n", + " <td>0.345382</td>\n", + " <td>0.611274</td>\n", + " <td>0.748239</td>\n", + " <td>0.401674</td>\n", + " <td>0.616644</td>\n", + " <td>0.634585</td>\n", + " <td>0.378922</td>\n", + " <td>0.637860</td>\n", + " <td>0.638250</td>\n", + " <td>0.357643</td>\n", + " <td>0.733803</td>\n", + " </tr>\n", + " <tr>\n", + " <td>120</td>\n", + " <td>0.004400</td>\n", + " <td>0.004570</td>\n", + " <td>0.626761</td>\n", + " <td>0.237410</td>\n", + " <td>0.605180</td>\n", + " <td>0.836268</td>\n", + " <td>0.620408</td>\n", + " <td>0.782437</td>\n", + " <td>0.753521</td>\n", + " <td>0.239130</td>\n", + " <td>0.585964</td>\n", + " <td>0.709507</td>\n", + " <td>0.331984</td>\n", + " <td>0.623427</td>\n", + " <td>0.762324</td>\n", + " <td>0.478764</td>\n", + " <td>0.626481</td>\n", + " <td>0.629304</td>\n", + " <td>0.385952</td>\n", + " <td>0.646639</td>\n", + " <td>0.644698</td>\n", + " <td>0.381539</td>\n", + " <td>0.737676</td>\n", + " </tr>\n", + " <tr>\n", + " <td>140</td>\n", + " <td>0.004300</td>\n", + " <td>0.004569</td>\n", + " <td>0.632042</td>\n", + " <td>0.256228</td>\n", + " <td>0.609035</td>\n", + " <td>0.836268</td>\n", + " <td>0.623482</td>\n", + " <td>0.784984</td>\n", + " <td>0.755282</td>\n", + " <td>0.240437</td>\n", + " <td>0.599663</td>\n", + " <td>0.713028</td>\n", + " <td>0.340081</td>\n", + " <td>0.633747</td>\n", + " <td>0.774648</td>\n", + " <td>0.492063</td>\n", + " <td>0.634085</td>\n", + " <td>0.631651</td>\n", + " <td>0.389567</td>\n", + " <td>0.653623</td>\n", + " <td>0.652303</td>\n", + " <td>0.390458</td>\n", + " <td>0.742254</td>\n", + " </tr>\n", + " <tr>\n", + " <td>160</td>\n", + " <td>0.004300</td>\n", + " <td>0.004562</td>\n", + " <td>0.642606</td>\n", + " <td>0.302405</td>\n", + " <td>0.619516</td>\n", + " <td>0.839789</td>\n", + " <td>0.628571</td>\n", + " <td>0.787539</td>\n", + " <td>0.751761</td>\n", + " <td>0.261780</td>\n", + " <td>0.609606</td>\n", + " <td>0.714789</td>\n", + " <td>0.386364</td>\n", + " <td>0.637617</td>\n", + " <td>0.776408</td>\n", + " <td>0.520755</td>\n", + " <td>0.637143</td>\n", + " <td>0.635172</td>\n", + " <td>0.389261</td>\n", + " <td>0.654359</td>\n", + " <td>0.658284</td>\n", + " <td>0.419975</td>\n", + " <td>0.745070</td>\n", + " </tr>\n", + " <tr>\n", + " <td>180</td>\n", + " <td>0.004300</td>\n", + " <td>0.004550</td>\n", + " <td>0.658451</td>\n", + " <td>0.370130</td>\n", + " <td>0.630811</td>\n", + " <td>0.839789</td>\n", + " <td>0.634538</td>\n", + " <td>0.786483</td>\n", + " <td>0.750000</td>\n", + " <td>0.275510</td>\n", + " <td>0.623828</td>\n", + " <td>0.709507</td>\n", + " <td>0.395604</td>\n", + " <td>0.641669</td>\n", + " <td>0.776408</td>\n", + " <td>0.538182</td>\n", + " <td>0.643549</td>\n", + " <td>0.632825</td>\n", + " <td>0.410094</td>\n", + " <td>0.661050</td>\n", + " <td>0.665268</td>\n", + " <td>0.442793</td>\n", + " <td>0.746831</td>\n", + " </tr>\n", + " <tr>\n", + " <td>200</td>\n", + " <td>0.004200</td>\n", + " <td>0.004557</td>\n", + " <td>0.661972</td>\n", + " <td>0.380645</td>\n", + " <td>0.633685</td>\n", + " <td>0.841549</td>\n", + " <td>0.640000</td>\n", + " <td>0.793321</td>\n", + " <td>0.755282</td>\n", + " <td>0.315271</td>\n", + " <td>0.631605</td>\n", + " <td>0.716549</td>\n", + " <td>0.383142</td>\n", + " <td>0.650061</td>\n", + " <td>0.779930</td>\n", + " <td>0.542125</td>\n", + " <td>0.643854</td>\n", + " <td>0.635759</td>\n", + " <td>0.403682</td>\n", + " <td>0.666880</td>\n", + " <td>0.670505</td>\n", + " <td>0.452236</td>\n", + " <td>0.751056</td>\n", + " </tr>\n", + " <tr>\n", + " <td>220</td>\n", + " <td>0.004200</td>\n", + " <td>0.004563</td>\n", + " <td>0.665493</td>\n", + " <td>0.383117</td>\n", + " <td>0.635577</td>\n", + " <td>0.845070</td>\n", + " <td>0.645161</td>\n", + " <td>0.797415</td>\n", + " <td>0.758803</td>\n", + " <td>0.311558</td>\n", + " <td>0.633070</td>\n", + " <td>0.721831</td>\n", + " <td>0.423358</td>\n", + " <td>0.659874</td>\n", + " <td>0.786972</td>\n", + " <td>0.546816</td>\n", + " <td>0.646406</td>\n", + " <td>0.639085</td>\n", + " <td>0.414961</td>\n", + " <td>0.668390</td>\n", + " <td>0.674468</td>\n", + " <td>0.462002</td>\n", + " <td>0.755634</td>\n", + " </tr>\n", + " <tr>\n", + " <td>240</td>\n", + " <td>0.004100</td>\n", + " <td>0.004577</td>\n", + " <td>0.670775</td>\n", + " <td>0.398714</td>\n", + " <td>0.636090</td>\n", + " <td>0.841549</td>\n", + " <td>0.640000</td>\n", + " <td>0.795393</td>\n", + " <td>0.755282</td>\n", + " <td>0.328502</td>\n", + " <td>0.635547</td>\n", + " <td>0.723592</td>\n", + " <td>0.407547</td>\n", + " <td>0.665070</td>\n", + " <td>0.779930</td>\n", + " <td>0.548736</td>\n", + " <td>0.652865</td>\n", + " <td>0.641041</td>\n", + " <td>0.421542</td>\n", + " <td>0.671348</td>\n", + " <td>0.676993</td>\n", + " <td>0.464700</td>\n", + " <td>0.754225</td>\n", + " </tr>\n", + " <tr>\n", + " <td>260</td>\n", + " <td>0.004000</td>\n", + " <td>0.004589</td>\n", + " <td>0.674296</td>\n", + " <td>0.416404</td>\n", + " <td>0.640381</td>\n", + " <td>0.839789</td>\n", + " <td>0.637450</td>\n", + " <td>0.794770</td>\n", + " <td>0.765845</td>\n", + " <td>0.338308</td>\n", + " <td>0.642702</td>\n", + " <td>0.725352</td>\n", + " <td>0.430657</td>\n", + " <td>0.669078</td>\n", + " <td>0.771127</td>\n", + " <td>0.542254</td>\n", + " <td>0.653966</td>\n", + " <td>0.642410</td>\n", + " <td>0.428757</td>\n", + " <td>0.672502</td>\n", + " <td>0.680179</td>\n", + " <td>0.473015</td>\n", + " <td>0.755282</td>\n", + " </tr>\n", + " <tr>\n", + " <td>280</td>\n", + " <td>0.004000</td>\n", + " <td>0.004598</td>\n", + " <td>0.679577</td>\n", + " <td>0.431250</td>\n", + " <td>0.642402</td>\n", + " <td>0.843310</td>\n", + " <td>0.639676</td>\n", + " <td>0.794082</td>\n", + " <td>0.769366</td>\n", + " <td>0.360976</td>\n", + " <td>0.646245</td>\n", + " <td>0.725352</td>\n", + " <td>0.438849</td>\n", + " <td>0.674354</td>\n", + " <td>0.772887</td>\n", + " <td>0.544170</td>\n", + " <td>0.653787</td>\n", + " <td>0.641432</td>\n", + " <td>0.428405</td>\n", + " <td>0.675171</td>\n", + " <td>0.682174</td>\n", + " <td>0.482984</td>\n", + " <td>0.758099</td>\n", + " </tr>\n", + " <tr>\n", + " <td>300</td>\n", + " <td>0.003900</td>\n", + " <td>0.004605</td>\n", + " <td>0.683099</td>\n", + " <td>0.437500</td>\n", + " <td>0.642299</td>\n", + " <td>0.839789</td>\n", + " <td>0.637450</td>\n", + " <td>0.796408</td>\n", + " <td>0.765845</td>\n", + " <td>0.369668</td>\n", + " <td>0.642294</td>\n", + " <td>0.721831</td>\n", + " <td>0.443662</td>\n", + " <td>0.680224</td>\n", + " <td>0.774648</td>\n", + " <td>0.552448</td>\n", + " <td>0.655878</td>\n", + " <td>0.640845</td>\n", + " <td>0.426792</td>\n", + " <td>0.675000</td>\n", + " <td>0.683421</td>\n", + " <td>0.488146</td>\n", + " <td>0.757042</td>\n", + " </tr>\n", + " </tbody>\n", + "</table><p>" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n", + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "TrainOutput(global_step=300, training_loss=0.004406587146222591, metrics={'train_runtime': 524.3153, 'train_samples_per_second': 91.548, 'train_steps_per_second': 0.954, 'total_flos': 0.0, 'train_loss': 0.004406587146222591, 'epoch': 10.714285714285714})" + ] + }, + "metadata": {}, + "execution_count": 20 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# wandb.finish()" + ], + "metadata": { + "id": "BVTEmrXU21zV", + "collapsed": true + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# After training\n", + "test_results = trainer.evaluate(eval_dataset=test_dataset)\n", + "\n", + "log_run_info(run_name, model, training_args)\n", + "save_test_results_to_json(test_results, artefacts_in_focus)\n", + "\n", + "# run_info_path = os.path.join(\"./logs\", run_name, \"run_info.txt\")\n", + "\n", + "wandb.save(os.path.join(\"./logs\", run_name, '*'))\n", + "\n", + "wandb.finish()\n", + "\n", + "# Print the metrics\n", + "print(\"Test set metrics:\")\n", + "# for key, value in test_results.items():\n", + "# print(f\"{key}: {value}\")\n", + "\n", + "\n", + "# # List of artifacts\n", + "# artifacts = [\n", + "# \"motion_blur\",\n", + "# \"dark_scenes\",\n", + "# \"graininess\",\n", + "# \"aliasing\",\n", + "# \"banding\",\n", + "# \"transmission_error\",\n", + "# \"spatial_blur\n", + "# \"black_screen\",\n", + "# \"frame_drop\"\n", + "# ]\n", + "\n", + "\n", + "# Print overall metrics\n", + "print(\"Overall Metrics:\")\n", + "print(f\"Loss: {test_results['eval_loss']:.4f}\")\n", + "print(f\"Accuracy: {test_results['eval_accuracy']:.4f}\")\n", + "print(f\"F1 Score: {test_results['eval_f1']:.4f}\")\n", + "print(f\"AUC: {test_results['eval_auc']:.4f}\")\n", + "\n", + "# Print metrics for each artifact\n", + "for artefact in artefacts_in_focus:\n", + " print_artifact_metrics(artefact, test_results)\n", + "\n", + "# Print additional information\n", + "print(\"\\nAdditional Information:\")\n", + "print(f\"Runtime: {test_results['eval_runtime']:.4f} seconds\")\n", + "print(f\"Samples per second: {test_results['eval_samples_per_second']:.2f}\")\n", + "print(f\"Steps per second: {test_results['eval_steps_per_second']:.2f}\")\n", + "print(f\"Epoch: {test_results['epoch']:.2f}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "5445dbf6343740f9b89f1c2133762370", + "4f31997fd8704ba3b8a6ce1c0b7988a4", + "b369d7a80aae4d59aebbe5c0a0673bd4", + "e0509e15a96e496d96df8d6a82df3782", + "06a3702b5df94291bc87a20f7b1ef6b5", + "be79c9ec0adb4d2699b58354ad8493b7", + "8e26e03388da4aea8dfbfb93e7a54f87", + "4af18f329a83473783f0f32151b79476" + ] + }, + "id": "_PjDoIcNdDNl", + "outputId": "c723de83-c878-4cf7-a067-fc7c5ff3c66a" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = os.fork()\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "<IPython.core.display.HTML object>" + ], + "text/html": [ + "\n", + " <div>\n", + " \n", + " <progress value='19' max='19' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", + " [19/19 00:03]\n", + " </div>\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Symlinked 2 files into the W&B run directory, call wandb.save again to sync new files.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Test results saved to ./logs/run-2024-09-03--04-44-54/test_set_results.json\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "VBox(children=(Label(value='0.007 MB of 0.007 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "5445dbf6343740f9b89f1c2133762370" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "<IPython.core.display.HTML object>" + ], + "text/html": [ + "<style>\n", + " table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n", + " .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n", + " .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n", + " </style>\n", + "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy</td><td>â–▅▆▆▇▇▇▇▇▇██████</td></tr><tr><td>eval/accuracy_aliasing</td><td>â–▄▆▆▇▆▇▇▆▇██████</td></tr><tr><td>eval/accuracy_banding</td><td>â–▅▅▆▇▇█████████▆</td></tr><tr><td>eval/accuracy_dark_scenes</td><td>â–▇██▇██████████▄</td></tr><tr><td>eval/accuracy_graininess</td><td>â–▅▅▅▆▆▆▆▅▆▇▆███â–</td></tr><tr><td>eval/accuracy_motion_blur</td><td>â–▂▂▃▃▃▃▄▄▄▄▄▄▄▄█</td></tr><tr><td>eval/auc</td><td>â–▃▄▅▆▆▆▆▇▇▇▇▇▇▇█</td></tr><tr><td>eval/auc_aliasing</td><td>â–▃▄▄▅▆▆▆▆▇▇▇███▇</td></tr><tr><td>eval/auc_banding</td><td>â–▃▅▆▇▇▇▇███████▇</td></tr><tr><td>eval/auc_dark_scenes</td><td>â–▅▇▇▇▇█████████▆</td></tr><tr><td>eval/auc_graininess</td><td>â–▂▂▃▃▃▄▄▅▅▅▅▆▆▆█</td></tr><tr><td>eval/auc_motion_blur</td><td>â–▂▂▃▃▃▄▄▄▄▄▄▄▄▄█</td></tr><tr><td>eval/f1</td><td>â–„â–▃▄▅▅▅▅▆▆▆▇▇▇▇█</td></tr><tr><td>eval/f1_aliasing</td><td>â–â–▄▄▅▅▅▆▆▆▇▇▇▇▇█</td></tr><tr><td>eval/f1_banding</td><td>â–ƒâ–▃▅▆▇▇████████▇</td></tr><tr><td>eval/f1_dark_scenes</td><td>â–▂▆▇▇▇▇▇███████▅</td></tr><tr><td>eval/f1_graininess</td><td>â–‚â–▂▃▃▃▃▃▄▄▄▅▅▅▅█</td></tr><tr><td>eval/f1_motion_blur</td><td>â–„â–▂▃▃▄▄▅▅▆▆▆▆▆▆█</td></tr><tr><td>eval/loss</td><td>█▆▃▂▂â–â–â–â–â–â–â–‚â–‚â–‚â–‚â–‚</td></tr><tr><td>eval/runtime</td><td>█▃▃â–▃▂▃▂▃▂▂▂▂▂▂▄</td></tr><tr><td>eval/samples_per_second</td><td>â–▆▆█▆▆▆▇▆▆▇▆▇▇▇▆</td></tr><tr><td>eval/source_artefacts_avg_accuracy</td><td>â–▅▆▆▇▇▇▇▇███████</td></tr><tr><td>eval/source_artefacts_avg_auc</td><td>â–▃▄▅▅▆▆▆▆▇▇▇▇▇▇█</td></tr><tr><td>eval/source_artefacts_avg_f1</td><td>â–ƒâ–▃▄▅▅▆▆▇▇▇▇▇▇▇█</td></tr><tr><td>eval/steps_per_second</td><td>â–▆▆█▆▆▆▇▆▆▇▆▇▇▇▇</td></tr><tr><td>train/epoch</td><td>â–â–â–â–▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇██████</td></tr><tr><td>train/global_step</td><td>â–â–â–â–▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇██████</td></tr><tr><td>train/grad_norm</td><td>█▅▃▄▂▂▂▂▂▂▂▂▂▂▂â–â–â–‚â–â–â–â–â–â–â–‚â–‚â–â–â–â–</td></tr><tr><td>train/learning_rate</td><td>â–▆███▇▇▇▆▆▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂â–</td></tr><tr><td>train/loss</td><td>██▆▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂â–â–â–â–â–</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy</td><td>0.64346</td></tr><tr><td>eval/accuracy_aliasing</td><td>0.72432</td></tr><tr><td>eval/accuracy_banding</td><td>0.73459</td></tr><tr><td>eval/accuracy_dark_scenes</td><td>0.73288</td></tr><tr><td>eval/accuracy_graininess</td><td>0.71918</td></tr><tr><td>eval/accuracy_motion_blur</td><td>0.83048</td></tr><tr><td>eval/auc</td><td>0.68896</td></tr><tr><td>eval/auc_aliasing</td><td>0.66432</td></tr><tr><td>eval/auc_banding</td><td>0.63605</td></tr><tr><td>eval/auc_dark_scenes</td><td>0.72998</td></tr><tr><td>eval/auc_graininess</td><td>0.69241</td></tr><tr><td>eval/auc_motion_blur</td><td>0.80945</td></tr><tr><td>eval/f1</td><td>0.45053</td></tr><tr><td>eval/f1_aliasing</td><td>0.47213</td></tr><tr><td>eval/f1_banding</td><td>0.47458</td></tr><tr><td>eval/f1_dark_scenes</td><td>0.55172</td></tr><tr><td>eval/f1_graininess</td><td>0.51479</td></tr><tr><td>eval/f1_motion_blur</td><td>0.57511</td></tr><tr><td>eval/loss</td><td>0.0046</td></tr><tr><td>eval/runtime</td><td>6.1607</td></tr><tr><td>eval/samples_per_second</td><td>94.795</td></tr><tr><td>eval/source_artefacts_avg_accuracy</td><td>0.74829</td></tr><tr><td>eval/source_artefacts_avg_auc</td><td>0.70644</td></tr><tr><td>eval/source_artefacts_avg_f1</td><td>0.51767</td></tr><tr><td>eval/steps_per_second</td><td>3.084</td></tr><tr><td>total_flos</td><td>0.0</td></tr><tr><td>train/epoch</td><td>10.71429</td></tr><tr><td>train/global_step</td><td>300</td></tr><tr><td>train/grad_norm</td><td>0.00885</td></tr><tr><td>train/learning_rate</td><td>0.0</td></tr><tr><td>train/loss</td><td>0.0039</td></tr><tr><td>train_loss</td><td>0.00441</td></tr><tr><td>train_runtime</td><td>524.3153</td></tr><tr><td>train_samples_per_second</td><td>91.548</td></tr><tr><td>train_steps_per_second</td><td>0.954</td></tr></table><br/></div></div>" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "<IPython.core.display.HTML object>" + ], + "text/html": [ + " View run <strong style=\"color:#cdcd00\">run-2024-09-03--04-44-54</strong> at: <a href='https://wandb.ai/sanchitv7-university-of-surrey/video_classification/runs/suem3kz6' target=\"_blank\">https://wandb.ai/sanchitv7-university-of-surrey/video_classification/runs/suem3kz6</a><br/> View project at: <a href='https://wandb.ai/sanchitv7-university-of-surrey/video_classification' target=\"_blank\">https://wandb.ai/sanchitv7-university-of-surrey/video_classification</a><br/>Synced 5 W&B file(s), 0 media file(s), 1 artifact file(s) and 2 other file(s)" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "<IPython.core.display.HTML object>" + ], + "text/html": [ + "Find logs at: <code>./wandb/run-20240903_034455-suem3kz6/logs</code>" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "<IPython.core.display.HTML object>" + ], + "text/html": [ + "The new W&B backend becomes opt-out in version 0.18.0; try it out with `wandb.require(\"core\")`! See https://wandb.me/wandb-core for more information." + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Test set metrics:\n", + "Overall Metrics:\n", + "Loss: 0.0046\n", + "Accuracy: 0.6435\n", + "F1 Score: 0.4505\n", + "AUC: 0.6890\n", + "\n", + "Motion_blur Metrics:\n", + " Accuracy: 0.8305\n", + " F1 Score: 0.5751\n", + " AUC: 0.8094\n", + "\n", + "Dark_scenes Metrics:\n", + " Accuracy: 0.7329\n", + " F1 Score: 0.5517\n", + " AUC: 0.7300\n", + "\n", + "Graininess Metrics:\n", + " Accuracy: 0.7192\n", + " F1 Score: 0.5148\n", + " AUC: 0.6924\n", + "\n", + "Aliasing Metrics:\n", + " Accuracy: 0.7243\n", + " F1 Score: 0.4721\n", + " AUC: 0.6643\n", + "\n", + "Banding Metrics:\n", + " Accuracy: 0.7346\n", + " F1 Score: 0.4746\n", + " AUC: 0.6361\n", + "\n", + "Additional Information:\n", + "Runtime: 6.1607 seconds\n", + "Samples per second: 94.80\n", + "Steps per second: 3.08\n", + "Epoch: 10.71\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "run_name" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "id": "FrFfu6P6vT5g", + "outputId": "38dc4f60-8cf9-4203-ed84-a87aa18d88e3" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "'run-2024-09-03--01-02-23'" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + } + }, + "metadata": {}, + "execution_count": 168 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# @title Save or Not\n", + "# @markdown **Save the model or not**\n", + "\n", + "# form input for saving or not\n", + "save_logs_or_not = True # @param {type: \"boolean\"}\n", + "save_model_or_not = False # @param {type: \"boolean\"}\n", + "\n", + "if save_logs_or_not:\n", + " # Move logs to drive with today's data\n", + " !scp -r logs/{run_name}/ /content/drive/MyDrive/MSProjectMisc\n", + "\n", + "if save_model_or_not:\n", + " # Save the final model\n", + " final_model_path = os.path.join(\"./results\", run_name)\n", + "\n", + " trainer.save_model(final_model_path)\n", + " print(f\"Final model saved to {final_model_path}\")\n", + "\n", + " # Move the final model to drive\n", + " !scp -r {final_model_path} /content/drive/MyDrive/MSProjectMisc/{run_name}/" + ], + "metadata": { + "id": "3haYagRvmtMj", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# !rm -rf logs results\n", + "\n", + "# !scp -r logs_successful_run_2/ /content/drive/MyDrive/MSProjectMisc" + ], + "metadata": { + "id": "U-e7LLQeJOFC" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# from safetensors import safe_open\n", + "# model = MultiTaskTimeSformer(num_labels, id2label, label2id, train_dataset)\n", + "\n", + "# # Load the saved weights\n", + "# checkpoint_path = \"/content/results/run-2024-08-30--00-49-32/checkpoint-400/model.safetensors\"\n", + "# with safe_open(checkpoint_path, framework=\"pt\", device=\"cpu\") as f:\n", + "# state_dict = {key: f.get_tensor(key) for key in f.keys()}\n", + "\n", + "# model.load_state_dict(state_dict)\n", + "\n", + "# # Set the model to evaluation mode\n", + "# model.eval()\n", + "\n", + "# print(\"Model loaded successfully!\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uuLHkPhNzp4R", + "outputId": "c2702569-133b-4b3c-b0bb-ae730ea0e9c8" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model loaded successfully!\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "_I9XUa4nz-yl" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file