diff --git a/discarded_src/balance_dataset.py b/discarded_src/balance_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb2bc028cefe28bf9174664db9ba7cf9f52b8713
--- /dev/null
+++ b/discarded_src/balance_dataset.py
@@ -0,0 +1,101 @@
+import argparse
+import json
+import os
+import random
+import shutil
+
+from tqdm import tqdm
+
+
+def load_labels(labels_path):
+    with open(labels_path, 'r') as f:
+        return json.load(f)
+
+
+def get_video_paths(input_dir):
+    video_paths = {}
+    for part in ['part1', 'part2']:
+        part_dir = os.path.join(input_dir, part)
+        for video in os.listdir(part_dir):
+            video_paths[video] = os.path.join(part_dir, video)
+    return video_paths
+
+
+def get_maximum_balanced_subset(labels, video_paths):
+    artefacts = set()
+    for video_labels in labels.values():
+        artefacts.update(video_labels.keys())
+
+    balanced_subset = {}
+
+    for artefact in artefacts:
+        positive_videos = [video for video, video_labels in labels.items()
+                           if video in video_paths and video_labels.get(artefact, 0) == 1]
+        negative_videos = [video for video, video_labels in labels.items()
+                           if video in video_paths and video_labels.get(artefact, 0) == 0]
+
+        count_per_label = min(len(positive_videos), len(negative_videos))
+
+        selected_positive = set(random.sample(positive_videos, count_per_label))
+        selected_negative = set(random.sample(negative_videos, count_per_label))
+
+        for video in selected_positive.union(selected_negative):
+            if video not in balanced_subset:
+                # Copy so the caller's labels dict is not mutated in place
+                balanced_subset[video] = dict(labels[video])
+            balanced_subset[video][artefact] = 1 if video in selected_positive else 0
+
+    return balanced_subset
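+
+
+# Illustrative sketch (not part of the pipeline): how the balancing behaves on
+# a tiny, made-up labels dict. All names and paths below are hypothetical.
+def _balancing_example():
+    toy_labels = {
+        'a.avi': {'graininess': 1},
+        'b.avi': {'graininess': 0},
+        'c.avi': {'graininess': 1},
+    }
+    toy_paths = {name: os.path.join('/tmp', name) for name in toy_labels}
+    random.seed(0)
+    subset = get_maximum_balanced_subset(toy_labels, toy_paths)
+    # Only one positive/negative pair is available, so the subset keeps exactly
+    # one video with graininess=1 and one with graininess=0.
+    assert len(subset) == 2
+    assert sum(v['graininess'] for v in subset.values()) == 1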
+
+
+def copy_videos(videos, video_paths, dst_dir):
+    os.makedirs(dst_dir, exist_ok=True)
+    for video in tqdm(videos, desc=f"Copying to {os.path.basename(dst_dir)}"):
+        src_path = video_paths[video]
+        dst_path = os.path.join(dst_dir, video)
+        shutil.copy2(src_path, dst_path)
+
+
+def create_subset_labels(balanced_subset):
+    # The balanced subset already maps video name -> per-artefact labels, so it
+    # can be written out as labels.json unchanged.
+    return balanced_subset
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Create a maximum balanced subset of videos for all artefacts and relocate them.")
+    parser.add_argument("--input_dir", type=str, required=True, help="Path to processed_BVIArtefact folder")
+    parser.add_argument("--output_dir", type=str, required=True, help="Path to output directory")
+    args = parser.parse_args()
+
+    labels_path = os.path.join(args.input_dir, 'processed_labels.json')
+    labels = load_labels(labels_path)
+
+    video_paths = get_video_paths(args.input_dir)
+
+    balanced_subset = get_maximum_balanced_subset(labels, video_paths)
+
+    copy_videos(balanced_subset.keys(), video_paths, args.output_dir)
+
+    # Create and save the subset labels.json
+    subset_labels = create_subset_labels(balanced_subset)
+    labels_json_path = os.path.join(args.output_dir, 'labels.json')
+    with open(labels_json_path, 'w') as f:
+        json.dump(subset_labels, f, indent=4)
+
+    print(f"Maximum balanced subset created in {args.output_dir}")
+    print(f"Total videos in subset: {len(balanced_subset)}")
+    print(f"Labels.json created at {labels_json_path}")
+
+    artefacts = set()
+    for video_labels in balanced_subset.values():
+        artefacts.update(video_labels.keys())
+
+    for artefact in sorted(artefacts):
+        presence_count = sum(1 for labels in balanced_subset.values() if labels.get(artefact, 0) == 1)
+        absence_count = sum(1 for labels in balanced_subset.values() if labels.get(artefact, 0) == 0)
+        print(f"{artefact}:")
+        print(f"  Presence count: {presence_count}")
+        print(f"  Absence count: {absence_count}")
+
+
+if __name__ == "__main__":
+    main()
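+
+# Example invocation (paths are placeholders):
+#   python balance_dataset.py --input_dir data/processed_BVIArtefact --output_dir data/balanced_subset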
diff --git a/discarded_src/sample_code.py b/discarded_src/sample_code.py
new file mode 100644
index 0000000000000000000000000000000000000000..975732c89a404ead168301857c0b435ec0d3e97a
--- /dev/null
+++ b/discarded_src/sample_code.py
@@ -0,0 +1,260 @@
+# dataset structure:
+'''
+data/graininess_100_balanced_subset_split
+├── test
+│   ├── BirdsInCage_1920x1080_30fps_8bit_420_Pristine_QP32_FBT_1.avi
+│   ├── Chimera1_4096x2160_60fps_10bit_420_graininess_QP32_FB_1.avi
+│   ├── Chimera3_4096x2160_24fps_10bit_420_graininess_QP32_FT_1.avi
+│   ├── ...
+│   └── labels.json
+├── train
+│   ├── labels.json
+│   ├── lamppost_1920x1080_120fps_8bit_420_Pristine_QP32_BT_3.avi
+│   ├── lamppost_1920x1080_120fps_8bit_420_Pristine_QP47_SF_3.avi
+│   ├── leaveswall_1920x1080_120fps_8bit_420_Motion_QP32_SB_1.avi
+│   ├── leaveswall_1920x1080_120fps_8bit_420_Motion_QP32_SFB_4.avi
+│   ├── library_1920x1080_120fps_8bit_420_aliasing_QP47_FT_1.avi
+│   ├── ...
+└── val
+    ├── Chimera2_4096x2160_60fps_10bit_420_Dark_QP32_BT_1.avi
+    ├── ...
+    ├── labels.json
+    ├── shields_1280x720_50fps_8bit_420_graininess_QP47_SFB_1.avi
+    ├── station_1920x1080_30fps_8bit_420_graininess_QP32_SB_1.avi
+    ├── svtmidnightsun_3840x2160_50fps_10bit_420_banding_QP47_SBT_3.avi
+    ├── svtmidnightsun_3840x2160_50fps_10bit_420_banding_QP47_SFT_1.avi
+    ├── svtsmokesauna_3840x2160_50fps_10bit_420_banding_QP32_F_4.avi
+    ├── svtwaterflyover_3840x2160_50fps_10bit_420_banding_QP32_T_3.avi
+    └── typing_1920x1080_120fps_8bit_420_aliasing_QP47_BT_4.avi
+
+4 directories, 103 files
+'''
+
+'''
+labels.json in each split is like:
+{
+  "Chimera1_4096x2160_60fps_10bit_420_graininess_QP47_FT_1.avi": {
+    "graininess": 1
+  },
+  "riverbed_1920x1080_25fps_8bit_420_banding_QP47_SBT_1.avi": {
+    "graininess": 0
+  },
+  "Meridian1_3840x2160_60fps_10bit_420_banding_QP47_SFT_1.avi": {
+    "graininess": 0
+  },
+  ...
+}
+'''
+
+
+# Import necessary libraries
+import os
+import json
+import torch
+import numpy as np
+from transformers import VivitImageProcessor, VivitForVideoClassification, TrainingArguments, Trainer
+from datasets import Dataset, DatasetDict
+from torchvision.io import read_video
+from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+import functools
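+
+
+# Quick sanity-check sketch (illustrative, not part of the training pipeline):
+# confirm that every .avi file in a split directory has an entry in that
+# split's labels.json. The helper name is an assumption.
+def check_split_labels(split_dir):
+    with open(os.path.join(split_dir, 'labels.json'), 'r') as f:
+        labels = json.load(f)
+    videos = [f for f in os.listdir(split_dir) if f.endswith('.avi')]
+    # An empty result means every video in the split is labelled
+    return [v for v in videos if v not in labels]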
+
+
+def load_video(video_path):
+    # Read the video file
+    video, _, info = read_video(video_path, pts_unit='sec')
+
+    # Set the number of frames we want to sample
+    num_frames_to_sample = 32
+
+    # Get the total number of frames in the video
+    total_frames = video.shape[0]
+
+    # Calculate the sampling rate to evenly distribute frames
+    sampling_rate = max(total_frames // num_frames_to_sample, 1)
+
+    # Sample frames at the calculated rate
+    sampled_frames = video[::sampling_rate][:num_frames_to_sample]
+
+    # If we don't have enough frames, pad with zeros
+    if sampled_frames.shape[0] < num_frames_to_sample:
+        padding = torch.zeros(
+            (num_frames_to_sample - sampled_frames.shape[0], *sampled_frames.shape[1:]), dtype=sampled_frames.dtype)
+        sampled_frames = torch.cat([sampled_frames, padding], dim=0)
+
+    # Ensure we have exactly the number of frames we want
+    sampled_frames = sampled_frames[:num_frames_to_sample]
+
+    # Return a numpy array of frames in channel-first format (T, C, H, W)
+    return sampled_frames.permute(0, 3, 1, 2).numpy()
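+
+# Shape note (illustrative numbers): for a 300-frame clip, sampling_rate is
+# 300 // 32 = 9, video[::9] keeps 34 frames, and the [:32] slice trims this to
+# exactly 32, so the returned array has shape (32, 3, H, W).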
+
+
+def create_dataset(data_dir, split):
+    # Construct the path to the video directory and labels file
+    video_dir = os.path.join(data_dir, split)
+    json_path = os.path.join(video_dir, 'labels.json')
+
+    # Load the labels from the JSON file
+    with open(json_path, 'r') as f:
+        labels = json.load(f)
+
+    # Get all video files in the directory
+    video_files = [f for f in os.listdir(video_dir) if f.endswith('.avi')]
+
+    # Create a dataset with video paths and their corresponding labels
+    dataset = Dataset.from_dict({
+        'video_path': [os.path.join(video_dir, f) for f in video_files],
+        'label': [labels[f]['graininess'] for f in video_files]
+    })
+
+    return dataset
+
+
+# Load the ViViT image processor
+image_processor = VivitImageProcessor.from_pretrained(
+    "google/vivit-b-16x2-kinetics400")
+
+
+def preprocess_video(example, image_processor):
+    # Load the video
+    video = load_video(example['video_path'])
+
+    # Process the video frames using the ViViT image processor
+    inputs = image_processor(list(video), return_tensors="np")
+
+    # Add the processed inputs to the example dictionary
+    for k, v in inputs.items():
+        example[k] = v.squeeze()  # Remove batch dimension
+
+    return example
+
+
+def preprocess_dataset(dataset, num_proc=4):
+    # Use multiprocessing to preprocess the dataset in parallel
+    return dataset.map(
+        functools.partial(preprocess_video, image_processor=image_processor),
+        remove_columns=['video_path'],
+        num_proc=num_proc
+    )
+
+
+# Define the path to the dataset
+data_dir = 'graininess_100_balanced_subset_split'
+
+# Load the datasets for each split
+dataset = DatasetDict({
+    'train': create_dataset(data_dir, 'train'),
+    'validation': create_dataset(data_dir, 'val'),
+    'test': create_dataset(data_dir, 'test')
+})
+
+# Define the path where the preprocessed dataset will be saved
+preprocessed_path = './preprocessed_dataset'
+
+# Check if preprocessed dataset already exists
+if os.path.exists(preprocessed_path):
+    print("Loading preprocessed dataset...")
+    # Load the preprocessed dataset from disk
+    preprocessed_dataset = DatasetDict.load_from_disk(preprocessed_path)
+else:
+    print("Preprocessing dataset...")
+    # Preprocess each split of the dataset
+    preprocessed_dataset = DatasetDict({
+        split: preprocess_dataset(dataset[split])
+        for split in dataset.keys()
+    })
+    # Save the preprocessed dataset to disk
+    preprocessed_dataset.save_to_disk(preprocessed_path)
+    print("Preprocessed dataset saved to disk.")
+
+# Load the ViViT model
+model = VivitForVideoClassification.from_pretrained(
+    "google/vivit-b-16x2-kinetics400")
+
+# Modify the model for binary classification
+model.classifier = torch.nn.Linear(model.config.hidden_size, 2)
+model.num_labels = 2
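+
+# Note: an equivalent route is to pass num_labels=2 together with
+# ignore_mismatched_sizes=True to from_pretrained, which reinitialises the
+# classifier head in one step (the "shapes did not match" warning captured in
+# notebooks/data_prep.ipynb reflects that approach).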
+
+# Set up training arguments
+training_args = TrainingArguments(
+    output_dir="./results",  # Directory to save the model checkpoints
+    num_train_epochs=3,  # Number of training epochs
+    per_device_train_batch_size=2,  # Batch size for training
+    per_device_eval_batch_size=2,  # Batch size for evaluation
+    warmup_steps=500,  # Number of warmup steps for learning rate scheduler
+    weight_decay=0.01,  # Strength of weight decay
+    logging_dir='./logs',  # Directory for storing logs
+    logging_steps=10,  # Log every X updates steps
+    evaluation_strategy="steps",  # Evaluate during training
+    eval_steps=100,  # Evaluate every X steps
+    save_steps=1000,  # Save checkpoint every X steps
+    # Load the best model when finished training (default metric is loss)
+    load_best_model_at_end=True,
+)
+
+# Define function to compute evaluation metrics
+
+
+def compute_metrics(eval_pred):
+    # Get the predictions and true labels
+    predictions = np.argmax(eval_pred.predictions, axis=1)
+    labels = eval_pred.label_ids
+
+    # Compute precision, recall, and F1 score
+    precision, recall, f1, _ = precision_recall_fscore_support(
+        labels, predictions, average='binary')
+
+    # Compute accuracy
+    accuracy = accuracy_score(labels, predictions)
+
+    # Return all metrics
+    return {
+        'accuracy': accuracy,
+        'f1': f1,
+        'precision': precision,
+        'recall': recall
+    }
+
+
+# Initialize the Trainer
+trainer = Trainer(
+    model=model,  # The instantiated model to be trained
+    args=training_args,  # Training arguments, defined above
+    train_dataset=preprocessed_dataset['train'],  # Training dataset
+    eval_dataset=preprocessed_dataset['validation'],  # Evaluation dataset
+    compute_metrics=compute_metrics,  # The function that computes metrics
+)
+
+# Train the model
+trainer.train()
+
+# Evaluate the model on the test set
+evaluation_results = trainer.evaluate(preprocessed_dataset['test'])
+print(evaluation_results)
+
+# Save the final model
+trainer.save_model("./vivit_binary_classifier")
+
+# Function to predict on new videos
+
+
+def predict_video(video_path):
+    # Load and preprocess the video
+    video = load_video(video_path)
+    inputs = image_processor(list(video), return_tensors="pt")
+
+    # Move inputs to the same device as the (possibly GPU-resident) model
+    inputs = inputs.to(model.device)
+
+    # Make prediction
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    # Get probabilities and predicted class
+    probabilities = torch.softmax(outputs.logits, dim=1)
+    predicted_class = torch.argmax(probabilities, dim=1).item()
+
+    return predicted_class, probabilities[0][predicted_class].item()
+
+
+# Example usage of prediction function
+# video_path = "path/to/your/video.avi"
+# predicted_class, confidence = predict_video(video_path)
+# print(f"Predicted class: {predicted_class}, Confidence: {confidence:.2f}")
diff --git a/discarded_src/sample_code_try_2.py b/discarded_src/sample_code_try_2.py
new file mode 100644
index 0000000000000000000000000000000000000000..2923397483e1e2860998c76c0cb7fb4cffa8a041
--- /dev/null
+++ b/discarded_src/sample_code_try_2.py
@@ -0,0 +1,231 @@
+'''
+# dataset structure:
+data/graininess_100_balanced_subset_split
+├── test
+│   ├── BirdsInCage_1920x1080_30fps_8bit_420_Pristine_QP32_FBT_1.avi
+│   ├── Chimera1_4096x2160_60fps_10bit_420_graininess_QP32_FB_1.avi
+│   ├── Chimera3_4096x2160_24fps_10bit_420_graininess_QP32_FT_1.avi
+│   ├── ...
+│   └── labels.json
+├── train
+│   ├── labels.json
+│   ├── lamppost_1920x1080_120fps_8bit_420_Pristine_QP32_BT_3.avi
+│   ├── lamppost_1920x1080_120fps_8bit_420_Pristine_QP47_SF_3.avi
+│   ├── leaveswall_1920x1080_120fps_8bit_420_Motion_QP32_SB_1.avi
+│   ├── leaveswall_1920x1080_120fps_8bit_420_Motion_QP32_SFB_4.avi
+│   ├── library_1920x1080_120fps_8bit_420_aliasing_QP47_FT_1.avi
+│   ├── ...
+└── val
+    ├── Chimera2_4096x2160_60fps_10bit_420_Dark_QP32_BT_1.avi
+    ├── ...
+    ├── labels.json
+    ├── shields_1280x720_50fps_8bit_420_graininess_QP47_SFB_1.avi
+    ├── station_1920x1080_30fps_8bit_420_graininess_QP32_SB_1.avi
+    ├── svtmidnightsun_3840x2160_50fps_10bit_420_banding_QP47_SBT_3.avi
+    ├── svtmidnightsun_3840x2160_50fps_10bit_420_banding_QP47_SFT_1.avi
+    ├── svtsmokesauna_3840x2160_50fps_10bit_420_banding_QP32_F_4.avi
+    ├── svtwaterflyover_3840x2160_50fps_10bit_420_banding_QP32_T_3.avi
+    └── typing_1920x1080_120fps_8bit_420_aliasing_QP47_BT_4.avi
+
+4 directories, 103 files
+'''
+
+'''
+labels.json in each split is like:
+{
+  "Chimera1_4096x2160_60fps_10bit_420_graininess_QP47_FT_1.avi": {
+    "graininess": 1
+  },
+  "riverbed_1920x1080_25fps_8bit_420_banding_QP47_SBT_1.avi": {
+    "graininess": 0
+  },
+  "Meridian1_3840x2160_60fps_10bit_420_banding_QP47_SFT_1.avi": {
+    "graininess": 0
+  },
+  ...
+}
+'''
+
+import os
+import json
+import torch
+import numpy as np
+from transformers import VivitImageProcessor, VivitForVideoClassification, TrainingArguments, Trainer
+from datasets import Dataset, DatasetDict
+from torchvision.io import read_video
+from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+from functools import partial
+
+
+def get_augmentation():
+    return A.Compose([
+        A.HorizontalFlip(p=0.5),
+        A.VerticalFlip(p=0.5),
+        A.RandomRotate90(p=0.5),
+        A.Transpose(p=0.5),
+        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
+        ToTensorV2(),
+    ])
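+
+# Note: albumentations transforms operate on HWC numpy arrays, and ToTensorV2
+# returns a CHW torch tensor (without rescaling pixel values), which is why
+# apply_augmentation below converts each frame to numpy and yields frames
+# already in (T, C, H, W) layout.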
+
+
+def apply_augmentation(frames, augmentation):
+    aug_frames = []
+    for frame in frames:
+        # albumentations expects an HWC numpy array; ToTensorV2 gives back a CHW tensor
+        augmented = augmentation(image=frame.numpy())
+        aug_frames.append(augmented['image'])
+    # Stacked result is already in (T, C, H, W) layout
+    return torch.stack(aug_frames)
+
+
+def uniform_frame_sample(video, num_frames):
+    total_frames = len(video)
+    if total_frames <= num_frames:
+        return video
+
+    indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+    return video[indices]
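+
+# Example: with a 300-frame clip and num_frames=32, np.linspace(0, 299, 32)
+# picks indices 0, 9, 19, ..., 299, i.e. 32 roughly evenly spaced frames.
+# Clips with at most num_frames frames are returned unchanged (no padding).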
+
+
+def load_video(video_path, num_frames=32, augmentation=None):
+    video, _, info = read_video(video_path, pts_unit='sec')
+
+    # Uniform sampling
+    sampled_frames = uniform_frame_sample(video, num_frames)
+
+    if augmentation:
+        # apply_augmentation already returns frames in (T, C, H, W) layout
+        sampled_frames = apply_augmentation(sampled_frames, augmentation)
+    else:
+        sampled_frames = sampled_frames.permute(0, 3, 1, 2)
+
+    # Keep raw uint8 pixel values; VivitImageProcessor applies its own rescaling
+    return sampled_frames
+
+
+def create_dataset(data_dir, split):
+    video_dir = os.path.join(data_dir, split)
+    json_path = os.path.join(video_dir, 'labels.json')
+    with open(json_path, 'r') as f:
+        labels = json.load(f)
+
+    video_files = [f for f in os.listdir(video_dir) if f.endswith('.avi')]
+
+    dataset = Dataset.from_dict({
+        'video_path': [os.path.join(video_dir, f) for f in video_files],
+        'label': [labels[f]['graininess'] for f in video_files]
+    })
+
+    return dataset
+
+
+# Load the image processor
+image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
+
+
+def preprocess_video(example, image_processor, augmentation=None):
+    video = load_video(example['video_path'], augmentation=augmentation)
+    inputs = image_processor(list(video), return_tensors="pt")
+    for k, v in inputs.items():
+        example[k] = v.squeeze()
+    return example
+
+
+def preprocess_dataset(dataset, augmentation=None):
+    return dataset.map(
+        partial(preprocess_video, image_processor=image_processor, augmentation=augmentation),
+        remove_columns=['video_path'],
+        num_proc=4
+    )
+
+
+# Load and preprocess the datasets
+data_dir = 'graininess_100_balanced_subset_split'
+dataset = DatasetDict({
+    'train': create_dataset(data_dir, 'train'),
+    'validation': create_dataset(data_dir, 'val'),
+    'test': create_dataset(data_dir, 'test')
+})
+
+augmentation = get_augmentation()
+
+preprocessed_path = './preprocessed_dataset_augmented'
+if os.path.exists(preprocessed_path):
+    print("Loading preprocessed dataset...")
+    preprocessed_dataset = DatasetDict.load_from_disk(preprocessed_path)
+else:
+    print("Preprocessing dataset with augmentation...")
+    preprocessed_dataset = DatasetDict({
+        'train': preprocess_dataset(dataset['train'], augmentation),
+        'validation': preprocess_dataset(dataset['validation']),
+        'test': preprocess_dataset(dataset['test'])
+    })
+    preprocessed_dataset.save_to_disk(preprocessed_path)
+    print("Preprocessed dataset saved to disk.")
+
+# Load the model
+model = VivitForVideoClassification.from_pretrained("google/vivit-b-16x2-kinetics400")
+model.classifier = torch.nn.Linear(model.config.hidden_size, 2)
+model.num_labels = 2
+
+# Set up training arguments
+training_args = TrainingArguments(
+    output_dir="./results",
+    num_train_epochs=5,
+    per_device_train_batch_size=4,
+    per_device_eval_batch_size=4,
+    warmup_steps=500,
+    weight_decay=0.01,
+    logging_dir='./logs',
+    logging_steps=10,
+    evaluation_strategy="steps",
+    eval_steps=100,
+    save_steps=1000,
+    load_best_model_at_end=True,
+    fp16=True,  # Enable mixed precision training
+    gradient_accumulation_steps=2,  # Accumulate gradients over 2 steps
+)
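+
+# With per_device_train_batch_size=4 and gradient_accumulation_steps=2, the
+# effective batch size per optimizer step is 8 per device.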
+
+
+def compute_metrics(eval_pred):
+    predictions = np.argmax(eval_pred.predictions, axis=1)
+    labels = eval_pred.label_ids
+    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
+    accuracy = accuracy_score(labels, predictions)
+    return {
+        'accuracy': accuracy,
+        'f1': f1,
+        'precision': precision,
+        'recall': recall
+    }
+
+
+# Initialize Trainer
+trainer = Trainer(
+    model=model,
+    args=training_args,
+    train_dataset=preprocessed_dataset['train'],
+    eval_dataset=preprocessed_dataset['validation'],
+    compute_metrics=compute_metrics,
+)
+
+# Train the model
+trainer.train()
+
+# Evaluate the model
+evaluation_results = trainer.evaluate(preprocessed_dataset['test'])
+print(evaluation_results)
+
+# Save the model
+trainer.save_model("./vivit_binary_classifier_augmented")
+
+
+def predict_video(video_path):
+    video = load_video(video_path)
+    inputs = image_processor(list(video), return_tensors="pt")
+    # Move inputs to the same device as the trained model
+    inputs = inputs.to(model.device)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    probabilities = torch.softmax(outputs.logits, dim=1)
+    predicted_class = torch.argmax(probabilities, dim=1).item()
+    return predicted_class, probabilities[0][predicted_class].item()
+
+# Example usage of prediction function
+# video_path = "path/to/your/video.avi"
+# predicted_class, confidence = predict_video(video_path)
+# print(f"Predicted class: {predicted_class}, Confidence: {confidence:.2f}")
diff --git a/discarded_src/test_run.py b/discarded_src/test_run.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2df458ad95edb73bb1248a3103a816f0f3d1a20
--- /dev/null
+++ b/discarded_src/test_run.py
@@ -0,0 +1,212 @@
+import os
+import json
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms, models
+from torchvision.io import read_video
+from torchvision.models import ResNet50_Weights
+
+# Set device
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+elif torch.backends.mps.is_available():
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")
+
+print(f"Using device: {device}")
+
+# Define paths
+data_path = "data/graininess_100_balanced_subset_split"
+train_path = os.path.join(data_path, "train")
+val_path = os.path.join(data_path, "val")
+test_path = os.path.join(data_path, "test")
+
+# Define artifact (can be extended for multi-task later)
+artifact = "graininess"
+
+
+# Helper function to load labels
+def load_labels(split_path):
+    with open(os.path.join(split_path, "labels.json"), "r") as f:
+        return json.load(f)
+
+
+# Custom dataset class
+class VideoDataset(Dataset):
+    def __init__(self, root_dir, labels, artifact):
+        self.root_dir = root_dir
+        self.labels = labels
+        self.artifact = artifact
+        self.video_files = [f for f in os.listdir(root_dir) if f.endswith('.avi')]
+        self.transform = transforms.Compose([
+            transforms.ConvertImageDtype(torch.float32),
+            transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])
+        ])
+
+    def __len__(self):
+        return len(self.video_files)
+
+    def __getitem__(self, idx):
+        video_name = self.video_files[idx]
+        video_path = os.path.join(self.root_dir, video_name)
+        label = self.labels[video_name][self.artifact]
+
+        # Load video using torchvision; frames arrive as (T, H, W, C) uint8
+        video, _, _ = read_video(video_path, pts_unit='sec')
+
+        # Subsample 16 frames (guard against clips shorter than 16 frames)
+        step = max(video.shape[0] // 16, 1)
+        video = video[::step][:16]
+
+        # Rearrange to (T, C, H, W) so Normalize broadcasts over the channel dim
+        video = video.permute(0, 3, 1, 2)
+
+        # Apply normalization
+        video = self.transform(video)
+
+        # Rearrange dimensions to [C, T, H, W]
+        video = video.permute(1, 0, 2, 3)
+
+        return video, torch.tensor(label, dtype=torch.float32)
+
+
+# Create datasets
+train_labels = load_labels(train_path)
+val_labels = load_labels(val_path)
+test_labels = load_labels(test_path)
+
+train_dataset = VideoDataset(train_path, train_labels, artifact)
+val_dataset = VideoDataset(val_path, val_labels, artifact)
+test_dataset = VideoDataset(test_path, test_labels, artifact)
+
+# Create data loaders
+batch_size = 8
+train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
+val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+
+
+# Define model
+class VideoClassifier(nn.Module):
+    def __init__(self, num_classes=1):
+        super(VideoClassifier, self).__init__()
+        # Keep the pretrained 3-channel stem: forward() folds the 16 frames into
+        # the batch dimension, so conv1 still sees RGB frames.
+        self.resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT)
+        self.fc = nn.Linear(2048, num_classes)
+
+    def forward(self, x):
+        b, c, t, h, w = x.shape
+        x = x.transpose(1, 2).reshape(b * t, c, h, w)
+        x = self.resnet.conv1(x)
+        x = self.resnet.bn1(x)
+        x = self.resnet.relu(x)
+        x = self.resnet.maxpool(x)
+        x = self.resnet.layer1(x)
+        x = self.resnet.layer2(x)
+        x = self.resnet.layer3(x)
+        x = self.resnet.layer4(x)
+        x = self.resnet.avgpool(x)
+        x = x.reshape(b, t, -1).mean(1)
+        x = self.fc(x)
+        return torch.sigmoid(x)
+
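+# Shape flow for a batch of 8 clips of 16 RGB frames at 224x224 (illustrative
+# sizes): input (8, 3, 16, 224, 224) -> frames folded into the batch as
+# (128, 3, 224, 224) -> ResNet features after avgpool (128, 2048, 1, 1) ->
+# temporal mean (8, 2048) -> fc + sigmoid (8, 1).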
+
+model = VideoClassifier().to(device)
+
+# Define loss function and optimizer
+criterion = nn.BCELoss()
+optimizer = optim.Adam(model.parameters(), lr=0.001)
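+
+# Note: the model applies sigmoid inside forward(), so BCELoss is consistent
+# here; a common alternative is to return raw logits and use
+# nn.BCEWithLogitsLoss, which is numerically more stable.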
+
+
+# Training function
+def train(model, train_loader, criterion, optimizer, device):
+    model.train()
+    running_loss = 0.0
+    correct = 0
+    total = 0
+
+    for videos, labels in train_loader:
+        videos, labels = videos.to(device), labels.to(device)
+
+        optimizer.zero_grad()
+        outputs = model(videos)
+        loss = criterion(outputs.squeeze(), labels)
+        loss.backward()
+        optimizer.step()
+
+        running_loss += loss.item()
+        predicted = (outputs.squeeze() > 0.5).float()
+        total += labels.size(0)
+        correct += (predicted == labels).sum().item()
+
+    epoch_loss = running_loss / len(train_loader)
+    epoch_acc = correct / total
+    return epoch_loss, epoch_acc
+
+
+# Validation function
+def validate(model, val_loader, criterion, device):
+    model.eval()
+    running_loss = 0.0
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        for videos, labels in val_loader:
+            videos, labels = videos.to(device), labels.to(device)
+
+            outputs = model(videos)
+            loss = criterion(outputs.squeeze(), labels)
+
+            running_loss += loss.item()
+            predicted = (outputs.squeeze() > 0.5).float()
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+
+    epoch_loss = running_loss / len(val_loader)
+    epoch_acc = correct / total
+    return epoch_loss, epoch_acc
+
+
+# Training loop
+num_epochs = 10
+for epoch in range(num_epochs):
+    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
+    val_loss, val_acc = validate(model, val_loader, criterion, device)
+
+    print(f"Epoch {epoch + 1}/{num_epochs}")
+    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
+    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
+    print()
+
+
+# Test function
+def test(model, test_loader, criterion, device):
+    model.eval()
+    running_loss = 0.0
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        for videos, labels in test_loader:
+            videos, labels = videos.to(device), labels.to(device)
+
+            outputs = model(videos)
+            loss = criterion(outputs.squeeze(), labels)
+
+            running_loss += loss.item()
+            predicted = (outputs.squeeze() > 0.5).float()
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+
+    test_loss = running_loss / len(test_loader)
+    test_acc = correct / total
+    return test_loss, test_acc
+
+
+# Evaluate on test set
+test_loss, test_acc = test(model, test_loader, criterion, device)
+print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
+
+# Save the model
+torch.save(model.state_dict(), f"video_classifier_{artifact}.pth")
+print(f"Model saved as video_classifier_{artifact}.pth")
diff --git a/logs/first_run_logs/run-2024-08-25--18-57-46/events.out.tfevents.1724612332.26610f3d33fa.1894.0 b/logs/first_run_logs/run-2024-08-25--18-57-46/events.out.tfevents.1724612332.26610f3d33fa.1894.0
new file mode 100644
index 0000000000000000000000000000000000000000..cd6b3acc3dd59de5924f5c744c6dfed965de0bc9
Binary files /dev/null and b/logs/first_run_logs/run-2024-08-25--18-57-46/events.out.tfevents.1724612332.26610f3d33fa.1894.0 differ
diff --git a/logs/first_run_logs/run-2024-08-25--18-57-46/events.out.tfevents.1724612413.26610f3d33fa.1894.1 b/logs/first_run_logs/run-2024-08-25--18-57-46/events.out.tfevents.1724612413.26610f3d33fa.1894.1
new file mode 100644
index 0000000000000000000000000000000000000000..1dcb305968855178b24dd6f7a555aee19ce2cfc5
Binary files /dev/null and b/logs/first_run_logs/run-2024-08-25--18-57-46/events.out.tfevents.1724612413.26610f3d33fa.1894.1 differ
diff --git a/logs/first_run_logs/run-2024-08-25--19-00-48/events.out.tfevents.1724612452.26610f3d33fa.1894.2 b/logs/first_run_logs/run-2024-08-25--19-00-48/events.out.tfevents.1724612452.26610f3d33fa.1894.2
new file mode 100644
index 0000000000000000000000000000000000000000..3619b50087b5c9d5154ef9d48164fbb6b557b856
Binary files /dev/null and b/logs/first_run_logs/run-2024-08-25--19-00-48/events.out.tfevents.1724612452.26610f3d33fa.1894.2 differ
diff --git a/logs/first_run_logs/run-2024-08-25--19-00-48/events.out.tfevents.1724612515.26610f3d33fa.1894.3 b/logs/first_run_logs/run-2024-08-25--19-00-48/events.out.tfevents.1724612515.26610f3d33fa.1894.3
new file mode 100644
index 0000000000000000000000000000000000000000..bb878f89e4e3b7341f99e0c1014a936814064682
Binary files /dev/null and b/logs/first_run_logs/run-2024-08-25--19-00-48/events.out.tfevents.1724612515.26610f3d33fa.1894.3 differ
diff --git a/logs/first_run_logs/run-2024-08-25--19-02-49/events.out.tfevents.1724612582.26610f3d33fa.5666.0 b/logs/first_run_logs/run-2024-08-25--19-02-49/events.out.tfevents.1724612582.26610f3d33fa.5666.0
new file mode 100644
index 0000000000000000000000000000000000000000..7588ca876fce4b56f15ca5f21534edb4b4f6fcc6
Binary files /dev/null and b/logs/first_run_logs/run-2024-08-25--19-02-49/events.out.tfevents.1724612582.26610f3d33fa.5666.0 differ
diff --git a/logs/first_run_logs/run-2024-08-25--19-02-49/events.out.tfevents.1724612683.26610f3d33fa.5666.1 b/logs/first_run_logs/run-2024-08-25--19-02-49/events.out.tfevents.1724612683.26610f3d33fa.5666.1
new file mode 100644
index 0000000000000000000000000000000000000000..ef95393aca2d3ecbd709755be79cb946c7b69f49
Binary files /dev/null and b/logs/first_run_logs/run-2024-08-25--19-02-49/events.out.tfevents.1724612683.26610f3d33fa.5666.1 differ
diff --git a/logs/first_run_logs/run-2024-08-25--19-05-27/events.out.tfevents.1724612737.26610f3d33fa.5666.2 b/logs/first_run_logs/run-2024-08-25--19-05-27/events.out.tfevents.1724612737.26610f3d33fa.5666.2
new file mode 100644
index 0000000000000000000000000000000000000000..6ea8859409230e96f44bb013b856c7203d32cb1b
Binary files /dev/null and b/logs/first_run_logs/run-2024-08-25--19-05-27/events.out.tfevents.1724612737.26610f3d33fa.5666.2 differ
diff --git a/logs/first_run_logs/run-2024-08-25--19-05-27/run_info.txt b/logs/first_run_logs/run-2024-08-25--19-05-27/run_info.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a22fb47aedf86c9898107e31b3e6e02f3e0ee33d
--- /dev/null
+++ b/logs/first_run_logs/run-2024-08-25--19-05-27/run_info.txt
@@ -0,0 +1,138 @@
+Run Name: run-2024-08-25--19-05-27
+Model: MultiTaskTimeSformer
+Training Arguments:
+  output_dir: ./results/run-2024-08-25--19-05-27
+  overwrite_output_dir: False
+  do_train: False
+  do_eval: True
+  do_predict: False
+  eval_strategy: steps
+  prediction_loss_only: False
+  per_device_train_batch_size: 16
+  per_device_eval_batch_size: 16
+  per_gpu_train_batch_size: None
+  per_gpu_eval_batch_size: None
+  gradient_accumulation_steps: 2
+  eval_accumulation_steps: None
+  eval_delay: 0
+  learning_rate: 5e-05
+  weight_decay: 0.01
+  adam_beta1: 0.9
+  adam_beta2: 0.999
+  adam_epsilon: 1e-08
+  max_grad_norm: 1.0
+  num_train_epochs: 3.0
+  max_steps: 420
+  lr_scheduler_type: linear
+  lr_scheduler_kwargs: {}
+  warmup_ratio: 0.1
+  warmup_steps: 0
+  log_level: passive
+  log_level_replica: warning
+  log_on_each_node: True
+  logging_dir: ./logs/run-2024-08-25--19-05-27
+  logging_strategy: steps
+  logging_first_step: False
+  logging_steps: 20
+  logging_nan_inf_filter: True
+  save_strategy: steps
+  save_steps: 100
+  save_total_limit: 2
+  save_safetensors: True
+  save_on_each_node: False
+  save_only_model: False
+  restore_callback_states_from_checkpoint: False
+  no_cuda: False
+  use_cpu: False
+  use_mps_device: False
+  seed: 42
+  data_seed: None
+  jit_mode_eval: False
+  use_ipex: False
+  bf16: False
+  fp16: True
+  fp16_opt_level: O1
+  half_precision_backend: auto
+  bf16_full_eval: False
+  fp16_full_eval: False
+  tf32: None
+  local_rank: 0
+  ddp_backend: None
+  tpu_num_cores: None
+  tpu_metrics_debug: False
+  debug: []
+  dataloader_drop_last: False
+  eval_steps: 50
+  dataloader_num_workers: 12
+  dataloader_prefetch_factor: None
+  past_index: -1
+  run_name: run-2024-08-25--19-05-27
+  disable_tqdm: False
+  remove_unused_columns: True
+  label_names: None
+  load_best_model_at_end: True
+  metric_for_best_model: f1
+  greater_is_better: True
+  ignore_data_skip: False
+  fsdp: []
+  fsdp_min_num_params: 0
+  fsdp_config: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
+  fsdp_transformer_layer_cls_to_wrap: None
+  accelerator_config: AcceleratorConfig(split_batches=False, dispatch_batches=None, even_batches=True, use_seedable_sampler=True, non_blocking=False, gradient_accumulation_kwargs=None, use_configured_state=False)
+  deepspeed: None
+  label_smoothing_factor: 0.0
+  optim: adamw_torch
+  optim_args: None
+  adafactor: False
+  group_by_length: False
+  length_column_name: length
+  report_to: ['tensorboard']
+  ddp_find_unused_parameters: None
+  ddp_bucket_cap_mb: None
+  ddp_broadcast_buffers: None
+  dataloader_pin_memory: True
+  dataloader_persistent_workers: False
+  skip_memory_metrics: True
+  use_legacy_prediction_loop: False
+  push_to_hub: False
+  resume_from_checkpoint: None
+  hub_model_id: None
+  hub_strategy: every_save
+  hub_token: None
+  hub_private_repo: False
+  hub_always_push: False
+  gradient_checkpointing: False
+  gradient_checkpointing_kwargs: None
+  include_inputs_for_metrics: False
+  eval_do_concat_batches: True
+  fp16_backend: auto
+  evaluation_strategy: None
+  push_to_hub_model_id: None
+  push_to_hub_organization: None
+  push_to_hub_token: None
+  mp_parameters: 
+  auto_find_batch_size: False
+  full_determinism: False
+  torchdynamo: None
+  ray_scope: last
+  ddp_timeout: 1800
+  torch_compile: False
+  torch_compile_backend: None
+  torch_compile_mode: None
+  dispatch_batches: None
+  split_batches: None
+  include_tokens_per_second: False
+  include_num_input_tokens_seen: False
+  neftune_noise_alpha: None
+  optim_target_modules: None
+  batch_eval_metrics: False
+  eval_on_start: False
+  distributed_state: Distributed environment: NO
+Num processes: 1
+Process index: 0
+Local process index: 0
+Device: cuda
+
+  _n_gpu: 1
+  __cached__setup_devices: cuda:0
+  deepspeed_plugin: None
diff --git a/logs/logs_successful_run_2/run-2024-08-25--23-53-08/events.out.tfevents.1724630048.0826fd70f652.869.0 b/logs/logs_successful_run_2/run-2024-08-25--23-53-08/events.out.tfevents.1724630048.0826fd70f652.869.0
new file mode 100644
index 0000000000000000000000000000000000000000..6011474c6cca20208705553f872d20e212a29edd
Binary files /dev/null and b/logs/logs_successful_run_2/run-2024-08-25--23-53-08/events.out.tfevents.1724630048.0826fd70f652.869.0 differ
diff --git a/logs/logs_successful_run_2/run-2024-08-25--23-53-08/events.out.tfevents.1724630112.0826fd70f652.869.1 b/logs/logs_successful_run_2/run-2024-08-25--23-53-08/events.out.tfevents.1724630112.0826fd70f652.869.1
new file mode 100644
index 0000000000000000000000000000000000000000..552bc0441b298fc0199417237135023b8a741954
Binary files /dev/null and b/logs/logs_successful_run_2/run-2024-08-25--23-53-08/events.out.tfevents.1724630112.0826fd70f652.869.1 differ
diff --git a/logs/logs_successful_run_2/run-2024-08-25--23-53-08/events.out.tfevents.1724630132.0826fd70f652.869.2 b/logs/logs_successful_run_2/run-2024-08-25--23-53-08/events.out.tfevents.1724630132.0826fd70f652.869.2
new file mode 100644
index 0000000000000000000000000000000000000000..6b9adbf3fc6f38a81768f98e594fa1eae681dc1e
Binary files /dev/null and b/logs/logs_successful_run_2/run-2024-08-25--23-53-08/events.out.tfevents.1724630132.0826fd70f652.869.2 differ
diff --git a/logs/logs_successful_run_2/run-2024-08-26--00-02-24/events.out.tfevents.1724630555.0826fd70f652.869.3 b/logs/logs_successful_run_2/run-2024-08-26--00-02-24/events.out.tfevents.1724630555.0826fd70f652.869.3
new file mode 100644
index 0000000000000000000000000000000000000000..8ca1d6ad846d489b7b578fb916e39aeabbb94a96
Binary files /dev/null and b/logs/logs_successful_run_2/run-2024-08-26--00-02-24/events.out.tfevents.1724630555.0826fd70f652.869.3 differ
diff --git a/logs/logs_successful_run_2/run-2024-08-26--00-03-58/events.out.tfevents.1724630643.0826fd70f652.6853.0 b/logs/logs_successful_run_2/run-2024-08-26--00-03-58/events.out.tfevents.1724630643.0826fd70f652.6853.0
new file mode 100644
index 0000000000000000000000000000000000000000..52264f3345f36505026169c61baa8d7eee1185fe
Binary files /dev/null and b/logs/logs_successful_run_2/run-2024-08-26--00-03-58/events.out.tfevents.1724630643.0826fd70f652.6853.0 differ
diff --git a/logs/second_run_logs/run-2024-08-25--18-57-46/events.out.tfevents.1724612332.26610f3d33fa.1894.0 b/logs/second_run_logs/run-2024-08-25--18-57-46/events.out.tfevents.1724612332.26610f3d33fa.1894.0
new file mode 100644
index 0000000000000000000000000000000000000000..cd6b3acc3dd59de5924f5c744c6dfed965de0bc9
Binary files /dev/null and b/logs/second_run_logs/run-2024-08-25--18-57-46/events.out.tfevents.1724612332.26610f3d33fa.1894.0 differ
diff --git a/logs/second_run_logs/run-2024-08-25--18-57-46/events.out.tfevents.1724612413.26610f3d33fa.1894.1 b/logs/second_run_logs/run-2024-08-25--18-57-46/events.out.tfevents.1724612413.26610f3d33fa.1894.1
new file mode 100644
index 0000000000000000000000000000000000000000..1dcb305968855178b24dd6f7a555aee19ce2cfc5
Binary files /dev/null and b/logs/second_run_logs/run-2024-08-25--18-57-46/events.out.tfevents.1724612413.26610f3d33fa.1894.1 differ
diff --git a/logs/second_run_logs/run-2024-08-25--19-00-48/events.out.tfevents.1724612452.26610f3d33fa.1894.2 b/logs/second_run_logs/run-2024-08-25--19-00-48/events.out.tfevents.1724612452.26610f3d33fa.1894.2
new file mode 100644
index 0000000000000000000000000000000000000000..3619b50087b5c9d5154ef9d48164fbb6b557b856
Binary files /dev/null and b/logs/second_run_logs/run-2024-08-25--19-00-48/events.out.tfevents.1724612452.26610f3d33fa.1894.2 differ
diff --git a/logs/second_run_logs/run-2024-08-25--19-00-48/events.out.tfevents.1724612515.26610f3d33fa.1894.3 b/logs/second_run_logs/run-2024-08-25--19-00-48/events.out.tfevents.1724612515.26610f3d33fa.1894.3
new file mode 100644
index 0000000000000000000000000000000000000000..bb878f89e4e3b7341f99e0c1014a936814064682
Binary files /dev/null and b/logs/second_run_logs/run-2024-08-25--19-00-48/events.out.tfevents.1724612515.26610f3d33fa.1894.3 differ
diff --git a/logs/second_run_logs/run-2024-08-25--19-02-49/events.out.tfevents.1724612582.26610f3d33fa.5666.0 b/logs/second_run_logs/run-2024-08-25--19-02-49/events.out.tfevents.1724612582.26610f3d33fa.5666.0
new file mode 100644
index 0000000000000000000000000000000000000000..7588ca876fce4b56f15ca5f21534edb4b4f6fcc6
Binary files /dev/null and b/logs/second_run_logs/run-2024-08-25--19-02-49/events.out.tfevents.1724612582.26610f3d33fa.5666.0 differ
diff --git a/logs/second_run_logs/run-2024-08-25--19-02-49/events.out.tfevents.1724612683.26610f3d33fa.5666.1 b/logs/second_run_logs/run-2024-08-25--19-02-49/events.out.tfevents.1724612683.26610f3d33fa.5666.1
new file mode 100644
index 0000000000000000000000000000000000000000..ef95393aca2d3ecbd709755be79cb946c7b69f49
Binary files /dev/null and b/logs/second_run_logs/run-2024-08-25--19-02-49/events.out.tfevents.1724612683.26610f3d33fa.5666.1 differ
diff --git a/logs/second_run_logs/run-2024-08-25--19-05-27/events.out.tfevents.1724612737.26610f3d33fa.5666.2 b/logs/second_run_logs/run-2024-08-25--19-05-27/events.out.tfevents.1724612737.26610f3d33fa.5666.2
new file mode 100644
index 0000000000000000000000000000000000000000..6ea8859409230e96f44bb013b856c7203d32cb1b
Binary files /dev/null and b/logs/second_run_logs/run-2024-08-25--19-05-27/events.out.tfevents.1724612737.26610f3d33fa.5666.2 differ
diff --git a/logs/second_run_logs/run-2024-08-25--19-05-27/run_info.txt b/logs/second_run_logs/run-2024-08-25--19-05-27/run_info.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a22fb47aedf86c9898107e31b3e6e02f3e0ee33d
--- /dev/null
+++ b/logs/second_run_logs/run-2024-08-25--19-05-27/run_info.txt
@@ -0,0 +1,138 @@
+Run Name: run-2024-08-25--19-05-27
+Model: MultiTaskTimeSformer
+Training Arguments:
+  output_dir: ./results/run-2024-08-25--19-05-27
+  overwrite_output_dir: False
+  do_train: False
+  do_eval: True
+  do_predict: False
+  eval_strategy: steps
+  prediction_loss_only: False
+  per_device_train_batch_size: 16
+  per_device_eval_batch_size: 16
+  per_gpu_train_batch_size: None
+  per_gpu_eval_batch_size: None
+  gradient_accumulation_steps: 2
+  eval_accumulation_steps: None
+  eval_delay: 0
+  learning_rate: 5e-05
+  weight_decay: 0.01
+  adam_beta1: 0.9
+  adam_beta2: 0.999
+  adam_epsilon: 1e-08
+  max_grad_norm: 1.0
+  num_train_epochs: 3.0
+  max_steps: 420
+  lr_scheduler_type: linear
+  lr_scheduler_kwargs: {}
+  warmup_ratio: 0.1
+  warmup_steps: 0
+  log_level: passive
+  log_level_replica: warning
+  log_on_each_node: True
+  logging_dir: ./logs/run-2024-08-25--19-05-27
+  logging_strategy: steps
+  logging_first_step: False
+  logging_steps: 20
+  logging_nan_inf_filter: True
+  save_strategy: steps
+  save_steps: 100
+  save_total_limit: 2
+  save_safetensors: True
+  save_on_each_node: False
+  save_only_model: False
+  restore_callback_states_from_checkpoint: False
+  no_cuda: False
+  use_cpu: False
+  use_mps_device: False
+  seed: 42
+  data_seed: None
+  jit_mode_eval: False
+  use_ipex: False
+  bf16: False
+  fp16: True
+  fp16_opt_level: O1
+  half_precision_backend: auto
+  bf16_full_eval: False
+  fp16_full_eval: False
+  tf32: None
+  local_rank: 0
+  ddp_backend: None
+  tpu_num_cores: None
+  tpu_metrics_debug: False
+  debug: []
+  dataloader_drop_last: False
+  eval_steps: 50
+  dataloader_num_workers: 12
+  dataloader_prefetch_factor: None
+  past_index: -1
+  run_name: run-2024-08-25--19-05-27
+  disable_tqdm: False
+  remove_unused_columns: True
+  label_names: None
+  load_best_model_at_end: True
+  metric_for_best_model: f1
+  greater_is_better: True
+  ignore_data_skip: False
+  fsdp: []
+  fsdp_min_num_params: 0
+  fsdp_config: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
+  fsdp_transformer_layer_cls_to_wrap: None
+  accelerator_config: AcceleratorConfig(split_batches=False, dispatch_batches=None, even_batches=True, use_seedable_sampler=True, non_blocking=False, gradient_accumulation_kwargs=None, use_configured_state=False)
+  deepspeed: None
+  label_smoothing_factor: 0.0
+  optim: adamw_torch
+  optim_args: None
+  adafactor: False
+  group_by_length: False
+  length_column_name: length
+  report_to: ['tensorboard']
+  ddp_find_unused_parameters: None
+  ddp_bucket_cap_mb: None
+  ddp_broadcast_buffers: None
+  dataloader_pin_memory: True
+  dataloader_persistent_workers: False
+  skip_memory_metrics: True
+  use_legacy_prediction_loop: False
+  push_to_hub: False
+  resume_from_checkpoint: None
+  hub_model_id: None
+  hub_strategy: every_save
+  hub_token: None
+  hub_private_repo: False
+  hub_always_push: False
+  gradient_checkpointing: False
+  gradient_checkpointing_kwargs: None
+  include_inputs_for_metrics: False
+  eval_do_concat_batches: True
+  fp16_backend: auto
+  evaluation_strategy: None
+  push_to_hub_model_id: None
+  push_to_hub_organization: None
+  push_to_hub_token: None
+  mp_parameters: 
+  auto_find_batch_size: False
+  full_determinism: False
+  torchdynamo: None
+  ray_scope: last
+  ddp_timeout: 1800
+  torch_compile: False
+  torch_compile_backend: None
+  torch_compile_mode: None
+  dispatch_batches: None
+  split_batches: None
+  include_tokens_per_second: False
+  include_num_input_tokens_seen: False
+  neftune_noise_alpha: None
+  optim_target_modules: None
+  batch_eval_metrics: False
+  eval_on_start: False
+  distributed_state: Distributed environment: NO
+Num processes: 1
+Process index: 0
+Local process index: 0
+Device: cuda
+
+  _n_gpu: 1
+  __cached__setup_devices: cuda:0
+  deepspeed_plugin: None
diff --git a/notebooks/data_prep.ipynb b/notebooks/data_prep.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..110b2e532ff30b031b073405edfd8a02e3459f30
--- /dev/null
+++ b/notebooks/data_prep.ipynb
@@ -0,0 +1,473 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e006d00d0980cdb6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "from pathlib import Path\n",
+    "import torch\n",
+    "from torch.utils.data import Dataset\n",
+    "import torchvision.transforms as transforms\n",
+    "from torchvision.io import read_video\n",
+    "\n",
+    "\n",
+    "class VideoNormalize(torch.nn.Module):\n",
+    "    def __init__(self, mean, std):\n",
+    "        super().__init__()\n",
+    "        self.mean = torch.tensor(mean).view(3, 1, 1, 1)\n",
+    "        self.std = torch.tensor(std).view(3, 1, 1, 1)\n",
+    "\n",
+    "    def forward(self, video):\n",
+    "        return (video - self.mean) / self.std\n",
+    "\n",
+    "\n",
+    "class VideoDataset(Dataset):\n",
+    "    def __init__(self, root_dir, split, transform=None, clip_duration=5.0, target_fps=30):\n",
+    "        self.root_dir = Path(root_dir) / split\n",
+    "        self.transform = transform\n",
+    "        self.clip_duration = clip_duration\n",
+    "        self.target_fps = target_fps\n",
+    "        self.target_frames = int(clip_duration * target_fps)\n",
+    "        self.video_files = []\n",
+    "        self.labels = {}\n",
+    "\n",
+    "        # Load labels from labels.json\n",
+    "        labels_path = self.root_dir / 'labels.json'\n",
+    "        with open(labels_path, 'r') as f:\n",
+    "            self.labels = json.load(f)\n",
+    "\n",
+    "        # Collect video file paths\n",
+    "        self.video_files = list(self.root_dir.glob('*.avi'))\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return len(self.video_files)\n",
+    "\n",
+    "    def __getitem__(self, idx):\n",
+    "        video_path = str(self.video_files[idx])\n",
+    "        video_name = self.video_files[idx].name\n",
+    "        label = self.labels[video_name]['graininess']\n",
+    "\n",
+    "        # Read video using torchvision\n",
+    "        video, audio, meta = read_video(video_path, pts_unit='sec')\n",
+    "\n",
+    "        # Extract frame rate from metadata\n",
+    "        fps = meta['video_fps']\n",
+    "\n",
+    "        # Calculate the number of frames to sample based on the clip duration and video's fps\n",
+    "        num_frames_to_sample = min(int(self.clip_duration * fps), video.shape[0])\n",
+    "\n",
+    "        # Sample frames\n",
+    "        if num_frames_to_sample < video.shape[0]:\n",
+    "            start_idx = torch.randint(0, video.shape[0] - num_frames_to_sample + 1, (1,)).item()\n",
+    "            video = video[start_idx:start_idx + num_frames_to_sample]\n",
+    "\n",
+    "        # Resample to target FPS\n",
+    "        if fps != self.target_fps:\n",
+    "            indices = torch.linspace(0, video.shape[0] - 1, self.target_frames).long()\n",
+    "            video = video[indices]\n",
+    "\n",
+    "        # Ensure we have exactly target_frames\n",
+    "        if video.shape[0] < self.target_frames:\n",
+    "            video = torch.cat([video, video[-1].unsqueeze(0).repeat(self.target_frames - video.shape[0], 1, 1, 1)])\n",
+    "        elif video.shape[0] > self.target_frames:\n",
+    "            video = video[:self.target_frames]\n",
+    "\n",
+    "        # Change from (T, H, W, C) to (C, T, H, W)\n",
+    "        video = video.permute(3, 0, 1, 2)\n",
+    "\n",
+    "        if self.transform:\n",
+    "            video = self.transform(video)\n",
+    "\n",
+    "        return video, torch.tensor(label, dtype=torch.long)\n",
+    "\n",
+    "\n",
+    "# Example usage\n",
+    "transform = transforms.Compose([\n",
+    "    transforms.Lambda(lambda x: x.float() / 255.0),  # Normalize to [0, 1]\n",
+    "    VideoNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
+    "])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "a21a7b0a8e86913c",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-08-19T16:31:42.133858Z",
+     "start_time": "2024-08-19T16:31:42.128809Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Path to the dataset\n",
+    "data_root = Path('/Users/sv7/Projects/mtl-video-classification/data/graininess_100_balanced_subset_split')\n",
+    "\n",
+    "train_dataset = VideoDataset(data_root,\n",
+    "                             split='train',\n",
+    "                             transform=transform)\n",
+    "\n",
+    "test_dataset = VideoDataset(data_root,\n",
+    "                            split='test',\n",
+    "                            transform=transform)\n",
+    "\n",
+    "val_dataset = VideoDataset(data_root,\n",
+    "                           split='val',\n",
+    "                           transform=transform)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "a9092ed9c5027597",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-08-19T16:31:42.761488Z",
+     "start_time": "2024-08-19T16:31:42.759166Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# DataLoader example\n",
+    "from torch.utils.data import DataLoader\n",
+    "import os\n",
+    "\n",
+    "batch_size = 4\n",
+    "num_workers = os.cpu_count()\n",
+    "\n",
+    "train_loader = DataLoader(train_dataset,\n",
+    "                          batch_size=batch_size,\n",
+    "                          shuffle=True,\n",
+    "                          num_workers=num_workers)\n",
+    "\n",
+    "test_loader = DataLoader(test_dataset,\n",
+    "                         batch_size=batch_size,\n",
+    "                         shuffle=False,\n",
+    "                         num_workers=num_workers)\n",
+    "\n",
+    "val_loader = DataLoader(val_dataset,\n",
+    "                        batch_size=batch_size,\n",
+    "                        shuffle=False,\n",
+    "                        num_workers=num_workers)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "id": "77d2d43a9fe4c2c2",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-08-19T16:55:37.595972Z",
+     "start_time": "2024-08-19T16:55:36.873079Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "from pathlib import Path\n",
+    "\n",
+    "train_data_path = Path('/Users/sv7/Projects/mtl-video-classification/data/graininess_100_balanced_subset_split/train')\n",
+    "labels_path = train_data_path / 'labels.json'\n",
+    "\n",
+    "# /Users/sv7/Projects/mtl-video-classification/data/graininess_100_balanced_subset_split/train/labels.json\n",
+    "video_files = list(train_data_path.glob('*.avi'))\n",
+    "with open(labels_path) as f:\n",
+    "    labels = json.load(f)\n",
+    "\n",
+    "video_path = str(video_files[5])\n",
+    "video, audio, meta = read_video(video_path, pts_unit='sec')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "id": "f7d927a0c9c73948",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-08-19T16:57:53.317039Z",
+     "start_time": "2024-08-19T16:57:53.314764Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "clip_duration = 5.0\n",
+    "\n",
+    "# Extract frame rate from metadata\n",
+    "fps = meta['video_fps']\n",
+    "\n",
+    "# Calculate the number of frames to sample based on the clip duration and video's fps\n",
+    "num_frames_to_sample = min(int(clip_duration * fps), video.shape[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 37,
+   "id": "b2c6a74027e9f3",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-08-19T17:13:49.049139Z",
+     "start_time": "2024-08-19T17:13:49.046501Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "300"
+      ]
+     },
+     "execution_count": 37,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "num_frames_to_sample"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 41,
+   "id": "4d960113bee6e247",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-08-19T18:51:49.590079Z",
+     "start_time": "2024-08-19T18:51:19.547632Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Some weights of VivitForVideoClassification were not initialized from the model checkpoint at google/vivit-b-16x2-kinetics400 and are newly initialized because the shapes did not match:\n",
+      "- classifier.weight: found shape torch.Size([400, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated\n",
+      "- classifier.bias: found shape torch.Size([400]) in the checkpoint and torch.Size([2]) in the model instantiated\n",
+      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
+      "/Users/sv7/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/transformers/feature_extraction_utils.py:142: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:281.)\n",
+      "  return torch.tensor(value)\n"
+     ]
+    },
+    {
+     "ename": "RuntimeError",
+     "evalue": "MPS backend out of memory (MPS allocated: 17.77 GB, other allocations: 40.66 MB, max allowed: 18.13 GB). Tried to allocate 1.76 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[41], line 124\u001b[0m\n\u001b[1;32m    116\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Trainer(\n\u001b[1;32m    117\u001b[0m     model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[1;32m    118\u001b[0m     args\u001b[38;5;241m=\u001b[39mtraining_args,\n\u001b[1;32m    119\u001b[0m     train_dataset\u001b[38;5;241m=\u001b[39mtrain_dataset,\n\u001b[1;32m    120\u001b[0m     eval_dataset\u001b[38;5;241m=\u001b[39mval_dataset,\n\u001b[1;32m    121\u001b[0m )\n\u001b[1;32m    123\u001b[0m \u001b[38;5;66;03m# Cell 8: Train the model\u001b[39;00m\n\u001b[0;32m--> 124\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    126\u001b[0m \u001b[38;5;66;03m# Cell 9: Evaluate on test set\u001b[39;00m\n\u001b[1;32m    127\u001b[0m test_results \u001b[38;5;241m=\u001b[39m trainer\u001b[38;5;241m.\u001b[39mevaluate(test_dataset)\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/transformers/trainer.py:1948\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m   1946\u001b[0m         hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m   1947\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1948\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1949\u001b[0m \u001b[43m        \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1950\u001b[0m \u001b[43m        \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1951\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1952\u001b[0m \u001b[43m        \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1953\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/transformers/trainer.py:2289\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   2286\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m   2288\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 2289\u001b[0m     tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2291\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m   2292\u001b[0m     args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m   2293\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m   2294\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m   2295\u001b[0m ):\n\u001b[1;32m   2296\u001b[0m     \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m   2297\u001b[0m     tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/transformers/trainer.py:3328\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m   3325\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m loss_mb\u001b[38;5;241m.\u001b[39mreduce_mean()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m   3327\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 3328\u001b[0m     loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3330\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m inputs\n\u001b[1;32m   3331\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m   3332\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mtorch_empty_cache_steps \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   3333\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m%\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mtorch_empty_cache_steps \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m   3334\u001b[0m ):\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/transformers/trainer.py:3373\u001b[0m, in \u001b[0;36mTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m   3371\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   3372\u001b[0m     labels \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 3373\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3374\u001b[0m \u001b[38;5;66;03m# Save past state if it exists\u001b[39;00m\n\u001b[1;32m   3375\u001b[0m \u001b[38;5;66;03m# TODO: this needs to be fixed and made cleaner later.\u001b[39;00m\n\u001b[1;32m   3376\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mpast_index \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1551\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1560\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1561\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1565\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/transformers/models/vivit/modeling_vivit.py:759\u001b[0m, in \u001b[0;36mVivitForVideoClassification.forward\u001b[0;34m(self, pixel_values, head_mask, labels, output_attentions, output_hidden_states, interpolate_pos_encoding, return_dict)\u001b[0m\n\u001b[1;32m    673\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    674\u001b[0m \u001b[38;5;124;03mlabels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\u001b[39;00m\n\u001b[1;32m    675\u001b[0m \u001b[38;5;124;03m    Labels for computing the image classification/regression loss. Indices should be in `[0, ...,\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    755\u001b[0m \u001b[38;5;124;03mLABEL_116\u001b[39;00m\n\u001b[1;32m    756\u001b[0m \u001b[38;5;124;03m```\"\"\"\u001b[39;00m\n\u001b[1;32m    757\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[0;32m--> 759\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvivit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    760\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpixel_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    761\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    762\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    763\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    764\u001b[0m \u001b[43m    \u001b[49m\u001b[43minterpolate_pos_encoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minterpolate_pos_encoding\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    765\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    766\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    768\u001b[0m sequence_output \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    770\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclassifier(sequence_output[:, \u001b[38;5;241m0\u001b[39m, :])\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1551\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1560\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1561\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1565\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/transformers/models/vivit/modeling_vivit.py:611\u001b[0m, in \u001b[0;36mVivitModel.forward\u001b[0;34m(self, pixel_values, head_mask, output_attentions, output_hidden_states, interpolate_pos_encoding, return_dict)\u001b[0m\n\u001b[1;32m    607\u001b[0m head_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_head_mask(head_mask, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mnum_hidden_layers)\n\u001b[1;32m    609\u001b[0m embedding_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membeddings(pixel_values, interpolate_pos_encoding\u001b[38;5;241m=\u001b[39minterpolate_pos_encoding)\n\u001b[0;32m--> 611\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    612\u001b[0m \u001b[43m    \u001b[49m\u001b[43membedding_output\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    613\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    614\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    615\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    616\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    617\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    618\u001b[0m sequence_output \u001b[38;5;241m=\u001b[39m encoder_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    619\u001b[0m sequence_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayernorm(sequence_output)\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1551\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1560\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1561\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1565\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/transformers/models/vivit/modeling_vivit.py:378\u001b[0m, in \u001b[0;36mVivitEncoder.forward\u001b[0;34m(self, hidden_states, head_mask, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    371\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m    372\u001b[0m         layer_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m    373\u001b[0m         hidden_states,\n\u001b[1;32m    374\u001b[0m         layer_head_mask,\n\u001b[1;32m    375\u001b[0m         output_attentions,\n\u001b[1;32m    376\u001b[0m     )\n\u001b[1;32m    377\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 378\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mlayer_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlayer_head_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    380\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    382\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_attentions:\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1551\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1560\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1561\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1565\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/transformers/models/vivit/modeling_vivit.py:321\u001b[0m, in \u001b[0;36mVivitLayer.forward\u001b[0;34m(self, hidden_states, head_mask, output_attentions)\u001b[0m\n\u001b[1;32m    320\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, hidden_states, head_mask\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, output_attentions\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[0;32m--> 321\u001b[0m     self_attention_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattention\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    322\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;66;43;03m# in Vivit, layernorm is applied before self-attention\u001b[39;49;00m\n\u001b[1;32m    323\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayernorm_before\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    324\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    325\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    326\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    327\u001b[0m     attention_output \u001b[38;5;241m=\u001b[39m self_attention_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    328\u001b[0m     \u001b[38;5;66;03m# add self attentions if we output attention weights\u001b[39;00m\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1551\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1560\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1561\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1565\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/transformers/models/vivit/modeling_vivit.py:265\u001b[0m, in \u001b[0;36mVivitAttention.forward\u001b[0;34m(self, hidden_states, head_mask, output_attentions)\u001b[0m\n\u001b[1;32m    259\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m    260\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m    261\u001b[0m     hidden_states: torch\u001b[38;5;241m.\u001b[39mTensor,\n\u001b[1;32m    262\u001b[0m     head_mask: Optional[torch\u001b[38;5;241m.\u001b[39mTensor] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m    263\u001b[0m     output_attentions: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m    264\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[Tuple[torch\u001b[38;5;241m.\u001b[39mTensor, torch\u001b[38;5;241m.\u001b[39mTensor], Tuple[torch\u001b[38;5;241m.\u001b[39mTensor]]:\n\u001b[0;32m--> 265\u001b[0m     self_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattention\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    267\u001b[0m     attention_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput(self_outputs[\u001b[38;5;241m0\u001b[39m], hidden_states)\n\u001b[1;32m    269\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m (attention_output,) \u001b[38;5;241m+\u001b[39m self_outputs[\u001b[38;5;241m1\u001b[39m:]  \u001b[38;5;66;03m# add attentions if we output them\u001b[39;00m\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1551\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1560\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1561\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1565\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+      "File \u001b[0;32m~/opt/anaconda3/envs/video_classification/lib/python3.12/site-packages/transformers/models/vivit/modeling_vivit.py:188\u001b[0m, in \u001b[0;36mVivitSelfAttention.forward\u001b[0;34m(self, hidden_states, head_mask, output_attentions)\u001b[0m\n\u001b[1;32m    185\u001b[0m query_layer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtranspose_for_scores(mixed_query_layer)\n\u001b[1;32m    187\u001b[0m \u001b[38;5;66;03m# Take the dot product between \"query\" and \"key\" to get the raw attention scores.\u001b[39;00m\n\u001b[0;32m--> 188\u001b[0m attention_scores \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmatmul\u001b[49m\u001b[43m(\u001b[49m\u001b[43mquery_layer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey_layer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtranspose\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    190\u001b[0m attention_scores \u001b[38;5;241m=\u001b[39m attention_scores \u001b[38;5;241m/\u001b[39m math\u001b[38;5;241m.\u001b[39msqrt(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattention_head_size)\n\u001b[1;32m    192\u001b[0m \u001b[38;5;66;03m# Normalize the attention scores to probabilities.\u001b[39;00m\n",
+      "\u001b[0;31mRuntimeError\u001b[0m: MPS backend out of memory (MPS allocated: 17.77 GB, other allocations: 40.66 MB, max allowed: 18.13 GB). Tried to allocate 1.76 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure)."
+     ]
+    }
+   ],
+   "source": [
+    "# Cell 1: Import necessary libraries\n",
+    "import os\n",
+    "import json\n",
+    "import random\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "from torch.utils.data import Dataset, DataLoader\n",
+    "from torchvision.io import read_video\n",
+    "from transformers import VivitImageProcessor, VivitForVideoClassification, TrainingArguments, Trainer\n",
+    "\n",
+    "\n",
+    "# Cell 2: Set random seed for reproducibility\n",
+    "def set_seed(seed):\n",
+    "    random.seed(seed)\n",
+    "    np.random.seed(seed)\n",
+    "    torch.manual_seed(seed)\n",
+    "    torch.cuda.manual_seed_all(seed)\n",
+    "\n",
+    "\n",
+    "set_seed(42)\n",
+    "\n",
+    "\n",
+    "# Cell 3: Define custom dataset class\n",
+    "# Cell 3: Define custom dataset class\n",
+    "class VideoDataset(Dataset):\n",
+    "    def __init__(self, data_dir, split, processor, max_frames=32):\n",
+    "        self.data_dir = os.path.join(data_dir, split)\n",
+    "        self.processor = processor\n",
+    "        self.max_frames = max_frames\n",
+    "        \n",
+    "        with open(os.path.join(self.data_dir, 'labels.json'), 'r') as f:\n",
+    "            self.labels = json.load(f)\n",
+    "        \n",
+    "        self.video_files = list(self.labels.keys())\n",
+    "    \n",
+    "    def __len__(self):\n",
+    "        return len(self.video_files)\n",
+    "    \n",
+    "    def __getitem__(self, idx):\n",
+    "        video_file = self.video_files[idx]\n",
+    "        video_path = os.path.join(self.data_dir, video_file)\n",
+    "        \n",
+    "        # Read video\n",
+    "        video, _, _ = read_video(video_path, pts_unit='sec')\n",
+    "        \n",
+    "        # Sample frames\n",
+    "        num_frames = video.shape[0]\n",
+    "        if num_frames > self.max_frames:\n",
+    "            start = random.randint(0, num_frames - self.max_frames)\n",
+    "            video = video[start:start+self.max_frames]\n",
+    "        else:\n",
+    "            video = video[:self.max_frames]\n",
+    "        \n",
+    "        # Ensure we have 3 channels (RGB)\n",
+    "        if video.shape[-1] != 3:\n",
+    "            video = video.expand(-1, -1, -1, 3)\n",
+    "        \n",
+    "        # Convert to numpy array and ensure correct shape\n",
+    "        video = video.numpy()\n",
+    "        \n",
+    "        # Ensure the video has the correct shape (num_frames, height, width, channels)\n",
+    "        if video.shape[1] == 3:  # If channels are in the second dimension\n",
+    "            video = np.transpose(video, (0, 2, 3, 1))\n",
+    "        \n",
+    "        # Process frames\n",
+    "        pixel_values = self.processor(\n",
+    "            list(video),\n",
+    "            return_tensors=\"pt\",\n",
+    "            do_resize=True,\n",
+    "            size={\"shortest_edge\": 224},  # Adjust this size as needed\n",
+    "            do_center_crop=True,\n",
+    "            crop_size={\"height\": 224, \"width\": 224},  # Adjust this size as needed\n",
+    "        ).pixel_values\n",
+    "        \n",
+    "        # Get label\n",
+    "        label = self.labels[video_file]['graininess']\n",
+    "        \n",
+    "        return {'pixel_values': pixel_values.squeeze(), 'label': torch.tensor(label)}\n",
+    "\n",
+    "\n",
+    "# Cell 4: Initialize ViViT model and processor\n",
+    "model_name = \"google/vivit-b-16x2-kinetics400\"\n",
+    "processor = VivitImageProcessor.from_pretrained(model_name,\n",
+    "                                                ignore_mismatched_sizes=True)\n",
+    "model = VivitForVideoClassification.from_pretrained(model_name, num_labels=2,\n",
+    "                                                    ignore_mismatched_sizes=True)\n",
+    "\n",
+    "# Cell 5: Prepare datasets and dataloaders\n",
+    "data_dir = \"/Users/sv7/Projects/mtl-video-classification/data/graininess_100_balanced_subset_split\"\n",
+    "batch_size = 4\n",
+    "\n",
+    "train_dataset = VideoDataset(data_dir, 'train', processor)\n",
+    "val_dataset = VideoDataset(data_dir, 'val', processor)\n",
+    "test_dataset = VideoDataset(data_dir, 'test', processor)\n",
+    "\n",
+    "train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
+    "val_dataloader = DataLoader(val_dataset, batch_size=batch_size)\n",
+    "test_dataloader = DataLoader(test_dataset, batch_size=batch_size)\n",
+    "\n",
+    "# Cell 6: Define training arguments\n",
+    "training_args = TrainingArguments(\n",
+    "    output_dir=\"./results\",\n",
+    "    num_train_epochs=3,\n",
+    "    per_device_train_batch_size=batch_size,\n",
+    "    per_device_eval_batch_size=batch_size,\n",
+    "    warmup_steps=500,\n",
+    "    weight_decay=0.01,\n",
+    "    logging_dir='./logs',\n",
+    "    logging_steps=10,\n",
+    "    eval_strategy=\"epoch\",\n",
+    "    save_strategy=\"epoch\",\n",
+    "    load_best_model_at_end=True,\n",
+    ")\n",
+    "\n",
+    "# Cell 7: Define Trainer\n",
+    "trainer = Trainer(\n",
+    "    model=model,\n",
+    "    args=training_args,\n",
+    "    train_dataset=train_dataset,\n",
+    "    eval_dataset=val_dataset,\n",
+    ")\n",
+    "\n",
+    "# Cell 8: Train the model\n",
+    "trainer.train()\n",
+    "\n",
+    "# Cell 9: Evaluate on test set\n",
+    "test_results = trainer.evaluate(test_dataset)\n",
+    "print(test_results)\n",
+    "\n",
+    "# Cell 10: Save the model\n",
+    "model.save_pretrained(\"./vivit_graininess_classifier\")\n",
+    "processor.save_pretrained(\"./vivit_graininess_classifier\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c239dc3cc6e29490",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "# Cell 11: Inference example\n",
+    "def predict_video(video_path):\n",
+    "    video, _, _ = read_video(video_path, pts_unit='sec')\n",
+    "    inputs = processor(list(video.permute(0, 2, 3, 1).numpy()), return_tensors=\"pt\")\n",
+    "\n",
+    "    with torch.no_grad():\n",
+    "        outputs = model(**inputs)\n",
+    "        logits = outputs.logits\n",
+    "        predicted_class = logits.argmax(-1).item()\n",
+    "\n",
+    "    return \"Grainy\" if predicted_class == 1 else \"Not Grainy\"\n",
+    "\n",
+    "\n",
+    "# Example usage\n",
+    "example_video_path = \"path/to/example/video.avi\"\n",
+    "prediction = predict_video(example_video_path)\n",
+    "print(f\"The video is predicted to be: {prediction}\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/subset_for_patching.ipynb b/notebooks/subset_for_patching.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..b5852cfb8d24e8db6e164e8488ea7a8dc8f15c74
--- /dev/null
+++ b/notebooks/subset_for_patching.ipynb
@@ -0,0 +1,293 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "id": "initial_id",
+   "metadata": {
+    "collapsed": true,
+    "ExecuteTime": {
+     "end_time": "2024-08-23T23:46:55.159025Z",
+     "start_time": "2024-08-23T23:46:55.155910Z"
+    }
+   },
+   "source": [
+    "files_to_use = ['Tennis_1920x1080_24fps_8bit_420_Motion_QP47_SFB_1.avi',\n",
+    "                'Tennis_1920x1080_24fps_8bit_420_Motion_QP32_BT_1.avi',\n",
+    "                'DanceKiss_1920x1080_25fps_8bit_420_Dark_QP47_FB_4.avi',\n",
+    "                'DanceKiss_1920x1080_25fps_8bit_420_Dark_QP32_SB_4.avi',\n",
+    "                'Kimono1_1920x1080_24fps_8bit_420_graininess_QP47_B_4.avi',\n",
+    "                'Kimono1_1920x1080_24fps_8bit_420_graininess_QP32_FB_1.avi',\n",
+    "                'OldTownCross_1920x1080_25fps_8bit_420_graininess_QP47_SB_4.avi',\n",
+    "                'OldTownCross_1920x1080_25fps_8bit_420_graininess_QP32_SBT_2.avi',\n",
+    "                'BirdsInCage_1920x1080_30fps_8bit_420_Pristine_QP47_SFB_3.avi',\n",
+    "                'BirdsInCage_1920x1080_30fps_8bit_420_Pristine_QP32_FBT_1.avi',\n",
+    "                'ElFuente1_1920x1080_30fps_8bit_420_aliasing_QP47_SFB_1.avi',\n",
+    "                'ElFuente1_1920x1080_30fps_8bit_420_aliasing_QP32_FB_4.avi',\n",
+    "                'ElFuente2_1920x1080_30fps_8bit_420_graininess_QP47_SFB_3.avi',\n",
+    "                'ElFuente2_1920x1080_30fps_8bit_420_graininess_QP32_S_2.avi',\n",
+    "                'BQTerrace_1920x1080_30fps_8bit_420_aliasing_QP47_FB_3.avi',\n",
+    "                'BQTerrace_1920x1080_30fps_8bit_420_aliasing_QP32_SF_4.avi',\n",
+    "                'CrowdRun_1920x1080_25fps_8bit_420_aliasing_QP47_SFT_4.avi',\n",
+    "                'CrowdRun_1920x1080_25fps_8bit_420_aliasing_QP32_SF_1.avi',\n",
+    "                'Seeking_1920x1080_25fps_8bit_420_graininess_QP47_SF_2.avi',\n",
+    "                'Seeking_1920x1080_25fps_8bit_420_graininess_QP32_SFT_1.avi',\n",
+    "                'riverbed_1920x1080_25fps_8bit_420_banding_QP47_SFBT_2.avi',\n",
+    "                'riverbed_1920x1080_25fps_8bit_420_banding_QP32_S_3.avi',\n",
+    "                'station_1920x1080_30fps_8bit_420_graininess_QP47_SBT_2.avi',\n",
+    "                'station_1920x1080_30fps_8bit_420_graininess_QP32_SB_1.avi',\n",
+    "                'shields_1280x720_50fps_8bit_420_graininess_QP47_SBT_3.avi',\n",
+    "                'shields_1280x720_50fps_8bit_420_graininess_QP32_SFBT_2.avi']"
+   ],
+   "outputs": [],
+   "execution_count": 1
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-08-23T23:47:04.814760Z",
+     "start_time": "2024-08-23T23:47:04.812533Z"
+    }
+   },
+   "cell_type": "code",
+   "source": "from pathlib import Path",
+   "id": "f68ef83150ac3734",
+   "outputs": [],
+   "execution_count": 2
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-08-23T23:50:25.116050Z",
+     "start_time": "2024-08-23T23:50:25.090048Z"
+    }
+   },
+   "cell_type": "code",
+   "source": [
+    "dataset_path = Path('/Volumes/SSD/BVIArtefact')\n",
+    "\n",
+    "parts = ['part1', 'part2']\n",
+    "\n",
+    "# file paths of all files in files_to_use in part1 and part2\n",
+    "file_paths = []\n",
+    "for part in parts:\n",
+    "    file_path = dataset_path / part\n",
+    "    all_files = list(file_path.glob('*.avi'))\n",
+    "    for file in all_files:\n",
+    "        if file.name in files_to_use:\n",
+    "            file_paths.append(file)    "
+   ],
+   "id": "fdfacf937f9f286e",
+   "outputs": [],
+   "execution_count": 3
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-08-23T23:50:36.713565Z",
+     "start_time": "2024-08-23T23:50:36.711235Z"
+    }
+   },
+   "cell_type": "code",
+   "source": "len(file_paths)",
+   "id": "b4c910a7e71b9503",
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "26"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 5
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-08-23T23:51:31.282402Z",
+     "start_time": "2024-08-23T23:51:05.913927Z"
+    }
+   },
+   "cell_type": "code",
+   "source": [
+    "# copy files to a new folder\n",
+    "import shutil\n",
+    "\n",
+    "new_folder = Path('/Volumes/SSD/BVIArtefact/subset_for_patching')\n",
+    "new_folder.mkdir(exist_ok=True)\n",
+    "for file in file_paths:\n",
+    "    shutil.copy(file, new_folder)"
+   ],
+   "id": "fa2b07cf8f56b3c6",
+   "outputs": [],
+   "execution_count": 6
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-08-23T23:53:20.804168Z",
+     "start_time": "2024-08-23T23:53:20.793023Z"
+    }
+   },
+   "cell_type": "code",
+   "source": [
+    "# copy labels of files in file from /Volumes/SSD/BVIArtefact/processed_labels.json to /Volumes/SSD/BVIArtefact/subset_for_patching\n",
+    "import json\n",
+    "\n",
+    "with open(dataset_path / 'processed_labels.json', 'r') as f:\n",
+    "    labels = json.load(f)\n",
+    "    \n",
+    "new_labels = {}\n",
+    "for file in file_paths:\n",
+    "    new_labels[file.name] = labels[file.name]\n",
+    "    \n",
+    "with open(new_folder / 'labels.json', 'w') as f:\n",
+    "    json.dump(new_labels, f)"
+   ],
+   "id": "3ab6eaf72d2ebf1c",
+   "outputs": [],
+   "execution_count": 7
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-08-24T00:02:44.629506Z",
+     "start_time": "2024-08-24T00:02:44.547315Z"
+    }
+   },
+   "cell_type": "code",
+   "source": [
+    "import os\n",
+    "import random\n",
+    "\n",
+    "# Paths (Assuming the script is in the same directory as the dataset)\n",
+    "dataset_dir = '/Volumes/SSD/subsets/subset_for_patching'\n",
+    "labels_file = os.path.join(dataset_dir, 'labels.json')\n",
+    "\n",
+    "# Load the labels\n",
+    "with open(labels_file, 'r') as f:\n",
+    "    labels = json.load(f)\n",
+    "\n",
+    "# Split ratios\n",
+    "train_ratio = 0.7\n",
+    "val_ratio = 0.15\n",
+    "test_ratio = 0.15\n",
+    "\n",
+    "# Ensure the output directories exist\n",
+    "train_dir = os.path.join(dataset_dir, 'train')\n",
+    "val_dir = os.path.join(dataset_dir, 'val')\n",
+    "test_dir = os.path.join(dataset_dir, 'test')\n",
+    "\n",
+    "os.makedirs(train_dir, exist_ok=True)\n",
+    "os.makedirs(val_dir, exist_ok=True)\n",
+    "os.makedirs(test_dir, exist_ok=True)\n",
+    "\n",
+    "# Get list of all video files\n",
+    "video_files = [f for f in os.listdir(dataset_dir) if f.endswith('.avi')]\n",
+    "\n",
+    "# Shuffle the dataset\n",
+    "random.shuffle(video_files)\n",
+    "\n",
+    "# Calculate the split indices\n",
+    "train_idx = int(len(video_files) * train_ratio)\n",
+    "val_idx = train_idx + int(len(video_files) * val_ratio)\n",
+    "\n",
+    "# Split the files\n",
+    "train_files = video_files[:train_idx]\n",
+    "val_files = video_files[train_idx:val_idx]\n",
+    "test_files = video_files[val_idx:]\n",
+    "\n",
+    "# Helper function to move files and save labels\n",
+    "def move_files_and_save_labels(files, destination_dir, label_dict):\n",
+    "    dest_labels = {}\n",
+    "    for file in files:\n",
+    "        # Skip hidden files or files not present in the label_dict\n",
+    "        if file not in label_dict:\n",
+    "            print(f\"Skipping {file} as it is not found in labels.json\")\n",
+    "            continue\n",
+    "        src_path = os.path.join(dataset_dir, file)\n",
+    "        dest_path = os.path.join(destination_dir, file)\n",
+    "        shutil.move(src_path, dest_path)\n",
+    "        dest_labels[file] = label_dict[file]\n",
+    "    \n",
+    "    # Save the labels file\n",
+    "    labels_file_path = os.path.join(destination_dir, 'labels.json')\n",
+    "    with open(labels_file_path, 'w') as f:\n",
+    "        json.dump(dest_labels, f, indent=4)\n",
+    "\n",
+    "# Move the files and save the corresponding labels\n",
+    "move_files_and_save_labels(train_files, train_dir, labels)\n",
+    "move_files_and_save_labels(val_files, val_dir, labels)\n",
+    "move_files_and_save_labels(test_files, test_dir, labels)\n",
+    "\n",
+    "print(\"Dataset has been reorganized successfully!\")"
+   ],
+   "id": "9b909bde7c2e0915",
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Skipping ._Kimono1_1920x1080_24fps_8bit_420_graininess_QP32_FB_1.avi as it is not found in labels.json\n",
+      "Skipping ._ElFuente1_1920x1080_30fps_8bit_420_aliasing_QP32_FB_4.avi as it is not found in labels.json\n",
+      "Skipping ._BQTerrace_1920x1080_30fps_8bit_420_aliasing_QP32_SF_4.avi as it is not found in labels.json\n",
+      "Skipping ._Seeking_1920x1080_25fps_8bit_420_graininess_QP47_SF_2.avi as it is not found in labels.json\n",
+      "Skipping ._BirdsInCage_1920x1080_30fps_8bit_420_Pristine_QP32_FBT_1.avi as it is not found in labels.json\n",
+      "Skipping ._riverbed_1920x1080_25fps_8bit_420_banding_QP32_S_3.avi as it is not found in labels.json\n",
+      "Skipping ._station_1920x1080_30fps_8bit_420_graininess_QP32_SB_1.avi as it is not found in labels.json\n",
+      "Skipping ._shields_1280x720_50fps_8bit_420_graininess_QP32_SFBT_2.avi as it is not found in labels.json\n",
+      "Skipping ._DanceKiss_1920x1080_25fps_8bit_420_Dark_QP32_SB_4.avi as it is not found in labels.json\n",
+      "Skipping ._DanceKiss_1920x1080_25fps_8bit_420_Dark_QP47_FB_4.avi as it is not found in labels.json\n",
+      "Skipping ._riverbed_1920x1080_25fps_8bit_420_banding_QP47_SFBT_2.avi as it is not found in labels.json\n",
+      "Skipping ._Seeking_1920x1080_25fps_8bit_420_graininess_QP32_SFT_1.avi as it is not found in labels.json\n",
+      "Skipping ._BQTerrace_1920x1080_30fps_8bit_420_aliasing_QP47_FB_3.avi as it is not found in labels.json\n",
+      "Skipping ._shields_1280x720_50fps_8bit_420_graininess_QP47_SBT_3.avi as it is not found in labels.json\n",
+      "Skipping ._BirdsInCage_1920x1080_30fps_8bit_420_Pristine_QP47_SFB_3.avi as it is not found in labels.json\n",
+      "Skipping ._Tennis_1920x1080_24fps_8bit_420_Motion_QP32_BT_1.avi as it is not found in labels.json\n",
+      "Skipping ._ElFuente1_1920x1080_30fps_8bit_420_aliasing_QP47_SFB_1.avi as it is not found in labels.json\n",
+      "Skipping ._OldTownCross_1920x1080_25fps_8bit_420_graininess_QP47_SB_4.avi as it is not found in labels.json\n",
+      "Skipping ._ElFuente2_1920x1080_30fps_8bit_420_graininess_QP32_S_2.avi as it is not found in labels.json\n",
+      "Skipping ._CrowdRun_1920x1080_25fps_8bit_420_aliasing_QP32_SF_1.avi as it is not found in labels.json\n",
+      "Skipping ._ElFuente2_1920x1080_30fps_8bit_420_graininess_QP47_SFB_3.avi as it is not found in labels.json\n",
+      "Skipping ._Kimono1_1920x1080_24fps_8bit_420_graininess_QP47_B_4.avi as it is not found in labels.json\n",
+      "Skipping ._Tennis_1920x1080_24fps_8bit_420_Motion_QP47_SFB_1.avi as it is not found in labels.json\n",
+      "Dataset has been reorganized successfully!\n"
+     ]
+    }
+   ],
+   "execution_count": 10
+  },
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "outputs": [],
+   "execution_count": null,
+   "source": "",
+   "id": "e52181730c5b3138"
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..74a4f29ab4f1aec7e302f2968d5aae5c25c62db0
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+albumentations==1.4.14
+av==12.3.0
+datasets==2.20.0
+numpy==2.1.0
+opencv_python==4.10.0.84
+paramiko==3.4.1
+scikit_learn==1.5.1
+torch==2.4.0
+torchmetrics==1.4.1
+torchvision==0.19.0
+tqdm==4.66.5
+transformers==4.44.0
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/data_prep_utils/__init__.py b/src/data_prep_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/data_prep_utils/data_setup.py b/src/data_prep_utils/data_setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..2122d13d8963579ddc0ff83b612e406ba79d060e
--- /dev/null
+++ b/src/data_prep_utils/data_setup.py
@@ -0,0 +1,31 @@
+"""
+Contains functionality for creating PyTorch DataLoaders
+"""
+
+import os
+
+from torchvision import datasets, transforms
+from torch.utils.data import DataLoader
+
+NUM_WORKERS = os.cpu_count()
+
+
+def create_dataloaders(
+        train_dir: str,
+        test_dir: str,
+        transform: transforms.Compose,
+        batch_size: int,
+        num_workers: int = NUM_WORKERS
+):
+    # Minimal implementation, assuming train_dir/test_dir are ImageFolder-style
+    # directories (one sub-folder per class); adjust for video datasets as needed.
+    train_data = datasets.ImageFolder(train_dir, transform=transform)
+    test_data = datasets.ImageFolder(test_dir, transform=transform)
+    class_names = train_data.classes
+
+    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
+                                  num_workers=num_workers, pin_memory=True)
+    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False,
+                                 num_workers=num_workers, pin_memory=True)
+
+    return train_dataloader, test_dataloader, class_names
+
+
+"""
+Usage:
+from src.data_prep_utils import data_setup
+
+# Create train/test dataloader and get class names as a list
+train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(...
+"""
diff --git a/src/data_prep_utils/preprocess.py b/src/data_prep_utils/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdee0ea6e9a37e0e8c753f93fe443db80ff63a87
--- /dev/null
+++ b/src/data_prep_utils/preprocess.py
@@ -0,0 +1,92 @@
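+"""Preprocess BVIArtefact videos into random spatio-temporal crops for VideoMAE and save them as .pt tensors."""
+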
+import os
+import torch
+import random
+import json
+from torchvision.io import read_video
+from transformers import VideoMAEImageProcessor
+from pathlib import Path
+
+# Load the VideoMAE image processor
+model_ckpt = "MCG-NJU/videomae-base"
+image_processor = VideoMAEImageProcessor.from_pretrained(model_ckpt,
+                                                         do_rescale=False)
+
+
+def random_spatio_temporal_crop(video, num_frames=16, height=224, width=224):
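+    """Randomly pick num_frames consecutive frames and a height x width window from a (T, H, W, C) video (resizing if the frames are too small)."""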
+    T, H, W, C = video.shape
+
+    # Random temporal crop
+    start_frame = random.randint(0, T - num_frames)
+    video = video[start_frame:start_frame + num_frames]
+
+    # Random spatial crop
+    if H > height and W > width:
+        top = random.randint(0, H - height)
+        left = random.randint(0, W - width)
+        video = video[:, top:top + height, left:left + width, :]
+    else:
+        video = torch.nn.functional.interpolate(video.permute(0, 3, 1, 2), size=(height, width)).permute(0, 2, 3, 1)
+
+    return video
+
+
+def preprocess_video(video_path, num_crops=6, num_frames=16, height=224, width=224):
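+    """Read a video and return num_crops random crops, each run through the VideoMAE image processor, stacked into one tensor."""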
+    video, _, _ = read_video(video_path, pts_unit="sec")
+    video = video.float() / 255.0  # Normalize to [0, 1]
+
+    crops = []
+    for _ in range(num_crops):
+        crop = random_spatio_temporal_crop(video, num_frames, height, width)
+        # Apply VideoMAE preprocessing
+        crop = image_processor(list(crop.permute(0, 3, 1, 2)), return_tensors="pt")["pixel_values"]
+        crops.append(crop.squeeze(0))  # Remove batch dimension
+
+    return torch.stack(crops)  # Stack all crops
+
+
+def main():
+    dataset_root_path = "/Volumes/SSD/BVIArtefact"
+    output_root_path = "/Volumes/SSD/BVIArtefact_preprocessed"
+    os.makedirs(output_root_path, exist_ok=True)
+
+    # Load original labels
+    with open(os.path.join(dataset_root_path, "processed_labels.json"), "r") as f:
+        original_labels = json.load(f)
+
+    # New labels dictionary
+    new_labels = {}
+
+    # Process videos
+    for part in ["part1", "part2"]:
+        part_dir = os.path.join(dataset_root_path, part)
+        for video_name in os.listdir(part_dir):
+            if video_name.endswith('.avi'):
+                video_path = os.path.join(part_dir, video_name)
+
+                if video_name in original_labels:
+                    try:
+                        preprocessed_crops = preprocess_video(video_path)
+
+                        # Save preprocessed video crops
+                        output_name = f"{Path(video_name).stem}_crops.pt"
+                        output_path = os.path.join(output_root_path, output_name)
+                        torch.save(preprocessed_crops, output_path)
+
+                        # Add to new labels dictionary
+                        new_labels[output_name] = original_labels[video_name]
+
+                        print(f"Processed {video_name}")
+                    except Exception as e:
+                        print(f"Error processing {video_name}: {str(e)}")
+                else:
+                    print(f"Skipping {video_name} - not found in labels")
+
+    # Save the new labels
+    with open(os.path.join(output_root_path, "preprocessed_labels.json"), "w") as f:
+        json.dump(new_labels, f)
+
+    print("Preprocessing complete.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/data_prep_utils/resize_bvi_artefact.py b/src/data_prep_utils/resize_bvi_artefact.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca8d064efd2aa3e27881e4c672fbf2dd99048921
--- /dev/null
+++ b/src/data_prep_utils/resize_bvi_artefact.py
@@ -0,0 +1,108 @@
+# resize_bvi_artefact.py
+
+import multiprocessing
+import os
+import re
+import shutil
+
+import ffmpeg
+from tqdm import tqdm
+
+
+def resize_video(input_path, output_path, width=224, height=224):
+    try:
+        (
+            ffmpeg
+            .input(input_path)
+            .filter('scale', width, height)
+            .output(output_path)
+            .overwrite_output()
+            .run(capture_stdout=True, capture_stderr=True)
+        )
+        return None  # Success
+    except ffmpeg.Error as e:
+        return f"Error processing {input_path}: {e.stderr.decode()}"
+
+
+def get_new_filename(old_filename, width, height):
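+    """Insert the target resolution into names matching "<name>_<WxH>_<fps>fps_<rest>.avi";
+    otherwise append "_to_<width>x<height>" before the file extension."""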
+    pattern = r'(.+)_(\d+x\d+)_(\d+fps)_(.+)\.avi'
+    match = re.match(pattern, old_filename)
+
+    if match:
+        video_name, old_resolution, fps, rest = match.groups()
+        return f"{video_name}_{old_resolution}_to_{width}x{height}_{fps}_{rest}.avi"
+    else:
+        name, ext = os.path.splitext(old_filename)
+        return f"{name}_to_{width}x{height}{ext}"
+
+
+def process_video(args):
+    input_path, output_dir, relative_path, width, height = args
+    file = os.path.basename(input_path)
+    new_filename = get_new_filename(file, width, height)
+    output_path = os.path.join(output_dir, relative_path, new_filename)
+
+    os.makedirs(os.path.dirname(output_path), exist_ok=True)
+    return resize_video(input_path, output_path, width, height)
+
+
+def preprocess_dataset(input_dir, output_dir, width=560, height=560, num_processes=None):
+    if num_processes is None:
+        num_processes = multiprocessing.cpu_count()
+
+    video_files = []
+    for part in ['part1', 'part2']:
+        part_dir = os.path.join(input_dir, part)
+        print(f"Searching for videos in: {part_dir}")
+        if not os.path.exists(part_dir):
+            print(f"Directory not found: {part_dir}")
+            continue
+        for root, _, files in os.walk(part_dir):
+            for file in files:
+                if file.endswith('.avi'):
+                    relative_path = os.path.relpath(root, input_dir)
+                    input_path = os.path.join(root, file)
+                    video_files.append((input_path, output_dir, relative_path, width, height))
+
+    print(f"Found {len(video_files)} video files to process.")
+
+    if not video_files:
+        print("No video files found. Please check the input directory.")
+        return
+
+    with multiprocessing.Pool(processes=num_processes) as pool:
+        results = list(tqdm(pool.imap(process_video, video_files), total=len(video_files), desc="Processing videos"))
+
+    # Print any errors that occurred
+    errors = [error for error in results if error is not None]
+    for error in errors:
+        print(error)
+
+    # Copy json files to the output directory
+    json_files = ['labels.json', 'processed_labels.json', 'subsets.json']
+    for json_file in json_files:
+        src = os.path.join(input_dir, json_file)
+        dst = os.path.join(output_dir, json_file)
+        if os.path.exists(src):
+            shutil.copy2(src, dst)
+        else:
+            print(f"Warning: {json_file} not found in {input_dir}")
+
+    print(f"Preprocessing completed! Processed {len(video_files)} videos with {len(errors)} errors.")
+
+
+if __name__ == "__main__":
+    input_dir = "/Volumes/SSD/BVIArtefact"
+    output_dir = "/Volumes/SSD/preprocessed_BVIArtefact"
+
+    # Resolve the directories relative to this script; os.path.join leaves the
+    # absolute defaults above unchanged
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    input_dir = os.path.join(script_dir, input_dir)
+    output_dir = os.path.join(script_dir, output_dir)
+
+    print(f"Input directory: {input_dir}")
+    print(f"Output directory: {output_dir}")
+
+    preprocess_dataset(input_dir, output_dir)
diff --git a/src/data_prep_utils/split_dataset.py b/src/data_prep_utils/split_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bb73114ceed7e0004e0e99f4e4db55e5541aea0
--- /dev/null
+++ b/src/data_prep_utils/split_dataset.py
@@ -0,0 +1,92 @@
+import random
+import os
+import json
+import shutil
+from collections import defaultdict
+from pathlib import Path
+
+
+def split_dataset(preprocessed_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
+    # Load labels
+    with open(os.path.join(preprocessed_dir, 'preprocessed_labels.json'), 'r') as f:
+        labels = json.load(f)
+
+    # Group crops by artifacts
+    artifact_crops = defaultdict(lambda: {'positive': set(), 'negative': set()})
+    for crop, artifacts in labels.items():
+        for artifact, value in artifacts.items():
+            if value == 1:
+                artifact_crops[artifact]['positive'].add(crop)
+            else:
+                artifact_crops[artifact]['negative'].add(crop)
+
+    # Find the minimum number of crops for any artifact
+    min_pos = min(len(crops['positive']) for crops in artifact_crops.values())
+    min_neg = min(len(crops['negative']) for crops in artifact_crops.values())
+    min_crops = min(min_pos, min_neg) * 2  # Ensure balance between positive and negative
+
+    # Calculate the number of crops for each split
+    train_size = int(min_crops * train_ratio)
+    val_size = int(min_crops * val_ratio)
+    test_size = min_crops - train_size - val_size
+
+    splits = {'train': set(), 'val': set(), 'test': set()}
+    split_artifacts = {split: defaultdict(lambda: {'positive': set(), 'negative': set()}) for split in splits}
+
+    # Distribute crops ensuring balance for each artifact in each split
+    for split, size in [('train', train_size), ('val', val_size), ('test', test_size)]:
+        pos_count = size // 2
+        neg_count = size - pos_count
+
+        for artifact, crops in artifact_crops.items():
+            pos_crops = list(crops['positive'])
+            neg_crops = list(crops['negative'])
+            random.shuffle(pos_crops)
+            random.shuffle(neg_crops)
+
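+            # The membership check below keeps each crop in at most one split, since the
+            # same crop can appear under several artifacts with different signs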
+            for _ in range(pos_count):
+                if pos_crops:
+                    crop = pos_crops.pop()
+                    if all(crop not in s for s in splits.values()):
+                        splits[split].add(crop)
+                        split_artifacts[split][artifact]['positive'].add(crop)
+
+            for _ in range(neg_count):
+                if neg_crops:
+                    crop = neg_crops.pop()
+                    if all(crop not in s for s in splits.values()):
+                        splits[split].add(crop)
+                        split_artifacts[split][artifact]['negative'].add(crop)
+
+    # Create directories and move crops
+    preprocessed_dir_path = Path(preprocessed_dir)
+    data_split_path = preprocessed_dir_path.parent / f"{preprocessed_dir_path.name}_split"
+
+    for split, crops in splits.items():
+        os.makedirs(data_split_path / split, exist_ok=True)
+        split_labels = {}
+        for crop in crops:
+            src = os.path.join(preprocessed_dir, crop)
+            dst = os.path.join(data_split_path, split, crop)
+            shutil.copy(src, dst)  # Use copy instead of move to preserve original data
+            split_labels[crop] = labels[crop]
+        with open(os.path.join(data_split_path, split, 'labels.json'), 'w') as f:
+            json.dump(split_labels, f, indent=2)
+
+    print("Dataset split complete")
+    print(f"Train set: {len(splits['train'])} crops")
+    print(f"Validation set: {len(splits['val'])} crops")
+    print(f"Test set: {len(splits['test'])} crops")
+
+    # Print balance information for each artifact in each split
+    for split in splits:
+        print(f"\n{split.capitalize()} set balance:")
+        for artifact in artifact_crops:
+            pos = len(split_artifacts[split][artifact]['positive'])
+            neg = len(split_artifacts[split][artifact]['negative'])
+            print(f"  {artifact}: Positive: {pos}, Negative: {neg}")
+
+
+if __name__ == "__main__":
+    preprocessed_dir = "/Volumes/SSD/BVIArtefact_crops"  # Update this to your preprocessed dataset path
+    split_dataset(preprocessed_dir)
diff --git a/src/data_prep_utils/subset_and_process.py b/src/data_prep_utils/subset_and_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7332750ceb7a305efac46ef6b2ea9d487b79ae3
--- /dev/null
+++ b/src/data_prep_utils/subset_and_process.py
@@ -0,0 +1,274 @@
+import os
+import json
+import random
+from collections import Counter
+
+import torch
+import cv2
+import numpy as np
+from pathlib import Path
+from tqdm import tqdm
+import argparse
+from sklearn.model_selection import train_test_split
+
+# Argument parser
+parser = argparse.ArgumentParser(description='Preprocess BVIArtefact dataset')
+parser.add_argument('--input_dir', type=str, default="/Volumes/SSD/BVIArtefact",
+                    help='Input directory containing BVIArtefact dataset')
+parser.add_argument('--output_dir', type=str, default="/Volumes/SSD/BVIArtefact_8_crops_all_videos",
+                    help='Output directory for preprocessed data')
+parser.add_argument('--num_samples', type=int, default=None, help='Number of videos to sample (None for all)')
+parser.add_argument('--crop_size', type=int, default=224, help='Size of spatial crop')
+parser.add_argument('--num_frames', type=int, default=8, help='Number of frames to extract')
+parser.add_argument('--crops_per_video', type=int, default=4, help='Number of crops to extract per video')
+parser.add_argument('--train_ratio', type=float, default=0.7, help='Ratio of videos for training set')
+parser.add_argument('--val_ratio', type=float, default=0.15, help='Ratio of videos for validation set')
+args = parser.parse_args()
+
+# Configuration
+INPUT_DIR = args.input_dir
+OUTPUT_DIR = args.output_dir
+LABELS_FILE = os.path.join(INPUT_DIR, "labels.json")
+CROP_SIZE = (args.crop_size, args.crop_size)
+NUM_FRAMES = args.num_frames
+NUM_CROPS_PER_VIDEO = args.crops_per_video
+
+random.seed(42)
+
+# Create output directories
+for split in ['train', 'val', 'test']:
+    os.makedirs(os.path.join(OUTPUT_DIR, split), exist_ok=True)
+
+# Load labels
+with open(LABELS_FILE, 'r') as f:
+    labels = json.load(f)
+
+
+def parse_size(size_str):
+    """Convert size string to bytes"""
+    size = float(size_str[:-1])
+    unit = size_str[-1]
+    if unit == 'G':
+        return int(size * 1e9)
+    elif unit == 'M':
+        return int(size * 1e6)
+    else:
+        return int(size)
+
+
+def read_file_sizes(filename):
+    """Read file sizes from text file"""
+    sizes = {}
+    with open(filename, 'r') as f:
+        for line in f:
+            parts = line.strip().split()
+            if len(parts) == 2:
+                sizes[parts[0]] = parse_size(parts[1])
+    return sizes
+
+
+def extract_random_crop(frames, num_frames, crop_size):
+    """Extract a random spatio-temporal crop from the frames."""
+    t, h, w, _ = frames.shape
+
+    if t < num_frames:
+        raise ValueError(f"Video has fewer frames ({t}) than required ({num_frames})")
+
+    if h < crop_size[0] or w < crop_size[1]:
+        raise ValueError(f"Frame size {h}x{w} is smaller than the requested crop {crop_size}")
+
+    start_frame = random.randint(0, t - num_frames)
+    top = random.randint(0, h - crop_size[0])
+    left = random.randint(0, w - crop_size[1])
+
+    crop = frames[start_frame:start_frame + num_frames,
+           top:top + crop_size[0],
+           left:left + crop_size[1]]
+
+    return crop
+
+
+def normalize(video, mean, std):
+    """Normalize the video tensor"""
+    mean = torch.tensor(mean).view(1, 3, 1, 1)
+    std = torch.tensor(std).view(1, 3, 1, 1)
+    return (video - mean) / std
+
+
+def process_videos(video_list, split):
+    """Process videos and save crops for a specific split"""
+    preprocessed_labels = {}
+    label_counts = Counter()
+    total_crops = 0
+
+    for video_file, video_name in tqdm(video_list, desc=f"Processing {split} set"):
+        video_path = os.path.join(INPUT_DIR, video_file)
+
+        # Skip if video is not in labels
+        if video_name not in labels:
+            print(f"Skipping {video_file}: No labels found")
+            continue
+
+        video_labels = labels[video_name]
+
+        try:
+            # Read video
+            cap = cv2.VideoCapture(video_path)
+            frames = []
+            while len(frames) < NUM_FRAMES * 2:  # Read more frames than needed
+                ret, frame = cap.read()
+                if not ret:
+                    break
+                # OpenCV returns BGR frames; convert to RGB before ImageNet normalization
+                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+            cap.release()
+
+            if len(frames) < NUM_FRAMES:
+                print(f"Warning: {video_file} has fewer than {NUM_FRAMES} frames. Skipping.")
+                continue
+
+            frames = np.array(frames)
+
+            for i in range(NUM_CROPS_PER_VIDEO):
+                # Extract random crop
+                crop = extract_random_crop(frames, NUM_FRAMES, CROP_SIZE)
+
+                # Convert to torch tensor and normalize
+                crop = torch.from_numpy(crop).permute(0, 3, 1, 2).float() / 255.0
+
+                # Normalize using ImageNet stats
+                crop = normalize(crop, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+                # Generate unique filename for the crop
+                crop_filename = f"{Path(video_name).stem}_crop_{i}.pt"
+                crop_path = os.path.join(OUTPUT_DIR, split, crop_filename)
+
+                # Save crop as .pt file
+                torch.save(crop, crop_path)
+
+                # Store labels for the crop
+                preprocessed_labels[crop_filename] = video_labels
+
+                total_crops += 1
+
+            # Update label counts
+            for artifact, present in video_labels.items():
+                if present == 1:
+                    label_counts[f"{artifact}_Positive"] += NUM_CROPS_PER_VIDEO
+                else:
+                    label_counts[f"{artifact}_Negative"] += NUM_CROPS_PER_VIDEO
+
+        except Exception as e:
+            print(f"Error processing {video_file}: {str(e)}")
+
+    # Save preprocessed labels
+    labels_path = os.path.join(OUTPUT_DIR, split, "labels.json")
+    with open(labels_path, 'w') as f:
+        json.dump(preprocessed_labels, f, indent=4)
+
+    print(f"\n{split} set statistics:")
+    print(f"Total crops generated: {total_crops}")
+    print(f"Number of entries in labels JSON: {len(preprocessed_labels)}")
+
+    # Check if numbers match
+    if total_crops == len(preprocessed_labels):
+        print("✅ Numbers match!")
+    else:
+        print("❌ Numbers don't match. There might be an issue.")
+
+    return label_counts, total_crops
+
+
+def check_split_overlap(output_dir):
+    splits = ['train', 'val', 'test']
+    parent_videos = {split: set() for split in splits}
+
+    for split in splits:
+        labels_path = Path(output_dir) / split / "labels.json"
+        with open(labels_path, 'r') as f:
+            labels = json.load(f)
+
+        for crop_filename in labels.keys():
+            # Extract parent video name by removing the "_crop_{i}.pt" suffix
+            parent_video = crop_filename.rsplit('_crop_', 1)[0]
+            parent_videos[split].add(parent_video)
+
+    # Check for overlap between splits
+    for i, split1 in enumerate(splits):
+        for split2 in splits[i + 1:]:
+            overlap = parent_videos[split1].intersection(parent_videos[split2])
+            if overlap:
+                print(f"❌ Overlap found between {split1} and {split2} splits:")
+                print(f"   Common parent videos: {overlap}")
+            else:
+                print(f"✅ No overlap found between {split1} and {split2} splits")
+
+    # Print summary
+    print("\nSummary:")
+    for split in splits:
+        print(f"{split} split: {len(parent_videos[split])} unique parent videos")
+
+
+def print_label_balance(label_counts, split_name):
+    print(f"\n{split_name} set balance:")
+    artifacts = ['black_screen', 'frame_drop', 'spatial_blur', 'transmission_error', 'aliasing', 'banding',
+                 'dark_scenes', 'graininess', 'motion_blur']
+    for artifact in artifacts:
+        positive = label_counts[f"{artifact}_Positive"]
+        negative = label_counts[f"{artifact}_Negative"]
+        print(f"    {artifact}: Positive: {positive}, Negative: {negative}")
+
+
+# Read file sizes
+part1_sizes = read_file_sizes(os.path.join(INPUT_DIR, "part1_files_sizes.txt"))
+part2_sizes = read_file_sizes(os.path.join(INPUT_DIR, "part2_files_sizes.txt"))
+
+all_sizes = {**part1_sizes, **part2_sizes}
+
+# Sort videos by size
+sorted_videos = sorted(all_sizes.items(), key=lambda x: x[1])
+
+# Sample videos if num_samples is specified
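+# (the list is sorted by size, so --num_samples keeps the smallest videos)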
+if args.num_samples is not None:
+    sampled_videos = sorted_videos[:args.num_samples]
+else:
+    sampled_videos = sorted_videos
+
+# Extract video files and their corresponding folders
+video_files = [(os.path.join('part1' if f in part1_sizes else 'part2', f), f) for f, _ in sampled_videos]
+
+# Split videos into train, validation, and test sets
+train_videos, temp_videos = train_test_split(video_files, train_size=args.train_ratio, random_state=42)
+val_ratio = args.val_ratio / (1 - args.train_ratio)
+val_videos, test_videos = train_test_split(temp_videos, train_size=val_ratio, random_state=42)
+
+# Process each split and collect per-label statistics
+train_label_counts, train_crops = process_videos(train_videos, 'train')
+val_label_counts, val_crops = process_videos(val_videos, 'val')
+test_label_counts, test_crops = process_videos(test_videos, 'test')
+
+# Add a final summary
+print("\nFinal Summary:")
+print(f"Total crops - Train: {train_crops}, Val: {val_crops}, Test: {test_crops}")
+total_crops = train_crops + val_crops + test_crops
+print(f"Total crops across all splits: {total_crops}")
+
+# Check total number of label entries
+train_labels = json.load(open(os.path.join(OUTPUT_DIR, 'train', 'labels.json')))
+val_labels = json.load(open(os.path.join(OUTPUT_DIR, 'val', 'labels.json')))
+test_labels = json.load(open(os.path.join(OUTPUT_DIR, 'test', 'labels.json')))
+
+total_label_entries = len(train_labels) + len(val_labels) + len(test_labels)
+print(f"Total label entries across all splits: {total_label_entries}")
+
+if total_crops == total_label_entries:
+    print("✅ Total crops match total label entries!")
+else:
+    print("❌ Total crops and total label entries don't match. There might be an issue.")
+
+print_label_balance(train_label_counts, "Train")
+print_label_balance(val_label_counts, "Val")
+print_label_balance(test_label_counts, "Test")
+
+check_split_overlap(OUTPUT_DIR)
+
+print("Preprocessing completed.")
+
+# sample usage of this script:
+# python src/subset_and_process.py --input_dir /Volumes/SSD/BVIArtefact --output_dir /Volumes/SSD/BVIArtefact_crops --num_samples 100 --crop_size 224 --num_frames 8 --crops_per_video 2 --train_ratio 0.7 --val_ratio 0.15
diff --git a/src/data_prep_utils/subset_data.py b/src/data_prep_utils/subset_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..f30aa02daf86311a08b964cd1ac31e960888bbc8
--- /dev/null
+++ b/src/data_prep_utils/subset_data.py
@@ -0,0 +1,158 @@
+import argparse
+import json
+import os
+import shutil
+from collections import defaultdict
+from pathlib import Path
+from tqdm import tqdm
+from src.data_prep_utils.split_dataset import split_dataset
+
+# Configuration
+local_labels_path = 'data/bviArtefactMetaInfo/processed_labels.json'
+artefacts_to_choose = ['graininess', 'aliasing', 'banding', 'motion_blur']  # Add more labels as needed
+size_limit_gb = 4  # Size limit in GB
+
+part1_sizes_path = 'data/bviArtefactMetaInfo/part1_files_sizes.txt'
+part2_sizes_path = 'data/bviArtefactMetaInfo/part2_files_sizes.txt'
+
+
+def convert_to_bytes(size_str):
+    size_unit = size_str[-1]
+    size_value = float(size_str[:-1])
+    if size_unit == 'G':
+        return int(size_value * 1e9)
+    elif size_unit == 'M':
+        return int(size_value * 1e6)
+    elif size_unit == 'K':
+        return int(size_value * 1e3)
+    else:
+        return int(size_value)
+
+
+def load_file_sizes(file_path):
+    file_sizes = {}
+    with open(file_path, 'r') as f:
+        for line in f:
+            parts = line.strip().split()
+            file_name = parts[0]
+            file_size = convert_to_bytes(parts[1])
+            file_sizes[file_name] = file_size
+    return file_sizes
+
+
+def get_balanced_videos(labels, artefacts, size_limit):
+    video_labels = defaultdict(dict)
+    for video, details in labels.items():
+        for artefact in artefacts:
+            video_labels[video][artefact] = details.get(artefact, 0)
+
+    # Separate positive and negative videos
+    positive_videos = [v for v, l in video_labels.items() if all(l[a] == 1 for a in artefacts)]
+    negative_videos = [v for v, l in video_labels.items() if all(l[a] == 0 for a in artefacts)]
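+    # "Positive" videos exhibit every chosen artefact; "negative" videos exhibit none of them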
+
+    # Sort videos by size (smallest to largest)
+    positive_videos.sort(key=lambda x: file_sizes.get(x, 0))
+    negative_videos.sort(key=lambda x: file_sizes.get(x, 0))
+
+    balanced_videos = []
+    total_size = 0
+
+    print(f"Size limit: {size_limit / 1e9:.2f} GB")
+    print(f"Total positive videos available: {len(positive_videos)}")
+    print(f"Total negative videos available: {len(negative_videos)}")
+
+    # Select videos while maintaining balance and respecting size limit
+    for pos, neg in zip(positive_videos, negative_videos):
+        pos_size = file_sizes.get(pos, 0)
+        neg_size = file_sizes.get(neg, 0)
+
+        if total_size + pos_size + neg_size <= size_limit:
+            balanced_videos.extend([pos, neg])
+            total_size += pos_size + neg_size
+        else:
+            break
+
+    final_subset = {video: video_labels[video] for video in balanced_videos}
+
+    final_size = sum(file_sizes.get(video, 0) for video in final_subset)
+    print(f"\nFinal balanced dataset:")
+    print(f"Size: {final_size / 1e9:.2f} GB")
+    print(f"Total videos: {len(final_subset)}")
+    print(f"Positive videos: {len(final_subset) // 2}")
+    print(f"Negative videos: {len(final_subset) // 2}")
+
+    return final_subset
+
+
+def copy_videos_local(subset_videos, source_base_path, destination_base_path):
+    progress_bar = tqdm(total=len(subset_videos), desc="Copying videos", unit="file", dynamic_ncols=True)
+
+    for video in subset_videos:
+        found = False
+        for part in ['part1', 'part2']:
+            source_path = os.path.join(source_base_path, part, video)
+            destination_path = os.path.join(destination_base_path, video)
+            if os.path.exists(source_path):
+                progress_bar.set_postfix(file=video)
+                shutil.copy2(source_path, destination_path)
+                found = True
+                break
+        if not found:
+            print(f"Video {video} not found in either part1 or part2.")
+        progress_bar.update(1)
+
+    progress_bar.close()
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Create a balanced subset of videos for multi-label classification.")
+    parser.add_argument("--local", help="Path to local bviDataset folder", type=str, required=True)
+    parser.add_argument("--size_limit", help="Size limit in GB", type=float, default=2.0)
+    args = parser.parse_args()
+
+    global size_limit_gb
+    size_limit_gb = args.size_limit
+
+    # Load file sizes
+    part1_file_sizes = load_file_sizes(part1_sizes_path)
+    part2_file_sizes = load_file_sizes(part2_sizes_path)
+    global file_sizes
+    file_sizes = {**part1_file_sizes, **part2_file_sizes}
+
+    # Load labels
+    with open(local_labels_path, 'r') as f:
+        labels = json.load(f)
+
+    size_limit_bytes = size_limit_gb * 1e9
+    balanced_subset = get_balanced_videos(labels, artefacts_to_choose, size_limit_bytes)
+
+    # Create the local download directory
+    local_download_dir = f'/Volumes/SSD/subsets/{"_".join(artefacts_to_choose)}_subset_{int(size_limit_gb)}_GB'
+    os.makedirs(local_download_dir, exist_ok=True)
+
+    # Save the subset list locally
+    subset_file_path = f'{local_download_dir}/labels.json'
+    with open(subset_file_path, 'w') as f:
+        json.dump(balanced_subset, f, indent=4)
+
+    print(f"Balanced subset saved to {subset_file_path}")
+
+    # Verify the balance of the subset labels
+    for artefact in artefacts_to_choose:
+        presence_count = sum(1 for labels in balanced_subset.values() if labels[artefact] == 1)
+        absence_count = sum(1 for labels in balanced_subset.values() if labels[artefact] == 0)
+        print(f"{artefact}:")
+        print(f"  Presence count: {presence_count}")
+        print(f"  Absence count: {absence_count}")
+
+    # Use local dataset
+    print(f"Using local dataset at: {args.local}")
+    copy_videos_local(balanced_subset.keys(), args.local, local_download_dir)
+
+    print(f"All raw videos copied to {local_download_dir}")
+
+    split_dataset(local_download_dir)
+
+
+if __name__ == "__main__":
+    main()
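+
+# sample usage of this script (run from the repo root so the `src` package import resolves):
+# python -m src.data_prep_utils.subset_data --local /Volumes/SSD/BVIArtefact --size_limit 2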
diff --git a/src/data_prep_utils/subset_processed_dataset.py b/src/data_prep_utils/subset_processed_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a28d06d4fbd9999ab6a3e1e4280582f071c3448
--- /dev/null
+++ b/src/data_prep_utils/subset_processed_dataset.py
@@ -0,0 +1,113 @@
+import argparse
+import json
+import os
+import random
+import shutil
+from collections import defaultdict
+
+from tqdm import tqdm
+
+
+def load_labels(labels_path):
+    with open(labels_path, 'r') as f:
+        return json.load(f)
+
+
+def get_balanced_subset(labels, artefacts, count_per_label):
+    video_labels = defaultdict(dict)
+    for video, details in labels.items():
+        for artefact in artefacts:
+            video_labels[video][artefact] = details.get(artefact, 0)
+
+    final_subset = {}
+    artefact_counts = {artefact: {'positive': 0, 'negative': 0} for artefact in artefacts}
+
+    # Shuffle videos to ensure random selection
+    shuffled_videos = list(video_labels.keys())
+    random.shuffle(shuffled_videos)
+
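+    # Greedy pass: include a video only if it does not push any artefact past
+    # count_per_label on either the positive or the negative side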
+    for video in shuffled_videos:
+        include_video = True
+        for artefact in artefacts:
+            label = video_labels[video][artefact]
+            if label == 1 and artefact_counts[artefact]['positive'] >= count_per_label:
+                include_video = False
+                break
+            elif label == 0 and artefact_counts[artefact]['negative'] >= count_per_label:
+                include_video = False
+                break
+
+        if include_video:
+            final_subset[video] = video_labels[video]
+            for artefact in artefacts:
+                if video_labels[video][artefact] == 1:
+                    artefact_counts[artefact]['positive'] += 1
+                else:
+                    artefact_counts[artefact]['negative'] += 1
+
+        # Check if we have reached the target count for all artefacts
+        if all(counts['positive'] >= count_per_label and counts['negative'] >= count_per_label
+               for counts in artefact_counts.values()):
+            break
+
+    return final_subset
+
+
+def copy_videos(videos, src_dir, dst_dir):
+    os.makedirs(dst_dir, exist_ok=True)
+    for video in tqdm(videos, desc=f"Copying to {os.path.basename(dst_dir)}"):
+        src_path_part1 = os.path.join(src_dir, 'part1', video)
+        src_path_part2 = os.path.join(src_dir, 'part2', video)
+        dst_path = os.path.join(dst_dir, video)
+
+        if os.path.exists(src_path_part1):
+            shutil.copy2(src_path_part1, dst_path)
+        elif os.path.exists(src_path_part2):
+            shutil.copy2(src_path_part2, dst_path)
+        else:
+            print(f"Warning: Video {video} not found in either part1 or part2.")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Create a balanced subset of videos and relocate them.")
+    parser.add_argument("--input_dir", type=str, required=True, help="Path to processed_BVIArtefact folder")
+    parser.add_argument("--output_dir", type=str, required=True, help="Path to output directory")
+    parser.add_argument("--count_per_label", type=int, default=500,
+                        help="Number of videos per label (positive/negative)")
+    args = parser.parse_args()
+
+    # Load labels
+    labels_path = os.path.join(args.input_dir, 'processed_labels.json')
+    labels = load_labels(labels_path)
+
+    # Define artefacts to balance (same set used in subset_data.py; add more labels as needed)
+    artefacts = ['graininess', 'aliasing', 'banding', 'motion_blur']
+
+    # Get balanced subset
+    balanced_subset = get_balanced_subset(labels, artefacts, args.count_per_label)
+
+    # Copy videos to output directory
+    copy_videos(balanced_subset.keys(), args.input_dir, args.output_dir)
+
+    # Save the subset labels
+    subset_labels_path = os.path.join(args.output_dir, 'labels.json')
+    with open(subset_labels_path, 'w') as f:
+        json.dump(balanced_subset, f, indent=4)
+
+    print(f"Balanced subset created in {args.output_dir}")
+    print(f"Total videos in subset: {len(balanced_subset)}")
+
+    # Verify the balance of the subset labels
+    for artefact in artefacts:
+        presence_count = sum(1 for labels in balanced_subset.values() if labels[artefact] == 1)
+        absence_count = sum(1 for labels in balanced_subset.values() if labels[artefact] == 0)
+        print(f"{artefact}:")
+        print(f"  Presence count: {presence_count}")
+        print(f"  Absence count: {absence_count}")
+
+
+if __name__ == "__main__":
+    main()
+
+    # sample usage of the script
+    # python subset_processed_dataset.py --input_dir /Volumes/SSD/preprocessed_BVIArtefact --output_dir /Volumes/SSD/balanced_subset --count_per_label 500
diff --git a/src/data_prep_utils/subset_random.py b/src/data_prep_utils/subset_random.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2b94521837756d5b937b420067fc08592cc1e28
--- /dev/null
+++ b/src/data_prep_utils/subset_random.py
@@ -0,0 +1,82 @@
+import os
+import json
+import random
+
+
+def get_file_sizes(file_path):
+    sizes = {}
+    with open(file_path, 'r') as f:
+        for line in f:
+            parts = line.strip().split()
+            if len(parts) == 2:
+                filename, size = parts
+                # Strip the unit suffix and normalize to megabytes (handles G, M and K)
+                unit, value = size[-1], float(size[:-1])
+                if unit == 'G':
+                    sizes[filename] = value * 1024
+                elif unit == 'K':
+                    sizes[filename] = value / 1024
+                else:
+                    sizes[filename] = value
+    return sizes
+
+
+def create_dataset(labels_file, part1_sizes, part2_sizes, target_size_gb):
+    # Load labels
+    with open(labels_file, 'r') as f:
+        labels = json.load(f)
+
+    # Combine file sizes
+    all_sizes = {**part1_sizes, **part2_sizes}
+
+    # Create a list of (filename, size) tuples, sorted by size
+    sorted_files = sorted(all_sizes.items(), key=lambda x: x[1])
+
+    target_size_mb = target_size_gb * 1024
+    selected_files = []
+    current_size = 0
+
+    # Randomly select files, prioritizing smaller ones
+    while current_size < target_size_mb and sorted_files:
+        # Randomly choose from the smallest 10% of remaining files
+        chunk_size = max(1, len(sorted_files) // 10)
+        chosen_file, file_size = random.choice(sorted_files[:chunk_size])
+
+        if chosen_file in labels and (current_size + file_size) <= target_size_mb:
+            selected_files.append(chosen_file)
+            current_size += file_size
+
+        sorted_files.remove((chosen_file, file_size))
+
+    # Create a new labels dictionary with only the selected files
+    selected_labels = {file: labels[file] for file in selected_files if file in labels}
+
+    return selected_files, selected_labels, current_size / 1024  # Convert back to GB
+
+
+# File paths
+labels_file = '/Volumes/SSD/BVIArtefact/processed_labels.json'
+part1_sizes_file = '/Volumes/SSD/BVIArtefact/part1_files_sizes.txt'
+part2_sizes_file = '/Volumes/SSD/BVIArtefact/part2_files_sizes.txt'
+
+# Target dataset size in GB
+target_size_gb = 2  # Change this to your desired size
+
+# Get file sizes
+part1_sizes = get_file_sizes(part1_sizes_file)
+part2_sizes = get_file_sizes(part2_sizes_file)
+
+# Create the dataset
+selected_files, selected_labels, actual_size_gb = create_dataset(
+    labels_file, part1_sizes, part2_sizes, target_size_gb
+)
+
+# Print results
+print(f"Selected {len(selected_files)} files")
+print(f"Total size: {actual_size_gb:.2f} GB")
+
+# Save the new labels to a file
+output_dir = '/Volumes/SSD/BVIArtefact'
+with open(os.path.join(output_dir, 'selected_labels.json'), 'w') as f:
+    json.dump(selected_labels, f, indent=2)
+
+# Save the list of selected files
+with open(os.path.join(output_dir, 'selected_files.txt'), 'w') as f:
+    for file in selected_files:
+        f.write(f"{file}\n")
+
+print("Selected labels saved to 'selected_labels.json'")
+print("Selected files list saved to 'selected_files.txt'")
diff --git a/src/plots.py b/src/plots.py
new file mode 100644
index 0000000000000000000000000000000000000000..baec0570115a0767f7a24068c58f5cf6c9447e1d
--- /dev/null
+++ b/src/plots.py
@@ -0,0 +1,123 @@
+import json
+import matplotlib.pyplot as plt
+import pandas as pd
+import os
+
+
+def load_json_labels(file_path):
+    with open(file_path, 'r') as f:
+        return json.load(f)
+
+
+def create_label_df(json_data):
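+    """Rows are crop/video filenames; columns are binary artefact labels."""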
+    return pd.DataFrame.from_dict(json_data, orient='index')
+
+
+def plot_label_balance_stacked(df, title, save_path):
+    """
+    Plot the positive/negative balance for each label using stacked bars and save as PNG.
+    """
+    label_balance = df.mean()
+    label_balance_negative = 1 - label_balance
+
+    plt.figure(figsize=(14, 6))
+    bar_width = 0.8
+
+    labels = label_balance.index
+    pos_bars = plt.bar(labels, label_balance, bar_width, label='Positive', color='#2ecc71')
+    neg_bars = plt.bar(labels, label_balance_negative, bar_width, bottom=label_balance, label='Negative',
+                       color='#e74c3c')
+
+    plt.title(f'Label Balance - {title}')
+    plt.xlabel('Labels')
+    plt.ylabel('Proportion')
+    plt.legend(title='Class')
+    plt.xticks(rotation=45, ha='right')
+
+    # Add percentage labels on the bars
+    for i, (pos, neg) in enumerate(zip(label_balance, label_balance_negative)):
+        plt.text(i, pos / 2, f'{pos:.1%}', ha='center', va='center', color='white', fontweight='bold')
+        plt.text(i, pos + neg / 2, f'{neg:.1%}', ha='center', va='center', color='white', fontweight='bold')
+
+    plt.tight_layout()
+    plt.savefig(save_path, dpi=300, bbox_inches='tight')
+    plt.close()
+
+
+def plot_label_distribution_across_splits_stacked(train_df, val_df, test_df, save_path):
+    """
+    Plot the distribution of positive and negative labels across train, validation, and test splits and save as PNG.
+    """
+    train_dist = train_df.mean()
+    val_dist = val_df.mean()
+    test_dist = test_df.mean()
+
+    df = pd.DataFrame({
+        'Train Positive': train_dist,
+        'Train Negative': 1 - train_dist,
+        'Validation Positive': val_dist,
+        'Validation Negative': 1 - val_dist,
+        'Test Positive': test_dist,
+        'Test Negative': 1 - test_dist
+    })
+
+    # Pass figsize to pandas directly; df.plot creates its own figure, so a prior plt.figure() is ignored
+    df.plot(kind='bar', stacked=True, width=0.8, figsize=(16, 6))
+    plt.title('Label Distribution Across Splits')
+    plt.xlabel('Labels')
+    plt.ylabel('Proportion')
+    plt.xticks(rotation=45, ha='right')
+    plt.legend(title='Split and Class', bbox_to_anchor=(1.05, 1), loc='upper left')
+    plt.tight_layout()
+    plt.savefig(save_path, dpi=300, bbox_inches='tight')
+    plt.close()
+
+
+def plot_sample_counts(train_df, val_df, test_df, save_path):
+    """
+    Plot the number of samples in each split and save as PNG.
+    """
+    counts = [len(train_df), len(val_df), len(test_df)]
+    splits = ['Train', 'Validation', 'Test']
+
+    plt.figure(figsize=(4, 6))
+    bars = plt.bar(splits, counts)
+    plt.title('Number of Samples in Each Split')
+    plt.ylabel('Number of Samples')
+
+    # Add value labels on the bars
+    for bar in bars:
+        height = bar.get_height()
+        plt.text(bar.get_x() + bar.get_width() / 2., height,
+                 f'{height:,}',
+                 ha='center', va='bottom')
+
+    plt.tight_layout()
+    plt.savefig(save_path, dpi=300, bbox_inches='tight')
+    plt.close()
+
+
+# Directory containing the per-split label JSON files
+json_dir = '/Volumes/SSD/BVIArtefact_8_crops_all_videos'
+
+# Load the data
+train_data = load_json_labels(os.path.join(json_dir, 'train_labels.json'))
+val_data = load_json_labels(os.path.join(json_dir, 'val_labels.json'))
+test_data = load_json_labels(os.path.join(json_dir, 'test_labels.json'))
+
+# Create DataFrames
+train_df = create_label_df(train_data)
+val_df = create_label_df(val_data)
+test_df = create_label_df(test_data)
+
+# Generate and save plots
+plot_label_balance_stacked(train_df, 'Train Set', os.path.join(json_dir, 'label_balance_train.png'))
+plot_label_balance_stacked(val_df, 'Validation Set', os.path.join(json_dir, 'label_balance_val.png'))
+plot_label_balance_stacked(test_df, 'Test Set', os.path.join(json_dir, 'label_balance_test.png'))
+plot_label_distribution_across_splits_stacked(train_df, val_df, test_df,
+                                              os.path.join(json_dir, 'label_distribution_across_splits.png'))
+plot_sample_counts(train_df, val_df, test_df, os.path.join(json_dir, 'sample_counts.png'))
+
+print(f"Plots have been saved in the directory: {json_dir}")