+
+ 

+
+
+
+
diff --git a/frame-interpolation-pytorch/export.py b/frame-interpolation-pytorch/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..1414f768348f0412afa0f94a57ca5e68693c89e3
--- /dev/null
+++ b/frame-interpolation-pytorch/export.py
@@ -0,0 +1,155 @@
+import warnings
+
+import numpy as np
+import tensorflow as tf
+import torch
+
+from interpolator import Interpolator
+
+
+def translate_state_dict(var_dict, state_dict):
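+ # Relies on the TF variable dict and the PyTorch state_dict enumerating their
+ # parameters in the same order; the shape assert below is the sanity check for
+ # that assumption.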
+ for name, (prev_name, weight) in zip(state_dict, var_dict.items()):
+ print('Mapping', prev_name, '->', name)
+ weight = torch.from_numpy(weight)
+ if 'kernel' in prev_name:
+ # Transpose the conv2d kernel weights, since TF uses (H, W, C, K) and PyTorch uses (K, C, H, W)
+ weight = weight.permute(3, 2, 0, 1)
+
+ assert state_dict[name].shape == weight.shape, f'Shape mismatch {state_dict[name].shape} != {weight.shape}'
+
+ state_dict[name] = weight
+
+
+def import_state_dict(interpolator: Interpolator, saved_model):
+ variables = saved_model.keras_api.variables
+
+ extract_dict = interpolator.extract.state_dict()
+ flow_dict = interpolator.predict_flow.state_dict()
+ fuse_dict = interpolator.fuse.state_dict()
+
+ extract_vars = {}
+ _flow_vars = {}
+ _fuse_vars = {}
+
+ for var in variables:
+ name = var.name
+ if name.startswith('feat_net'):
+ extract_vars[name[9:]] = var.numpy()
+ elif name.startswith('predict_flow'):
+ _flow_vars[name[13:]] = var.numpy()
+ elif name.startswith('fusion'):
+ _fuse_vars[name[7:]] = var.numpy()
+
+ # reverse order of modules to allow jit export
+ # TODO: improve this hack
+ flow_vars = dict(sorted(_flow_vars.items(), key=lambda x: x[0].split('/')[0], reverse=True))
+ fuse_vars = dict(sorted(_fuse_vars.items(), key=lambda x: int((x[0].split('/')[0].split('_')[1:] or [0])[0]) // 3, reverse=True))
+
+ assert len(extract_vars) == len(extract_dict), f'{len(extract_vars)} != {len(extract_dict)}'
+ assert len(flow_vars) == len(flow_dict), f'{len(flow_vars)} != {len(flow_dict)}'
+ assert len(fuse_vars) == len(fuse_dict), f'{len(fuse_vars)} != {len(fuse_dict)}'
+
+ for state_dict, var_dict in ((extract_dict, extract_vars), (flow_dict, flow_vars), (fuse_dict, fuse_vars)):
+ translate_state_dict(var_dict, state_dict)
+
+ interpolator.extract.load_state_dict(extract_dict)
+ interpolator.predict_flow.load_state_dict(flow_dict)
+ interpolator.fuse.load_state_dict(fuse_dict)
+
+
+def verify_debug_outputs(pt_outputs, tf_outputs):
+ max_error = 0
+ for name, predicted in pt_outputs.items():
+ if name == 'image':
+ continue
+ pred_frfp = [f.permute(0, 2, 3, 1).detach().cpu().numpy() for f in predicted]
+ true_frfp = [f.numpy() for f in tf_outputs[name]]
+
+ for i, (pred, true) in enumerate(zip(pred_frfp, true_frfp)):
+ assert pred.shape == true.shape, f'{name} {i} shape mismatch {pred.shape} != {true.shape}'
+ error = np.max(np.abs(pred - true))
+ max_error = max(max_error, error)
+ assert error < 1, f'{name} {i} max error: {error}'
+ print('Max intermediate error:', max_error)
+
+
+def test_model(interpolator, model, half=False, gpu=False):
+ torch.manual_seed(0)
+ time = torch.full((1, 1), .5)
+ x0 = torch.rand(1, 3, 256, 256)
+ x1 = torch.rand(1, 3, 256, 256)
+
+ x0_ = tf.convert_to_tensor(x0.permute(0, 2, 3, 1).numpy(), dtype=tf.float32)
+ x1_ = tf.convert_to_tensor(x1.permute(0, 2, 3, 1).numpy(), dtype=tf.float32)
+ time_ = tf.convert_to_tensor(time.numpy(), dtype=tf.float32)
+ tf_outputs = model({'x0': x0_, 'x1': x1_, 'time': time_}, training=False)
+
+ if half:
+ x0 = x0.half()
+ x1 = x1.half()
+ time = time.half()
+
+ if gpu and torch.cuda.is_available():
+ x0 = x0.cuda()
+ x1 = x1.cuda()
+ time = time.cuda()
+
+ with torch.no_grad():
+ pt_outputs = interpolator.debug_forward(x0, x1, time)
+
+ verify_debug_outputs(pt_outputs, tf_outputs)
+
+ with torch.no_grad():
+ prediction = interpolator(x0, x1, time)
+ output_color = prediction.permute(0, 2, 3, 1).detach().cpu().numpy()
+ true_color = tf_outputs['image'].numpy()
+ error = np.abs(output_color - true_color).max()
+
+ print('Color max error:', error)
+
+
+def main(model_path, save_path, export_to_torchscript=True, use_gpu=False, fp16=True, skiptest=False):
+ print(f'Exporting model to FP{["32", "16"][fp16]} {["state_dict", "torchscript"][export_to_torchscript]} '
+ f'using {"CG"[use_gpu]}PU')
+ model = tf.compat.v2.saved_model.load(model_path)
+ interpolator = Interpolator()
+ interpolator.eval()
+ import_state_dict(interpolator, model)
+
+ if use_gpu and torch.cuda.is_available():
+ interpolator = interpolator.cuda()
+ else:
+ use_gpu = False
+
+ if fp16:
+ interpolator = interpolator.half()
+ if export_to_torchscript:
+ interpolator = torch.jit.script(interpolator)
+ if export_to_torchscript:
+ interpolator.save(save_path)
+ else:
+ torch.save(interpolator.state_dict(), save_path)
+
+ if not skiptest:
+ if not use_gpu and fp16:
+ warnings.warn('Testing an FP16 model on CPU is not supported; casting it back to FP32 for the test')
+ interpolator = interpolator.float()
+ fp16 = False
+ test_model(interpolator, model, fp16, use_gpu)
+
+
+if __name__ == '__main__':
+ import argparse
+
+ parser = argparse.ArgumentParser(description='Export the frame interpolation TF SavedModel to PyTorch (TorchScript or state dict)')
+
+ parser.add_argument('model_path', type=str, help='Path to the TF SavedModel')
+ parser.add_argument('save_path', type=str, help='Path to save the exported model (TorchScript module or state dict)')
+ parser.add_argument('--statedict', action='store_true', help='Export to state dict instead of TorchScript')
+ parser.add_argument('--fp32', action='store_true', help='Save at full precision')
+ parser.add_argument('--skiptest', action='store_true', help='Skip verifying the exported model against the TF model')
+ parser.add_argument('--gpu', action='store_true', help='Use GPU')
+
+ args = parser.parse_args()
+
+ main(args.model_path, args.save_path, not args.statedict, args.gpu, not args.fp32, args.skiptest)
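+
+# Example invocations (the SavedModel path is a placeholder for wherever the original
+# TF FILM SavedModel was downloaded; the output names match the checkpoints shipped
+# in this folder):
+#
+#   python export.py /path/to/film_saved_model film_net_fp16.pt --gpu
+#   python export.py /path/to/film_saved_model film_net_fp32.pt --fp32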
diff --git a/frame-interpolation-pytorch/feature_extractor.py b/frame-interpolation-pytorch/feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b41975291c64173e4d98619d0ea5f2ca67f3240
--- /dev/null
+++ b/frame-interpolation-pytorch/feature_extractor.py
@@ -0,0 +1,156 @@
+"""PyTorch layer for extracting image features for the film_net interpolator.
+
+The feature extractor implemented here converts an image pyramid into a pyramid
+of deep features. The feature pyramid serves a similar purpose as U-Net
+architecture's encoder, but we use a special cascaded architecture described in
+Multi-view Image Fusion [1].
+
+For comprehensiveness, below is a short description of the idea. While the
+description is a bit involved, the cascaded feature pyramid can be used just
+like any image feature pyramid.
+
+Why cascaded architecture?
+==========================
+To understand the concept it is worth reviewing a traditional feature pyramid
+first: *A traditional feature pyramid* as in U-net or in many optical flow
+networks is built by alternating between convolutions and pooling, starting
+from the input image.
+
+It is well known that early features of such architecture correspond to low
+level concepts such as edges in the image whereas later layers extract
+semantically higher level concepts such as object classes etc. In other words,
+the meaning of the filters in each resolution level is different. For problems
+such as semantic segmentation and many others this is a desirable property.
+
+However, the asymmetric features preclude sharing weights across resolution
+levels in the feature extractor itself and in any subsequent neural networks
+that follow. This can be a downside, since optical flow prediction, for
+instance is symmetric across resolution levels. The cascaded feature
+architecture addresses this shortcoming.
+
+How is it built?
+================
+The *cascaded* feature pyramid contains feature vectors that have constant
+length and meaning on each resolution level, except a few of the finest ones. The
+advantage of this is that the subsequent optical flow layer can learn
+synergistically from many resolutions. This means that coarse-level prediction can
+benefit from finer resolution training examples, which can be useful with
+moderately sized datasets to avoid overfitting.
+
+The cascaded feature pyramid is built by extracting shallower subtree pyramids,
+each one of them similar to the traditional architecture. Each subtree
+pyramid S_i is extracted starting from each resolution level:
+
+image resolution 0 -> S_0
+image resolution 1 -> S_1
+image resolution 2 -> S_2
+...
+
+If we denote the features at level j of subtree i as S_i_j, the cascaded pyramid
+is constructed by concatenating features as follows (assuming subtree depth=3):
+
+lvl
+feat_0 = concat( S_0_0 )
+feat_1 = concat( S_1_0 S_0_1 )
+feat_2 = concat( S_2_0 S_1_1 S_0_2 )
+feat_3 = concat( S_3_0 S_2_1 S_1_2 )
+feat_4 = concat( S_4_0 S_3_1 S_2_2 )
+feat_5 = concat( S_5_0 S_4_1 S_3_2 )
+ ....
+
+In the above, all levels except feat_0 and feat_1 have the same number of features
+with similar semantic meaning. This enables training a single optical flow
+predictor module shared by levels 2, 3, 4, 5 and so on. For more details and
+evaluation see [1].
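+
+As a worked example: with k = 64 filters at the finest sub-tree level and a subtree
+depth of 3 as drawn above, feat_2 and every coarser level concatenate
+k + 2k + 4k = 448 channels. With the default sub_levels = 4 used in this port, the
+shared size becomes k + 2k + 4k + 8k = 960 channels.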
+
+[1] Multi-view Image Fusion, Trinidad et al. 2019
+"""
+from typing import List
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from util import Conv2d
+
+
+class SubTreeExtractor(nn.Module):
+ """Extracts a hierarchical set of features from an image.
+
+ This is a conventional hierarchical image feature extractor that extracts
+ [k, 2*k, 4*k, ...] filters for the image pyramid, where k = channels and the
+ number of levels is n_layers. Each level is followed by average pooling.
+ """
+
+ def __init__(self, in_channels=3, channels=64, n_layers=4):
+ super().__init__()
+ convs = []
+ for i in range(n_layers):
+ convs.append(nn.Sequential(
+ Conv2d(in_channels, (channels << i), 3),
+ Conv2d((channels << i), (channels << i), 3)
+ ))
+ in_channels = channels << i
+ self.convs = nn.ModuleList(convs)
+
+ def forward(self, image: torch.Tensor, n: int) -> List[torch.Tensor]:
+ """Extracts a pyramid of features from the image.
+
+ Args:
+ image: torch.Tensor with shape B x C x H x W.
+ n: number of pyramid levels to extract. This can be less than or equal to
+ the n_layers given in __init__.
+ Returns:
+ The pyramid of features, starting from the finest level. Each element
+ contains the output after the last convolution on the corresponding
+ pyramid level.
+ """
+ head = image
+ pyramid = []
+ for i, layer in enumerate(self.convs):
+ head = layer(head)
+ pyramid.append(head)
+ if i < n - 1:
+ head = F.avg_pool2d(head, kernel_size=2, stride=2)
+ return pyramid
+
+
+class FeatureExtractor(nn.Module):
+ """Extracts features from an image pyramid using a cascaded architecture.
+ """
+
+ def __init__(self, in_channels=3, channels=64, sub_levels=4):
+ super().__init__()
+ self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels)
+ self.sub_levels = sub_levels
+
+ def forward(self, image_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
+ """Extracts a cascaded feature pyramid.
+
+ Args:
+ image_pyramid: Image pyramid as a list, starting from the finest level.
+ Returns:
+ A pyramid of cascaded features.
+ """
+ sub_pyramids: List[List[torch.Tensor]] = []
+ for i in range(len(image_pyramid)):
+ # At each level of the image pyramid, creates a sub_pyramid of features
+ # with 'sub_levels' pyramid levels, re-using the same SubTreeExtractor.
+ # We use the same instance since we want to share the weights.
+ #
+ # However, we cap the depth of the sub_pyramid so we don't create features
+ # that are beyond the coarsest level of the cascaded feature pyramid we
+ # want to generate.
+ capped_sub_levels = min(len(image_pyramid) - i, self.sub_levels)
+ sub_pyramids.append(self.extract_sublevels(image_pyramid[i], capped_sub_levels))
+ # Below we generate the cascades of features on each level of the feature
+ # pyramid. Assuming sub_levels=3, the layout of the features will be
+ # as shown in the example on file documentation above.
+ feature_pyramid: List[torch.Tensor] = []
+ for i in range(len(image_pyramid)):
+ features = sub_pyramids[i][0]
+ for j in range(1, self.sub_levels):
+ if j <= i:
+ features = torch.cat([features, sub_pyramids[i - j][j]], dim=1)
+ feature_pyramid.append(features)
+ return feature_pyramid
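+
+
+# Rough shape sketch under the defaults of this file (channels=64, sub_levels=4),
+# using the image-pyramid helper from util.py in this port; exact spatial sizes
+# depend on the input resolution:
+#
+#   import util
+#   extractor = FeatureExtractor(3, 64, 4)
+#   pyramid = util.build_image_pyramid(torch.rand(1, 3, 256, 256), pyramid_levels=7)
+#   feats = extractor(pyramid)
+#   # channel counts: feats[0]: 64, feats[1]: 192, feats[2]: 448, feats[3:]: 960 each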
diff --git a/frame-interpolation-pytorch/film_net_fp16.pt b/frame-interpolation-pytorch/film_net_fp16.pt
new file mode 100644
index 0000000000000000000000000000000000000000..e2695211566846c6137de304743e5e4b5dd56739
--- /dev/null
+++ b/frame-interpolation-pytorch/film_net_fp16.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d48a9c8f1032f046d7dfcbed40299d51e615b4bd8bbfbb36a83c9a49c76aca9
+size 69048401
diff --git a/frame-interpolation-pytorch/film_net_fp32.pt b/frame-interpolation-pytorch/film_net_fp32.pt
new file mode 100644
index 0000000000000000000000000000000000000000..2691162477f27fe5e3cd4c69890fa2c28be27713
--- /dev/null
+++ b/frame-interpolation-pytorch/film_net_fp32.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f810cada26d0c288e50a27eac43af74446eb84b857ccbc77a22bb006f4d27240
+size 137922129
diff --git a/frame-interpolation-pytorch/fusion.py b/frame-interpolation-pytorch/fusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca79661fdea7435a118e783cb74436f477faab2c
--- /dev/null
+++ b/frame-interpolation-pytorch/fusion.py
@@ -0,0 +1,120 @@
+"""The final fusion stage for the film_net frame interpolator.
+
+The inputs to this module are the warped input images, image features and
+flow fields, all aligned to the target frame (often midway point between the
+two original inputs). The output is the final image. FILM has no explicit
+occlusion handling -- instead, using the above information, this module
+automatically decides how best to blend the inputs together to produce content
+in areas where the pixels can only be borrowed from one of the inputs.
+
+Similarly, this module also decides how much to blend in each input in the case
+of a fractional timestep that is not at the halfway point. For example, if the two
+input images are at t=0 and t=1, and we were to synthesize a frame at t=0.1,
+it often makes most sense to favor the first input. However, this is not
+always the case -- in particular in occluded pixels.
+
+The architecture of the Fusion module follows the decoder side of the U-net [1]
+architecture, i.e. each pyramid level consists of concatenation with the upsampled
+coarser level output, followed by two 3x3 convolutions.
+
+The upsampling is implemented as a 'resize convolution', i.e. nearest-neighbor
+upsampling followed by a 2x2 convolution, as explained in [2]. The transposed
+convolutions used by the classic U-net decoder tend to create checkerboard artifacts.
+
+[1] Ronneberger et al. U-Net: Convolutional Networks for Biomedical Image
+ Segmentation, 2015, https://arxiv.org/pdf/1505.04597.pdf
+[2] https://distill.pub/2016/deconv-checkerboard/
+"""
+from typing import List
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from util import Conv2d
+
+_NUMBER_OF_COLOR_CHANNELS = 3
+
+
+def get_channels_at_level(level, filters):
+ n_images = 2
+ channels = _NUMBER_OF_COLOR_CHANNELS
+ flows = 2
+
+ return (sum(filters << i for i in range(level)) + channels + flows) * n_images
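+
+# Worked example with the defaults of this port (filters=64):
+#   get_channels_at_level(4, 64) = (64 + 128 + 256 + 512 + 3 + 2) * 2 = 1930,
+# i.e. the cascaded features, RGB image and flow of both warped inputs concatenated
+# together at the coarsest aligned-pyramid level fed into the decoder.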
+
+
+class Fusion(nn.Module):
+ """The decoder."""
+
+ def __init__(self, n_layers=4, specialized_layers=3, filters=64):
+ """
+ Args:
+ m: specialized levels
+ """
+ super().__init__()
+
+ # The final convolution that outputs RGB:
+ self.output_conv = nn.Conv2d(filters, 3, kernel_size=1)
+
+ # Each item 'convs[i]' will contain the list of convolutions to be applied
+ # for pyramid level 'i'.
+ self.convs = nn.ModuleList()
+
+ # Create the convolutions. Roughly following the feature extractor, we
+ # double the number of filters when the resolution halves, but only up to
+ # the specialized_levels, after which we use the same number of filters on
+ # all levels.
+ #
+ # We create the convs in fine-to-coarse order, so that the array index
+ # for the convs will correspond to our normal indexing (0=finest level).
+ # in_channels: tuple = (128, 202, 256, 522, 512, 1162, 1930, 2442)
+
+ in_channels = get_channels_at_level(n_layers, filters)
+ increase = 0
+ for i in range(n_layers)[::-1]:
+ num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers)
+ convs = nn.ModuleList([
+ Conv2d(in_channels, num_filters, size=2, activation=None),
+ Conv2d(in_channels + (increase or num_filters), num_filters, size=3),
+ Conv2d(num_filters, num_filters, size=3)]
+ )
+ self.convs.append(convs)
+ in_channels = num_filters
+ increase = get_channels_at_level(i, filters) - num_filters // 2
+
+ def forward(self, pyramid: List[torch.Tensor]) -> torch.Tensor:
+ """Runs the fusion module.
+
+ Args:
+ pyramid: The input feature pyramid as a list of tensors. Each tensor is
+ in (B x C x H x W) format, with the finest level tensor first.
+
+ Returns:
+ A batch of RGB images.
+ Note:
+ The pyramid is expected to have one more level than the number of
+ convolution blocks created in the constructor; the coarsest level is only
+ passed through for concatenation and receives no extra convolutions.
+ """
+
+ # As a slight difference to a conventional decoder (e.g. U-net), we don't
+ # apply any extra convolutions to the coarsest level, but just pass it
+ # to finer levels for concatenation. This choice has not been thoroughly
+ # evaluated, but is motivated by the educated guess that the fusion part
+ # probably does not need large spatial context, because at this point the
+ # features are spatially aligned by the preceding warp.
+ net = pyramid[-1]
+
+ # Loop starting from the 2nd coarsest level:
+ # for i in reversed(range(0, len(pyramid) - 1)):
+ for k, layers in enumerate(self.convs):
+ i = len(self.convs) - 1 - k
+ # Resize the tensor from coarser level to match for concatenation.
+ level_size = pyramid[i].shape[2:4]
+ net = F.interpolate(net, size=level_size, mode='nearest')
+ net = layers[0](net)
+ net = torch.cat([pyramid[i], net], dim=1)
+ net = layers[1](net)
+ net = layers[2](net)
+ net = self.output_conv(net)
+ return net
diff --git a/frame-interpolation-pytorch/inference.py b/frame-interpolation-pytorch/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..344d8bb62d81f02b60846cbf0865b28744047f33
--- /dev/null
+++ b/frame-interpolation-pytorch/inference.py
@@ -0,0 +1,105 @@
+import bisect
+import os
+from tqdm import tqdm
+import torch
+import numpy as np
+import cv2
+
+from util import load_image
+
+
+def inference(model_path, img1, img2, save_path, gpu, inter_frames, fps, half):
+ model = torch.jit.load(model_path, map_location='cpu')
+ model.eval()
+ img_batch_1, crop_region_1 = load_image(img1)
+ img_batch_2, crop_region_2 = load_image(img2)
+
+ img_batch_1 = torch.from_numpy(img_batch_1).permute(0, 3, 1, 2)
+ img_batch_2 = torch.from_numpy(img_batch_2).permute(0, 3, 1, 2)
+
+ if not half:
+ model.float()
+
+ if gpu and torch.cuda.is_available():
+ if half:
+ model = model.half()
+ else:
+ model.float()
+ model = model.cuda()
+
+ if save_path == 'img1 folder':
+ save_path = os.path.join(os.path.split(img1)[0], 'output.mp4')
+
+ results = [
+ img_batch_1,
+ img_batch_2
+ ]
+
+ idxes = [0, inter_frames + 1]
+ remains = list(range(1, inter_frames + 1))
+
+ splits = torch.linspace(0, 1, inter_frames + 2)
+
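+ # Each iteration picks the remaining timestep whose position within its bracketing
+ # pair of already-generated frames is closest to the midpoint, so the model (often
+ # trained only for t=0.5) is always queried as close to t=0.5 as possible; `dt`
+ # below is that relative position.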
+ for _ in tqdm(range(len(remains)), 'Generating in-between frames'):
+ starts = splits[idxes[:-1]]
+ ends = splits[idxes[1:]]
+ distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs()
+ matrix = torch.argmin(distances).item()
+ start_i, step = np.unravel_index(matrix, distances.shape)
+ end_i = start_i + 1
+
+ x0 = results[start_i]
+ x1 = results[end_i]
+
+ if gpu and torch.cuda.is_available():
+ if half:
+ x0 = x0.half()
+ x1 = x1.half()
+ x0 = x0.cuda()
+ x1 = x1.cuda()
+
+ dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])
+
+ with torch.no_grad():
+ prediction = model(x0, x1, dt)
+ insert_position = bisect.bisect_left(idxes, remains[step])
+ idxes.insert(insert_position, remains[step])
+ results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
+ del remains[step]
+
+ video_folder = os.path.split(save_path)[0]
+ os.makedirs(video_folder, exist_ok=True)
+
+ y1, x1, y2, x2 = crop_region_1
+ frames = [(tensor[0] * 255).byte().flip(0).permute(1, 2, 0).numpy()[y1:y2, x1:x2].copy() for tensor in results]
+
+ w, h = frames[0].shape[1::-1]
+ fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
+ writer = cv2.VideoWriter(save_path, fourcc, fps, (w, h))
+ for frame in frames:
+ writer.write(frame)
+
+ for frame in frames[1:][::-1]:
+ writer.write(frame)
+
+ writer.release()
+
+
+if __name__ == '__main__':
+ import argparse
+
+ parser = argparse.ArgumentParser(description='Test frame interpolator model')
+
+ parser.add_argument('model_path', type=str, help='Path to the TorchScript model')
+ parser.add_argument('img1', type=str, help='Path to the first image')
+ parser.add_argument('img2', type=str, help='Path to the second image')
+
+ parser.add_argument('--save_path', type=str, default='img1 folder', help='Path to save the interpolated frames')
+ parser.add_argument('--gpu', action='store_true', help='Use GPU')
+ parser.add_argument('--fp16', action='store_true', help='Use FP16')
+ parser.add_argument('--frames', type=int, default=18, help='Number of frames to interpolate')
+ parser.add_argument('--fps', type=int, default=10, help='FPS of the output video')
+
+ args = parser.parse_args()
+
+ inference(args.model_path, args.img1, args.img2, args.save_path, args.gpu, args.frames, args.fps, args.fp16)
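+
+# Example invocation using the files shipped in this folder (by default the video is
+# written as output.mp4 next to the first image):
+#
+#   python inference.py film_net_fp16.pt photos/one.png photos/two.png --gpu --fp16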
diff --git a/frame-interpolation-pytorch/interpolator.py b/frame-interpolation-pytorch/interpolator.py
new file mode 100644
index 0000000000000000000000000000000000000000..707f8a69af0c6783a75766fed38c1353e96d1c16
--- /dev/null
+++ b/frame-interpolation-pytorch/interpolator.py
@@ -0,0 +1,158 @@
+"""The film_net frame interpolator main model code.
+
+Basics
+======
+The film_net is an end-to-end learned neural frame interpolator implemented as
+a PyTorch model. It has the following inputs and outputs:
+
+Inputs:
+ x0: image A.
+ x1: image B.
+ time: desired sub-frame time.
+
+Outputs:
+ image: the predicted in-between image at the chosen time in range [0, 1].
+
+Additional outputs include forward and backward warped image pyramids, flow
+pyramids, etc., that can be visualized for debugging and analysis.
+
+Note that many training sets only contain triplets with ground truth at
+time=0.5. If a model has been trained with such a training set, it will only work
+well for synthesizing frames at time=0.5. Such models can only generate more
+in-between frames using recursion.
+
+Architecture
+============
+The inference consists of three main stages: 1) feature extraction 2) warping
+3) fusion. At a high level, the architecture has similarities to Context-aware
+Synthesis for Video Frame Interpolation [1], but the exact architecture is
+closer to Multi-view Image Fusion [2] with some modifications for the frame
+interpolation use-case.
+
+The feature extraction stage employs the cascaded multi-scale architecture described
+in [2]. The advantage of this architecture is that coarse level flow prediction
+can be learned from finer resolution image samples. This is especially useful
+to avoid overfitting with moderately sized datasets.
+
+The warping stage uses a residual flow prediction idea that is similar to
+PWC-Net [3], Multi-view Image Fusion [2] and many others.
+
+The fusion stage is similar to U-Net's decoder where the skip connections are
+connected to warped image and feature pyramids. This is described in [2].
+
+Implementation Conventions
+==========================
+Pyramids
+--------
+Throughout the model, all image and feature pyramids are stored as python lists
+with finest level first followed by downscaled versions obtained by successively
+halving the resolution. The depths of all pyramids are determined by the
+pyramid_levels argument. The only exception to this is internal to the feature
+extractor, where smaller feature pyramids are temporarily constructed with depth
+sub_levels.
+
+Color ranges & gamma
+--------------------
+The model code makes no assumptions on whether the images are in gamma or
+linearized space, or what the range of RGB color values is. So a model can be
+trained with different choices. This does not mean that all the choices lead to
+similar results. In practice the model has been proven to work well with RGB
+scale = [0,1] with gamma-space images (i.e. not linearized).
+
+[1] Context-aware Synthesis for Video Frame Interpolation, Niklaus and Liu, 2018
+[2] Multi-view Image Fusion, Trinidad et al, 2019
+[3] PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume
+"""
+from typing import Dict, List
+
+import torch
+from torch import nn
+
+import util
+from feature_extractor import FeatureExtractor
+from fusion import Fusion
+from pyramid_flow_estimator import PyramidFlowEstimator
+
+
+class Interpolator(nn.Module):
+ def __init__(
+ self,
+ pyramid_levels=7,
+ fusion_pyramid_levels=5,
+ specialized_levels=3,
+ sub_levels=4,
+ filters=64,
+ flow_convs=(3, 3, 3, 3),
+ flow_filters=(32, 64, 128, 256),
+ ):
+ super().__init__()
+ self.pyramid_levels = pyramid_levels
+ self.fusion_pyramid_levels = fusion_pyramid_levels
+
+ self.extract = FeatureExtractor(3, filters, sub_levels)
+ self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters)
+ self.fuse = Fusion(sub_levels, specialized_levels, filters)
+
+ def shuffle_images(self, x0, x1):
+ return [
+ util.build_image_pyramid(x0, self.pyramid_levels),
+ util.build_image_pyramid(x1, self.pyramid_levels)
+ ]
+
+ def debug_forward(self, x0, x1, batch_dt) -> Dict[str, List[torch.Tensor]]:
+ image_pyramids = self.shuffle_images(x0, x1)
+
+ # Siamese feature pyramids:
+ feature_pyramids = [self.extract(image_pyramids[0]), self.extract(image_pyramids[1])]
+
+ # Predict forward flow.
+ forward_residual_flow_pyramid = self.predict_flow(feature_pyramids[0], feature_pyramids[1])
+
+ # Predict backward flow.
+ backward_residual_flow_pyramid = self.predict_flow(feature_pyramids[1], feature_pyramids[0])
+
+ # Concatenate features and images:
+
+ # Note that we keep up to 'fusion_pyramid_levels' levels as only those
+ # are used by the fusion module.
+
+ forward_flow_pyramid = util.flow_pyramid_synthesis(forward_residual_flow_pyramid)[:self.fusion_pyramid_levels]
+
+ backward_flow_pyramid = util.flow_pyramid_synthesis(backward_residual_flow_pyramid)[:self.fusion_pyramid_levels]
+
+ # We multiply the flows with t and 1-t to warp to the desired fractional time.
+ #
+ # Note: In the original film_net, time is fixed to 0.5 and the interpolator is
+ # invoked recursively for multi-frame interpolation. Here batch_dt is passed in
+ # directly with shape [B, 1] and broadcast over each flow pyramid level.
+ backward_flow = util.multiply_pyramid(backward_flow_pyramid, batch_dt)
+ forward_flow = util.multiply_pyramid(forward_flow_pyramid, 1 - batch_dt)
+
+ pyramids_to_warp = [
+ util.concatenate_pyramids(image_pyramids[0][:self.fusion_pyramid_levels],
+ feature_pyramids[0][:self.fusion_pyramid_levels]),
+ util.concatenate_pyramids(image_pyramids[1][:self.fusion_pyramid_levels],
+ feature_pyramids[1][:self.fusion_pyramid_levels])
+ ]
+
+ # Warp features and images using the flow. Note that we use backward warping
+ # and backward flow is used to read from image 0 and forward flow from
+ # image 1.
+ forward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[0], backward_flow)
+ backward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[1], forward_flow)
+
+ aligned_pyramid = util.concatenate_pyramids(forward_warped_pyramid,
+ backward_warped_pyramid)
+ aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, backward_flow)
+ aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, forward_flow)
+
+ return {
+ 'image': [self.fuse(aligned_pyramid)],
+ 'forward_residual_flow_pyramid': forward_residual_flow_pyramid,
+ 'backward_residual_flow_pyramid': backward_residual_flow_pyramid,
+ 'forward_flow_pyramid': forward_flow_pyramid,
+ 'backward_flow_pyramid': backward_flow_pyramid,
+ }
+
+ def forward(self, x0, x1, batch_dt) -> torch.Tensor:
+ return self.debug_forward(x0, x1, batch_dt)['image'][0]
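+
+
+# Minimal usage sketch (shapes only): the weights come from export.py, which copies
+# them out of the original TF SavedModel. H and W should be divisible by 64 with the
+# default pyramid_levels=7 (util.load_image pads to this alignment).
+#
+#   model = Interpolator().eval()
+#   x0 = torch.rand(1, 3, 256, 256)   # first frame, B x C x H x W in [0, 1]
+#   x1 = torch.rand(1, 3, 256, 256)   # second frame
+#   dt = torch.full((1, 1), 0.5)      # sub-frame time per batch element
+#   with torch.no_grad():
+#       mid = model(x0, x1, dt)       # predicted in-between frame, B x 3 x H x W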
diff --git a/frame-interpolation-pytorch/photos/one.png b/frame-interpolation-pytorch/photos/one.png
new file mode 100644
index 0000000000000000000000000000000000000000..044b61a95f23ff2b140c4deaf94230e10db2f7e2
--- /dev/null
+++ b/frame-interpolation-pytorch/photos/one.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8bad1c97feb31a4bec60a809f808e1b0a26f55219fa991c4caa2e696bce8e81f
+size 3442971
diff --git a/frame-interpolation-pytorch/photos/output.gif b/frame-interpolation-pytorch/photos/output.gif
new file mode 100644
index 0000000000000000000000000000000000000000..423413a343e899ff721db372c58d6c3452eba47d
--- /dev/null
+++ b/frame-interpolation-pytorch/photos/output.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:81ff68882dfca2c22343d1a435de6815b7d1c9747899febf9bb429ec8746cc35
+size 2829322
diff --git a/frame-interpolation-pytorch/photos/two.png b/frame-interpolation-pytorch/photos/two.png
new file mode 100644
index 0000000000000000000000000000000000000000..c6aac8b76c7d8170987b380424facd2c3f30527f
--- /dev/null
+++ b/frame-interpolation-pytorch/photos/two.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d80058cede12e10b9d7fe49ea022d1cc4f9c28bd2a00a1c3d4830d048c55f3fa
+size 3392356
diff --git a/frame-interpolation-pytorch/pyramid_flow_estimator.py b/frame-interpolation-pytorch/pyramid_flow_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..3083690a0737a19479a199ed633155cd6ff30163
--- /dev/null
+++ b/frame-interpolation-pytorch/pyramid_flow_estimator.py
@@ -0,0 +1,149 @@
+"""PyTorch layer for estimating optical flow by a residual flow pyramid.
+
+This approach of estimating optical flow between two images can be traced back
+to [1], but is also used by later neural optical flow computation methods such
+as SpyNet [2] and PWC-Net [3].
+
+The basic idea is that the optical flow is first estimated at a coarse
+resolution, then the flow is upsampled to warp the higher-resolution image, and
+then a residual correction is computed and added to the estimated flow. This
+process is repeated in a pyramid in coarse-to-fine order to successively
+increase the resolution of both optical flow and the warped image.
+
+Here, the optical flow predictor is used as an internal component of the
+film_net frame interpolator, to warp the two input images toward the in-between
+target frame.
+
+[1] F. Glazer, Hierarchical motion detection. PhD thesis, 1987.
+[2] A. Ranjan and M. J. Black, Optical Flow Estimation using a Spatial Pyramid
+ Network. 2016
+[3] D. Sun X. Yang, M-Y. Liu and J. Kautz, PWC-Net: CNNs for Optical Flow Using
+ Pyramid, Warping, and Cost Volume, 2017
+"""
+from typing import List
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+import util
+
+
+class FlowEstimator(nn.Module):
+ """Small-receptive field predictor for computing the flow between two images.
+
+ This is used to compute the residual flow fields in PyramidFlowEstimator.
+
+ Note that while the number of 3x3 convolutions & filters to apply is
+ configurable, two extra 1x1 convolutions are appended to extract the flow in
+ the end.
+
+ Args:
+ in_channels: Number of input feature channels (both images' features concatenated)
+ num_convs: Number of 3x3 convolutions to apply
+ num_filters: Number of filters in each 3x3 convolution
+ """
+
+ def __init__(self, in_channels: int, num_convs: int, num_filters: int):
+ super(FlowEstimator, self).__init__()
+
+ self._convs = nn.ModuleList()
+ for i in range(num_convs):
+ self._convs.append(util.Conv2d(in_channels=in_channels, out_channels=num_filters, size=3))
+ in_channels = num_filters
+ self._convs.append(util.Conv2d(in_channels, num_filters // 2, size=1))
+ in_channels = num_filters // 2
+ # For the final convolution, we want no activation at all to predict the
+ # optical flow vector values. We have done extensive testing on explicitly
+ # bounding these values using sigmoid, but it turned out that having no
+ # activation gives better results.
+ self._convs.append(util.Conv2d(in_channels, 2, size=1, activation=None))
+
+ def forward(self, features_a: torch.Tensor, features_b: torch.Tensor) -> torch.Tensor:
+ """Estimates optical flow between two images.
+
+ Args:
+ features_a: per-pixel feature vectors for image A (B x C x H x W)
+ features_b: per-pixel feature vectors for image B (B x C x H x W)
+
+ Returns:
+ A tensor with optical flow from A to B
+ """
+ net = torch.cat([features_a, features_b], dim=1)
+ for conv in self._convs:
+ net = conv(net)
+ return net
+
+
+class PyramidFlowEstimator(nn.Module):
+ """Predicts optical flow by coarse-to-fine refinement.
+ """
+
+ def __init__(self, filters: int = 64,
+ flow_convs: tuple = (3, 3, 3, 3),
+ flow_filters: tuple = (32, 64, 128, 256)):
+ super(PyramidFlowEstimator, self).__init__()
+
+ in_channels = filters << 1
+ predictors = []
+ for i in range(len(flow_convs)):
+ predictors.append(
+ FlowEstimator(
+ in_channels=in_channels,
+ num_convs=flow_convs[i],
+ num_filters=flow_filters[i]))
+ in_channels += filters << (i + 2)
+ self._predictor = predictors[-1]
+ self._predictors = nn.ModuleList(predictors[:-1][::-1])
+
+ def forward(self, feature_pyramid_a: List[torch.Tensor],
+ feature_pyramid_b: List[torch.Tensor]) -> List[torch.Tensor]:
+ """Estimates residual flow pyramids between two image pyramids.
+
+ Each image pyramid is represented as a list of tensors in fine-to-coarse
+ order. Each individual image is represented as a tensor where each pixel is
+ a vector of image features.
+
+ util.flow_pyramid_synthesis can be used to convert the residual flow
+ pyramid returned by this method into a flow pyramid, where each level
+ encodes the flow instead of a residual correction.
+
+ Args:
+ feature_pyramid_a: image pyramid as a list in fine-to-coarse order
+ feature_pyramid_b: image pyramid as a list in fine-to-coarse order
+
+ Returns:
+ List of flow tensors, in fine-to-coarse order, each level encoding the
+ difference against the bilinearly upsampled version from the coarser
+ level. The coarsest flow tensor, i.e. the last element in the list, is the
+ 'DC term', i.e. not a residual (alternatively you can think of it as a
+ residual against zero).
+ """
+ levels = len(feature_pyramid_a)
+ v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1])
+ residuals = [v]
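+ # Pyramid levels coarser than the specialized ones all share `self._predictor`
+ # (handled in the loop right below); `self._predictors` holds the level-specific
+ # estimators in coarse-to-fine order, so i = len(self._predictors) - 1 - k maps
+ # each one back to its pyramid level. Keeping two explicit loops, rather than
+ # indexing the ModuleList, appears to be what keeps this scriptable with torch.jit.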
+ for i in range(levels - 2, len(self._predictors) - 1, -1):
+ # Upsamples the flow to match the current pyramid level. Also, scales the
+ # magnitude by two to reflect the new size.
+ level_size = feature_pyramid_a[i].shape[2:4]
+ v = F.interpolate(2 * v, size=level_size, mode='bilinear')
+ # Warp feature_pyramid_b[i] image based on the current flow estimate.
+ warped = util.warp(feature_pyramid_b[i], v)
+ # Estimate the residual flow between pyramid_a[i] and warped image:
+ v_residual = self._predictor(feature_pyramid_a[i], warped)
+ residuals.insert(0, v_residual)
+ v = v_residual + v
+
+ for k, predictor in enumerate(self._predictors):
+ i = len(self._predictors) - 1 - k
+ # Upsamples the flow to match the current pyramid level. Also, scales the
+ # magnitude by two to reflect the new size.
+ level_size = feature_pyramid_a[i].shape[2:4]
+ v = F.interpolate(2 * v, size=level_size, mode='bilinear')
+ # Warp feature_pyramid_b[i] image based on the current flow estimate.
+ warped = util.warp(feature_pyramid_b[i], v)
+ # Estimate the residual flow between pyramid_a[i] and warped image:
+ v_residual = predictor(feature_pyramid_a[i], warped)
+ residuals.insert(0, v_residual)
+ v = v_residual + v
+ return residuals
diff --git a/frame-interpolation-pytorch/requirements.txt b/frame-interpolation-pytorch/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8cf85fc758e71c0f018843b4e238b3198d89da30
--- /dev/null
+++ b/frame-interpolation-pytorch/requirements.txt
@@ -0,0 +1,3 @@
+opencv-python
+torch
+tqdm
\ No newline at end of file
diff --git a/frame-interpolation-pytorch/util.py b/frame-interpolation-pytorch/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0eec1a11fa9d885a44081f35917783acda14d626
--- /dev/null
+++ b/frame-interpolation-pytorch/util.py
@@ -0,0 +1,166 @@
+"""Various utilities used in the film_net frame interpolator model."""
+from typing import List, Optional
+
+import cv2
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def pad_batch(batch, align):
+ height, width = batch.shape[1:3]
+ height_to_pad = (align - height % align) if height % align != 0 else 0
+ width_to_pad = (align - width % align) if width % align != 0 else 0
+
+ crop_region = [height_to_pad >> 1, width_to_pad >> 1, height + (height_to_pad >> 1), width + (width_to_pad >> 1)]
+ batch = np.pad(batch, ((0, 0), (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)),
+ (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), (0, 0)), mode='constant')
+ return batch, crop_region
+
+
+def load_image(path, align=64):
+ image = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB).astype(np.float32) / np.float32(255)
+ image_batch, crop_region = pad_batch(np.expand_dims(image, axis=0), align)
+ return image_batch, crop_region
+
+
+def build_image_pyramid(image: torch.Tensor, pyramid_levels: int = 3) -> List[torch.Tensor]:
+ """Builds an image pyramid from a given image.
+
+ The original image is included in the pyramid and the rest are generated by
+ successively halving the resolution.
+
+ Args:
+ image: the input image.
+ pyramid_levels: number of pyramid levels to build.
+
+ Returns:
+ A list of images starting from the finest, with pyramid_levels items.
+ """
+
+ pyramid = []
+ for i in range(pyramid_levels):
+ pyramid.append(image)
+ if i < pyramid_levels - 1:
+ image = F.avg_pool2d(image, 2, 2)
+ return pyramid
+
+
+def warp(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
+ """Backward warps the image using the given flow.
+
+ Specifically, the output pixel in batch b, at position x, y will be computed
+ as follows:
+ (flowed_y, flowed_x) = (y + flow[b, 1, y, x], x + flow[b, 0, y, x])
+ output[b, :, y, x] = bilinear_lookup(image, b, flowed_y, flowed_x)
+
+ Note that the flow vectors are expected as (dx, dy), i.e. x in channel 0 and
+ y in channel 1.
+
+ Args:
+ image: An image with shape B x C x H x W.
+ flow: A flow with shape B x 2 x H x W, with the two channels denoting the
+ relative offset in order: (dx, dy).
+ Returns:
+ A warped image.
+ """
+ flow = -flow.flip(1)
+
+ dtype = flow.dtype
+ device = flow.device
+
+ # warped = tfa_image.dense_image_warp(image, flow)
+ # Same as above but with pytorch
+ ls1 = 1 - 1 / flow.shape[3]
+ ls2 = 1 - 1 / flow.shape[2]
+
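+ # With align_corners=False, grid_sample places pixel centers at normalized
+ # coordinates spanning [-(1 - 1/N), 1 - 1/N]; that is what ls1 (width) and ls2
+ # (height) encode. The pixel-space flow (already negated above) is divided by
+ # (H/2, W/2) below to express it in the same normalized units before it is
+ # combined with the base sampling grid.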
+ normalized_flow2 = flow.permute(0, 2, 3, 1) / torch.tensor(
+ [flow.shape[2] * .5, flow.shape[3] * .5], dtype=dtype, device=device)[None, None, None]
+ normalized_flow2 = torch.stack([
+ torch.linspace(-ls1, ls1, flow.shape[3], dtype=dtype, device=device)[None, None, :] - normalized_flow2[..., 1],
+ torch.linspace(-ls2, ls2, flow.shape[2], dtype=dtype, device=device)[None, :, None] - normalized_flow2[..., 0],
+ ], dim=3)
+
+ warped = F.grid_sample(image, normalized_flow2,
+ mode='bilinear', padding_mode='border', align_corners=False)
+ return warped.reshape(image.shape)
+
+
+def multiply_pyramid(pyramid: List[torch.Tensor],
+ scalar: torch.Tensor) -> List[torch.Tensor]:
+ """Multiplies all image batches in the pyramid by a batch of scalars.
+
+ Args:
+ pyramid: Pyramid of image batches.
+ scalar: Batch of scalars.
+
+ Returns:
+ An image pyramid with all images multiplied by the scalar.
+ """
+ # To multiply each image by its corresponding scalar, we append two trailing
+ # singleton dimensions to the [B, 1] scalar batch so it broadcasts over the
+ # channel and spatial dimensions of each B x C x H x W image batch.
+ return [image * scalar[..., None, None] for image in pyramid]
+
+
+def flow_pyramid_synthesis(
+ residual_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
+ """Converts a residual flow pyramid into a flow pyramid."""
+ flow = residual_pyramid[-1]
+ flow_pyramid: List[torch.Tensor] = [flow]
+ for residual_flow in residual_pyramid[:-1][::-1]:
+ level_size = residual_flow.shape[2:4]
+ flow = F.interpolate(2 * flow, size=level_size, mode='bilinear')
+ flow = residual_flow + flow
+ flow_pyramid.insert(0, flow)
+ return flow_pyramid
+
+
+def pyramid_warp(feature_pyramid: List[torch.Tensor],
+ flow_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
+ """Warps the feature pyramid using the flow pyramid.
+
+ Args:
+ feature_pyramid: feature pyramid starting from the finest level.
+ flow_pyramid: flow fields, starting from the finest level.
+
+ Returns:
+ Reverse warped feature pyramid.
+ """
+ warped_feature_pyramid = []
+ for features, flow in zip(feature_pyramid, flow_pyramid):
+ warped_feature_pyramid.append(warp(features, flow))
+ return warped_feature_pyramid
+
+
+def concatenate_pyramids(pyramid1: List[torch.Tensor],
+ pyramid2: List[torch.Tensor]) -> List[torch.Tensor]:
+ """Concatenates each pyramid level together in the channel dimension."""
+ result = []
+ for features1, features2 in zip(pyramid1, pyramid2):
+ result.append(torch.cat([features1, features2], dim=1))
+ return result
+
+
+class Conv2d(nn.Sequential):
+ def __init__(self, in_channels, out_channels, size, activation: Optional[str] = 'relu'):
+ assert activation in (None, 'relu')
+ super().__init__(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=size,
+ padding='same' if size % 2 else 0)
+ )
+ self.size = size
+ self.activation = nn.LeakyReLU(.2) if activation == 'relu' else None
+
+ def forward(self, x):
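+ # For even kernel sizes, pad one extra pixel on the right and bottom to mirror
+ # TF's SAME padding convention (presumably what the converted TF weights expect);
+ # odd kernel sizes rely on padding='same' in the underlying nn.Conv2d instead.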
+ if not self.size % 2:
+ x = F.pad(x, (0, 1, 0, 1))
+ y = self[0](x)
+ if self.activation is not None:
+ y = self.activation(y)
+ return y
diff --git a/models/.DS_Store b/models/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
Binary files /dev/null and b/models/.DS_Store differ
diff --git a/models/jointembedding_high_env0.py b/models/jointembedding_high_env0.py
new file mode 100644
index 0000000000000000000000000000000000000000..044f144fcf1318ffa4eeee1fb3c25dec3f38767d
--- /dev/null
+++ b/models/jointembedding_high_env0.py
@@ -0,0 +1,483 @@
+import copy
+import math
+import pickle
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import difflib
+from typing import Optional, Tuple, Union
+
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, BertTokenizer, BertModel, Wav2Vec2Model, Wav2Vec2Config
+from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2FeatureEncoder
+from .motion_encoder import VQEncoderV6
+
+
+def audio_to_time_aligned_text_features(inputs, processor, model, tokenizer, bert_model):
+ with torch.no_grad():
+ logits = model(inputs.input_values).logits # shape: (1, time_steps, vocab_size)
+
+ predicted_ids_per_timestep = torch.argmax(logits, dim=-1) # shape: (1, time_steps)
+ predicted_ids_per_timestep = predicted_ids_per_timestep[0].cpu().numpy()
+ vocab = processor.tokenizer.get_vocab()
+ id_to_token = {v: k for k, v in vocab.items()}
+ tokens_per_timestep = [id_to_token[id] for id in predicted_ids_per_timestep]
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ transcription = processor.decode(predicted_ids[0])
+ inputs_bert = tokenizer(transcription, return_tensors='pt')
+ input_ids = inputs_bert['input_ids'][0]
+ tokens_bert = tokenizer.convert_ids_to_tokens(input_ids)
+
+ with torch.no_grad():
+ outputs_bert = bert_model(**inputs_bert.to(inputs.input_values.device))
+ all_token_embeddings = outputs_bert.last_hidden_state[0]
+ per_timestep_chars = []
+ per_timestep_char_indices = []
+ for idx, t in enumerate(tokens_per_timestep):
+ if t not in ('