diff --git a/i3D/gtransforms.py b/i3D/gtransforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..aef89b83417602e82c197ded0537a9ecaad1ec07
--- /dev/null
+++ b/i3D/gtransforms.py
@@ -0,0 +1,179 @@
+# Borrowed from: https://github.com/yjxiong/tsn-pytorch/blob/master/transforms.py
+
+import torchvision
+import random
+from PIL import Image
+import numbers
+import torch
+import torchvision.transforms.functional as F
+
+
+class GroupResize(object):
+    def __init__(self, size, interpolation=Image.BILINEAR):
+        self.worker = torchvision.transforms.Resize(size, interpolation)
+
+    def __call__(self, img_group):
+        return [self.worker(img) for img in img_group]
+
+
+class GroupRandomCrop(object):
+    def __init__(self, size):
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            self.size = size
+
+    def __call__(self, img_group):
+        w, h = img_group[0].size
+        th, tw = self.size
+
+        out_images = list()
+
+        x1 = random.randint(0, w - tw)
+        y1 = random.randint(0, h - th)
+
+        for img in img_group:
+            assert (img.size[0] == w and img.size[1] == h)
+            if w == tw and h == th:
+                out_images.append(img)
+            else:
+                out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
+
+        return out_images
+
+
+class GroupCenterCrop(object):
+    def __init__(self, size):
+        self.worker = torchvision.transforms.CenterCrop(size)
+
+    def __call__(self, img_group):
+        return [self.worker(img) for img in img_group]
+
+
+class GroupRandomHorizontalFlip(object):
+    def __call__(self, img_group):
+        if random.random() < 0.5:
+            img_group = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
+        return img_group
+
+
+class GroupNormalize(object):
+    def __init__(self, mean, std):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, tensor):  # (T, 3, 224, 224)
+        for b in range(tensor.size(0)):
+            for t, m, s in zip(tensor[b], self.mean, self.std):
+                t.sub_(m).div_(s)
+        return tensor
+
+
+class LoopPad(object):
+
+    def __init__(self, max_len):
+        self.max_len = max_len
+
+    def __call__(self, tensor):
+        length = tensor.size(0)
+
+        if length == self.max_len:
+            return tensor
+
+        # repeat the clip as many times as is necessary
+        n_pad = self.max_len - length
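+        # e.g. length=10, max_len=48: n_pad=38, so pad with 3 full copies of the
+        # clip (30 frames) plus its first 8 frames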
+        pad = [tensor] * (n_pad // length)
+        if n_pad % length > 0:
+            pad += [tensor[0:n_pad % length]]
+
+        tensor = torch.cat([tensor] + pad, 0)
+        return tensor
+
+
+# NOTE: Returns [0-255] rather than torchvision's [0-1]
+class ToTensor(object):
+    def __init__(self):
+        self.worker = lambda x: F.to_tensor(x) * 255
+
+    def __call__(self, img_group):
+        img_group = [self.worker(img) for img in img_group]
+        return torch.stack(img_group, 0)
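+
+# A minimal usage sketch (illustrative only; `clip` stands for a hypothetical
+# list of PIL frames, and the mean/std values are placeholders):
+#
+#   transform = torchvision.transforms.Compose([
+#       GroupResize(256),
+#       GroupRandomCrop(224),
+#       GroupRandomHorizontalFlip(),
+#       ToTensor(),                                   # (T, 3, 224, 224) in [0, 255]
+#       GroupNormalize(mean=[114.75] * 3, std=[57.375] * 3),
+#   ])
+#   clip_tensor = transform(clip)
+#
+# GroupMultiScaleCrop is omitted from the sketch since it returns a
+# (frames, crop_box) tuple rather than just the frames.
+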
+
+class GroupMultiScaleCrop(object):
+    def __init__(self, output_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True,
+                 center_crop_only=False):
+        self.scales = scales if scales is not None else [1, .875, .75, .66]
+        self.max_distort = max_distort
+        self.fix_crop = fix_crop
+        self.more_fix_crop = more_fix_crop
+        self.center_crop_only = center_crop_only
+        assert center_crop_only is False or (max_distort == 0 and len(self.scales) == 1), \
+            'Center crop should only be performed at test time.'
+        self.output_size = output_size if not isinstance(output_size, int) else [output_size, output_size]
+        self.interpolation = Image.BILINEAR
+
+    def __call__(self, img_group):
+
+        im_size = img_group[0].size
+
+        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
+        crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
+        ret_img_group = [img.resize((self.output_size[0], self.output_size[1]), self.interpolation)
+                         for img in crop_img_group]
+        return ret_img_group, (offset_h, offset_w, crop_h, crop_w)
+
+    def _sample_crop_size(self, im_size):
+        image_w, image_h = im_size[0], im_size[1]
+
+        # find a crop size
+        base_size = min(image_w, image_h)
+        crop_sizes = [int(base_size * x) for x in self.scales]
+        crop_h = [self.output_size[1] if abs(x - self.output_size[1]) < 3 else x for x in crop_sizes]
+        crop_w = [self.output_size[0] if abs(x - self.output_size[0]) < 3 else x for x in crop_sizes]
+
+        pairs = []
+        for i, h in enumerate(crop_h):
+            for j, w in enumerate(crop_w):
+                if abs(i - j) <= self.max_distort:
+                    pairs.append((w, h))
+
+        crop_pair = random.choice(pairs)
+        if not self.fix_crop:
+            w_offset = random.randint(0, image_w - crop_pair[0])
+            h_offset = random.randint(0, image_h - crop_pair[1])
+        else:
+            w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
+
+        return crop_pair[0], crop_pair[1], w_offset, h_offset
+
+    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
+        offsets = self.fill_fix_offset(self.center_crop_only, self.more_fix_crop, image_w, image_h, crop_w, crop_h)
+        return random.choice(offsets)
+
+    @staticmethod
+    def fill_fix_offset(center_crop_only, more_fix_crop, image_w, image_h, crop_w, crop_h):
+        w_step = (image_w - crop_w) // 4
+        h_step = (image_h - crop_h) // 4
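+        # offsets lie on a 5x5 grid per axis (0..4 steps), where 4 steps spans
+        # the full slack (image size minus crop size)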
+
+        ret = list()
+        ret.append((0, 0))  # upper left
+        ret.append((2 * w_step, 2 * h_step))  # center
+        if center_crop_only:
+            return ret
+        ret.append((4 * w_step, 0))  # upper right
+        ret.append((0, 4 * h_step))  # lower left
+        ret.append((4 * w_step, 4 * h_step))  # lower right
+
+        if more_fix_crop:
+            ret.append((0, 2 * h_step))  # center left
+            ret.append((4 * w_step, 2 * h_step))  # center right
+            ret.append((2 * w_step, 4 * h_step))  # lower center
+            ret.append((2 * w_step, 0 * h_step))  # upper center
+
+            ret.append((1 * w_step, 1 * h_step))  # upper left quarter
+            ret.append((3 * w_step, 1 * h_step))  # upper right quarter
+            ret.append((1 * w_step, 3 * h_step))  # lower left quarter
+            ret.append((3 * w_step, 3 * h_step))  # lower right quarter
+
+        return ret
diff --git a/i3D/model.py b/i3D/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..55cb83dfbccb3c112337f89ae7270e1e8b0da3d0
--- /dev/null
+++ b/i3D/model.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+from i3D.resnet3d_xl import Net
+import torch.nn.functional as F
+'''
+Video Classification Model library.
+'''
+
+class TrainingScheduleError(Exception):
+    pass
+
+class VideoModel(nn.Module):
+    def __init__(self,
+                 num_classes,
+                 num_boxes,
+                 num_videos=16,
+                 restore_dict=None,
+                 freeze_weights=None,
+                 device=None,
+                 loss_type='softmax'):
+        super(VideoModel, self).__init__()
+        self.device = device
+        self.num_frames = num_videos
+        self.num_classes = num_classes
+        # The network loads Kinetics pre-trained weights at initialization
+        self.i3D = Net(num_classes, extract_features=True, loss_type=loss_type)
+
+        # Submodules used by _get_i3d_features() below; the dimensions here are
+        # assumptions, chosen to match the analogous modules in model_lib.py.
+        self.conv = nn.Conv3d(2048, 512, kernel_size=(1, 1, 1), stride=1)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # applied to (V*T/2, 512, h, w)
+        self.dropout = nn.Dropout(0.3)
+
+        try:
+            # Restore weights
+            if restore_dict:
+                self.restore(restore_dict)
+            # Freeze weights
+            if freeze_weights:
+                self.freeze_weights(freeze_weights)
+            else:
+                print(" > No weights are freezed")
+        except Exception as e:
+            print(" > Exception {}".format(e))
+
+    def restore(self, restore=None):
+        # Load pre-trained I3D + Graph weights for fine-tune (replace the last FC)
+        restore_finetuned = restore.get("restore_finetuned", None)
+        if restore_finetuned:
+            self._restore_finetuned(restore_finetuned)
+            print(" > Restored I3D + Graph weights")
+            return
+
+        # Load pre-trained I3D weights
+        restore_i3d = restore.get("restore_i3d", None)
+        if restore_i3d:
+            self._restore_i3d(restore_i3d)
+            print(" > Restored only I3D weights")
+            return
+
+        # Load pre-trained I3D + Graph weights without replacing anything
+        restore_predict = restore.get("restore_predict", None)
+        if restore_predict:
+            self._restore_predict(restore_predict)
+            print(" > Restored the model with strict weights")
+            return
+
+    def _restore_predict(self, path):
+        if path is None:
+            raise TrainingScheduleError('You should pre-train the video model on your training data first')
+
+        weights = torch.load(path, map_location=self.device)['state_dict']
+        new_weights = {}
+        for k, v in weights.items():
+            new_weights[k.replace('module.', '')] = v
+
+        self.load_state_dict(new_weights, strict=True)
+        print(" > Weights {} loaded".format(path))
+
+    def _restore_i3d(self, path):
+        if path is None:
+            raise TrainingScheduleError('You should pre-train the video model on your training data first')
+
+        weights = torch.load(path, map_location=self.device)['state_dict']
+        new_weights = {}
+        for k, v in weights.items():
+            if not k.startswith('module.fc') and not k.startswith('module.i3D.classifier'):
+                new_weights[k.replace('module.', '')] = v
+        self.load_state_dict(new_weights, strict=False)
+
+    def _restore_finetuned(self, path):
+        if path is None:
+            raise TrainingScheduleError('You should pre-train the video model on your training data first')
+
+        weights = torch.load(path, map_location=self.device)['state_dict']
+        new_weights = {}
+        for k, v in weights.items():
+            # Don't load classifiers (different classes 88 vs 86)
+            if not k.startswith('module.fc'):
+                if not k.startswith('module.i3D.classifier'):
+                    new_weights[k.replace('module.', '')] = v
+
+        self.load_state_dict(new_weights, strict=False)
+        print(" > Weights {} loaded".format(path))
+
+    def freeze_weights(self, module):
+        if module == 'i3d':
+            print(" > Freeze I3D module")
+            for param in self.i3D.parameters():
+                param.requires_grad = False
+        elif module == 'fine_tuned':
+            print(" > Freeze Graph + I3D module, only last FC is training")
+            # Fixed the entire params without the last FC
+            for name, param in self.i3D.named_parameters():
+                if not name.startswith('classifier'):
+                    param.requires_grad = False
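+            # NOTE: graph_embedding is assumed to be attached to this model by a
+            # subclass or external code; it is not defined in VideoModel itself.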
+            for param in self.graph_embedding.parameters():
+                param.requires_grad = False
+            for param in self.conv.parameters():
+                param.requires_grad = False
+
+        else:
+            raise NotImplementedError("Unrecognized option: freeze either the I3D module ('i3d') or everything except the last FC ('fine_tuned')")
+
+    def _get_i3d_features(self, videos, output_video_features=False):
+        # org_features - [V x 2048 x T / 2 x 14 x 14]
+        _, org_features = self.i3D(videos)
+        # Reduce dimension video_features - [V x 512 x T / 2 x 14 x 14]
+        videos_features = self.conv(org_features)
+        bs, d, t, h, w = videos_features.size()
+        # Get global features
+        videos_features_rs = videos_features.permute(0, 2, 1, 3, 4)  # [V x T / 2 x 512 x h x w]
+        videos_features_rs = videos_features_rs.reshape(-1, d, h, w)  # [V * T / 2 x 512 x h x w]
+        global_features = self.avgpool(videos_features_rs)  # [V * T / 2 x 512 x 1 x 1]
+        global_features = self.dropout(global_features)
+        global_features = global_features.reshape(bs, t, d)  # [V x T / 2 x 512]
+        if output_video_features:
+            return global_features, videos_features
+        else:
+            return global_features
+
+    def flatten(self, x):
+        return [item for sublist in x for item in sublist]
+
diff --git a/i3D/model_lib.py b/i3D/model_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..54027926a3e2029b7b41952ade9ec85850eb1f57
--- /dev/null
+++ b/i3D/model_lib.py
@@ -0,0 +1,1050 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from i3D.resnet3d_xl import Net
+from i3D.nonlocal_helper import Nonlocal
+
+
+class VideoModelCoord(nn.Module):
+    def __init__(self, opt):
+        super(VideoModelCoord, self).__init__()
+        self.nr_boxes = opt.num_boxes
+        self.nr_actions = opt.num_classes
+        self.nr_frames = opt.num_frames // 2
+        self.coord_feature_dim = opt.coord_feature_dim
+
+        self.coord_to_feature = nn.Sequential(
+            nn.Linear(4, self.coord_feature_dim//2, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim//2),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim//2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.spatial_node_fusion = nn.Sequential(
+            nn.Linear(self.coord_feature_dim*2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.box_feature_fusion = nn.Sequential(
+            nn.Linear(self.nr_frames*self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.classifier = nn.Sequential(
+            nn.Linear(self.coord_feature_dim, self.coord_feature_dim),
+            # nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, 512), #self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, self.nr_actions)
+        )
+
+        if opt.fine_tune:
+            self.fine_tune(opt.fine_tune)
+
+    def fine_tune(self, restore_path, parameters_to_train=['classifier']):
+        weights = torch.load(restore_path)['state_dict']
+        new_weights = {}
+        for k, v in weights.items():
+            if 'classifier.4' not in k:
+                new_weights[k.replace('module.', '')] = v
+        self.load_state_dict(new_weights, strict=False)
+        print('Num of weights in restore dict {}'.format(len(new_weights.keys())))
+
+        frozen_weights = 0
+        for name, param in self.named_parameters():
+            if 'classifier.4' not in name:
+                param.requires_grad = False
+                frozen_weights += 1
+            else:
+                print('Training : {}'.format(name))
+        print('Number of frozen weights {}'.format(frozen_weights))
+        assert frozen_weights != 0, 'You are trying to fine tune, but no weights are frozen!!! ' \
+                                    'Check the naming convention of the parameters'
+
+    def forward(self, global_img_input, box_categories, box_input, video_label, is_inference=False):
+        # local_img_tensor is (b, nr_frames, nr_boxes, 3, h, w)
+        # global_img_tensor is (b, nr_frames, 3, h, w)
+        # box_input is (b, nr_frames, nr_boxes, 4)
+
+        b, _, _, _h, _w = global_img_input.size()
+        # global_imgs = global_img_input.view(b*self.nr_frames, 3, _h, _w)
+        # local_imgs = local_img_input.view(b*self.nr_frames*self.nr_boxes, 3, _h, _w)
+
+        box_input = box_input.transpose(2, 1).contiguous()
+        box_input = box_input.view(b*self.nr_boxes*self.nr_frames, 4)
+
+        bf = self.coord_to_feature(box_input)
+        bf = bf.view(b, self.nr_boxes, self.nr_frames, self.coord_feature_dim)
+
+        # spatial message passing (graph)
+        spatial_message = bf.sum(dim=1, keepdim=True)  # (b, 1, self.nr_frames, coord_feature_dim)
+        # each box's incoming message excludes its own feature and is normalized
+        # by the number of other boxes
+        spatial_message = (spatial_message - bf) / (self.nr_boxes - 1)
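+        # e.g. with nr_boxes=4, box i receives (sum_j bf_j - bf_i) / 3, i.e. the
+        # mean of the other three boxes' features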
+        bf_and_message = torch.cat([bf, spatial_message], dim=3)  # (b, nr_boxes, nr_frames, 2*coord_feature_dim)
+
+        # (b*nr_boxes*nr_frames, coord_feature_dim)
+        bf_spatial = self.spatial_node_fusion(bf_and_message.view(b*self.nr_boxes*self.nr_frames, -1))
+        bf_spatial = bf_spatial.view(b, self.nr_boxes, self.nr_frames, self.coord_feature_dim)
+
+        bf_temporal_input = bf_spatial.view(b, self.nr_boxes, self.nr_frames*self.coord_feature_dim)
+
+        box_features = self.box_feature_fusion(bf_temporal_input.view(b*self.nr_boxes, -1))  # (b*nr_boxes, coord_feature_dim)
+        box_features = torch.mean(box_features.view(b, self.nr_boxes, -1), dim=1)  # (b, coord_feature_dim)
+        # video_features = torch.cat([global_features, local_features, box_features], dim=1)
+        video_features = box_features
+
+        cls_output = self.classifier(video_features)  # (b, num_classes)
+        return cls_output
+
+class VideoModelCoordLatent(nn.Module):
+    def __init__(self, opt):
+        super(VideoModelCoordLatent, self).__init__()
+        self.nr_boxes = opt.num_boxes
+        self.nr_actions = opt.num_classes
+        self.nr_frames = opt.num_frames // 2
+        self.img_feature_dim = opt.img_feature_dim
+        self.coord_feature_dim = opt.coord_feature_dim
+
+        self.category_embed_layer = nn.Embedding(3, opt.coord_feature_dim // 2, padding_idx=0, scale_grad_by_freq=True)
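+        # 3 box-category ids; padding_idx=0 pins id 0 to a fixed all-zero embedding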
+
+        self.coord_to_feature = nn.Sequential(
+            nn.Linear(4, self.coord_feature_dim//2, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim//2),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim//2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.coord_category_fusion = nn.Sequential(
+            nn.Linear(self.coord_feature_dim+self.coord_feature_dim//2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+        )
+
+        self.spatial_node_fusion = nn.Sequential(
+            nn.Linear(self.coord_feature_dim*2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.box_feature_fusion = nn.Sequential(
+            nn.Linear(self.nr_frames*self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.classifier = nn.Sequential(
+            nn.Linear(self.coord_feature_dim, self.coord_feature_dim),
+            # nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, 512), #self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, self.nr_actions)
+        )
+
+        if opt.fine_tune:
+            self.fine_tune(opt.fine_tune)
+
+    def fine_tune(self, restore_path, parameters_to_train=['classifier']):
+        weights = torch.load(restore_path)['state_dict']
+        new_weights = {}
+        for k, v in weights.items():
+            if 'classifier.4' not in k:
+                new_weights[k.replace('module.', '')] = v
+        self.load_state_dict(new_weights, strict=False)
+        print('Num of weights in restore dict {}'.format(len(new_weights.keys())))
+
+        frozen_weights = 0
+        for name, param in self.named_parameters():
+            if 'classifier.4' not in name:
+                param.requires_grad = False
+                frozen_weights += 1
+            else:
+                print('Training : {}'.format(name))
+        print('Number of frozen weights {}'.format(frozen_weights))
+        assert frozen_weights != 0, 'You are trying to fine tune, but no weights are frozen!!! ' \
+                                    'Check the naming convention of the parameters'
+
+    def forward(self, global_img_input, box_categories, box_input, video_label, is_inference=False):
+        # local_img_tensor is (b, nr_frames, nr_boxes, 3, h, w)
+        # global_img_tensor is (b, nr_frames, 3, h, w)
+        # box_input is (b, nr_frames, nr_boxes, 4)
+
+        b, _, _, _h, _w = global_img_input.size()
+
+        box_input = box_input.transpose(2, 1).contiguous()
+        box_input = box_input.view(b*self.nr_boxes*self.nr_frames, 4)
+
+        box_categories = box_categories.long()
+        box_categories = box_categories.transpose(2, 1).contiguous()
+        box_categories = box_categories.view(b*self.nr_boxes*self.nr_frames)
+        box_category_embeddings = self.category_embed_layer(box_categories)  # (b*nr_b*nr_f, coord_feature_dim//2)
+
+        bf = self.coord_to_feature(box_input)
+        bf = torch.cat([bf, box_category_embeddings], dim=1)  # (b*nr_b*nr_f, coord_feature_dim + coord_feature_dim//2)
+        bf = self.coord_category_fusion(bf)  # (b*nr_b*nr_f, coord_feature_dim)
+        bf = bf.view(b, self.nr_boxes, self.nr_frames, self.coord_feature_dim)
+
+        # spatial message passing (graph)
+        spatial_message = bf.sum(dim=1, keepdim=True)  # (b, 1, self.nr_frames, coord_feature_dim)
+        # each box's incoming message excludes its own feature and is normalized
+        # by the number of other boxes
+        spatial_message = (spatial_message - bf) / (self.nr_boxes - 1)
+        bf_and_message = torch.cat([bf, spatial_message], dim=3)  # (b, nr_boxes, nr_frames, 2*coord_feature_dim)
+
+        # (b*nr_boxes*nr_frames, coord_feature_dim)
+        bf_spatial = self.spatial_node_fusion(bf_and_message.view(b*self.nr_boxes*self.nr_frames, -1))
+        bf_spatial = bf_spatial.view(b, self.nr_boxes, self.nr_frames, self.coord_feature_dim)
+
+        bf_temporal_input = bf_spatial.view(b, self.nr_boxes, self.nr_frames*self.coord_feature_dim)
+
+        box_features = self.box_feature_fusion(bf_temporal_input.view(b*self.nr_boxes, -1))  # (b*nr_boxes, coord_feature_dim)
+        box_features = torch.mean(box_features.view(b, self.nr_boxes, -1), dim=1)  # (b, coord_feature_dim)
+        # video_features = torch.cat([global_features, local_features, box_features], dim=1)
+        video_features = box_features
+
+        cls_output = self.classifier(video_features)  # (b, num_classes)
+        return cls_output
+
+class VideoModelCoordLatentNL(nn.Module):
+    def __init__(self, opt):
+        super(VideoModelCoordLatentNL, self).__init__()
+        self.nr_boxes = opt.num_boxes
+        self.nr_actions = opt.num_classes
+        self.nr_frames = opt.num_frames // 2
+        self.img_feature_dim = opt.img_feature_dim
+        self.coord_feature_dim = opt.coord_feature_dim
+
+        self.category_embed_layer = nn.Embedding(3, opt.coord_feature_dim // 2, padding_idx=0, scale_grad_by_freq=True)
+
+        self.coord_to_feature = nn.Sequential(
+            nn.Linear(4, self.coord_feature_dim // 2, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim // 2),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim // 2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.coord_category_fusion = nn.Sequential(
+            nn.Linear(self.coord_feature_dim + self.coord_feature_dim // 2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+        )
+
+        self.spatial_node_fusion = nn.Sequential(
+            nn.Linear(self.coord_feature_dim * 2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.nr_nonlocal_layers = 3
+        self.nonlocal_fusion = []
+        for i in range(self.nr_nonlocal_layers):
+            self.nonlocal_fusion.append(nn.Sequential(
+                Nonlocal(dim=self.coord_feature_dim, dim_inner=self.coord_feature_dim // 2),
+                nn.Conv1d(self.coord_feature_dim, self.coord_feature_dim, kernel_size=1, stride=1, padding=0,
+                          bias=False),
+                nn.BatchNorm1d(self.coord_feature_dim),
+                nn.ReLU()
+            ))
+        self.nonlocal_fusion = nn.ModuleList(self.nonlocal_fusion)
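+        # nn.ModuleList (rather than a plain list) registers the non-local blocks,
+        # so their parameters appear in .parameters() and move with .to(device)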
+
+        self.box_feature_fusion = nn.Sequential(
+            nn.Linear(self.nr_frames * self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.classifier = nn.Sequential(
+            nn.Linear(self.coord_feature_dim, self.coord_feature_dim),
+            # nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, 512),  # self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, self.nr_actions)
+        )
+
+        if opt.fine_tune:
+            self.fine_tune(opt.fine_tune)
+
+    def train(self, mode=True):  # overriding default train function
+        super(VideoModelCoordLatentNL, self).train(mode)
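+        # keep every BatchNorm layer in eval mode so running statistics stay
+        # frozen even while the rest of the model trains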
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm3d):
+                m.eval()
+                # shutdown update in frozen mode
+                m.weight.requires_grad = False
+                m.bias.requires_grad = False
+
+    def fine_tune(self, restore_path, parameters_to_train=['classifier']):
+        weights = torch.load(restore_path)['state_dict']
+        new_weights = {}
+        for k, v in weights.items():
+            if 'classifier.4' not in k:
+                new_weights[k.replace('module.', '')] = v
+        self.load_state_dict(new_weights, strict=False)
+        print('Num of weights in restore dict {}'.format(len(new_weights.keys())))
+
+        frozen_weights = 0
+        for name, param in self.named_parameters():
+            if 'classifier.4' not in name:
+                param.requires_grad = False
+                frozen_weights += 1
+            else:
+                print('Training : {}'.format(name))
+        print('Number of frozen weights {}'.format(frozen_weights))
+        assert frozen_weights != 0, 'You are trying to fine tune, but no weights are frozen!!! ' \
+                                    'Check the naming convention of the parameters'
+
+    def forward(self, global_img_input, box_categories, box_input, video_label, is_inference=False):
+        # local_img_tensor is (b, nr_frames, nr_boxes, 3, h, w)
+        # global_img_tensor is (b, nr_frames, 3, h, w)
+        # box_input is (b, nr_frames, nr_boxes, 4)
+
+        b, _, _, _h, _w = global_img_input.size()
+
+        box_input = box_input.transpose(2, 1).contiguous()
+        box_input = box_input.view(b * self.nr_boxes * self.nr_frames, 4)
+
+        box_categories = box_categories.long()
+        box_categories = box_categories.transpose(2, 1).contiguous()
+        box_categories = box_categories.view(b * self.nr_boxes * self.nr_frames)
+        box_category_embeddings = self.category_embed_layer(box_categories)  # (b*nr_b*nr_f, coord_feature_dim//2)
+
+        bf = self.coord_to_feature(box_input)
+        bf = torch.cat([bf, box_category_embeddings], dim=1)  # (b*nr_b*nr_f, coord_feature_dim + coord_feature_dim//2)
+        bf = self.coord_category_fusion(bf)  # (b*nr_b*nr_f, coord_feature_dim)
+        bf = bf.view(b, self.nr_boxes, self.nr_frames, self.coord_feature_dim)
+
+        # spatial message passing (graph)
+        spatial_message = bf.sum(dim=1, keepdim=True)  # (b, 1, self.nr_frames, coord_feature_dim)
+        # each box's incoming message excludes its own feature and is normalized
+        # by the number of other boxes
+        spatial_message = (spatial_message - bf) / (self.nr_boxes - 1)
+        bf_and_message = torch.cat([bf, spatial_message], dim=3)  # (b, nr_boxes, nr_frames, 2*coord_feature_dim)
+
+        # (b*nr_boxes*nr_frames, coord_feature_dim)
+        bf_spatial = self.spatial_node_fusion(bf_and_message.view(b * self.nr_boxes * self.nr_frames, -1))
+        bf_spatial = bf_spatial.view(b, self.nr_boxes, self.nr_frames, self.coord_feature_dim)
+
+        bf_temporal_input = bf_spatial.view(b, self.nr_boxes, self.nr_frames * self.coord_feature_dim)
+
+        bf_nonlocal = self.box_feature_fusion(
+            bf_temporal_input.view(b * self.nr_boxes, -1))  # (b*nr_boxes, coord_feature_dim)
+        bf_nonlocal = bf_nonlocal.view(b, self.nr_boxes, self.coord_feature_dim).permute(0, 2,
+                                                                                         1).contiguous()  # (N, C, NB)
+        for i in range(self.nr_nonlocal_layers):
+            bf_nonlocal = self.nonlocal_fusion[i](bf_nonlocal)
+
+        box_features = torch.mean(bf_nonlocal, dim=2)  # (b, coord_feature_dim)
+
+        # video_features = torch.cat([global_features, local_features, box_features], dim=1)
+        video_features = box_features
+
+        cls_output = self.classifier(video_features)  # (b, num_classes)
+        return cls_output
+
+class VideoModelGlobalCoordLatent(nn.Module):
+    """
+    This model contains only global pooling without any graph.
+    """
+
+    def __init__(self, opt):
+        super(VideoModelGlobalCoordLatent, self).__init__()
+
+        self.nr_boxes = opt.num_boxes
+        self.nr_actions = opt.num_classes
+        self.nr_frames = opt.num_frames
+        self.img_feature_dim = opt.img_feature_dim
+        self.coord_feature_dim = opt.coord_feature_dim
+        self.i3D = Net(self.nr_actions, extract_features=True, loss_type='softmax')
+        self.dropout = nn.Dropout(0.3)
+        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
+        self.conv = nn.Conv3d(2048, 512, kernel_size=(1, 1, 1), stride=1)
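+        # 1x1x1 conv reduces 2048 -> 512 channels; the forward pass later views
+        # this as 2 * img_feature_dim (assuming img_feature_dim == 256)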
+
+        self.category_embed_layer = nn.Embedding(3, opt.coord_feature_dim // 2, padding_idx=0, scale_grad_by_freq=True)
+
+        self.c_coord_category_fusion = nn.Sequential(
+            nn.Linear(self.coord_feature_dim+self.coord_feature_dim//2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+        )
+
+        self.c_coord_to_feature = nn.Sequential(
+            nn.Linear(4, self.coord_feature_dim // 2, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim // 2),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim // 2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.c_spatial_node_fusion = nn.Sequential(
+            nn.Linear(self.coord_feature_dim * 2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.c_box_feature_fusion = nn.Sequential(
+            nn.Linear((self.nr_frames // 2) * self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.classifier = nn.Sequential(
+            nn.Linear(self.coord_feature_dim + 2*self.img_feature_dim, self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, 512),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, self.nr_actions)
+        )
+        if opt.fine_tune:
+            self.fine_tune(opt.fine_tune)
+        if opt.restore_i3d:
+            self.restore_i3d(opt.restore_i3d)
+        if opt.restore_custom:
+            self.restore_custom(opt.restore_custom)
+
+    def train(self, mode=True):  # overriding default train function
+        super(VideoModelGlobalCoordLatent, self).train(mode)
+        for m in self.modules():  # freeze all BN layers in the model
+            if isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm3d):
+                m.eval()
+                # shutdown update in frozen mode
+                m.weight.requires_grad = False
+                m.bias.requires_grad = False
+
+    def restore_custom(self, restore_path):
+        print("restoring path {}".format(restore_path))
+        weights = torch.load(restore_path)
+
+        ks = list(weights.keys())
+        print('\n\n BEFORE', weights[ks[0]][0, 0, 0])
+        new_weights = {}
+        for k, v in weights.items():
+            new_weights[k.replace('module.', '')] = v
+        self.load_state_dict(new_weights, strict=False)
+        print('\n\n AFTER', self.state_dict()[ks[0]][0, 0, 0])
+        print('Num of weights in restore dict {}'.format(len(new_weights.keys())))
+
+        frozen_weights = 0
+        for name, param in self.named_parameters():
+            if not name.startswith('classifier'):
+                param.requires_grad = False
+                frozen_weights += 1
+            else:
+                print('Training : {}'.format(name))
+        print('Number of frozen weights {}'.format(frozen_weights))
+        assert frozen_weights != 0, 'You are trying to fine tune, but no weights are frozen!!! ' \
+                                    'Check the naming convention of the parameters'
+
+
+    def restore_i3d(self, restore_path, parameters_to_train=['classifier']):
+        weights = torch.load(restore_path)['state_dict']
+        new_weights = {}
+        for k, v in weights.items():
+            if 'i3D' in k:
+                new_weights[k.replace('module.', '')] = v
+        self.load_state_dict(new_weights, strict=False)
+        print('Num of weights in restore dict {}'.format(len(new_weights.keys())))
+
+        for m in self.i3D.modules():
+            if isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm3d):
+                m.eval()
+                # shutdown update in frozen mode
+                m.weight.requires_grad = False
+                m.bias.requires_grad = False
+
+        frozen_weights = 0
+        for name, param in self.named_parameters():
+            if 'i3D' in name:
+                param.requires_grad = False
+                frozen_weights += 1
+            else:
+                print('Training : {}'.format(name))
+        print('Number of frozen weights {}'.format(frozen_weights))
+        assert frozen_weights != 0, 'You are trying to fine tune, but no weights are frozen!!! ' \
+                                    'Check the naming convention of the parameters'
+
+    def fine_tune(self, restore_path, parameters_to_train=['classifier']):
+        weights = torch.load(restore_path)['state_dict']
+        new_weights = {}
+        for k, v in weights.items():
+            if 'classifier.4' not in k and 'i3D.classifier' not in k:
+                new_weights[k.replace('module.', '')] = v
+        self.load_state_dict(new_weights, strict=False)
+        print('Num of weights in restore dict {}'.format(len(new_weights.keys())))
+
+        frozen_weights = 0
+        for name, param in self.named_parameters():
+            if 'classifier.4' not in name:
+                param.requires_grad = False
+                frozen_weights += 1
+            else:
+                print('Training : {}'.format(name))
+        print('Number of frozen weights {}'.format(frozen_weights))
+        assert frozen_weights != 0, 'You are trying to fine tune, but no weights are frozen!!! ' \
+                                    'Check the naming convention of the parameters'
+
+    def forward(self, global_img_input, box_categories, box_input, video_label, is_inference=False):
+        """
+        V: num of videos
+        T: num of frames
+        P: num of proposals
+        :param global_img_input: [V x 3 x T x 224 x 224]
+        :param box_input: [V x T x P x 4] box coordinates per frame
+        :return: classification logits [V x num_classes]
+        """
+
+        # org_features - [V x 2048 x T / 2 x 14 x 14]
+        bs, _, _, _, _ = global_img_input.shape
+        y_i3d, org_features = self.i3D(global_img_input)
+        # Reduce dimension video_features - [V x 512 x T / 2 x 14 x 14]
+        videos_features = self.conv(org_features)
+        b = bs
+
+        box_input = box_input.transpose(2, 1).contiguous()
+        box_input = box_input.view(b * self.nr_boxes * (self.nr_frames//2), 4)
+
+        box_categories = box_categories.long()
+        box_categories = box_categories.transpose(2, 1).contiguous()
+        box_categories = box_categories.view(b * self.nr_boxes * (self.nr_frames // 2))
+        box_category_embeddings = self.category_embed_layer(box_categories)  # (b*nr_b*nr_f, coord_feature_dim//2)
+
+        bf = self.c_coord_to_feature(box_input)
+        bf = torch.cat([bf, box_category_embeddings], dim=1)  # (b*nr_b*nr_f, coord_feature_dim + coord_feature_dim//2)
+        bf = self.c_coord_category_fusion(bf)  # (b*nr_b*nr_f, coord_feature_dim)
+
+        bf = bf.view(b, self.nr_boxes, self.nr_frames // 2, self.coord_feature_dim)
+
+        # spatial message passing (graph)
+        spatial_message = bf.sum(dim=1, keepdim=True)  # (b, 1, self.nr_frames, coord_feature_dim)
+        # each box's incoming message excludes its own feature and is normalized
+        # by the number of other boxes
+        spatial_message = (spatial_message - bf) / (self.nr_boxes - 1)
+
+        bf_message_gf = torch.cat([bf, spatial_message], dim=3)  # (b, nr_boxes, nr_frames, 2*coord_feature_dim)
+
+        # (b*nr_boxes*nr_frames, coord_feature_dim)
+        bf_spatial = self.c_spatial_node_fusion(bf_message_gf.view(b * self.nr_boxes * (self.nr_frames // 2), -1))
+        bf_spatial = bf_spatial.view(b, self.nr_boxes, self.nr_frames // 2, self.coord_feature_dim)
+
+        bf_temporal_input = bf_spatial.view(b, self.nr_boxes, (self.nr_frames // 2) * self.coord_feature_dim)
+
+        box_features = self.c_box_feature_fusion(
+            bf_temporal_input.view(b * self.nr_boxes, -1))  # (b*nr_boxes, img_feature_dim)
+        coord_ft = torch.mean(box_features.view(b, self.nr_boxes, -1), dim=1)  # (b, coord_feature_dim)
+        # video_features = torch.cat([global_features, local_features, box_features], dim=1)
+        # _gf = self.global_new_fc(_gf)
+        _gf = videos_features.mean(-1).mean(-1).view(b, (self.nr_frames//2), 2*self.img_feature_dim)
+        _gf = _gf.mean(1)
+        video_features = torch.cat([_gf.view(b, -1), coord_ft], dim=-1)
+
+        cls_output = self.classifier(video_features)  # (b, num_classes)
+        return cls_output
+
+class VideoModelGlobalCoordLatentNL(nn.Module):
+    """
+    This model contains only global pooling without any graph.
+    """
+
+    def __init__(self, base_net, opt):
+        super(VideoModelGlobalCoordLatentNL, self).__init__()
+
+        self.nr_boxes = opt.num_boxes
+        self.nr_actions = opt.num_classes
+        self.nr_frames = opt.num_frames
+        self.img_feature_dim = opt.img_feature_dim
+        self.coord_feature_dim = opt.coord_feature_dim
+        self.i3D = Net(self.nr_actions, extract_features=True, loss_type='softmax')
+        self.dropout = nn.Dropout(0.3)
+        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
+        self.conv = nn.Conv3d(2048, 512, kernel_size=(1, 1, 1), stride=1)
+
+
+        self.category_embed_layer = nn.Embedding(3, opt.coord_feature_dim // 2, padding_idx=0, scale_grad_by_freq=True)
+
+        self.c_coord_category_fusion = nn.Sequential(
+            nn.Linear(self.coord_feature_dim+self.coord_feature_dim//2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+        )
+
+        self.c_coord_to_feature = nn.Sequential(
+            nn.Linear(4, self.coord_feature_dim // 2, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim // 2),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim // 2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.c_spatial_node_fusion = nn.Sequential(
+            nn.Linear(self.coord_feature_dim * 2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.nr_nonlocal_layers = 3
+        self.c_nonlocal_fusion = []
+        for i in range(self.nr_nonlocal_layers):
+            self.c_nonlocal_fusion.append(nn.Sequential(
+                    Nonlocal(dim=self.coord_feature_dim, dim_inner=self.coord_feature_dim // 2),
+                    nn.Conv1d(self.coord_feature_dim, self.coord_feature_dim, kernel_size=1, stride=1, padding=0, bias=False),
+                    nn.BatchNorm1d(self.coord_feature_dim),
+                    nn.ReLU()
+            ))
+        self.c_nonlocal_fusion = nn.ModuleList(self.c_nonlocal_fusion)
+
+        self.c_box_feature_fusion = nn.Sequential(
+            nn.Linear((self.nr_frames // 2) * self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.classifier = nn.Sequential(
+            nn.Linear(self.coord_feature_dim + 2*self.img_feature_dim, self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, 512),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, self.nr_actions)
+        )
+        if opt.fine_tune:
+            self.fine_tune(opt.fine_tune)
+        if opt.restore_i3d:
+            self.restore_i3d(opt.restore_i3d)
+
+        if opt.restore_custom:
+            self.restore_custom(opt.restore_custom)
+
+    def restore_custom(self, restore_path):
+        print("restoring path {}".format(restore_path))
+        weights = torch.load(restore_path)
+        ks = list(weights.keys())
+        print('\n\n BEFORE', weights[ks[0]][0, 0, 0])
+        new_weights = {}
+        for k, v in weights.items():
+            new_weights[k.replace('module.', '')] = v
+        self.load_state_dict(new_weights, strict=False)
+        print('\n\n AFTER', self.state_dict()[ks[0]][0, 0, 0])
+        print('Num of weights in restore dict {}'.format(len(new_weights.keys())))
+
+        frozen_weights = 0
+        for name, param in self.named_parameters():
+            if not name.startswith('classifier'):
+                param.requires_grad = False
+                frozen_weights += 1
+            else:
+                print('Training : {}'.format(name))
+        print('Number of frozen weights {}'.format(frozen_weights))
+        assert frozen_weights != 0, 'You are trying to fine tune, but no weights are frozen!!! ' \
+                                    'Check the naming convention of the parameters'
+
+
+
+    def restore_i3d(self, restore_path, parameters_to_train=['classifier']):
+        weights = torch.load(restore_path)['state_dict']
+        new_weights = {}
+        for k, v in weights.items():
+            if 'i3D' in k or k.startswith('conv.'):
+                new_weights[k.replace('module.', '')] = v
+        self.load_state_dict(new_weights, strict=False)
+        print('Num of weights in restore dict {}'.format(len(new_weights.keys())))
+
+        for m in self.i3D.modules():
+            if isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm3d):
+                m.eval()
+                # shutdown update in frozen mode
+                m.weight.requires_grad = False
+                m.bias.requires_grad = False
+
+        frozen_weights = 0
+        for name, param in self.named_parameters():
+            if 'i3D' in name or name.startswith('conv.'):
+                param.requires_grad = False
+                frozen_weights += 1
+            else:
+                print('Training : {}'.format(name))
+        print('Number of frozen weights {}'.format(frozen_weights))
+        assert frozen_weights != 0, 'You are trying to fine tune, but no weights are frozen!!! ' \
+                                    'Check the naming convention of the parameters'
+
+    def train(self, mode=True):  # overriding default train function
+        super(VideoModelGlobalCoordLatentNL, self).train(mode)
+        for m in self.i3D.modules():  # or self.modules(), if freezing all bn layers
+            if isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm3d):
+                m.eval()
+                # shutdown update in frozen mode
+                m.weight.requires_grad = False
+                m.bias.requires_grad = False
+
+    def fine_tune(self, restore_path, parameters_to_train=['classifier']):
+        weights = torch.load(restore_path)['state_dict']
+        new_weights = {}
+        for k, v in weights.items():
+            if 'classifier.4' not in k and 'i3D.classifier' not in k:
+                new_weights[k.replace('module.', '')] = v
+        self.load_state_dict(new_weights, strict=False)
+        print('Num of weights in restore dict {}'.format(len(new_weights.keys())))
+
+        frozen_weights = 0
+        for name, param in self.named_parameters():
+            if 'classifier.4' not in name:
+                param.requires_grad = False
+                frozen_weights += 1
+            else:
+                print('Training : {}'.format(name))
+        print('Number of frozen weights {}'.format(frozen_weights))
+        assert frozen_weights != 0, 'You are trying to fine tune, but no weights are frozen!!! ' \
+                                    'Check the naming convention of the parameters'
+
+    def forward(self, global_img_input, box_categories, box_input, video_label, is_inference=False):
+        """
+        V: num of videos
+        T: num of frames
+        P: num of proposals
+        :param global_img_input: [V x 3 x T x 224 x 224]
+        :param box_input: [V x T x P x 4] box coordinates per frame
+        :return: classification logits [V x num_classes]
+        """
+
+        # org_features - [V x 2048 x T / 2 x 14 x 14]
+        bs, _, _, _, _ = global_img_input.shape
+        y_i3d, org_features = self.i3D(global_img_input)
+        # Reduce dimension video_features - [V x 512 x T / 2 x 14 x 14]
+        videos_features = self.conv(org_features)
+        b = bs
+
+        box_input = box_input.transpose(2, 1).contiguous()
+        box_input = box_input.view(b * self.nr_boxes * (self.nr_frames//2), 4)
+
+        box_categories = box_categories.long()
+        box_categories = box_categories.transpose(2, 1).contiguous()
+        box_categories = box_categories.view(b * self.nr_boxes * (self.nr_frames // 2))
+        box_category_embeddings = self.category_embed_layer(box_categories)  # (b*nr_b*nr_f, coord_feature_dim//2)
+
+        bf = self.c_coord_to_feature(box_input)
+        bf = torch.cat([bf, box_category_embeddings], dim=1)  # (b*nr_b*nr_f, coord_feature_dim + coord_feature_dim//2)
+        bf = self.c_coord_category_fusion(bf)  # (b*nr_b*nr_f, coord_feature_dim)
+
+        bf = bf.view(b, self.nr_boxes, self.nr_frames // 2, self.coord_feature_dim)
+
+        # spatial message passing (graph)
+        spatial_message = bf.sum(dim=1, keepdim=True)  # (b, 1, self.nr_frames, coord_feature_dim)
+        # each box's incoming message excludes its own feature and is normalized
+        # by the number of other boxes
+        spatial_message = (spatial_message - bf) / (self.nr_boxes - 1)
+
+        bf_message_gf = torch.cat([bf, spatial_message], dim=3)  # (b, nr_boxes, nr_frames, 2*coord_feature_dim)
+
+        # (b*nr_boxes*nr_frames, coord_feature_dim)
+        bf_spatial = self.c_spatial_node_fusion(bf_message_gf.view(b * self.nr_boxes * (self.nr_frames // 2), -1))
+        bf_spatial = bf_spatial.view(b, self.nr_boxes, self.nr_frames // 2, self.coord_feature_dim)
+
+        bf_temporal_input = bf_spatial.view(b, self.nr_boxes, (self.nr_frames // 2) * self.coord_feature_dim)
+
+        bf_nonlocal = self.c_box_feature_fusion(
+            bf_temporal_input.view(b * self.nr_boxes, -1))  # (b*nr_boxes, img_feature_dim)
+
+        bf_nonlocal = bf_nonlocal.view(b, self.nr_boxes, self.coord_feature_dim).permute(0, 2, 1).contiguous()  # (N, C, NB)
+        for i in range(self.nr_nonlocal_layers):
+            bf_nonlocal = self.c_nonlocal_fusion[i](bf_nonlocal)
+
+        coord_ft = torch.mean(bf_nonlocal, dim=2)  # (b, coord_feature_dim)
+
+        # video_features = torch.cat([global_features, local_features, box_features], dim=1)
+        _gf = videos_features.mean(-1).mean(-1).view(b, (self.nr_frames//2), 2*self.img_feature_dim)
+        _gf = _gf.mean(1)
+        video_features = torch.cat([_gf.view(b, -1), coord_ft], dim=-1)
+
+        cls_output = self.classifier(video_features)  # (b, num_classes)
+        return cls_output
+
+class VideoGlobalModel(nn.Module):
+    """
+    This model contains only global pooling without any graph.
+    """
+
+    def __init__(self, opt):
+        super(VideoGlobalModel, self).__init__()
+
+        self.nr_boxes = opt.num_boxes
+        self.nr_actions = opt.num_classes
+        self.nr_frames = opt.num_frames
+        self.img_feature_dim = opt.img_feature_dim
+        self.coord_feature_dim = opt.coord_feature_dim
+        self.i3D = Net(self.nr_actions, extract_features=True, loss_type='softmax')
+        self.dropout = nn.Dropout(0.3)
+        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
+        self.conv = nn.Conv3d(2048, 512, kernel_size=(1, 1, 1), stride=1)
+        self.fc = nn.Linear(512, self.nr_actions)
+        self.crit = nn.CrossEntropyLoss()
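+        # self.crit is not used inside forward(); presumably the loss is applied
+        # by the training loop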
+
+        if opt.fine_tune:
+            self.fine_tune(opt.fine_tune)
+
+    def fine_tune(self, restore_path, parameters_to_train=['classifier']):
+        weights = torch.load(restore_path)['state_dict']
+        new_weights = {}
+        for k, v in weights.items():
+            if 'fc' not in k and 'classifier' not in k:
+                new_weights[k.replace('module.', '')] = v
+        self.load_state_dict(new_weights, strict=False)
+        print('Num of weights in restore dict {}'.format(len(new_weights.keys())))
+
+        frozen_weights = 0
+        for name, param in self.named_parameters():
+            if 'fc' not in name:
+                param.requires_grad = False
+                frozen_weights += 1
+            else:
+                print('Training : {}'.format(name))
+        print('Number of frozen weights {}'.format(frozen_weights))
+        assert frozen_weights != 0, 'You are trying to fine tune, but no weights are frozen!!! ' \
+                                    'Check the naming convention of the parameters'
+
+    def forward(self, global_img_input, local_img_input, box_input, video_label, is_inference=False):
+        """
+        V: num of videos
+        T: num of frames
+        P: num of proposals
+        :param global_img_input: [V x 3 x T x 224 x 224]
+        :param box_input: [V x T x P x 4] box coordinates per frame
+        :return: classification logits [V x num_classes]
+        """
+
+        # org_features - [V x 2048 x T / 2 x 14 x 14]
+        y_i3d, org_features = self.i3D(global_img_input)
+        # Reduce dimension video_features - [V x 512 x T / 2 x 14 x 14]
+        videos_features = self.conv(org_features)
+
+        # Get global features - [V x 512]
+        global_features = self.avgpool(videos_features).flatten(1)  # (V, 512); flatten keeps the batch dim even when V == 1
+        global_features = self.dropout(global_features)
+
+        cls_output = self.fc(global_features)
+        return cls_output
+
+class VideoModelGlobalCoord(nn.Module):
+    """
+    This model contains only global pooling without any graph.
+    """
+
+    def __init__(self, opt):
+        super(VideoModelGlobalCoord, self).__init__()
+
+        self.nr_boxes = opt.num_boxes
+        self.nr_actions = opt.num_classes
+        self.nr_frames = opt.num_frames
+        self.img_feature_dim = opt.img_feature_dim
+        self.coord_feature_dim = opt.coord_feature_dim
+        self.i3D = Net(self.nr_actions, extract_features=True, loss_type='softmax')
+        self.dropout = nn.Dropout(0.3)
+        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
+        self.conv = nn.Conv3d(2048, 256, kernel_size=(1, 1, 1), stride=1)
+
+
+        self.global_new_fc = nn.Sequential(
+            nn.Linear(256, self.img_feature_dim, bias=False),
+            nn.BatchNorm1d(self.img_feature_dim),
+            nn.ReLU(inplace=True)
+        )
+
+
+        self.c_coord_to_feature = nn.Sequential(
+            nn.Linear(4, self.coord_feature_dim // 2, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim // 2),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim // 2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.c_spatial_node_fusion = nn.Sequential(
+            nn.Linear(self.coord_feature_dim * 2, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.c_box_feature_fusion = nn.Sequential(
+            nn.Linear((self.nr_frames // 2) * self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, self.coord_feature_dim, bias=False),
+            nn.BatchNorm1d(self.coord_feature_dim),
+            nn.ReLU()
+        )
+
+        self.classifier = nn.Sequential(
+            nn.Linear(self.coord_feature_dim + self.img_feature_dim, self.coord_feature_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.coord_feature_dim, 512),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, self.nr_actions)
+        )
+        if opt.fine_tune:
+            self.fine_tune(opt.fine_tune)
+        if opt.restore_i3d:
+            self.restore_i3d(opt.restore_i3d)
+
+    def train(self, mode=True):  # override to keep the i3D BatchNorm layers frozen
+        super(VideoModelGlobalCoord, self).train(mode)
+        for m in self.i3D.modules():  # or self.modules(), to freeze all BN layers
+            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
+                m.eval()  # use running statistics; do not update them
+                m.weight.requires_grad = False
+                m.bias.requires_grad = False
+
+    def restore_i3d(self, restore_path, parameters_to_train=['classifier']):
+        # Restore only the i3D backbone weights from the checkpoint.
+        weights = torch.load(restore_path)['state_dict']
+        new_weights = {}
+        for k, v in weights.items():
+            if 'i3D' in k:
+                new_weights[k.replace('module.', '')] = v
+        self.load_state_dict(new_weights, strict=False)
+        print('Num of weights in restore dict {}'.format(len(new_weights)))
+
+        frozen_weights = 0
+        for name, param in self.named_parameters():
+            if 'i3D' in name:
+                param.requires_grad = False
+                frozen_weights += 1
+            else:
+                print('Training : {}'.format(name))
+        print('Number of frozen weights {}'.format(frozen_weights))
+        assert frozen_weights != 0, 'You are trying to fine tune, but no weights are frozen!!! ' \
+                                    'Check the naming convention of the parameters'
+
+    def fine_tune(self, restore_path, parameters_to_train=['classifier']):
+        # Restore everything except the classifier heads, which are re-trained.
+        weights = torch.load(restore_path)['state_dict']
+        new_weights = {}
+        for k, v in weights.items():
+            # The original condition ended with the bare string 'i3D.classifier',
+            # which is always truthy; the intent is to skip both classifier heads.
+            if 'classifier.4' not in k and 'i3D.classifier' not in k:
+                new_weights[k.replace('module.', '')] = v
+        self.load_state_dict(new_weights, strict=False)
+        print('Num of weights in restore dict {}'.format(len(new_weights)))
+
+        frozen_weights = 0
+        for name, param in self.named_parameters():
+            if 'classifier.4' not in name:
+                param.requires_grad = False
+                frozen_weights += 1
+            else:
+                print('Training : {}'.format(name))
+        print('Number of frozen weights {}'.format(frozen_weights))
+        assert frozen_weights != 0, 'You are trying to fine tune, but no weights are frozen!!! ' \
+                                    'Check the naming convention of the parameters'
+
+    def forward(self, global_img_input, box_categories, box_input, video_label, is_inference=False):
+        """
+        V: num of videos
+        T: num of frames
+        P: num of boxes per frame
+        :param global_img_input: [V x 3 x T x 224 x 224] video clips
+        :param box_categories: unused here; kept for interface compatibility
+        :param box_input: [V x T/2 x P x 4] box coordinates
+        :return: [V x num_classes] classification logits
+        """
+
+        # org_features - [V x 2048 x T/2 x 14 x 14]
+        b = global_img_input.size(0)
+        y_i3d, org_features = self.i3D(global_img_input)
+        # Reduce channel dimension: videos_features - [V x 256 x T/2 x 14 x 14]
+        videos_features = self.conv(org_features)
+
+        # (V x T/2 x P x 4) -> (V x P x T/2 x 4) -> (V*P*T/2, 4)
+        box_input = box_input.transpose(2, 1).contiguous()
+        box_input = box_input.view(b * self.nr_boxes * (self.nr_frames // 2), 4)
+
+        bf = self.c_coord_to_feature(box_input)
+        bf = bf.view(b, self.nr_boxes, self.nr_frames // 2, self.coord_feature_dim)
+
+        # spatial message passing (graph)
+        spatial_message = bf.sum(dim=1, keepdim=True)  # (b, 1, nr_frames//2, coord_feature_dim)
+        # each box's message excludes its own feature, averaged over the other boxes
+        spatial_message = (spatial_message - bf) / (self.nr_boxes - 1)
+
+        bf_message_gf = torch.cat([bf, spatial_message], dim=3)  # (b, nr_boxes, nr_frames//2, 2*coord_feature_dim)
+
+        # (b*nr_boxes*nr_frames, coord_feature_dim)
+        bf_spatial = self.c_spatial_node_fusion(bf_message_gf.view(b * self.nr_boxes * (self.nr_frames // 2), -1))
+        bf_spatial = bf_spatial.view(b, self.nr_boxes, self.nr_frames // 2, self.coord_feature_dim)
+
+        bf_temporal_input = bf_spatial.view(b, self.nr_boxes, (self.nr_frames // 2) * self.coord_feature_dim)
+
+        box_features = self.c_box_feature_fusion(
+            bf_temporal_input.view(b * self.nr_boxes, -1))  # (b*nr_boxes, img_feature_dim)
+        coord_ft = torch.mean(box_features.view(b, self.nr_boxes, -1), dim=1)  # (b, coord_feature_dim)
+
+        # Spatially pool, then put time before channels so that each row fed to
+        # global_new_fc is one frame's feature vector (the original view mixed
+        # the channel and time axes).
+        _gf = videos_features.mean(-1).mean(-1)  # (b, img_feature_dim, T/2)
+        _gf = _gf.permute(0, 2, 1).contiguous().view(b * (self.nr_frames // 2), self.img_feature_dim)
+        _gf = self.global_new_fc(_gf)
+        _gf = _gf.view(b, self.nr_frames // 2, self.img_feature_dim).mean(1)
+        video_features = torch.cat([_gf.view(b, -1), coord_ft], dim=-1)
+
+        cls_output = self.classifier(video_features)  # (b, num_classes)
+        return cls_output
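+
+
+# Illustrative usage sketch; the `opt` fields mirror the argparse options used
+# elsewhere in this repo, and the values here are assumptions for the demo.
+# Note that building the model loads the Kinetics i3D checkpoint from disk.
+#
+#   from argparse import Namespace
+#   opt = Namespace(num_boxes=4, num_classes=50, num_frames=16,
+#                   img_feature_dim=256, coord_feature_dim=128,
+#                   fine_tune=None, restore_i3d=None)
+#   model = VideoModelGlobalCoord(opt)
+#   clips = torch.randn(2, 3, 16, 224, 224)   # V x 3 x T x 224 x 224
+#   boxes = torch.rand(2, 8, 4, 4)            # V x T/2 x P x 4
+#   logits = model(clips, None, boxes, video_label=None)  # (2, 50)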
diff --git a/i3D/nonlocal_helper.py b/i3D/nonlocal_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7afe2438f00be3e211a638a7ed11cd836ae3c8a
--- /dev/null
+++ b/i3D/nonlocal_helper.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+"""Non-local helper"""
+
+import torch
+import torch.nn as nn
+
+
+class Nonlocal(nn.Module):
+    """
+    Builds Non-local Neural Networks as a generic family of building
+    blocks for capturing long-range dependencies. Non-local Network
+    computes the response at a position as a weighted sum of the
+    features at all positions. This building block can be plugged into
+    many computer vision architectures.
+    More details in the paper: https://arxiv.org/pdf/1711.07971.pdf
+    """
+
+    def __init__(
+        self,
+        dim,
+        dim_inner,
+        pool_size=None,
+        instantiation="softmax",
+        norm_type="layernorm",
+        zero_init_final_conv=True,
+        zero_init_final_norm=False,
+        norm_eps=1e-5,
+        norm_momentum=0.1,
+    ):
+        """
+        Args:
+            dim (int): number of dimension for the input.
+            dim_inner (int): number of dimension inside of the Non-local block.
+            pool_size (list): kernel size of the pooling applied before phi
+                and g. This 1D adaptation pools over the box/position axis,
+                so a single-element list such as [2] is expected here. By
+                default pool_size is None, and no pooling is used.
+            instantiation (string): supports two different instantiation method:
+                "dot_product": normalizing correlation matrix with L2.
+                "softmax": normalizing correlation matrix with Softmax.
+            norm_type (string): support BatchNorm and LayerNorm for
+                normalization.
+                "batchnorm": using BatchNorm for normalization.
+                "layernorm": using LayerNorm for normalization.
+                "none": not using any normalization.
+            zero_init_final_conv (bool): If true, zero initializing the final
+                convolution of the Non-local block.
+            zero_init_final_norm (bool):
+                If true, zero initializing the final batch norm of the Non-local
+                block.
+        """
+        super(Nonlocal, self).__init__()
+        self.dim = dim
+        self.dim_inner = dim_inner
+        self.pool_size = pool_size
+        self.instantiation = instantiation
+        self.norm_type = norm_type
+        self.use_pool = (
+            pool_size is not None and any(size > 1 for size in pool_size)
+        )
+        self.norm_eps = norm_eps
+        self.norm_momentum = norm_momentum
+        self._construct_nonlocal(zero_init_final_conv, zero_init_final_norm)
+
+    def _construct_nonlocal(self, zero_init_final_conv, zero_init_final_norm):
+        # Three convolution heads: theta, phi, and g.
+        self.conv_theta = nn.Conv1d(
+            self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
+        )
+        self.conv_phi = nn.Conv1d(
+            self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
+        )
+        self.conv_g = nn.Conv1d(
+            self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
+        )
+
+        # Final convolution output.
+        self.conv_out = nn.Conv1d(
+            self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0
+        )
+        # Zero initializing the final convolution output.
+        self.conv_out.zero_init = zero_init_final_conv
+
+        if self.norm_type == "batchnorm":
+            self.bn = nn.BatchNorm1d(
+                self.dim, eps=self.norm_eps, momentum=self.norm_momentum
+            )
+            # Zero initializing the final bn.
+            self.bn.transform_final_bn = zero_init_final_norm
+        elif self.norm_type == "layernorm":
+            # In Caffe2 the LayerNorm op does not contain the scale and bias
+            # terms described in the paper:
+            # https://caffe2.ai/docs/operators-catalogue.html#layernorm
+            # Builds LayerNorm as GroupNorm with one single group.
+            # Setting Affine to false to align with Caffe2.
+            self.ln = nn.GroupNorm(1, self.dim, eps=self.norm_eps, affine=False)
+        elif self.norm_type == "none":
+            # Does not use any norm.
+            pass
+        else:
+            raise NotImplementedError(
+                "Norm type {} is not supported".format(self.norm_type)
+            )
+
+        # Optional to add the spatial-temporal pooling.
+        if self.use_pool:
+            # MaxPool1d expects scalar-like kernel/stride/padding; the original
+            # padding=[0, 0, 0] was left over from the 3D version of this block.
+            self.pool = nn.MaxPool1d(
+                kernel_size=self.pool_size,
+                stride=self.pool_size,
+                padding=0,
+            )
+
+    def forward(self, x):
+        x_identity = x
+        N, C, NB = x.size()
+
+        theta = self.conv_theta(x)
+
+        # Perform temporal-spatial pooling to reduce the computation.
+        if self.use_pool:
+            x = self.pool(x)
+
+        phi = self.conv_phi(x)
+        g = self.conv_g(x)
+
+        theta = theta.view(N, self.dim_inner, -1)
+        phi = phi.view(N, self.dim_inner, -1)
+        g = g.view(N, self.dim_inner, -1)
+
+        # (N, C, NB) * (N, C, NB) => (N, NB, NB).
+        theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi))
+        # For original Non-local paper, there are two main ways to normalize
+        # the affinity tensor:
+        #   1) Softmax normalization (norm on exp).
+        #   2) dot_product normalization.
+        if self.instantiation == "softmax":
+            # Normalizing the affinity tensor theta_phi before softmax.
+            theta_phi = theta_phi * (self.dim_inner ** -0.5)
+            theta_phi = nn.functional.softmax(theta_phi, dim=2)
+        elif self.instantiation == "dot_product":
+            spatial_temporal_dim = theta_phi.shape[2]
+            theta_phi = theta_phi / spatial_temporal_dim
+        else:
+            raise NotImplementedError(
+                "Unknown norm type {}".format(self.instantiation)
+            )
+
+        # (N, NB, NB) * (N, C, NB) => (N, C, NB).
+        theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g))
+
+        # (N, C, NB) => (N, C, NB).
+        theta_phi_g = theta_phi_g.view(N, self.dim_inner, NB)
+
+        p = self.conv_out(theta_phi_g)
+        if self.norm_type == "batchnorm":
+            p = self.bn(p)
+        elif self.norm_type == "layernorm":
+            p = self.ln(p)
+        return x_identity + p
+
+
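+# Minimal smoke test (illustrative addition, not from the original repo): the
+# block maps (N, C, NB) -> (N, C, NB); here N=batch, C=channels, NB=num boxes.
+if __name__ == "__main__":
+    nl = Nonlocal(dim=512, dim_inner=256, instantiation="softmax")
+    feats = torch.randn(2, 512, 8)
+    out = nl(feats)
+    print(out.shape)  # torch.Size([2, 512, 8])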
diff --git a/i3D/resnet3d_xl.py b/i3D/resnet3d_xl.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4d1695507c7a9f2b232bd886ee87c5489ff899d
--- /dev/null
+++ b/i3D/resnet3d_xl.py
@@ -0,0 +1,456 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import numpy as np
+
+from functools import partial
+
+__all__ = [
+    'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+    'resnet152', 'resnet200',
+]
+
+
+def conv3x3x3(in_planes, out_planes, stride=1):
+    # 3x3x3 convolution with padding
+    return nn.Conv3d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=1,
+        bias=False)
+
+
+def downsample_basic_block(x, planes, stride):
+    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
+    # Pad the channel dimension with zeros up to `planes` (shortcut type 'A').
+    zero_pads = torch.zeros(out.size(0), planes - out.size(1), out.size(2),
+                            out.size(3), out.size(4),
+                            dtype=out.dtype, device=out.device)
+    return torch.cat([out, zero_pads], dim=1)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm3d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3x3(planes, planes)
+        self.bn2 = nn.BatchNorm3d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    conv_op = None
+    offset_groups = 1
+
+    def __init__(self, dim_in, dim_out, stride, dim_inner, group=1, use_temp_conv=1, temp_stride=1, dcn=False,
+                 shortcut_type='B'):
+        super(Bottleneck, self).__init__()
+        # 1 x 1 layer
+        self.with_dcn = dcn
+        self.conv1 = self.Conv3dBN(dim_in, dim_inner, (1 + use_temp_conv * 2, 1, 1), (temp_stride, 1, 1),
+                                   (use_temp_conv, 0, 0))
+        self.relu = nn.ReLU(inplace=True)
+        # 3 x 3 layer
+        self.conv2 = self.Conv3dBN(dim_inner, dim_inner, (1, 3, 3), (1, stride, stride), (0, 1, 1))
+        # 1 x 1 layer
+        self.conv3 = self.Conv3dBN(dim_inner, dim_out, (1, 1, 1), (1, 1, 1), (0, 0, 0))
+
+        self.shortcut_type = shortcut_type
+        self.dim_in = dim_in
+        self.dim_out = dim_out
+        self.temp_stride = temp_stride
+        self.stride = stride
+        # nn.Conv3d(dim_in, dim_out, (1,1,1),(temp_stride,stride,stride),(0,0,0))
+        if self.shortcut_type == 'B':
+            # A projection shortcut is only needed when the output shape changes.
+            if not (self.dim_in == self.dim_out and self.temp_stride == 1 and self.stride == 1):
+                self.shortcut = self.Conv3dBN(dim_in, dim_out, (1, 1, 1), (temp_stride, stride, stride), (0, 0, 0))
+
+    def forward(self, x):
+        residual = x
+        out = self.conv1(x)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.relu(out)
+        out = self.conv3(out)
+        if not (self.dim_in == self.dim_out and self.temp_stride == 1 and self.stride == 1):
+            residual = self.shortcut(residual)
+        out += residual
+        out = self.relu(out)
+        return out
+
+    def Conv3dBN(self, dim_in, dim_out, kernels, strides, pads, group=1):
+        if self.with_dcn and kernels[0] > 1:
+            # use deformable conv
+            return nn.Sequential(
+                self.conv_op(dim_in, dim_out, kernel_size=kernels, stride=strides, padding=pads, bias=False,
+                             offset_groups=self.offset_groups),
+                nn.BatchNorm3d(dim_out)
+            )
+        else:
+            return nn.Sequential(
+                nn.Conv3d(dim_in, dim_out, kernel_size=kernels, stride=strides, padding=pads, bias=False),
+                nn.BatchNorm3d(dim_out)
+            )
+
+
+class ResNet(nn.Module):
+
+    def __init__(self,
+                 block,
+                 layers,
+                 use_temp_convs_set,
+                 temp_strides_set,
+                 sample_size,
+                 sample_duration,
+                 shortcut_type='B',
+                 num_classes=400,
+                 stage_with_dcn=(False, False, False, False),
+                 extract_features=False,
+                 loss_type='softmax'):
+        super(ResNet, self).__init__()
+        self.extract_features = extract_features
+        self.stage_with_dcn = stage_with_dcn
+        self.group = 1
+        self.width_per_group = 64
+        self.dim_inner = self.group * self.width_per_group
+        # self.shortcut_type = shortcut_type
+        self.conv1 = nn.Conv3d(
+            3,
+            64,
+            kernel_size=(1 + use_temp_convs_set[0][0] * 2, 7, 7),
+            stride=(temp_strides_set[0][0], 2, 2),
+            padding=(use_temp_convs_set[0][0], 3, 3),
+            bias=False)
+        self.bn1 = nn.BatchNorm3d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool1 = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 0, 0))
+        with_dcn = bool(self.stage_with_dcn[0])
+        self.layer1 = self._make_layer(block, 64, 256, shortcut_type, stride=1, num_blocks=layers[0],
+                                       dim_inner=self.dim_inner, group=self.group, use_temp_convs=use_temp_convs_set[1],
+                                       temp_strides=temp_strides_set[1], dcn=with_dcn)
+        self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+        with_dcn = bool(self.stage_with_dcn[1])
+        self.layer2 = self._make_layer(block, 256, 512, shortcut_type, stride=2, num_blocks=layers[1],
+                                       dim_inner=self.dim_inner * 2, group=self.group,
+                                       use_temp_convs=use_temp_convs_set[2], temp_strides=temp_strides_set[2],
+                                       dcn=with_dcn)
+        with_dcn = bool(self.stage_with_dcn[2])
+        self.layer3 = self._make_layer(block, 512, 1024, shortcut_type, stride=2, num_blocks=layers[2],
+                                       dim_inner=self.dim_inner * 4, group=self.group,
+                                       use_temp_convs=use_temp_convs_set[3], temp_strides=temp_strides_set[3],
+                                       dcn=with_dcn)
+        with_dcn = bool(self.stage_with_dcn[3])
+        self.layer4 = self._make_layer(block, 1024, 2048, shortcut_type, stride=1, num_blocks=layers[3],
+                                       dim_inner=self.dim_inner * 8, group=self.group,
+                                       use_temp_convs=use_temp_convs_set[4], temp_strides=temp_strides_set[4],
+                                       dcn=with_dcn)
+        # Adaptive pooling makes the head independent of sample_size and
+        # sample_duration, so neither needs to be tracked here.
+        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
+        self.dropout = torch.nn.Dropout(p=0.5)
+        self.classifier = nn.Linear(2048, num_classes)
+
+        # Initialize BatchNorm affine parameters; conv and linear weights keep
+        # the PyTorch defaults (pretrained weights are typically loaded later).
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm3d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, block, dim_in, dim_out, shortcut_type, stride, num_blocks, dim_inner=None, group=None,
+                    use_temp_convs=None, temp_strides=None, dcn=False):
+        if use_temp_convs is None:
+            use_temp_convs = np.zeros(num_blocks).astype(int)
+        if temp_strides is None:
+            temp_strides = np.ones(num_blocks).astype(int)
+        if len(use_temp_convs) < num_blocks:
+            for _ in range(num_blocks - len(use_temp_convs)):
+                use_temp_convs.append(0)
+                temp_strides.append(1)
+        layers = []
+        for idx in range(num_blocks):
+            block_stride = 2 if (idx == 0 and stride == 2) else 1
+
+            layers.append(
+                block(dim_in, dim_out, block_stride, dim_inner, group, use_temp_convs[idx], temp_strides[idx], dcn))
+            dim_in = dim_out
+        return nn.Sequential(*layers)
+
+    def forward_single(self, x):
+        x = self.conv1(x)
+
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool1(x)
+
+        x = self.layer1(x)
+        x = self.maxpool2(x)
+        x = self.layer2(x)
+
+        x = self.layer3(x)
+        features = self.layer4(x)
+
+        x = self.avgpool(features)
+
+        x = x.view(x.size(0), -1)
+        x = self.dropout(x)
+
+        y = self.classifier(x)
+        if self.extract_features:
+            return y, features
+        else:
+            return y
+
+    def forward_multi(self, x):
+        clip_preds = []
+        for clip_idx in range(x.shape[1]):  # B, 10, 3, 3, 32, 224, 224
+            spatial_crops = []
+            for crop_idx in range(x.shape[2]):
+                clip = x[:, clip_idx, crop_idx]
+                clip = self.forward_single(clip)
+                spatial_crops.append(clip)
+            spatial_crops = torch.stack(spatial_crops, 1).mean(1)  # (B, 400)
+            clip_preds.append(spatial_crops)
+        clip_preds = torch.stack(clip_preds, 1).mean(1)  # (B, 400)
+        return clip_preds
+
+    def forward(self, x):
+        # 5D tensor == single clip
+        if x.dim() == 5:
+            pred = self.forward_single(x)
+        # 7D tensor == multiple clips x spatial crops
+        elif x.dim() == 7:
+            pred = self.forward_multi(x)
+        else:
+            raise ValueError('Expected a 5D or 7D input, got a {}D tensor'.format(x.dim()))
+        return pred
+
+
+def get_fine_tuning_parameters(model, ft_begin_index):
+    if ft_begin_index == 0:
+        return model.parameters()
+
+    ft_module_names = []
+    for i in range(ft_begin_index, 5):
+        ft_module_names.append('layer{}'.format(i))
+    # The classification head of this ResNet is named 'classifier', not 'fc',
+    # so match that name to keep the head trainable.
+    ft_module_names.append('classifier')
+    parameters = []
+    for k, v in model.named_parameters():
+        for ft_module in ft_module_names:
+            if ft_module in k:
+                parameters.append({'params': v})
+                break
+        else:
+            parameters.append({'params': v, 'lr': 0.0})
+
+    return parameters
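+
+# Illustrative usage sketch (assumed SGD hyper-parameters): layers from
+# `layer{ft_begin_index}` onward plus the classifier train at the base LR,
+# everything earlier is pinned to lr=0.0.
+#
+#   model = resnet50(extract_features=False, num_classes=174,
+#                    sample_size=224, sample_duration=32)
+#   params = get_fine_tuning_parameters(model, ft_begin_index=4)
+#   optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9)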
+
+
+def obtain_arc(arc_type):
+    # c2d, ResNet50
+    if arc_type == 1:
+        use_temp_convs_1 = [0]
+        temp_strides_1 = [2]
+        use_temp_convs_2 = [0, 0, 0]
+        temp_strides_2 = [1, 1, 1]
+        use_temp_convs_3 = [0, 0, 0, 0]
+        temp_strides_3 = [1, 1, 1, 1]
+        use_temp_convs_4 = [0, ] * 6
+        temp_strides_4 = [1, ] * 6
+        use_temp_convs_5 = [0, 0, 0]
+        temp_strides_5 = [1, 1, 1]
+
+    # i3d, ResNet50
+    if arc_type == 2:
+        use_temp_convs_1 = [2]
+        temp_strides_1 = [1]
+        use_temp_convs_2 = [1, 1, 1]
+        temp_strides_2 = [1, 1, 1]
+        use_temp_convs_3 = [1, 0, 1, 0]
+        temp_strides_3 = [1, 1, 1, 1]
+        use_temp_convs_4 = [1, 0, 1, 0, 1, 0]
+        temp_strides_4 = [1, 1, 1, 1, 1, 1]
+        use_temp_convs_5 = [0, 1, 0]
+        temp_strides_5 = [1, 1, 1]
+
+    # c2d, ResNet101
+    if arc_type == 3:
+        use_temp_convs_1 = [0]
+        temp_strides_1 = [2]
+        use_temp_convs_2 = [0, 0, 0]
+        temp_strides_2 = [1, 1, 1]
+        use_temp_convs_3 = [0, 0, 0, 0]
+        temp_strides_3 = [1, 1, 1, 1]
+        use_temp_convs_4 = [0, ] * 23
+        temp_strides_4 = [1, ] * 23
+        use_temp_convs_5 = [0, 0, 0]
+        temp_strides_5 = [1, 1, 1]
+
+    # i3d, ResNet101
+    if arc_type == 4:
+        use_temp_convs_1 = [2]
+        temp_strides_1 = [2]
+        use_temp_convs_2 = [1, 1, 1]
+        temp_strides_2 = [1, 1, 1]
+        use_temp_convs_3 = [1, 0, 1, 0]
+        temp_strides_3 = [1, 1, 1, 1]
+        use_temp_convs_4 = []
+        for i in range(23):
+            if i % 2 == 0:
+                use_temp_convs_4.append(1)
+            else:
+                use_temp_convs_4.append(0)
+
+        temp_strides_4 = [1, ] * 23
+        use_temp_convs_5 = [0, 1, 0]
+        temp_strides_5 = [1, 1, 1]
+
+    if arc_type not in (1, 2, 3, 4):
+        raise ValueError('Unknown arc_type {}'.format(arc_type))
+
+    use_temp_convs_set = [use_temp_convs_1, use_temp_convs_2, use_temp_convs_3, use_temp_convs_4, use_temp_convs_5]
+    temp_strides_set = [temp_strides_1, temp_strides_2, temp_strides_3, temp_strides_4, temp_strides_5]
+
+    return use_temp_convs_set, temp_strides_set
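+
+# Example: obtain_arc(2) selects the i3d ResNet-50 schedule: the stem gets a
+# temporal kernel of 5 (= 1 + 2*2), res4 alternates temporal convolutions
+# [1, 0, 1, 0, 1, 0], and all temporal strides are 1, so the backbone halves
+# the temporal resolution exactly once (via maxpool2), giving T/2 features.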
+
+
+def resnet10(**kwargs):
+    """Constructs a ResNet-10 model.
+    """
+    # NOTE: the empty temporal-config lists here (and in the other variants
+    # below) are placeholders; only resnet50 and resnet101 pass valid configs
+    # through obtain_arc.
+    use_temp_convs_set = []
+    temp_strides_set = []
+    model = ResNet(BasicBlock, [1, 1, 1, 1], use_temp_convs_set, temp_strides_set, **kwargs)
+    return model
+
+
+def resnet18(**kwargs):
+    """Constructs a ResNet-18 model.
+    """
+    use_temp_convs_set = []
+    temp_strides_set = []
+    model = ResNet(BasicBlock, [2, 2, 2, 2], use_temp_convs_set, temp_strides_set, **kwargs)
+    return model
+
+
+def resnet34(**kwargs):
+    """Constructs a ResNet-34 model.
+    """
+    use_temp_convs_set = []
+    temp_strides_set = []
+    model = ResNet(BasicBlock, [3, 4, 6, 3], use_temp_convs_set, temp_strides_set, **kwargs)
+    return model
+
+
+def resnet50(extract_features, **kwargs):
+    """Constructs a ResNet-50 model.
+    """
+    use_temp_convs_set, temp_strides_set = obtain_arc(2)
+    model = ResNet(Bottleneck, [3, 4, 6, 3], use_temp_convs_set, temp_strides_set,
+                   extract_features=extract_features, **kwargs)
+    return model
+
+
+def resnet101(**kwargs):
+    """Constructs a ResNet-101 model.
+    """
+    use_temp_convs_set, temp_strides_set = obtain_arc(4)
+    model = ResNet(Bottleneck, [3, 4, 23, 3], use_temp_convs_set, temp_strides_set, **kwargs)
+    return model
+
+
+def resnet152(**kwargs):
+    """Constructs a ResNet-101 model.
+    """
+    use_temp_convs_set = []
+    temp_strides_set = []
+    model = ResNet(Bottleneck, [3, 8, 36, 3], use_temp_convs_set, temp_strides_set, **kwargs)
+    return model
+
+
+def resnet200(**kwargs):
+    """Constructs a ResNet-101 model.
+    """
+    use_temp_convs_set = []
+    temp_strides_set = []
+    model = ResNet(Bottleneck, [3, 24, 36, 3], use_temp_convs_set, temp_strides_set, **kwargs)
+    return model
+
+
+def Net(num_classes, extract_features=False, loss_type='softmax',
+        weights=None, freeze_all_but_cls=False):
+    net = resnet50(
+        num_classes=num_classes,
+        sample_size=50,
+        sample_duration=32,
+        extract_features=extract_features,
+        loss_type=loss_type,
+    )
+
+    if weights is not None:
+        kinetics_weights = torch.load(weights)['state_dict']
+        print("Found weights in {}.".format(weights))
+        cls_name = 'fc'
+    else:
+        kinetics_weights = torch.load('i3D/kinetics-res50.pth')
+        cls_name = 'fc'
+        print('\n Restoring Kinetics \n')
+
+    new_weights = {}
+    for k, v in kinetics_weights.items():
+        if not k.startswith('module.' + cls_name):
+            new_weights[k.replace('module.', '')] = v
+    net.load_state_dict(new_weights, strict=False)
+
+    if freeze_all_but_cls:
+        for name, par in net.named_parameters():
+            if not name.startswith('classifier'):
+                par.requires_grad = False
+    return net
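+
+
+# Illustrative usage (assumes the Kinetics checkpoint exists at the default
+# path 'i3D/kinetics-res50.pth'):
+#
+#   backbone = Net(num_classes=174, extract_features=True)
+#   clip = torch.randn(1, 3, 32, 224, 224)   # B x 3 x T x H x W
+#   logits, features = backbone(clip)        # features: (1, 2048, 16, 14, 14)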
diff --git a/regen_frame_fv.py b/regen_frame_fv.py
new file mode 100644
index 0000000000000000000000000000000000000000..59e6ab7ec83586e589567607f5241c4956efffe7
--- /dev/null
+++ b/regen_frame_fv.py
@@ -0,0 +1,130 @@
+# -*- coding: utf-8 -*-
+import os
+import sys
+import glob
+import datetime
+import argparse
+import random
+
+import numpy as np
+
+from pathlib import Path
+filepath = Path.cwd()
+sys.path.append(str(filepath))  # sys.path entries should be str, not Path
+from video_loaders import load_av
+from se_bb_from_np import annot_np
+from SmthSequence import SmthSequence
+from SmthFrameRelations import frame_relations
+
+from PIL import Image
+import cv2
+import torch
+from i3D.model import VideoModel
+from i3D.model_lib import VideoModelGlobalCoordLatent
+import i3D.gtransforms as gtransforms
+
+class FrameFV:
+    def __init__(self, path, args):
+        self.anno = annot_np(path)
+        self.net = VideoModelGlobalCoordLatent(args)
+        self.pre_resize_shape = (224, 224)
+        self.random_crop = gtransforms.GroupMultiScaleCrop(output_size=224,
+                                                           scales=[1],
+                                                           max_distort=0,
+                                                           center_crop_only=True)
+
+    def process_video(self, finput, verbose=False):
+        # get video id
+        vidnum = int(os.path.splitext(os.path.basename(finput))[0])
+
+        # load video to an ndarray list and resize to the network input size
+        img_array = load_av(finput)
+        print(img_array[0].shape)
+        img_array = [cv2.resize(img, (self.pre_resize_shape[1], self.pre_resize_shape[0])) for img in img_array]
+
+        # Split channels and regroup channel-first; this assumes the loader
+        # yields BGR frames (the cv2 convention; the original unpacked the
+        # channels as B, R, G, which swapped red and green). Note that only
+        # the first third of the frames is used here.
+        rs, gs, bs = [], [], []
+        for i in range(len(img_array) // 3):
+            b_ch, g_ch, r_ch = cv2.split(img_array[i])
+            rs.append(r_ch)
+            gs.append(g_ch)
+            bs.append(b_ch)
+        frames = [rs, gs, bs]
+
+        # Alternative path using the group transforms (currently unused):
+        # frames = [Image.fromarray(img.astype('uint8'), 'RGB') for img in img_array]
+        # frames, (offset_h, offset_w, crop_h, crop_w) = self.random_crop(frames)
+
+        # read frame annotations into Sequence
+        seq = SmthSequence()
+        for framenum in range(len(img_array)):
+            cats, bbs = self.anno.get_vf_bbx(vidnum, framenum + 1)
+            # add detections to Sequence
+            for i in range(len(cats)):
+                seq.add(framenum, cats[i], bbs[i])
+
+        # compute object relations per frame
+        relations = []
+        for framenum in range(len(img_array)):
+            fv = frame_relations(seq, 0, 1, framenum)
+            relations.append(fv)
+        relations = np.asarray(relations)
+
+        # TODO bb category embedding per frame
+
+        # i3D features per frame; clip is (1, 3, T, 224, 224)
+        clip = torch.from_numpy(np.asarray([frames]))
+        print(clip.shape)
+        clip = clip.float()
+        glo, vid = self.net.i3D(clip)
+
+        videos_features = self.net.conv(vid)
+
+        print(glo.shape)
+        print(vid.shape)
+        print(videos_features.shape)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--annotations',
+        dest='path_to_annotations',
+        default='../annotations_ground/',
+        help='folder to load annotations from')
+    parser.add_argument(
+        '--video',
+        dest='path_to_video',
+        default='.',
+        help='video to load')
+        
+    # begin import
+    parser.add_argument('--img_feature_dim', default=256, type=int, metavar='N',
+                        help='intermediate feature dimension for image-based features')
+    parser.add_argument('--coord_feature_dim', default=128, type=int, metavar='N',
+                        help='intermediate feature dimension for coord-based features')
+    parser.add_argument('--size', default=224, type=int, metavar='N',
+                        help='primary image input size')
+    parser.add_argument('--batch_size', '-b', default=72, type=int,
+                        metavar='N', help='mini-batch size (default: 72)')
+    parser.add_argument('--num_classes', default=50, type=int,
+                        help='num of class in the model')
+    parser.add_argument('--num_boxes', default=4, type=int,
+                        help='num of boxes for each image')
+    parser.add_argument('--num_frames', default=36, type=int,
+                        help='num of frames for the model')
+    parser.add_argument('--fine_tune', help='path with ckpt to restore')
+    parser.add_argument('--restore_i3d')
+    parser.add_argument('--restore_custom')
+    # end import
+    
+    args = parser.parse_args()
+    
+    compfv = FrameFV(args.path_to_annotations, args)
+    fv = compfv.process_video(args.path_to_video, verbose=True)
+    
+    print("fin")