Video Recognition using SlowFast and OpenVINO™¶
Teaching machines to detect, understand and analyze the contents of images has been one of the more well-known and well-studied problems in computer vision. However, analyzing videos to understand what is happening in them and detecting objects of interest are equally important and challenging tasks that have widespread applications in several areas including autonomous driving, healthcare, security, and many more.
The SlowFast model puts forth an interesting approach to analyzing videos based on the intuition that videos typically contain static as well as dynamic elements: a slow pathway operating at a low frame rate analyzes the static content, while a fast pathway operating at a high frame rate captures the dynamic content. Its strength lies in its ability to effectively capture both fast- and slow-motion information in video sequences, making it particularly well-suited to tasks that require a temporal as well as spatial understanding of the data.
More details about the network can be found in the original paper and repository.
In this notebook, we will see how to convert and run a ResNet-50 based SlowFast model using OpenVINO.
This tutorial consists of the following steps:
Prepare the PyTorch model
Download and prepare data
Check inference with the PyTorch model
Convert the model to OpenVINO Intermediate Representation
Verify inference with the converted model
Prepare PyTorch Model¶
Install necessary packages¶
%pip install -q "openvino>=2023.3.0" fvcore --extra-index-url https://download.pytorch.org/whl/cpu
Note: you may need to restart the kernel to use updated packages.
Imports and Settings¶
import json
import math
import sys
import cv2
import torch
import numpy as np
from pathlib import Path
from typing import Any, List, Dict
from IPython.display import Video
import openvino as ov
sys.path.append("../utils")
from notebook_utils import download_file
DATA_DIR = Path("data/")
MODEL_DIR = Path("model/")
MODEL_DIR.mkdir(exist_ok=True)
DATA_DIR.mkdir(exist_ok=True)
To begin, we download the PyTorch model from the PyTorchVideo repository. In this notebook, we will be using a SlowFast Network based on the ResNet-50 architecture trained on the Kinetics 400 dataset. Kinetics 400 is a large-scale dataset for action recognition in videos, containing 400 human action classes, with at least 400 video clips for each action. Read more about the dataset and the paper here.
MODEL_NAME = "slowfast_r50"
MODEL_REPOSITORY = "facebookresearch/pytorchvideo"
DEVICE = "cpu"
# load the pretrained model from the repository
model = torch.hub.load(
repo_or_dir=MODEL_REPOSITORY, model=MODEL_NAME, pretrained=True, skip_validation=True
)
# set the device to allocate tensors to. for example, "cpu" or "cuda"
model.to(DEVICE)
# set the model to eval mode
model.eval()
Using cache found in /opt/home/k8sworker/.cache/torch/hub/facebookresearch_pytorchvideo_main
Net(
(blocks): ModuleList(
(0): MultiPathWayWithFuse(
(multipathway_blocks): ModuleList(
(0): ResNetBasicStem(
(conv): Conv3d(3, 64, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False)
(norm): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
(pool): MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=[0, 1, 1], dilation=1, ceil_mode=False)
)
(1): ResNetBasicStem(
(conv): Conv3d(3, 8, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False)
(norm): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
(pool): MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=[0, 1, 1], dilation=1, ceil_mode=False)
)
)
(multipathway_fusion): FuseFastToSlow(
(conv_fast_to_slow): Conv3d(8, 16, kernel_size=(7, 1, 1), stride=(4, 1, 1), padding=(3, 0, 0), bias=False)
(norm): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
)
(1): MultiPathWayWithFuse(
(multipathway_blocks): ModuleList(
(0): ResStage(
(res_blocks): ModuleList(
(0): ResBlock(
(branch1_conv): Conv3d(80, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(branch1_norm): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(branch2): BottleneckBlock(
(conv_a): Conv3d(80, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_a): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
(1-2): 2 x ResBlock(
(branch2): BottleneckBlock(
(conv_a): Conv3d(256, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_a): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
)
)
(1): ResStage(
(res_blocks): ModuleList(
(0): ResBlock(
(branch1_conv): Conv3d(8, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(branch1_norm): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(branch2): BottleneckBlock(
(conv_a): Conv3d(8, 8, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
(norm_a): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(8, 8, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(8, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
(1-2): 2 x ResBlock(
(branch2): BottleneckBlock(
(conv_a): Conv3d(32, 8, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
(norm_a): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(8, 8, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(8, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
)
)
)
(multipathway_fusion): FuseFastToSlow(
(conv_fast_to_slow): Conv3d(32, 64, kernel_size=(7, 1, 1), stride=(4, 1, 1), padding=(3, 0, 0), bias=False)
(norm): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
)
(2): MultiPathWayWithFuse(
(multipathway_blocks): ModuleList(
(0): ResStage(
(res_blocks): ModuleList(
(0): ResBlock(
(branch1_conv): Conv3d(320, 512, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False)
(branch1_norm): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(branch2): BottleneckBlock(
(conv_a): Conv3d(320, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_a): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(128, 128, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(128, 512, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
(1-3): 3 x ResBlock(
(branch2): BottleneckBlock(
(conv_a): Conv3d(512, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_a): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(128, 128, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(128, 512, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
)
)
(1): ResStage(
(res_blocks): ModuleList(
(0): ResBlock(
(branch1_conv): Conv3d(32, 64, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False)
(branch1_norm): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(branch2): BottleneckBlock(
(conv_a): Conv3d(32, 16, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
(norm_a): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(16, 16, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(16, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
(1-3): 3 x ResBlock(
(branch2): BottleneckBlock(
(conv_a): Conv3d(64, 16, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
(norm_a): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(16, 16, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(16, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
)
)
)
(multipathway_fusion): FuseFastToSlow(
(conv_fast_to_slow): Conv3d(64, 128, kernel_size=(7, 1, 1), stride=(4, 1, 1), padding=(3, 0, 0), bias=False)
(norm): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
)
(3): MultiPathWayWithFuse(
(multipathway_blocks): ModuleList(
(0): ResStage(
(res_blocks): ModuleList(
(0): ResBlock(
(branch1_conv): Conv3d(640, 1024, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False)
(branch1_norm): BatchNorm3d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(branch2): BottleneckBlock(
(conv_a): Conv3d(640, 256, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
(norm_a): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(256, 256, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(256, 1024, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
(1-5): 5 x ResBlock(
(branch2): BottleneckBlock(
(conv_a): Conv3d(1024, 256, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
(norm_a): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(256, 256, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(256, 1024, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
)
)
(1): ResStage(
(res_blocks): ModuleList(
(0): ResBlock(
(branch1_conv): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False)
(branch1_norm): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(branch2): BottleneckBlock(
(conv_a): Conv3d(64, 32, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
(norm_a): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(32, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
(1-5): 5 x ResBlock(
(branch2): BottleneckBlock(
(conv_a): Conv3d(128, 32, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
(norm_a): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(32, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(32, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
)
)
)
(multipathway_fusion): FuseFastToSlow(
(conv_fast_to_slow): Conv3d(128, 256, kernel_size=(7, 1, 1), stride=(4, 1, 1), padding=(3, 0, 0), bias=False)
(norm): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
)
(4): MultiPathWayWithFuse(
(multipathway_blocks): ModuleList(
(0): ResStage(
(res_blocks): ModuleList(
(0): ResBlock(
(branch1_conv): Conv3d(1280, 2048, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False)
(branch1_norm): BatchNorm3d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(branch2): BottleneckBlock(
(conv_a): Conv3d(1280, 512, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
(norm_a): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(512, 512, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(512, 2048, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
(1-2): 2 x ResBlock(
(branch2): BottleneckBlock(
(conv_a): Conv3d(2048, 512, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
(norm_a): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(512, 512, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(512, 2048, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
)
)
(1): ResStage(
(res_blocks): ModuleList(
(0): ResBlock(
(branch1_conv): Conv3d(128, 256, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False)
(branch1_norm): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(branch2): BottleneckBlock(
(conv_a): Conv3d(128, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
(norm_a): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
(1-2): 2 x ResBlock(
(branch2): BottleneckBlock(
(conv_a): Conv3d(256, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
(norm_a): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_a): ReLU()
(conv_b): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
(norm_b): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_b): ReLU()
(conv_c): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
(norm_c): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(activation): ReLU()
)
)
)
)
(multipathway_fusion): Identity()
)
(5): PoolConcatPathway(
(pool): ModuleList(
(0): AvgPool3d(kernel_size=(8, 7, 7), stride=(1, 1, 1), padding=(0, 0, 0))
(1): AvgPool3d(kernel_size=(32, 7, 7), stride=(1, 1, 1), padding=(0, 0, 0))
)
)
(6): ResNetBasicHead(
(dropout): Dropout(p=0.5, inplace=False)
(proj): Linear(in_features=2304, out_features=400, bias=True)
(output_pool): AdaptiveAvgPool3d(output_size=1)
)
)
)
Now that we have loaded our pre-trained model, we will check inference with it. Since the model returns predicted class IDs, we download the ID-to-class-label mapping for the Kinetics 400 dataset and load it into a dictionary for later use.
CLASSNAMES_SOURCE = (
"https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json"
)
CLASSNAMES_FILE = "kinetics_classnames.json"
download_file(url=CLASSNAMES_SOURCE, directory=DATA_DIR, show_progress=True)
# load from json
with open(DATA_DIR / CLASSNAMES_FILE, "r") as f:
kinetics_classnames = json.load(f)
# load dict of id to class label mapping
kinetics_id_to_classname = {}
for k, v in kinetics_classnames.items():
kinetics_id_to_classname[v] = str(k).replace('"', "")
data/kinetics_classnames.json: 0.00B [00:00, ?B/s]
Let us download a sample video to run inference on, and take a look at the downloaded video.
VIDEO_SOURCE = "https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4"
VIDEO_NAME = "archery.mp4"
VIDEO_PATH = DATA_DIR / VIDEO_NAME
download_file(url=VIDEO_SOURCE, directory=DATA_DIR, show_progress=True)
Video(VIDEO_PATH, embed=True)
data/archery.mp4: 0%| | 0.00/536k [00:00<?, ?B/s]
The sample video requires some preprocessing before we can run inference on it. During preprocessing, the video frames are normalized and scaled to the required size. Additionally, the preprocessing pipeline samples frames from the video to pass through the two pathways. The slow pathway can be any convolutional network that uses a large temporal stride on the input frames. The fast pathway is another convolutional network that uses a temporal stride smaller by a factor of alpha (\(\alpha\)). In our model, both pathways use a 3D ResNet. We define the following helper functions to implement the preprocessing steps.
def scale_short_side(size: int, frame: np.ndarray) -> np.ndarray:
"""
Scale the short side of the frame to size and return a float
array.
"""
height = frame.shape[0]
width = frame.shape[1]
# return unchanged if short side already scaled
if (width <= height and width == size) or (height <= width and height == size):
return frame
new_width = size
new_height = size
if width < height:
new_height = int(math.floor((float(height) / width) * size))
else:
new_width = int(math.floor((float(width) / height) * size))
scaled = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
return scaled.astype(np.float32)
def center_crop(size: int, frame: np.ndarray) -> np.ndarray:
"""
Center crop the input frame to size.
"""
height = frame.shape[0]
width = frame.shape[1]
y_offset = int(math.ceil((height - size) / 2))
x_offset = int(math.ceil((width - size) / 2))
cropped = frame[y_offset:y_offset + size, x_offset:x_offset + size, :]
assert cropped.shape[0] == size, "Image height not cropped properly"
assert cropped.shape[1] == size, "Image width not cropped properly"
return cropped
def normalize(array: np.ndarray, mean: List[float], std: List[float]) -> np.ndarray:
"""
Normalize a given array by subtracting the mean and dividing the std.
"""
if array.dtype == np.uint8:
array = array.astype(np.float32)
array = array / 255.0
mean = np.array(mean, dtype=np.float32)
std = np.array(std, dtype=np.float32)
array = array - mean
array = array / std
return array
def pack_pathway_output(frames: np.ndarray, alpha: int = 4) -> List[np.ndarray]:
"""
Prepare output as a list of arrays, each corresponding
to a unique pathway.
"""
fast_pathway = frames
# Perform temporal sampling from the fast pathway.
slow_pathway = np.take(
frames,
indices=np.linspace(0, frames.shape[1] - 1, frames.shape[1] // alpha).astype(np.int_),
axis=1
)
frame_list = [slow_pathway, fast_pathway]
return frame_list
def process_inputs(
frames: List[np.ndarray],
num_frames: int,
crop_size: int,
mean: List[float],
std: List[float],
) -> List[np.ndarray]:
"""
Performs normalization and applies required transforms
to prepare the input frames and returns a list of arrays.
Specifically the following actions are performed
1. scale the short side of the frames
2. center crop the frames to crop_size
3. perform normalization by subtracting mean and dividing std
4. sample frames for specified num_frames
5. sample frames for slow and fast pathways
"""
inputs = [scale_short_side(size=crop_size, frame=frame) for frame in frames]
inputs = [center_crop(size=crop_size, frame=frame) for frame in inputs]
inputs = np.array(inputs).astype(np.float32) / 255
inputs = normalize(array=inputs, mean=mean, std=std)
# T H W C -> C T H W
inputs = inputs.transpose([3, 0, 1, 2])
# Sample frames for num_frames specified
indices = np.linspace(0, inputs.shape[1] - 1, num_frames).astype(np.int_)
inputs = np.take(inputs, indices=indices, axis=1)
# prepare pathways for the model
inputs = pack_pathway_output(inputs)
inputs = [np.expand_dims(inp, 0) for inp in inputs]
return inputs
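As a quick sanity check (not part of the original pipeline, using synthetic frames), the snippet below shows the shapes produced by process_inputs with the settings used later in this tutorial: the fast pathway keeps all 32 sampled frames, while the slow pathway keeps every fourth frame.
# sanity check with random frames; assumes the helper functions above are defined
dummy_frames = [np.random.randint(0, 255, (320, 480, 3), dtype=np.uint8) for _ in range(64)]
sample_inputs = process_inputs(
    frames=dummy_frames, num_frames=32, crop_size=256, mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]
)
print([inp.shape for inp in sample_inputs])
# expected: [(1, 3, 8, 256, 256), (1, 3, 32, 256, 256)] -> [slow pathway, fast pathway]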
We define another helper function to run inference on a custom video using the given model.
def run_inference(
model: Any,
video_path: str,
top_k: int,
id_to_label_mapping: Dict[str, str],
num_frames: int,
sampling_rate: int,
crop_size: int,
mean: List[float],
std: List[float],
) -> List[str]:
"""
Run inference on the video given by video_path using the given model.
First, the video is loaded from source. Frames are collected, processed
and fed to the model. The top top_k predicted class IDs are converted to class
labels and returned as a list of strings.
"""
video_cap = cv2.VideoCapture(video_path)
frames = []
seq_length = num_frames * sampling_rate
# get the list of frames from the video
ret = True
while ret and len(frames) < seq_length:
        ret, frame = video_cap.read()
        if ret:
            frames.append(frame)
# prepare the inputs
inputs = process_inputs(
frames=frames, num_frames=num_frames, crop_size=crop_size, mean=mean, std=std
)
if isinstance(model, ov.CompiledModel):
# openvino compiled model
output_blob = model.output(0)
predictions = model(inputs)[output_blob]
else:
# pytorch model
predictions = model([torch.from_numpy(inp) for inp in inputs])
predictions = predictions.detach().cpu().numpy()
def softmax(x):
return (np.exp(x) / np.exp(x).sum(axis=None))
# apply activation
predictions = softmax(predictions)
    # top_k predicted class IDs
    pred_classes = np.argsort(-1 * predictions, axis=1)[:, :top_k]
# Map the predicted classes to the label names
pred_class_names = [id_to_label_mapping[int(i)] for i in pred_classes[0]]
return pred_class_names
We define the model-specific parameters for processing the input and run inference using them. The top 5 predictions can be seen below.
NUM_FRAMES = 32
SAMPLING_RATE = 2
CROP_SIZE = 256
MEAN = [0.45, 0.45, 0.45]
STD = [0.225, 0.225, 0.225]
TOP_K = 5
predictions = run_inference(
model=model,
video_path=str(VIDEO_PATH),
top_k=TOP_K,
id_to_label_mapping=kinetics_id_to_classname,
num_frames=NUM_FRAMES,
sampling_rate=SAMPLING_RATE,
crop_size=CROP_SIZE,
mean=MEAN,
std=STD,
)
print(f"Predicted labels: {', '.join(predictions)}")
Predicted labels: archery, throwing axe, playing paintball, golf driving, riding or walking with horse
Convert model to OpenVINO Intermediate Representation¶
Now that we have obtained our trained model and checked inference with it, we export the PyTorch model to OpenVINO IR format. In this format, the network is represented using two files: an .xml file describing the network architecture and an accompanying .bin file that stores constant values such as convolution weights. We can use the model conversion API to convert the model into IR format as follows. The ov.convert_model method returns an ov.Model object that can either be compiled and inferred or serialized.
class ModelWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input):
        return self.model(list(input))
dummy_input = [torch.randn((1, 3, 8, 256, 256)), torch.randn([1, 3, 32, 256, 256])]
model = ov.convert_model(ModelWrapper(model), example_input=(dummy_input,))
IR_PATH = MODEL_DIR / "slowfast-r50.xml"
# serialize model for saving IR
ov.save_model(model=model, output_model=str(IR_PATH), compress_to_fp16=False)
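After serialization, the model directory contains both parts of the IR: the slowfast-r50.xml topology file and the matching slowfast-r50.bin weights file. A quick, optional check:
# optional: list the serialized IR files
print(sorted(p.name for p in MODEL_DIR.iterdir()))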
Next, we read and compile the serialized model using the OpenVINO Runtime. The read_model function expects the .bin weights file to have the same filename and be located in the same directory as the .xml file. If the weights file has a different filename, it can be specified using the weights parameter.
core = ov.Core()
# read converted model
conv_model = core.read_model(str(IR_PATH))
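If the weights file were stored under a different name or location, the same model could be read by passing the path explicitly through the weights parameter; a minimal illustrative sketch using the paths from this notebook (equivalent to the call above):
# equivalent read with an explicit weights path
conv_model = core.read_model(model=MODEL_DIR / "slowfast-r50.xml", weights=MODEL_DIR / "slowfast-r50.bin")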
Select inference device¶
Select the device from the dropdown list for running inference with OpenVINO.
import ipywidgets as widgets
device = widgets.Dropdown(
options=core.available_devices + ["AUTO"],
value='AUTO',
description='Device:',
disabled=False,
)
device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
# load model on device
compiled_model = core.compile_model(model=conv_model, device_name=device.value)
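The compiled model exposes its inputs and output, which is a convenient way to double-check the expected shapes for the slow and fast pathways (an optional, illustrative check):
# optional: inspect the compiled model's inputs and output
for model_input in compiled_model.inputs:
    print(model_input)
print(compiled_model.output(0))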
Verify Model Inference¶
Using the compiled model, we run inference on the same sample video and print the top 5 predictions again.
pred_class_names = run_inference(
model=compiled_model,
video_path=str(VIDEO_PATH),
top_k=TOP_K,
id_to_label_mapping=kinetics_id_to_classname,
num_frames=NUM_FRAMES,
sampling_rate=SAMPLING_RATE,
crop_size=CROP_SIZE,
mean=MEAN,
std=STD,
)
print(f"Predicted labels: {', '.join(pred_class_names)}")
Predicted labels: archery, throwing axe, playing paintball, golf driving, riding or walking with horse