diff --git a/3D_CNN_baseline.ipynb b/3D_CNN_baseline.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..5dd434dc473f4704e71c920def4f271a02792f0b --- /dev/null +++ b/3D_CNN_baseline.ipynb @@ -0,0 +1,497 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "048616b7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using GPU.\n" + ] + } + ], + "source": [ + "#Code for running 3D CNN adapted from:https://github.com/latte488/smth-smth-v2\n", + "import os\n", + "import cv2\n", + "import sys\n", + "import importlib\n", + "import torch\n", + "import torchvision\n", + "import numpy as np\n", + "from torch import nn\n", + "import json\n", + "\n", + "# imports for displaying a video an IPython cell\n", + "import io\n", + "import base64\n", + "from IPython.display import HTML\n", + "\n", + "from data_parser import WebmDataset\n", + "from data_loader_av import VideoFolder\n", + "\n", + "from models.multi_column import MultiColumn\n", + "from transforms_video import *\n", + "\n", + "from utils import load_json_config, remove_module_from_checkpoint_state_dict\n", + "from pprint import pprint\n", + "\n", + "from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay\n", + "from tqdm import tqdm\n", + "from matplotlib import pyplot as plt\n", + "\n", + "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "torch.backends.cudnn.deterministic = True\n", + "print(f\"Using {'GPU' if str(DEVICE) == 'cuda' else 'CPU'}.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "75e80d4f", + "metadata": {}, + "outputs": [], + "source": [ + "#helper functions\n", + "def count_parameters(model):\n", + " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "def train(model, dataloader, optimizer, criterion, device):\n", + " model.train()\n", + " running_loss = 0.0\n", + " running_correct_preds = 0\n", + " count = 0\n", + "\n", + " for i, (input, target) in tqdm(enumerate(dataloader), total=len(dataloader)):\n", + " count += 1\n", + " \n", + " optimizer.zero_grad()\n", + " \n", + " if config['nclips_train'] > 1:\n", + " input_var = list(input.split(config['clip_size'], 2))\n", + " for idx, inp in enumerate(input_var):\n", + " input_var[idx] = inp.to(device)\n", + " else:\n", + " input_var = [input.to(device)]\n", + "\n", + " target = target.to(device)\n", + "\n", + " model.zero_grad()\n", + "\n", + " # compute output and loss\n", + " output = model(input_var)\n", + " loss = criterion(output, target)\n", + " running_loss += loss.item()\n", + " \n", + " # compute accuracy\n", + " _, preds = torch.max(output.data, 1)\n", + " running_correct_preds += (preds == target).sum().item()\n", + "\n", + " # backward pass\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # calculate loss and accuracy\n", + " epoch_loss = running_loss / count\n", + " epoch_acc = 100. 
* (running_correct_preds / len(dataloader.dataset))\n", + "    return epoch_loss, epoch_acc\n", + "\n", + "def eval_model(model, dataloader, device):\n", + "    model.eval()  # disable dropout and use running batch-norm statistics during evaluation\n", + "    y_pred = []\n", + "    y_true = []\n", + "\n", + "    running_acc = 0\n", + "    count = 0  # total number of samples evaluated\n", + "\n", + "    with torch.no_grad():\n", + "        for i, (input, target) in tqdm(enumerate(dataloader), total=len(dataloader)):\n", + "            if config['nclips_val'] > 1:\n", + "                input_var = list(input.split(config['clip_size'], 2))\n", + "                for idx, inp in enumerate(input_var):\n", + "                    input_var[idx] = inp.to(device)\n", + "            else:\n", + "                input_var = [input.to(device)]\n", + "\n", + "            target = target.to(device)\n", + "\n", + "            output = model(input_var)\n", + "            _, preds = torch.max(output, 1)\n", + "\n", + "            count += target.size(0)\n", + "            running_acc += (preds == target).sum().item()\n", + "\n", + "            y_pred.extend(preds.to('cpu').tolist())\n", + "            y_true.extend(target.to('cpu').tolist())\n", + "\n", + "    acc = 100 * running_acc / count\n", + "    print(\"Validation accuracy: \" + str(acc))\n", + "\n", + "    # classification report\n", + "    print(classification_report(y_true, y_pred, target_names=['106', '112', '118'], zero_division=0))\n", + "\n", + "    # confusion matrix\n", + "    cm = confusion_matrix(y_true, y_pred, labels=[0,1,2], normalize='true')\n", + "    disp = ConfusionMatrixDisplay(confusion_matrix=cm)\n", + "    disp.plot(include_values=False)\n", + "    disp.ax_.get_images()[0].set_clim(0, 1.0)  # fix the colour scale so it does not vary between plots\n", + "    plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "3b79d192", + "metadata": {}, + "source": [ + "Create annotation files for the subset of 3 classes that the model will be fine-tuned on" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b4cdee0", + "metadata": {}, + "outputs": [], + "source": [ + "label_train_path = '/vol/research/TopDownVideo/labels/something-something-v2-train.json' \n", + "label_val_path = '/vol/research/TopDownVideo/labels/something-something-v2-validation.json'\n", + "\n", + "action_list = [\n", + "    'Putting [something] into [something]',\n", + "    'Putting [something] onto [something]',\n", + "    'Putting [something] underneath [something]'\n", + "]\n", + "\n", + "with open(label_train_path) as json_file:\n", + "    train_json = json.load(json_file)\n", + "\n", + "with open(label_val_path) as json_file:\n", + "    val_json = json.load(json_file)\n", + "\n", + "train_json_updated = []\n", + "for d in train_json:\n", + "    if d['template'] in action_list:\n", + "        train_json_updated.append(d)\n", + "print(\"Length of train set: \" + str(len(train_json_updated)))\n", + "\n", + "val_json_updated = []\n", + "for d in val_json:\n", + "    if d['template'] in action_list:\n", + "        val_json_updated.append(d)\n", + "print(\"Length of validation set: \" + str(len(val_json_updated)))\n", + "\n", + "label_train_target = '/vol/research/TopDownVideo/aa03813/LabelsForBaseline/something-something-v2-train3.json' \n", + "label_val_target = '/vol/research/TopDownVideo/aa03813/LabelsForBaseline/something-something-v2-val3.json'\n", + "\n", + "with open(label_train_target, \"w\") as write_file:\n", + "    json.dump(train_json_updated, write_file, indent=1)\n", + "\n", + "with open(label_val_target, \"w\") as write_file:\n", + "    json.dump(val_json_updated, write_file, indent=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "07b888a1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=> Name of the model -- model3D_1\n", + "=> Checkpoint path --> 
../trained_models/pretrained/model3D_1/model_best.pth.tar\n" + ] + } + ], + "source": [ + "# load config\n", + "config = load_json_config('./configs/pretrained/config_model1_for_finetuning.json')\n", + "\n", + "#setup model from checkpoint\n", + "column_cnn_def = importlib.import_module(\"{}\".format(config['conv_model']))\n", + "model_name = config[\"model_name\"]\n", + "\n", + "print(\"=> Name of the model -- {}\".format(model_name))\n", + "\n", + "# checkpoint path to a trained model\n", + "checkpoint_path = os.path.join(\"../\", config[\"output_dir\"], config[\"model_name\"], \"model_best.pth.tar\")\n", + "print(\"=> Checkpoint path --> {}\".format(checkpoint_path))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "dd213cfb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Num of trainable parameters before freezing: 23384430\n", + "Num of trainable parameters after freezing: 1539\n" + ] + }, + { + "data": { + "text/plain": [ + "MultiColumn(\n", + " (conv_column): Model(\n", + " (block1): Sequential(\n", + " (0): Conv3d(3, 32, kernel_size=(3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))\n", + " (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Dropout3d(p=0.2, inplace=False)\n", + " )\n", + " (block2): Sequential(\n", + " (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n", + " (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1))\n", + " (4): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " (6): Dropout3d(p=0.2, inplace=False)\n", + " )\n", + " (block3): Sequential(\n", + " (0): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n", + " (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n", + " (4): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " (6): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1))\n", + " (7): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (8): ReLU(inplace=True)\n", + " (9): Dropout3d(p=0.2, inplace=False)\n", + " )\n", + " (block4): Sequential(\n", + " (0): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n", + " (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n", + " (4): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " (6): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1))\n", + " (7): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (8): ReLU(inplace=True)\n", + " (9): Dropout3d(p=0.2, inplace=False)\n", + " )\n", + " (block5): Sequential(\n", + " (0): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n", + " (1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 
+ " (2): ReLU(inplace=True)\n", + " (3): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1))\n", + " (4): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (clf_layers): Linear(in_features=512, out_features=3, bias=True)\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#initialize and freeze model\n", + "model = MultiColumn(config['num_classes'], column_cnn_def.Model, int(config[\"column_units\"]))\n", + "\n", + "print(\"Num of trainable parameters before freezing: \" + str(count_parameters(model)))\n", + "for param in model.parameters():\n", + " param.requires_grad = False\n", + "#replace last layer so it's output is only 3\n", + "model.clf_layers = nn.Linear(512, 3)\n", + "print(\"Num of trainable parameters after freezing: \" + str(count_parameters(model)))\n", + "model.to(DEVICE)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "9fae288b", + "metadata": {}, + "outputs": [], + "source": [ + "# define augmentation pipeline\n", + "upscale_size_train = int(config['input_spatial_size'] * config[\"upscale_factor_train\"])\n", + "upscale_size_eval = int(config['input_spatial_size'] * config[\"upscale_factor_eval\"])\n", + "\n", + "# Random crop videos during training\n", + "transform_train_pre = ComposeMix([\n", + " [RandomRotationVideo(15), \"vid\"],\n", + " [Scale(upscale_size_train), \"img\"],\n", + " [RandomCropVideo(config['input_spatial_size']), \"vid\"],\n", + " ])\n", + "\n", + "# Center crop videos during evaluation\n", + "transform_eval_pre = ComposeMix([\n", + " [Scale(upscale_size_eval), \"img\"],\n", + " [torchvision.transforms.ToPILImage(), \"img\"],\n", + " [torchvision.transforms.CenterCrop(config['input_spatial_size']), \"img\"],\n", + " ])\n", + "\n", + "# Transforms common to train and eval sets and applied after \"pre\" transforms\n", + "transform_post = ComposeMix([\n", + " [torchvision.transforms.ToTensor(), \"img\"],\n", + " [torchvision.transforms.Normalize(\n", + " mean=[0.485, 0.456, 0.406], # default values for imagenet\n", + " std=[0.229, 0.224, 0.225]), \"img\"]\n", + " ])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3b5c018d", + "metadata": {}, + "outputs": [], + "source": [ + "train_data = VideoFolder(root=config['data_folder'],\n", + " json_file_input=config['json_data_train'],\n", + " json_file_labels=config['json_file_labels'],\n", + " clip_size=config['clip_size'],\n", + " nclips=config['nclips_train'],\n", + " step_size=config['step_size_train'],\n", + " is_val=False,\n", + " transform_pre=transform_train_pre,\n", + " transform_post=transform_post,\n", + " augmentation_mappings_json=config['augmentation_mappings_json'],\n", + " augmentation_types_todo=config['augmentation_types_todo'],\n", + " get_item_id=False,\n", + " )\n", + "\n", + "train_loader = torch.utils.data.DataLoader(\n", + " train_data,\n", + " batch_size=config['batch_size'], shuffle=True,\n", + " num_workers=config['num_workers'], pin_memory=True,\n", + " drop_last=True)\n", + "\n", + "val_data = VideoFolder(root=config['data_folder'],\n", + " json_file_input=config['json_data_val'],\n", + " json_file_labels=config['json_file_labels'],\n", + " clip_size=config['clip_size'],\n", + " nclips=config['nclips_val'],\n", + " step_size=config['step_size_val'],\n", + " is_val=True,\n", + " transform_pre=transform_eval_pre,\n", + " transform_post=transform_post,\n", + " 
get_item_id=False,\n", + " )\n", + "\n", + "val_loader = torch.utils.data.DataLoader(\n", + " val_data,\n", + " batch_size=config['batch_size'], shuffle=False,\n", + " num_workers=config['num_workers'], pin_memory=True,\n", + " drop_last=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "fafe4c40", + "metadata": {}, + "outputs": [], + "source": [ + "LR = 5e-2\n", + "OPTIMIZER = torch.optim.SGD(model.parameters(), LR)\n", + "CRITERION = nn.CrossEntropyLoss().to(DEVICE)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "bdbab144", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████| 146/146 [04:35<00:00, 1.89s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training loss: 1.7852653726323011\n", + "Training accuracy:39.977298524404084\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "loss, acc = train(model, train_loader, OPTIMIZER, CRITERION, DEVICE)\n", + "print(\"Training loss: \" + str(loss))\n", + "print(\"Training accuracy:\" + str(acc))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2dd0a6ae", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<All keys matched successfully>" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.save(model.state_dict(), '/vol/research/TopDownVideo/aa03813/LabelsForBaseline/model.pkl')\n", + "model.load_state_dict(torch.load('/vol/research/TopDownVideo/aa03813/LabelsForBaseline/model.pkl'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5990018", + "metadata": {}, + "outputs": [], + "source": [ + "eval_model(model, val_loader, DEVICE)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}