diff --git a/predict_multi_video.py b/predict_multi_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..561ad64f40c8614cae2920f73b94370be7ce765e
--- /dev/null
+++ b/predict_multi_video.py
@@ -0,0 +1,359 @@
+import os
+import math
+import argparse
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+import scipy.special
+
+import models.i3d as models
+from utils.misc import to_torch
+from utils.imutils import im_to_numpy, im_to_torch, resize_generic
+from utils.transforms import color_normalize
+
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as F
+import json
+import lzma
+import pickle
+import pathlib
+
+
+def _parser():
+    parser = argparse.ArgumentParser(description="Helper script to run the demo.")
+    parser.add_argument("--checkpoint_path", type=Path, default="checkpoint.pth.tar", help="Path to checkpoint.")
+    parser.add_argument("--word_data_json", type=Path, default="vocab.json", help="Path to word data.")
+    parser.add_argument("--video_path", type=Path, default="/vol/vssp/LF_datasets/mixedmode/EASIER-DSGS/SRF_daily_news/MASKED_VIDEOS/000090.2021-01-04.masked.mp4", help="Path to the directory of test videos.")
+    parser.add_argument("--output_dir", type=str, default="Preds", help="Directory in which to save predictions.")
+    parser.add_argument("--video_start_frame", type=int, default=0, help="Start frame of the test video.")
+    parser.add_argument("--video_end_frame", type=int, default=0, help="End frame of the test video. Default=0, in which case the whole video is predicted.")
+    parser.add_argument("--num_in_frames", type=int, default=16, help="Number of frames processed at a time by the model.")
+    parser.add_argument("--stride", type=int, default=1, help="Number of frames to stride when sliding the window.")
+    parser.add_argument("--batch_size", type=int, default=4, help="Maximum number of clips to put in each batch.")
+    parser.add_argument("--num_classes", type=int, default=2281, help="The number of classes predicted by the model.")
+    parser.add_argument("--topk", type=int, default=1, help="Top-k results to show.")
+    parser.add_argument("--confidence", type=float, default=0, help="Only keep predictions above this confidence threshold [0, 1].")
+    parser.add_argument("--resize_res", type=int, default=224, help="Spatial resolution of the network input.")
+    parser.add_argument("--datasetname", type=str, default="focusnews", help="focusnews | srf | bobsl")
+    parser.add_argument("--include_embds", type=int, default=0, help="Whether to save the I3D embeddings.")
+    parser.add_argument("--include_probs", type=int, default=0, help="Whether to save the I3D class probabilities.")
+    return parser.parse_args()
+
+
+def load_rgb_video(video_path: Path, video_start_frame: int, resize_res: int, batch_size: int, num_in_frames: int) -> torch.Tensor:
+    cap = cv2.VideoCapture(str(video_path))
+    cap.set(cv2.CAP_PROP_POS_FRAMES, video_start_frame)
+    cap_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    cap_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    # Center-crop non-square videos to a square
+    if cap_width > cap_height:  # crop from the sides
+        center_crop = transforms.CenterCrop((cap_height, cap_height))
+    elif cap_height > cap_width:
+        center_crop = transforms.CenterCrop((cap_width, cap_width))
+    else:
+        center_crop = None  # already square; no crop needed
+
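+    # Illustrative example (values assumed, not from the pipeline): a 1280x720
+    # landscape video is center-cropped to 720x720, and F.resize below then
+    # shrinks it so its short side equals resize_res (224 by default), giving
+    # the square spatial input the I3D network expects.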
+    rgb = []
+    # With stride 1 inside a chunk, num_in_frames + batch_size - 1 frames are
+    # enough to form batch_size overlapping clips.
+    bound = min(num_in_frames + batch_size - 1, num_frames - video_start_frame)
+    for i in range(bound):
+        # frame: BGR, (h, w, 3), dtype=uint8, 0..255
+        ret, frame = cap.read()
+        if not ret:
+            print("Could not read:", video_path)
+            print("i:", i)
+            break
+        # BGR (OpenCV) to RGB (Torch)
+        frame = frame[:, :, [2, 1, 0]]
+        rgb_t = im_to_torch(frame)
+        if cap_width != cap_height:
+            rgb_t = center_crop(rgb_t)
+
+        rgb_t = F.resize(rgb_t, resize_res)
+        rgb.append(rgb_t)
+
+    cap.release()
+
+    # (nframes, 3, H, W) => (3, nframes, H, W)
+    rgb = torch.stack(rgb).permute(1, 0, 2, 3)
+    return rgb
+
+
+def prepare_input(
+    rgb: torch.Tensor,
+    resize_res: int = 256,
+    inp_res: int = 224,
+    mean: torch.Tensor = 0.5 * torch.ones(3),
+    std: torch.Tensor = 1.0 * torch.ones(3),
+):
+    """
+    Process the video:
+    1) Resize to [resize_res x resize_res]
+    2) Center crop to [inp_res x inp_res]
+    3) Color normalize using mean/std
+    """
+    iC, iF, iH, iW = rgb.shape
+    rgb_resized = np.zeros((iF, resize_res, resize_res, iC))
+    for t in range(iF):
+        tmp = rgb[:, t, :, :]
+        tmp = resize_generic(
+            im_to_numpy(tmp), resize_res, resize_res, interp="bilinear", is_flow=False
+        )
+        rgb_resized[t] = tmp
+    rgb = np.transpose(rgb_resized, (3, 0, 1, 2))
+    # Center crop coordinates
+    ulx = int((resize_res - inp_res) / 2)
+    uly = int((resize_res - inp_res) / 2)
+    # Crop to inp_res x inp_res
+    rgb = rgb[:, :, uly : uly + inp_res, ulx : ulx + inp_res]
+    rgb = to_torch(rgb).float()
+    if rgb.max() > 1:
+        print(f"Max value of the tensor is {rgb.max()}")
+    rgb = color_normalize(rgb, mean, std)
+    return rgb
+
+
+def load_model(
+    checkpoint_path: Path,
+    num_classes: int,
+    num_in_frames: int,
+    include_embds: int,
+) -> torch.nn.Module:
+    model = models.InceptionI3d(
+        num_classes=num_classes,
+        spatiotemporal_squeeze=True,
+        final_endpoint="Logits",
+        name="inception_i3d",
+        in_channels=3,
+        dropout_keep_prob=0.5,
+        num_in_frames=num_in_frames,
+        activation_func="swish",
+        include_embds=include_embds,
+    )
+    model = torch.nn.DataParallel(model).cuda()
+    checkpoint = torch.load(checkpoint_path)
+    model.load_state_dict(checkpoint["state_dict"], strict=True)
+    model.eval()
+    return model
+
+
+def load_vocab(word_data_json):
+    with open(word_data_json, "r") as read_file:
+        word_data = json.load(read_file)
+    return word_data
+
+
+def sliding_windows(
+    rgb: torch.Tensor,
+    num_in_frames: int,
+    stride: int,
+) -> tuple:
+    """
+    Return sliding windows and the corresponding (middle) timestamps.
+    """
+    C, nFrames, H, W = rgb.shape
+    # If needed, pad to the minimum clip length by repeating the last frame
+    if nFrames < num_in_frames:
+        rgb_ = torch.zeros(C, num_in_frames, H, W)
+        rgb_[:, :nFrames] = rgb
+        rgb_[:, nFrames:] = rgb[:, -1].unsqueeze(1)
+        rgb = rgb_
+        nFrames = rgb.shape[1]
+
+    num_clips = math.ceil((nFrames - num_in_frames) / stride) + 1
+
+    rgb_slided = torch.zeros(num_clips, 3, num_in_frames, H, W)
+    t_mid = []
+    # For each clip
+    for j in range(num_clips):
+        actual_clip_length = min(num_in_frames, nFrames - j * stride)
+        if actual_clip_length == num_in_frames:
+            t_beg = j * stride
+        else:
+            # The last window would run past the end; shift it back so it
+            # ends exactly at the final frame
+            t_beg = nFrames - num_in_frames
+        t_mid.append(t_beg + num_in_frames / 2)
+        rgb_slided[j] = rgb[:, t_beg : t_beg + num_in_frames, :, :]
+    return rgb_slided, np.array(t_mid)
+
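+# Illustrative sanity check of the sliding-window arithmetic above (values
+# assumed, not part of the pipeline): with num_in_frames=16 and stride=1, a
+# 19-frame chunk yields ceil((19 - 16) / 1) + 1 = 4 clips starting at frames
+# 0, 1, 2 and 3, with middle timestamps 8.0, 9.0, 10.0 and 11.0.
+#
+#     clips, t_mid = sliding_windows(
+#         rgb=torch.zeros(3, 19, 224, 224), num_in_frames=16, stride=1
+#     )
+#     assert clips.shape == (4, 3, 16, 224, 224)
+#     assert t_mid.tolist() == [8.0, 9.0, 10.0, 11.0]
+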
+def main(
+    checkpoint_path: Path,
+    video_path: Path,
+    output_dir: str,
+    video_start_frame: int,
+    video_end_frame: int,
+    word_data_json,
+    num_classes: int,
+    num_in_frames: int,
+    confidence: float,
+    batch_size: int,
+    stride: int,
+    topk: int,
+    resize_res: int,
+    datasetname: str,
+    include_embds: int,
+    include_probs: int,
+):
+    # Rows of [start_frame, predicted_gloss, confidence], e.g. [1, "exchange", 0.14]
+    frame_pred_score = []
+
+    model = load_model(
+        checkpoint_path=checkpoint_path,
+        num_classes=num_classes,
+        num_in_frames=num_in_frames,
+        include_embds=include_embds,
+    )
+    word_data = load_vocab(word_data_json=word_data_json)
+
+    if video_end_frame == 0:
+        cap = cv2.VideoCapture(str(video_path))
+        if not cap.isOpened():
+            raise FileNotFoundError(f"Video does not exist: {video_path}")
+        video_end_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        cap.release()
+    print(video_path, video_end_frame)
+    print(f"{video_end_frame} frames.\n")
+
+    feature_dim = 1024
+    all_preds = np.zeros((video_end_frame - num_in_frames + 1, num_classes))
+    all_embds = np.zeros((video_end_frame - num_in_frames + 1, feature_dim))
+
+    video_name = str(video_path).split("/")[-1]
+    if datasetname in ["srf", "srf_dev", "srf_test"]:  # srf video names contain ".", so take index [1]
+        video_name = video_name.split(".")[1]
+    elif datasetname in ["focusnews", "bobsl", "meinedgs", "focusnews_wmt", "easier_dgs"]:
+        video_name = video_name.split(".")[0]
+    elif datasetname == "easier_lsf":
+        video_name = video_name.split(".mp4")[0] + video_name.split(".mp4")[1]
+    else:
+        print("UNKNOWN dataset name! Could not process video name.")
+        return
+
+    # Each iteration reads one chunk of frames and predicts batch_size clips from
+    # it, so the loop advances by stride + (batch_size - 1) frames at a time.
+    for start_frame in range(video_start_frame, video_end_frame - num_in_frames, stride + (batch_size - 1)):
+        rgb_orig = load_rgb_video(
+            video_path=video_path,
+            video_start_frame=start_frame,
+            num_in_frames=num_in_frames,
+            resize_res=resize_res,
+            batch_size=batch_size,
+        )
+        # Prepare: resize/crop/normalize
+        rgb_input = prepare_input(rgb_orig, resize_res)
+        # Sliding window
+        rgb_slides, t_mid = sliding_windows(
+            rgb=rgb_input,
+            stride=stride,
+            num_in_frames=num_in_frames,
+        )
+        # Number of windows/clips
+        num_clips = rgb_slides.shape[0]
+        # Group the clips into batches
+        num_batches = math.ceil(num_clips / batch_size)
+        raw_scores = np.empty((0, num_classes), dtype=float)
+        embds = np.empty((0, feature_dim), dtype=float)
+        for b in range(num_batches):
+            inp = rgb_slides[b * batch_size : (b + 1) * batch_size]
+            # Forward pass
+            with torch.no_grad():
+                out = model(inp)
+            raw_scores = np.append(raw_scores, out["logits"].cpu().numpy(), axis=0)
+            if include_embds:
+                embds = np.append(embds, out["embds"].cpu().numpy().squeeze(4).squeeze(3).squeeze(2), axis=0)
+        prob_scores = scipy.special.softmax(raw_scores, axis=1)
+        prob_sorted = np.sort(prob_scores, axis=1)[:, ::-1]
+        pred_sorted = np.argsort(prob_scores, axis=1)[:, ::-1]
+
+        # Look up the gloss names of the top-k predictions
+        word_topk = np.empty((topk, num_clips), dtype=object)
+        for k in range(topk):
+            for i, p in enumerate(pred_sorted[:, k]):
+                word_topk[k, i] = word_data["words"][p]
+        prob_topk = prob_sorted[:, :topk].transpose()
+
+        for b in range(num_clips):
+            if include_probs:
+                all_preds[start_frame + b] = prob_scores[b]
+            if include_embds:
+                all_embds[start_frame + b] = embds[b]
+            # Only keep predictions above the requested confidence threshold
+            # (the default of 0 keeps everything)
+            if prob_topk[0][b] >= confidence:
+                frame_pred_score.append([start_frame + b, word_topk[0][b], round(prob_topk[0][b], 2)])
+
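+    # At this point frame_pred_score holds one row per clip start frame, e.g.
+    # (illustrative values) [[0, "hello", 0.72], [1, "hello", 0.65], ...], while
+    # all_preds[i] / all_embds[i] hold the class probabilities / 1024-d I3D
+    # embedding of the clip starting at frame i, when the respective flag is set.
+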
+    # Save predicted glosses and their confidences
+    pathlib.Path(output_dir + "/pred_glosses_confidences").mkdir(parents=True, exist_ok=True)
+    np.savetxt(
+        f"{output_dir}/pred_glosses_confidences/{video_name}.csv",
+        frame_pred_score,
+        delimiter=",",
+        fmt="%s",
+        header="start_frame,predicted_gloss,confidence",
+        comments="",
+    )
+
+    if include_probs:
+        pathlib.Path(output_dir + "/probs").mkdir(parents=True, exist_ok=True)
+        out_file = os.path.join(output_dir + "/probs", video_name + ".lzma.pkl")
+        all_preds = all_preds.astype(np.float16)
+        with lzma.open(out_file, "wb") as f:
+            pickle.dump(all_preds, f)
+
+    if include_embds:
+        pathlib.Path(output_dir + "/embds").mkdir(parents=True, exist_ok=True)
+        out_file = os.path.join(output_dir + "/embds", video_name + ".lzma.pkl")
+        all_embds = all_embds.astype(np.float16)
+        with lzma.open(out_file, "wb") as f:
+            pickle.dump(all_embds, f)
+
+    # To load a saved .lzma.pkl file:
+    #     with lzma.open(out_file, "rb") as file:
+    #         data = pickle.loads(file.read())
+
+    return frame_pred_score
+
+
+if __name__ == "__main__":
+    # Parse the arguments once instead of re-parsing them for every field
+    args = _parser()
+    print(args)
+
+    for path in os.listdir(args.video_path):
+        path = os.path.join(args.video_path, path)
+        print("Processing video:", path)
+        main(
+            args.checkpoint_path,
+            path,
+            args.output_dir,
+            args.video_start_frame,
+            args.video_end_frame,
+            args.word_data_json,
+            args.num_classes,
+            args.num_in_frames,
+            args.confidence,
+            args.batch_size,
+            args.stride,
+            args.topk,
+            args.resize_res,
+            args.datasetname,
+            args.include_embds,
+            args.include_probs,
+        )
diff --git a/videos/1180490/1180490_1a1.masked.mp4 b/videos/1180490/1180490_1a1.masked.mp4
deleted file mode 100644
index 508863ff0786e23b7f1ec4b7fdda265facb37b17..0000000000000000000000000000000000000000
Binary files a/videos/1180490/1180490_1a1.masked.mp4 and /dev/null differ
diff --git a/videos/1180706/1180706_1a1.masked.mp4 b/videos/1180706/1180706_1a1.masked.mp4
deleted file mode 100644
index 1ab5b8f8f4025ec8f388c144dea322ff56c5879f..0000000000000000000000000000000000000000
Binary files a/videos/1180706/1180706_1a1.masked.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/ALRIGHT.mp4 b/videos/Test_Vid_2/ALRIGHT.mp4
deleted file mode 100644
index 52520a026fc5e2297d0c9db0e9b770978e5f0ec7..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/ALRIGHT.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/AND.mp4 b/videos/Test_Vid_2/AND.mp4
deleted file mode 100644
index 04e360700656426a436f033a934e3d86427bac9e..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/AND.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/ANGRY.mp4 b/videos/Test_Vid_2/ANGRY.mp4
deleted file mode 100644
index bca53cc818fae763b78fd28c168052120348c406..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/ANGRY.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/AWAY.mp4 b/videos/Test_Vid_2/AWAY.mp4
deleted file mode 100644
index e05f9636067a720ade725820b86631be211c98f4..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/AWAY.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/CHANGE.mp4 b/videos/Test_Vid_2/CHANGE.mp4
deleted file mode 100644
index f510a1e46e503b66f940f0d499bdb130fa28ff9c..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/CHANGE.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/FEEL.mp4 b/videos/Test_Vid_2/FEEL.mp4
deleted file mode 100644
index 0045db8655e4d52bd0930a8591f47b1cdc4ea962..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/FEEL.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/HAPPY.mp4 b/videos/Test_Vid_2/HAPPY.mp4
deleted file mode 100644
index 479ce6832035751d94d1881e3f462b90a880aa2e..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/HAPPY.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/HOLD.mp4 b/videos/Test_Vid_2/HOLD.mp4
deleted file mode 100644
index cf9b100f478cb853d7cfcccdaf6d5ba19b662645..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/HOLD.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/JUMP.mp4 b/videos/Test_Vid_2/JUMP.mp4
deleted file mode 100644
index 95ad882ce5b63f9adcc2470453fa190098206c69..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/JUMP.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/MAYBE.mp4 b/videos/Test_Vid_2/MAYBE.mp4
deleted file mode 100644
index da1b8b6fae1496c3c71ddb5a1d9a53b47242e611..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/MAYBE.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/MUST.mp4 b/videos/Test_Vid_2/MUST.mp4
deleted file mode 100644
index d25722bd25edbfac92507bf20692d3893c0f489b..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/MUST.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/NEXT.mp4 b/videos/Test_Vid_2/NEXT.mp4
deleted file mode 100644
index 9344b427b6d00a73a73de3a0a995a3dd262c6f36..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/NEXT.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/PERHAPS.mp4 b/videos/Test_Vid_2/PERHAPS.mp4
deleted file mode 100644
index e353fbf6c02175a31c722107cafcedbfeafe5e09..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/PERHAPS.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/PLAN.mp4 b/videos/Test_Vid_2/PLAN.mp4
deleted file mode 100644
index 0fde7515936eeafafd1dbcf0a9dd09903caf4131..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/PLAN.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/PROJECT.mp4 b/videos/Test_Vid_2/PROJECT.mp4
deleted file mode 100644
index 8ada62cef9f8ac99a105293b6de9cded3176cbbc..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/PROJECT.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/REMEMBER.mp4 b/videos/Test_Vid_2/REMEMBER.mp4
deleted file mode 100644
index 9fa3b50f34c19c4ca0f56f083b760d81e4126990..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/REMEMBER.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/RUN.mp4 b/videos/Test_Vid_2/RUN.mp4
deleted file mode 100644
index 8fb84b210a22f063554e097ca2e52b2aedb2a27d..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/RUN.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/SIGN.mp4 b/videos/Test_Vid_2/SIGN.mp4
deleted file mode 100644
index 996b8d71565bd8306b5c9ac6991c17be8dc343b9..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/SIGN.mp4 and /dev/null differ
diff --git a/videos/Test_Vid_2/WE.mp4 b/videos/Test_Vid_2/WE.mp4
deleted file mode 100644
index b8c0d1594e0e1a4c183c8f4c18815f99ca7a9e65..0000000000000000000000000000000000000000
Binary files a/videos/Test_Vid_2/WE.mp4 and /dev/null differ