diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..14d0651
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Princeton Vision & Learning Lab
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/RAFTStereo.png b/RAFTStereo.png
new file mode 100644
index 0000000..c382617
Binary files /dev/null and b/RAFTStereo.png differ
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..9a5c864
--- /dev/null
+++ b/README.md
@@ -0,0 +1,113 @@
+# RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching
+This repository contains the source code for our paper:
+
+[RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching](https://arxiv.org/abs/2109.07547)
+Lahav Lipson, Zachary Teed and Jia Deng
+
+
+
+## Requirements
+The code has been tested with PyTorch 1.7 and CUDA 10.2.
+```Shell
+conda env create -f environment.yaml
+conda activate raftstereo
+```
+
+
+
+
+## Required Data
+To evaluate/train RAFT-Stereo, you will need to download the required datasets.
+* [Sceneflow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html#:~:text=on%20Academic%20Torrents-,FlyingThings3D,-Driving) (includes FlyingThings3D, Driving & Monkaa)
+* [Middlebury](https://vision.middlebury.edu/stereo/data/)
+* [ETH3D](https://www.eth3d.net/datasets#low-res-two-view-test-data)
+* [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo)
+
+To download the ETH3D and Middlebury test datasets for the [demos](#demos), run
+```Shell
+chmod ug+x download_datasets.sh && ./download_datasets.sh
+```
+
+By default, `stereo_datasets.py` will search for the datasets in the locations listed below. You can create symbolic links in the `datasets` folder pointing to wherever the datasets were downloaded (see the example after the directory layout).
+
+```Shell
+├── datasets
+ ├── FlyingThings3D
+ ├── frames_cleanpass
+ ├── frames_finalpass
+ ├── disparity
+ ├── Monkaa
+ ├── frames_cleanpass
+ ├── frames_finalpass
+ ├── disparity
+ ├── Driving
+ ├── frames_cleanpass
+ ├── frames_finalpass
+ ├── disparity
+ ├── KITTI
+ ├── testing
+ ├── training
+ ├── devkit
+ ├── Middlebury
+ ├── MiddEval3
+ ├── ETH3D
+ ├── lakeside_1l
+ ├── ...
+ ├── tunnel_3s
+```
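+
+For example, if the SceneFlow subsets were downloaded elsewhere, you could link them in as follows (the source paths are placeholders):
+```Shell
+mkdir -p datasets
+ln -s /path/to/FlyingThings3D datasets/FlyingThings3D
+ln -s /path/to/Monkaa datasets/Monkaa
+ln -s /path/to/Driving datasets/Driving
+```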
+
+## Demos
+Pretrained models can be downloaded by running
+```Shell
+chmod ug+x download_models.sh && ./download_models.sh
+```
+or downloaded from [google drive](https://drive.google.com/drive/folders/1booUFYEXmsdombVuglatP0nZXb5qI89J)
+
+You can demo a trained model on pairs of images. To predict stereo for Middlebury, run
+```Shell
+python demo.py --restore_ckpt models/raftstereo-sceneflow.pth
+```
+Or for ETH3D:
+```Shell
+python demo.py --restore_ckpt models/raftstereo-eth3d.pth -l=datasets/ETH3D/*/im0.png -r=datasets/ETH3D/*/im1.png
+```
+Using our fastest model:
+```Shell
+python demo.py --restore_ckpt models/raftstereo-realtime.pth --shared_backbone --n_downsample 3 --n_gru_layers 2 --slow_fast_gru
+```
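+
+If you prefer to call the model directly from Python instead of through `demo.py`, the sketch below shows one way to do it. The argument values and the loading pattern are assumptions inferred from `core/raft_stereo.py` in this repository; `demo.py` remains the reference for the actual defaults.
+```python
+import argparse
+import torch
+from core.raft_stereo import RAFTStereo
+
+# Hypothetical hyperparameters; see demo.py for the canonical defaults.
+args = argparse.Namespace(
+    hidden_dims=[128, 128, 128], n_downsample=2, n_gru_layers=3,
+    corr_implementation="reg", corr_levels=4, corr_radius=4,
+    shared_backbone=False, slow_fast_gru=False, mixed_precision=False)
+
+# Checkpoints are saved from a DataParallel-wrapped model, so wrap before loading.
+model = torch.nn.DataParallel(RAFTStereo(args))
+model.load_state_dict(torch.load("models/raftstereo-sceneflow.pth", map_location="cpu"))
+model = model.module.eval()
+
+# A stereo pair as [1, 3, H, W] float tensors with values in [0, 255];
+# H and W are kept divisible by 32 here to sidestep the padding used in demo.py.
+image1 = torch.zeros(1, 3, 384, 512)
+image2 = torch.zeros(1, 3, 384, 512)
+with torch.no_grad():
+    _, flow_up = model(image1, image2, iters=32, test_mode=True)
+disparity = -flow_up  # the network predicts horizontal flow, i.e. negative disparity
+```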
+
+To save the disparity values as `.npy` files, run any of the demos with the `--save_numpy` flag.
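+
+For example, to run the Middlebury demo above and also write out `.npy` disparity maps:
+```Shell
+python demo.py --restore_ckpt models/raftstereo-sceneflow.pth --save_numpy
+```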
+
+## Converting Disparity to Depth
+
+If the camera focal length and camera baseline are known, disparity predictions can be converted to depth values using
+
+
+
+Note that the units of the focal length are _pixels_ not millimeters.
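+
+A minimal sketch of the conversion (the array and calibration names below are illustrative, not part of this repository):
+```python
+import numpy as np
+
+disparity = np.load("disparity.npy")  # e.g. a map saved with --save_numpy
+fx_pixels = 700.0                     # focal length in pixels (from calibration)
+baseline_m = 0.1                      # camera baseline in meters
+
+valid = disparity > 0                 # avoid division by zero
+depth_m = np.zeros_like(disparity)
+depth_m[valid] = fx_pixels * baseline_m / disparity[valid]
+```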
+
+## Evaluation
+
+To evaluate a trained model on a validation set (e.g. Middlebury), run
+```Shell
+python evaluate_stereo.py --restore_ckpt models/raftstereo-middlebury.pth --dataset middlebury_H
+```
+
+## Training
+
+Our model is trained on two RTX-6000 GPUs using the following command. Training logs will be written to `runs/` which can be visualized using tensorboard.
+
+```Shell
+python train_stereo.py --batch_size 8 --train_iters 22 --valid_iters 32 --spatial_scale -0.2 0.4 --saturation_range 0 1.4 --n_downsample 2 --num_steps 200000 --mixed_precision
+```
+To train using significantly less memory, change `--n_downsample 2` to `--n_downsample 3`. This will slightly reduce accuracy.
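+
+To monitor a run, point TensorBoard at the log directory:
+```Shell
+tensorboard --logdir runs
+```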
+
+## (Optional) Faster Implementation
+
+We provide a faster CUDA implementation of the correlation volume which works with mixed precision feature maps.
+```Shell
+cd sampler && python setup.py install && cd ..
+```
+Running demo.py, train_stereo.py or evaluate_stereo.py with `--corr_implementation reg_cuda` together with `--mixed_precision` will speed up the model without impacting performance.
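+
+For example, building on the evaluation command above (this assumes the CUDA extension from the previous step has been installed):
+```Shell
+python evaluate_stereo.py --restore_ckpt models/raftstereo-middlebury.pth --dataset middlebury_H --corr_implementation reg_cuda --mixed_precision
+```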
+
+To significantly decrease memory consumption on high resolution images, use `--corr_implementation alt`. This implementation is slower than the default, however.
diff --git a/core/__init__.py b/core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/core/corr.py b/core/corr.py
new file mode 100644
index 0000000..4ac5d2e
--- /dev/null
+++ b/core/corr.py
@@ -0,0 +1,188 @@
+import torch
+import torch.nn.functional as F
+from core.utils.utils import bilinear_sampler
+
+try:
+ import corr_sampler
+except ImportError:
+ # corr_sampler (the CUDA correlation sampler) is not compiled
+ pass
+
+try:
+ import alt_cuda_corr
+except ImportError:
+ # alt_cuda_corr is not compiled
+ pass
+
+
+class CorrSampler(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, volume, coords, radius):
+ ctx.save_for_backward(volume,coords)
+ ctx.radius = radius
+ corr, = corr_sampler.forward(volume, coords, radius)
+ return corr
+ @staticmethod
+ def backward(ctx, grad_output):
+ volume, coords = ctx.saved_tensors
+ grad_output = grad_output.contiguous()
+ grad_volume, = corr_sampler.backward(volume, coords, grad_output, ctx.radius)
+ return grad_volume, None, None
+
+class CorrBlockFast1D:
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
+ self.num_levels = num_levels
+ self.radius = radius
+ self.corr_pyramid = []
+ # all pairs correlation
+ corr = CorrBlockFast1D.corr(fmap1, fmap2)
+ batch, h1, w1, dim, w2 = corr.shape
+ corr = corr.reshape(batch*h1*w1, dim, 1, w2)
+ for i in range(self.num_levels):
+ self.corr_pyramid.append(corr.view(batch, h1, w1, -1, w2//2**i))
+ corr = F.avg_pool2d(corr, [1,2], stride=[1,2])
+
+ def __call__(self, coords):
+ out_pyramid = []
+ bz, _, ht, wd = coords.shape
+ coords = coords[:, [0]]
+ for i in range(self.num_levels):
+ corr = CorrSampler.apply(self.corr_pyramid[i].squeeze(3), coords/2**i, self.radius)
+ out_pyramid.append(corr.view(bz, -1, ht, wd))
+ return torch.cat(out_pyramid, dim=1)
+
+ @staticmethod
+ def corr(fmap1, fmap2):
+ B, D, H, W1 = fmap1.shape
+ _, _, _, W2 = fmap2.shape
+ fmap1 = fmap1.view(B, D, H, W1)
+ fmap2 = fmap2.view(B, D, H, W2)
+ corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2)
+ corr = corr.reshape(B, H, W1, 1, W2).contiguous()
+ return corr / torch.sqrt(torch.tensor(D).float())
+
+
+class PytorchAlternateCorrBlock1D:
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
+ self.num_levels = num_levels
+ self.radius = radius
+ self.corr_pyramid = []
+ self.fmap1 = fmap1
+ self.fmap2 = fmap2
+
+ def corr(self, fmap1, fmap2, coords):
+ B, D, H, W = fmap2.shape
+ # map grid coordinates to [-1,1]
+ xgrid, ygrid = coords.split([1,1], dim=-1)
+ xgrid = 2*xgrid/(W-1) - 1
+ ygrid = 2*ygrid/(H-1) - 1
+
+ grid = torch.cat([xgrid, ygrid], dim=-1)
+ output_corr = []
+ for grid_slice in grid.unbind(3):
+ fmapw_mini = F.grid_sample(fmap2, grid_slice, align_corners=True)
+ corr = torch.sum(fmapw_mini * fmap1, dim=1)
+ output_corr.append(corr)
+ corr = torch.stack(output_corr, dim=1).permute(0,2,3,1)
+
+ return corr / torch.sqrt(torch.tensor(D).float())
+
+ def __call__(self, coords):
+ r = self.radius
+ coords = coords.permute(0, 2, 3, 1)
+ batch, h1, w1, _ = coords.shape
+ fmap1 = self.fmap1
+ fmap2 = self.fmap2
+ out_pyramid = []
+ for i in range(self.num_levels):
+ dx = torch.zeros(1)
+ dy = torch.linspace(-r, r, 2*r+1)
+ delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
+ centroid_lvl = coords.reshape(batch, h1, w1, 1, 2).clone()
+ centroid_lvl[...,0] = centroid_lvl[...,0] / 2**i
+ coords_lvl = centroid_lvl + delta.view(-1, 2)
+ corr = self.corr(fmap1, fmap2, coords_lvl)
+ fmap2 = F.avg_pool2d(fmap2, [1, 2], stride=[1, 2])
+ out_pyramid.append(corr)
+ out = torch.cat(out_pyramid, dim=-1)
+ return out.permute(0, 3, 1, 2).contiguous().float()
+
+
+class CorrBlock1D:
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
+ self.num_levels = num_levels
+ self.radius = radius
+ self.corr_pyramid = []
+
+ # all pairs correlation
+ corr = CorrBlock1D.corr(fmap1, fmap2)
+
+ batch, h1, w1, dim, w2 = corr.shape
+ corr = corr.reshape(batch*h1*w1, dim, 1, w2)
+
+ self.corr_pyramid.append(corr)
+ for i in range(self.num_levels):
+ corr = F.avg_pool2d(corr, [1,2], stride=[1,2])
+ self.corr_pyramid.append(corr)
+
+ def __call__(self, coords):
+ r = self.radius
+ coords = coords[:, :1].permute(0, 2, 3, 1)
+ batch, h1, w1, _ = coords.shape
+
+ out_pyramid = []
+ for i in range(self.num_levels):
+ corr = self.corr_pyramid[i]
+ dx = torch.linspace(-r, r, 2*r+1)
+ dx = dx.view(1, 1, 2*r+1, 1).to(coords.device)
+ x0 = dx + coords.reshape(batch*h1*w1, 1, 1, 1) / 2**i
+ y0 = torch.zeros_like(x0)
+
+ coords_lvl = torch.cat([x0,y0], dim=-1)
+ corr = bilinear_sampler(corr, coords_lvl)
+ corr = corr.view(batch, h1, w1, -1)
+ out_pyramid.append(corr)
+
+ out = torch.cat(out_pyramid, dim=-1)
+ return out.permute(0, 3, 1, 2).contiguous().float()
+
+ @staticmethod
+ def corr(fmap1, fmap2):
+ B, D, H, W1 = fmap1.shape
+ _, _, _, W2 = fmap2.shape
+ fmap1 = fmap1.view(B, D, H, W1)
+ fmap2 = fmap2.view(B, D, H, W2)
+ corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2)
+ corr = corr.reshape(B, H, W1, 1, W2).contiguous()
+ return corr / torch.sqrt(torch.tensor(D).float())
+
+
+class AlternateCorrBlock:
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
+ raise NotImplementedError # this backend is currently disabled; the initialization code below is unreachable
+ self.num_levels = num_levels
+ self.radius = radius
+
+ self.pyramid = [(fmap1, fmap2)]
+ for i in range(self.num_levels):
+ fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
+ fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
+ self.pyramid.append((fmap1, fmap2))
+
+ def __call__(self, coords):
+ coords = coords.permute(0, 2, 3, 1)
+ B, H, W, _ = coords.shape
+ dim = self.pyramid[0][0].shape[1]
+
+ corr_list = []
+ for i in range(self.num_levels):
+ r = self.radius
+ fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
+ fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
+
+ coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
+ corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
+ corr_list.append(corr.squeeze(1))
+
+ corr = torch.stack(corr_list, dim=1)
+ corr = corr.reshape(B, -1, H, W)
+ return corr / torch.sqrt(torch.tensor(dim).float())
diff --git a/core/extractor.py b/core/extractor.py
new file mode 100644
index 0000000..edd71e2
--- /dev/null
+++ b/core/extractor.py
@@ -0,0 +1,300 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.Sequential()
+
+ if stride == 1 and in_planes == planes:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+
+
+ def forward(self, x):
+ y = x
+ y = self.conv1(y)
+ y = self.norm1(y)
+ y = self.relu(y)
+ y = self.conv2(y)
+ y = self.norm2(y)
+ y = self.relu(y)
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x+y)
+
+
+
+class BottleneckBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+ super(BottleneckBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
+ self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
+ self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(planes//4)
+ self.norm2 = nn.BatchNorm2d(planes//4)
+ self.norm3 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm4 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(planes//4)
+ self.norm2 = nn.InstanceNorm2d(planes//4)
+ self.norm3 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm4 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ self.norm3 = nn.Sequential()
+ if not stride == 1:
+ self.norm4 = nn.Sequential()
+
+ if stride == 1:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
+
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+ y = self.relu(self.norm3(self.conv3(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x+y)
+
+class BasicEncoder(nn.Module):
+ def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, downsample=3):
+ super(BasicEncoder, self).__init__()
+ self.norm_fn = norm_fn
+ self.downsample = downsample
+
+ if self.norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
+
+ elif self.norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(64)
+
+ elif self.norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(64)
+
+ elif self.norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.in_planes = 64
+ self.layer1 = self._make_layer(64, stride=1)
+ self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))
+ self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))
+
+ # output convolution
+ self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
+
+ self.dropout = None
+ if dropout > 0:
+ self.dropout = nn.Dropout2d(p=dropout)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1):
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+
+ def forward(self, x, dual_inp=False):
+
+ # if input is list, combine batch dimension
+ is_list = isinstance(x, tuple) or isinstance(x, list)
+ if is_list:
+ batch_dim = x[0].shape[0]
+ x = torch.cat(x, dim=0)
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+
+ x = self.conv2(x)
+
+ if self.training and self.dropout is not None:
+ x = self.dropout(x)
+
+ if is_list:
+ x = x.split(split_size=batch_dim, dim=0)
+
+ return x
+
+class MultiBasicEncoder(nn.Module):
+ def __init__(self, output_dim=[128], norm_fn='batch', dropout=0.0, downsample=3):
+ super(MultiBasicEncoder, self).__init__()
+ self.norm_fn = norm_fn
+ self.downsample = downsample
+
+ if self.norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
+
+ elif self.norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(64)
+
+ elif self.norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(64)
+
+ elif self.norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3)
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.in_planes = 64
+ self.layer1 = self._make_layer(64, stride=1)
+ self.layer2 = self._make_layer(96, stride=1 + (downsample > 1))
+ self.layer3 = self._make_layer(128, stride=1 + (downsample > 0))
+ self.layer4 = self._make_layer(128, stride=2)
+ self.layer5 = self._make_layer(128, stride=2)
+
+ output_list = []
+ for dim in output_dim:
+ conv_out = nn.Sequential(
+ ResidualBlock(128, 128, self.norm_fn, stride=1),
+ nn.Conv2d(128, dim[2], 3, padding=1))
+ output_list.append(conv_out)
+
+ self.outputs08 = nn.ModuleList(output_list)
+
+ output_list = []
+ for dim in output_dim:
+ conv_out = nn.Sequential(
+ ResidualBlock(128, 128, self.norm_fn, stride=1),
+ nn.Conv2d(128, dim[1], 3, padding=1))
+ output_list.append(conv_out)
+
+ self.outputs16 = nn.ModuleList(output_list)
+
+ output_list = []
+ for dim in output_dim:
+ conv_out = nn.Conv2d(128, dim[0], 3, padding=1)
+ output_list.append(conv_out)
+
+ self.outputs32 = nn.ModuleList(output_list)
+
+ if dropout > 0:
+ self.dropout = nn.Dropout2d(p=dropout)
+ else:
+ self.dropout = None
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1):
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x, dual_inp=False, num_layers=3):
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ if dual_inp:
+ v = x
+ x = x[:(x.shape[0]//2)]
+
+ outputs08 = [f(x) for f in self.outputs08]
+ if num_layers == 1:
+ return (outputs08, v) if dual_inp else (outputs08,)
+
+ y = self.layer4(x)
+ outputs16 = [f(y) for f in self.outputs16]
+
+ if num_layers == 2:
+ return (outputs08, outputs16, v) if dual_inp else (outputs08, outputs16)
+
+ z = self.layer5(y)
+ outputs32 = [f(z) for f in self.outputs32]
+
+ return (outputs08, outputs16, outputs32, v) if dual_inp else (outputs08, outputs16, outputs32)
diff --git a/core/raft_stereo.py b/core/raft_stereo.py
new file mode 100644
index 0000000..5f796cc
--- /dev/null
+++ b/core/raft_stereo.py
@@ -0,0 +1,141 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from core.update import BasicMultiUpdateBlock
+from core.extractor import BasicEncoder, MultiBasicEncoder, ResidualBlock
+from core.corr import CorrBlock1D, PytorchAlternateCorrBlock1D, CorrBlockFast1D, AlternateCorrBlock
+from core.utils.utils import coords_grid, upflow8
+
+
+try:
+ autocast = torch.cuda.amp.autocast
+except:
+ # dummy autocast for PyTorch < 1.6
+ class autocast:
+ def __init__(self, enabled):
+ pass
+ def __enter__(self):
+ pass
+ def __exit__(self, *args):
+ pass
+
+class RAFTStereo(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+ self.args = args
+
+ context_dims = args.hidden_dims
+
+ self.cnet = MultiBasicEncoder(output_dim=[args.hidden_dims, context_dims], norm_fn="batch", downsample=args.n_downsample)
+ self.update_block = BasicMultiUpdateBlock(self.args, hidden_dims=args.hidden_dims)
+
+ self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)])
+
+ if args.shared_backbone:
+ self.conv2 = nn.Sequential(
+ ResidualBlock(128, 128, 'instance', stride=1),
+ nn.Conv2d(128, 256, 3, padding=1))
+ else:
+ self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', downsample=args.n_downsample)
+
+ def freeze_bn(self):
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+
+ def initialize_flow(self, img):
+ """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
+ N, _, H, W = img.shape
+
+ coords0 = coords_grid(N, H, W).to(img.device)
+ coords1 = coords_grid(N, H, W).to(img.device)
+
+ return coords0, coords1
+
+ def upsample_flow(self, flow, mask):
+ """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
+ N, D, H, W = flow.shape
+ factor = 2 ** self.args.n_downsample
+ mask = mask.view(N, 1, 9, factor, factor, H, W)
+ mask = torch.softmax(mask, dim=2)
+
+ up_flow = F.unfold(factor * flow, [3,3], padding=1)
+ up_flow = up_flow.view(N, D, 9, 1, 1, H, W)
+
+ up_flow = torch.sum(mask * up_flow, dim=2)
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+ return up_flow.reshape(N, D, factor*H, factor*W)
+
+
+ def forward(self, image1, image2, iters=12, flow_init=None, test_mode=False):
+ """ Estimate optical flow between pair of frames """
+
+ image1 = (2 * (image1 / 255.0) - 1.0).contiguous()
+ image2 = (2 * (image2 / 255.0) - 1.0).contiguous()
+
+ # run the context network
+ with autocast(enabled=self.args.mixed_precision):
+ if self.args.shared_backbone:
+ *cnet_list, x = self.cnet(torch.cat((image1, image2), dim=0), dual_inp=True, num_layers=self.args.n_gru_layers)
+ fmap1, fmap2 = self.conv2(x).split(dim=0, split_size=x.shape[0]//2)
+ else:
+ cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
+ fmap1, fmap2 = self.fnet([image1, image2])
+ net_list = [torch.tanh(x[0]) for x in cnet_list]
+ inp_list = [torch.relu(x[1]) for x in cnet_list]
+
+ # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning
+ inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i,conv in zip(inp_list, self.context_zqr_convs)]
+
+ if self.args.corr_implementation == "reg": # Default
+ corr_block = CorrBlock1D
+ fmap1, fmap2 = fmap1.float(), fmap2.float()
+ elif self.args.corr_implementation == "alt": # More memory efficient than reg
+ corr_block = PytorchAlternateCorrBlock1D
+ fmap1, fmap2 = fmap1.float(), fmap2.float()
+ elif self.args.corr_implementation == "reg_cuda": # Faster version of reg
+ corr_block = CorrBlockFast1D
+ elif self.args.corr_implementation == "alt_cuda": # Faster version of alt
+ corr_block = AlternateCorrBlock
+ corr_fn = corr_block(fmap1, fmap2, radius=self.args.corr_radius, num_levels=self.args.corr_levels)
+
+ coords0, coords1 = self.initialize_flow(net_list[0])
+
+ if flow_init is not None:
+ coords1 = coords1 + flow_init
+
+ flow_predictions = []
+ for itr in range(iters):
+ coords1 = coords1.detach()
+ corr = corr_fn(coords1) # index correlation volume
+ flow = coords1 - coords0
+ with autocast(enabled=self.args.mixed_precision):
+ if self.args.n_gru_layers == 3 and self.args.slow_fast_gru: # Update low-res GRU
+ net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False)
+ if self.args.n_gru_layers >= 2 and self.args.slow_fast_gru:# Update low-res GRU and mid-res GRU
+ net_list = self.update_block(net_list, inp_list, iter32=self.args.n_gru_layers==3, iter16=True, iter08=False, update=False)
+ net_list, up_mask, delta_flow = self.update_block(net_list, inp_list, corr, flow, iter32=self.args.n_gru_layers==3, iter16=self.args.n_gru_layers>=2)
+
+ # in stereo mode, project flow onto epipolar
+ delta_flow[:,1] = 0.0
+
+ # F(t+1) = F(t) + \Delta(t)
+ coords1 = coords1 + delta_flow
+
+ # We do not need to upsample or output intermediate results in test_mode
+ if test_mode and itr < iters-1:
+ continue
+
+ # upsample predictions
+ if up_mask is None:
+ flow_up = upflow8(coords1 - coords0)
+ else:
+ flow_up = self.upsample_flow(coords1 - coords0, up_mask)
+ flow_up = flow_up[:,:1]
+
+ flow_predictions.append(flow_up)
+
+ if test_mode:
+ return coords1 - coords0, flow_up
+
+ return flow_predictions
diff --git a/core/stereo_datasets.py b/core/stereo_datasets.py
new file mode 100644
index 0000000..4475e4c
--- /dev/null
+++ b/core/stereo_datasets.py
@@ -0,0 +1,314 @@
+# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
+
+import numpy as np
+import torch
+import torch.utils.data as data
+import torch.nn.functional as F
+import logging
+import os
+import re
+import copy
+import math
+import random
+from pathlib import Path
+from glob import glob
+import os.path as osp
+
+from core.utils import frame_utils
+from core.utils.augmentor import FlowAugmentor, SparseFlowAugmentor
+
+
+class StereoDataset(data.Dataset):
+ def __init__(self, aug_params=None, sparse=False, reader=None):
+ self.augmentor = None
+ self.sparse = sparse
+ self.img_pad = aug_params.pop("img_pad", None) if aug_params is not None else None
+ if aug_params is not None and "crop_size" in aug_params:
+ if sparse:
+ self.augmentor = SparseFlowAugmentor(**aug_params)
+ else:
+ self.augmentor = FlowAugmentor(**aug_params)
+
+ if reader is None:
+ self.disparity_reader = frame_utils.read_gen
+ else:
+ self.disparity_reader = reader
+
+ self.is_test = False
+ self.init_seed = False
+ self.flow_list = []
+ self.disparity_list = []
+ self.image_list = []
+ self.extra_info = []
+
+ def __getitem__(self, index):
+
+ if self.is_test:
+ img1 = frame_utils.read_gen(self.image_list[index][0])
+ img2 = frame_utils.read_gen(self.image_list[index][1])
+ img1 = np.array(img1).astype(np.uint8)[..., :3]
+ img2 = np.array(img2).astype(np.uint8)[..., :3]
+ img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
+ img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
+ return img1, img2, self.extra_info[index]
+
+ if not self.init_seed:
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info is not None:
+ torch.manual_seed(worker_info.id)
+ np.random.seed(worker_info.id)
+ random.seed(worker_info.id)
+ self.init_seed = True
+
+ index = index % len(self.image_list)
+ disp = self.disparity_reader(self.disparity_list[index])
+ if isinstance(disp, tuple):
+ disp, valid = disp
+ else:
+ valid = disp < 512
+
+ img1 = frame_utils.read_gen(self.image_list[index][0])
+ img2 = frame_utils.read_gen(self.image_list[index][1])
+
+ img1 = np.array(img1).astype(np.uint8)
+ img2 = np.array(img2).astype(np.uint8)
+
+ disp = np.array(disp).astype(np.float32)
+ flow = np.stack([-disp, np.zeros_like(disp)], axis=-1)
+
+ # grayscale images
+ if len(img1.shape) == 2:
+ img1 = np.tile(img1[...,None], (1, 1, 3))
+ img2 = np.tile(img2[...,None], (1, 1, 3))
+ else:
+ img1 = img1[..., :3]
+ img2 = img2[..., :3]
+
+ if self.augmentor is not None:
+ if self.sparse:
+ img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
+ else:
+ img1, img2, flow = self.augmentor(img1, img2, flow)
+
+ img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
+ img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
+ flow = torch.from_numpy(flow).permute(2, 0, 1).float()
+
+ if self.sparse:
+ valid = torch.from_numpy(valid)
+ else:
+ valid = (flow[0].abs() < 512) & (flow[1].abs() < 512)
+
+ if self.img_pad is not None:
+ padH, padW = self.img_pad
+ img1 = F.pad(img1, [padW]*2 + [padH]*2)
+ img2 = F.pad(img2, [padW]*2 + [padH]*2)
+
+ flow = flow[:1]
+ return self.image_list[index] + [self.disparity_list[index]], img1, img2, flow, valid.float()
+
+
+ def __mul__(self, v):
+ copy_of_self = copy.deepcopy(self)
+ copy_of_self.flow_list = v * copy_of_self.flow_list
+ copy_of_self.image_list = v * copy_of_self.image_list
+ copy_of_self.disparity_list = v * copy_of_self.disparity_list
+ copy_of_self.extra_info = v * copy_of_self.extra_info
+ return copy_of_self
+
+ def __len__(self):
+ return len(self.image_list)
+
+
+class SceneFlowDatasets(StereoDataset):
+ def __init__(self, aug_params=None, root='datasets', dstype='frames_cleanpass', things_test=False):
+ super(SceneFlowDatasets, self).__init__(aug_params)
+ self.root = root
+ self.dstype = dstype
+
+ if things_test:
+ self._add_things("TEST")
+ else:
+ self._add_things("TRAIN")
+ self._add_monkaa()
+ self._add_driving()
+
+ def _add_things(self, split='TRAIN'):
+ """ Add FlyingThings3D data """
+
+ original_length = len(self.disparity_list)
+ root = osp.join(self.root, 'FlyingThings3D')
+ left_images = sorted( glob(osp.join(root, self.dstype, split, '*/*/left/*.png')) )
+ right_images = [ im.replace('left', 'right') for im in left_images ]
+ disparity_images = [ im.replace(self.dstype, 'disparity').replace('.png', '.pfm') for im in left_images ]
+
+ with open(osp.join('datasets', 'flyingthings_validation.txt')) as f:
+ validation_files = set(f.read().splitlines())
+
+ for img1, img2, disp in zip(left_images, right_images, disparity_images):
+ if split == 'TEST' and disp not in validation_files:
+ continue
+ self.image_list += [ [img1, img2] ]
+ self.disparity_list += [ disp ]
+ logging.info(f"Added {len(self.disparity_list) - original_length} from FlyingThings {self.dstype}")
+
+ def _add_monkaa(self):
+ """ Add FlyingThings3D data """
+
+ original_length = len(self.disparity_list)
+ root = osp.join(self.root, 'Monkaa')
+ left_images = sorted( glob(osp.join(root, self.dstype, '*/left/*.png')) )
+ right_images = [ image_file.replace('left', 'right') for image_file in left_images ]
+ disparity_images = [ im.replace(self.dstype, 'disparity').replace('.png', '.pfm') for im in left_images ]
+
+ for img1, img2, disp in zip(left_images, right_images, disparity_images):
+ self.image_list += [ [img1, img2] ]
+ self.disparity_list += [ disp ]
+ logging.info(f"Added {len(self.disparity_list) - original_length} from Monkaa {self.dstype}")
+
+
+ def _add_driving(self):
+ """ Add FlyingThings3D data """
+
+ original_length = len(self.disparity_list)
+ root = osp.join(self.root, 'Driving')
+ left_images = sorted( glob(osp.join(root, self.dstype, '*/*/*/left/*.png')) )
+ right_images = [ image_file.replace('left', 'right') for image_file in left_images ]
+ disparity_images = [ im.replace(self.dstype, 'disparity').replace('.png', '.pfm') for im in left_images ]
+
+ for img1, img2, disp in zip(left_images, right_images, disparity_images):
+ self.image_list += [ [img1, img2] ]
+ self.disparity_list += [ disp ]
+ logging.info(f"Added {len(self.disparity_list) - original_length} from Driving {self.dstype}")
+
+
+class ETH3D(StereoDataset):
+ def __init__(self, aug_params=None, root='datasets/ETH3D', split='training'):
+ super(ETH3D, self).__init__(aug_params, sparse=True)
+
+ image1_list = sorted( glob(osp.join(root, f'two_view_{split}/*/im0.png')) )
+ image2_list = sorted( glob(osp.join(root, f'two_view_{split}/*/im1.png')) )
+ disp_list = sorted( glob(osp.join(root, 'two_view_training_gt/*/disp0GT.pfm')) ) if split == 'training' else [osp.join(root, 'two_view_training_gt/playground_1l/disp0GT.pfm')]*len(image1_list)
+
+ for img1, img2, disp in zip(image1_list, image2_list, disp_list):
+ self.image_list += [ [img1, img2] ]
+ self.disparity_list += [ disp ]
+
+class SintelStereo(StereoDataset):
+ def __init__(self, aug_params=None, root='datasets/SintelStereo'):
+ super().__init__(aug_params, sparse=True, reader=frame_utils.readDispSintelStereo)
+
+ image1_list = sorted( glob(osp.join(root, 'training/*_left/*/frame_*.png')) )
+ image2_list = sorted( glob(osp.join(root, 'training/*_right/*/frame_*.png')) )
+ disp_list = sorted( glob(osp.join(root, 'training/disparities/*/frame_*.png')) ) * 2
+
+ for img1, img2, disp in zip(image1_list, image2_list, disp_list):
+ assert img1.split('/')[-2:] == disp.split('/')[-2:]
+ self.image_list += [ [img1, img2] ]
+ self.disparity_list += [ disp ]
+
+class FallingThings(StereoDataset):
+ def __init__(self, aug_params=None, root='datasets/FallingThings'):
+ super().__init__(aug_params, reader=frame_utils.readDispFallingThings)
+ assert os.path.exists(root)
+
+ with open(os.path.join(root, 'filenames.txt'), 'r') as f:
+ filenames = sorted(f.read().splitlines())
+
+ image1_list = [osp.join(root, e) for e in filenames]
+ image2_list = [osp.join(root, e.replace('left.jpg', 'right.jpg')) for e in filenames]
+ disp_list = [osp.join(root, e.replace('left.jpg', 'left.depth.png')) for e in filenames]
+
+ for img1, img2, disp in zip(image1_list, image2_list, disp_list):
+ self.image_list += [ [img1, img2] ]
+ self.disparity_list += [ disp ]
+
+class TartanAir(StereoDataset):
+ def __init__(self, aug_params=None, root='datasets', keywords=[]):
+ super().__init__(aug_params, reader=frame_utils.readDispTartanAir)
+ assert os.path.exists(root)
+
+ with open(os.path.join(root, 'tartanair_filenames.txt'), 'r') as f:
+ filenames = sorted(list(filter(lambda s: 'seasonsforest_winter/Easy' not in s, f.read().splitlines())))
+ for kw in keywords:
+ filenames = sorted(list(filter(lambda s: kw in s.lower(), filenames)))
+
+ image1_list = [osp.join(root, e) for e in filenames]
+ image2_list = [osp.join(root, e.replace('_left', '_right')) for e in filenames]
+ disp_list = [osp.join(root, e.replace('image_left', 'depth_left').replace('left.png', 'left_depth.npy')) for e in filenames]
+
+ for img1, img2, disp in zip(image1_list, image2_list, disp_list):
+ self.image_list += [ [img1, img2] ]
+ self.disparity_list += [ disp ]
+
+class KITTI(StereoDataset):
+ def __init__(self, aug_params=None, root='datasets/KITTI', image_set='training'):
+ super(KITTI, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispKITTI)
+ assert os.path.exists(root)
+
+ image1_list = sorted(glob(os.path.join(root, image_set, 'image_2/*_10.png')))
+ image2_list = sorted(glob(os.path.join(root, image_set, 'image_3/*_10.png')))
+ disp_list = sorted(glob(os.path.join(root, 'training', 'disp_occ_0/*_10.png'))) if image_set == 'training' else [osp.join(root, 'training/disp_occ_0/000085_10.png')]*len(image1_list)
+
+ for idx, (img1, img2, disp) in enumerate(zip(image1_list, image2_list, disp_list)):
+ self.image_list += [ [img1, img2] ]
+ self.disparity_list += [ disp ]
+
+
+class Middlebury(StereoDataset):
+ def __init__(self, aug_params=None, root='datasets/Middlebury', split='F'):
+ super(Middlebury, self).__init__(aug_params, sparse=True, reader=frame_utils.readDispMiddlebury)
+ assert os.path.exists(root)
+ assert split in "FHQ"
+ lines = list(map(osp.basename, glob(os.path.join(root, "MiddEval3/trainingF/*"))))
+ lines = list(filter(lambda p: any(s in p.split('/') for s in Path(os.path.join(root, "MiddEval3/official_train.txt")).read_text().splitlines()), lines))
+ image1_list = sorted([os.path.join(root, "MiddEval3", f'training{split}', f'{name}/im0.png') for name in lines])
+ image2_list = sorted([os.path.join(root, "MiddEval3", f'training{split}', f'{name}/im1.png') for name in lines])
+ disp_list = sorted([os.path.join(root, "MiddEval3", f'training{split}', f'{name}/disp0GT.pfm') for name in lines])
+
+ assert len(image1_list) == len(image2_list) == len(disp_list) > 0, [image1_list, split]
+ for img1, img2, disp in zip(image1_list, image2_list, disp_list):
+ self.image_list += [ [img1, img2] ]
+ self.disparity_list += [ disp ]
+
+
+def fetch_dataloader(args):
+ """ Create the data loader for the corresponding trainign set """
+
+ aug_params = {'crop_size': args.image_size, 'min_scale': args.spatial_scale[0], 'max_scale': args.spatial_scale[1], 'do_flip': False, 'yjitter': not args.noyjitter}
+ if hasattr(args, "saturation_range") and args.saturation_range is not None:
+ aug_params["saturation_range"] = args.saturation_range
+ if hasattr(args, "img_gamma") and args.img_gamma is not None:
+ aug_params["gamma"] = args.img_gamma
+ if hasattr(args, "do_flip") and args.do_flip is not None:
+ aug_params["do_flip"] = args.do_flip
+
+ train_dataset = None
+ for dataset_name in args.train_datasets:
+ if re.compile("middlebury_.*").fullmatch(dataset_name):
+ new_dataset = Middlebury(aug_params, split=dataset_name.replace('middlebury_',''))
+ elif dataset_name == 'sceneflow':
+ clean_dataset = SceneFlowDatasets(aug_params, dstype='frames_cleanpass')
+ final_dataset = SceneFlowDatasets(aug_params, dstype='frames_finalpass')
+ new_dataset = (clean_dataset*4) + (final_dataset*4)
+ logging.info(f"Adding {len(new_dataset)} samples from SceneFlow")
+ elif 'kitti' in dataset_name:
+ new_dataset = KITTI(aug_params)
+ logging.info(f"Adding {len(new_dataset)} samples from KITTI")
+ elif dataset_name == 'sintel_stereo':
+ new_dataset = SintelStereo(aug_params)*140
+ logging.info(f"Adding {len(new_dataset)} samples from Sintel Stereo")
+ elif dataset_name == 'falling_things':
+ new_dataset = FallingThings(aug_params)*5
+ logging.info(f"Adding {len(new_dataset)} samples from FallingThings")
+ elif dataset_name.startswith('tartan_air'):
+ new_dataset = TartanAir(aug_params, keywords=dataset_name.split('_')[2:])
+ logging.info(f"Adding {len(new_dataset)} samples from Tartain Air")
+ train_dataset = new_dataset if train_dataset is None else train_dataset + new_dataset
+
+ train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
+ pin_memory=True, shuffle=True, num_workers=int(os.environ.get('SLURM_CPUS_PER_TASK', 6))-2, drop_last=True)
+
+ logging.info('Training with %d image pairs' % len(train_dataset))
+ return train_loader
+
diff --git a/core/update.py b/core/update.py
new file mode 100644
index 0000000..ae3fac5
--- /dev/null
+++ b/core/update.py
@@ -0,0 +1,138 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from opt_einsum import contract
+
+class FlowHead(nn.Module):
+ def __init__(self, input_dim=128, hidden_dim=256, output_dim=2):
+ super(FlowHead, self).__init__()
+ self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
+ self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ return self.conv2(self.relu(self.conv1(x)))
+
+class ConvGRU(nn.Module):
+ def __init__(self, hidden_dim, input_dim, kernel_size=3):
+ super(ConvGRU, self).__init__()
+ self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
+ self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
+ self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
+
+ def forward(self, h, cz, cr, cq, *x_list):
+ x = torch.cat(x_list, dim=1)
+ hx = torch.cat([h, x], dim=1)
+
+ z = torch.sigmoid(self.convz(hx) + cz)
+ r = torch.sigmoid(self.convr(hx) + cr)
+ q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)) + cq)
+
+ h = (1-z) * h + z * q
+ return h
+
+class SepConvGRU(nn.Module):
+ def __init__(self, hidden_dim=128, input_dim=192+128):
+ super(SepConvGRU, self).__init__()
+ self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
+ self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
+ self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
+
+ self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
+ self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
+ self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
+
+
+ def forward(self, h, *x):
+ # horizontal
+ x = torch.cat(x, dim=1)
+ hx = torch.cat([h, x], dim=1)
+ z = torch.sigmoid(self.convz1(hx))
+ r = torch.sigmoid(self.convr1(hx))
+ q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
+ h = (1-z) * h + z * q
+
+ # vertical
+ hx = torch.cat([h, x], dim=1)
+ z = torch.sigmoid(self.convz2(hx))
+ r = torch.sigmoid(self.convr2(hx))
+ q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
+ h = (1-z) * h + z * q
+
+ return h
+
+class BasicMotionEncoder(nn.Module):
+ def __init__(self, args):
+ super(BasicMotionEncoder, self).__init__()
+ self.args = args
+
+ cor_planes = args.corr_levels * (2*args.corr_radius + 1)
+
+ self.convc1 = nn.Conv2d(cor_planes, 64, 1, padding=0)
+ self.convc2 = nn.Conv2d(64, 64, 3, padding=1)
+ self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
+ self.convf2 = nn.Conv2d(64, 64, 3, padding=1)
+ self.conv = nn.Conv2d(64+64, 128-2, 3, padding=1)
+
+ def forward(self, flow, corr):
+ cor = F.relu(self.convc1(corr))
+ cor = F.relu(self.convc2(cor))
+ flo = F.relu(self.convf1(flow))
+ flo = F.relu(self.convf2(flo))
+
+ cor_flo = torch.cat([cor, flo], dim=1)
+ out = F.relu(self.conv(cor_flo))
+ return torch.cat([out, flow], dim=1)
+
+def pool2x(x):
+ return F.avg_pool2d(x, 3, stride=2, padding=1)
+
+def pool4x(x):
+ return F.avg_pool2d(x, 5, stride=4, padding=1)
+
+def interp(x, dest):
+ interp_args = {'mode': 'bilinear', 'align_corners': True}
+ return F.interpolate(x, dest.shape[2:], **interp_args)
+
+class BasicMultiUpdateBlock(nn.Module):
+ def __init__(self, args, hidden_dims=[]):
+ super().__init__()
+ self.args = args
+ self.encoder = BasicMotionEncoder(args)
+ encoder_output_dim = 128
+
+ self.gru08 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (args.n_gru_layers > 1))
+ self.gru16 = ConvGRU(hidden_dims[1], hidden_dims[0] * (args.n_gru_layers == 3) + hidden_dims[2])
+ self.gru32 = ConvGRU(hidden_dims[0], hidden_dims[1])
+ self.flow_head = FlowHead(hidden_dims[2], hidden_dim=256, output_dim=2)
+ factor = 2**self.args.n_downsample
+
+ self.mask = nn.Sequential(
+ nn.Conv2d(hidden_dims[2], 256, 3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, (factor**2)*9, 1, padding=0))
+
+ def forward(self, net, inp, corr=None, flow=None, iter08=True, iter16=True, iter32=True, update=True):
+
+ if iter32:
+ net[2] = self.gru32(net[2], *(inp[2]), pool2x(net[1]))
+ if iter16:
+ if self.args.n_gru_layers > 2:
+ net[1] = self.gru16(net[1], *(inp[1]), pool2x(net[0]), interp(net[2], net[1]))
+ else:
+ net[1] = self.gru16(net[1], *(inp[1]), pool2x(net[0]))
+ if iter08:
+ motion_features = self.encoder(flow, corr)
+ if self.args.n_gru_layers > 1:
+ net[0] = self.gru08(net[0], *(inp[0]), motion_features, interp(net[1], net[0]))
+ else:
+ net[0] = self.gru08(net[0], *(inp[0]), motion_features)
+
+ if not update:
+ return net
+
+ delta_flow = self.flow_head(net[0])
+
+ # scale mask to balance gradients
+ mask = .25 * self.mask(net[0])
+ return net, mask, delta_flow
diff --git a/core/utils/__init__.py b/core/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/core/utils/augmentor.py b/core/utils/augmentor.py
new file mode 100644
index 0000000..b407f98
--- /dev/null
+++ b/core/utils/augmentor.py
@@ -0,0 +1,317 @@
+import numpy as np
+import random
+import warnings
+import os
+import time
+from glob import glob
+from skimage import color, io
+from PIL import Image
+
+import cv2
+cv2.setNumThreads(0)
+cv2.ocl.setUseOpenCL(False)
+
+import torch
+from torchvision.transforms import ColorJitter, functional, Compose
+import torch.nn.functional as F
+
+def get_middlebury_images():
+ root = "datasets/Middlebury/MiddEval3"
+ with open(os.path.join(root, "official_train.txt"), 'r') as f:
+ lines = f.read().splitlines()
+ return sorted([os.path.join(root, 'trainingQ', f'{name}/im0.png') for name in lines])
+
+def get_eth3d_images():
+ return sorted(glob('datasets/ETH3D/two_view_training/*/im0.png'))
+
+def get_kitti_images():
+ return sorted(glob('datasets/KITTI/training/image_2/*_10.png'))
+
+def transfer_color(image, style_mean, style_stddev):
+ reference_image_lab = color.rgb2lab(image)
+ reference_stddev = np.std(reference_image_lab, axis=(0,1), keepdims=True)# + 1
+ reference_mean = np.mean(reference_image_lab, axis=(0,1), keepdims=True)
+
+ reference_image_lab = reference_image_lab - reference_mean
+ lamb = style_stddev/reference_stddev
+ style_image_lab = lamb * reference_image_lab
+ output_image_lab = style_image_lab + style_mean
+ l, a, b = np.split(output_image_lab, 3, axis=2)
+ l = l.clip(0, 100)
+ output_image_lab = np.concatenate((l,a,b), axis=2)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=UserWarning)
+ output_image_rgb = color.lab2rgb(output_image_lab) * 255
+ return output_image_rgb
+
+class AdjustGamma(object):
+
+ def __init__(self, gamma_min, gamma_max, gain_min=1.0, gain_max=1.0):
+ self.gamma_min, self.gamma_max, self.gain_min, self.gain_max = gamma_min, gamma_max, gain_min, gain_max
+
+ def __call__(self, sample):
+ gain = random.uniform(self.gain_min, self.gain_max)
+ gamma = random.uniform(self.gamma_min, self.gamma_max)
+ return functional.adjust_gamma(sample, gamma, gain)
+
+ def __repr__(self):
+ return f"Adjust Gamma {self.gamma_min}, ({self.gamma_max}) and Gain ({self.gain_min}, {self.gain_max})"
+
+class FlowAugmentor:
+ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True, yjitter=False, saturation_range=[0.6,1.4], gamma=[1,1,1,1]):
+
+ # spatial augmentation params
+ self.crop_size = crop_size
+ self.min_scale = min_scale
+ self.max_scale = max_scale
+ self.spatial_aug_prob = 1.0
+ self.stretch_prob = 0.8
+ self.max_stretch = 0.2
+
+ # flip augmentation params
+ self.yjitter = yjitter
+ self.do_flip = do_flip
+ self.h_flip_prob = 0.5
+ self.v_flip_prob = 0.1
+
+ # photometric augmentation params
+ self.photo_aug = Compose([ColorJitter(brightness=0.4, contrast=0.4, saturation=saturation_range, hue=0.5/3.14), AdjustGamma(*gamma)])
+ self.asymmetric_color_aug_prob = 0.2
+ self.eraser_aug_prob = 0.5
+
+ def color_transform(self, img1, img2):
+ """ Photometric augmentation """
+
+ # asymmetric
+ if np.random.rand() < self.asymmetric_color_aug_prob:
+ img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
+ img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
+
+ # symmetric
+ else:
+ image_stack = np.concatenate([img1, img2], axis=0)
+ image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
+ img1, img2 = np.split(image_stack, 2, axis=0)
+
+ return img1, img2
+
+ def eraser_transform(self, img1, img2, bounds=[50, 100]):
+ """ Occlusion augmentation """
+
+ ht, wd = img1.shape[:2]
+ if np.random.rand() < self.eraser_aug_prob:
+ mean_color = np.mean(img2.reshape(-1, 3), axis=0)
+ for _ in range(np.random.randint(1, 3)):
+ x0 = np.random.randint(0, wd)
+ y0 = np.random.randint(0, ht)
+ dx = np.random.randint(bounds[0], bounds[1])
+ dy = np.random.randint(bounds[0], bounds[1])
+ img2[y0:y0+dy, x0:x0+dx, :] = mean_color
+
+ return img1, img2
+
+ def spatial_transform(self, img1, img2, flow):
+ # randomly sample scale
+ ht, wd = img1.shape[:2]
+ min_scale = np.maximum(
+ (self.crop_size[0] + 8) / float(ht),
+ (self.crop_size[1] + 8) / float(wd))
+
+ scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
+ scale_x = scale
+ scale_y = scale
+ if np.random.rand() < self.stretch_prob:
+ scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
+ scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
+
+ scale_x = np.clip(scale_x, min_scale, None)
+ scale_y = np.clip(scale_y, min_scale, None)
+
+ if np.random.rand() < self.spatial_aug_prob:
+ # rescale the images
+ img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ flow = flow * [scale_x, scale_y]
+
+ if self.do_flip:
+ if np.random.rand() < self.h_flip_prob and self.do_flip == 'hf': # h-flip
+ img1 = img1[:, ::-1]
+ img2 = img2[:, ::-1]
+ flow = flow[:, ::-1] * [-1.0, 1.0]
+
+ if np.random.rand() < self.h_flip_prob and self.do_flip == 'h': # h-flip for stereo
+ tmp = img1[:, ::-1]
+ img1 = img2[:, ::-1]
+ img2 = tmp
+
+ if np.random.rand() < self.v_flip_prob and self.do_flip == 'v': # v-flip
+ img1 = img1[::-1, :]
+ img2 = img2[::-1, :]
+ flow = flow[::-1, :] * [1.0, -1.0]
+
+ if self.yjitter:
+ y0 = np.random.randint(2, img1.shape[0] - self.crop_size[0] - 2)
+ x0 = np.random.randint(2, img1.shape[1] - self.crop_size[1] - 2)
+
+ y1 = y0 + np.random.randint(-2, 2 + 1)
+ img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ img2 = img2[y1:y1+self.crop_size[0], x0:x0+self.crop_size[1]]
+ flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+
+ else:
+ y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
+ x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
+
+ img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+
+ return img1, img2, flow
+
+
+ def __call__(self, img1, img2, flow):
+ img1, img2 = self.color_transform(img1, img2)
+ img1, img2 = self.eraser_transform(img1, img2)
+ img1, img2, flow = self.spatial_transform(img1, img2, flow)
+
+ img1 = np.ascontiguousarray(img1)
+ img2 = np.ascontiguousarray(img2)
+ flow = np.ascontiguousarray(flow)
+
+ return img1, img2, flow
+
+class SparseFlowAugmentor:
+ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False, yjitter=False, saturation_range=[0.7,1.3], gamma=[1,1,1,1]):
+ # spatial augmentation params
+ self.crop_size = crop_size
+ self.min_scale = min_scale
+ self.max_scale = max_scale
+ self.spatial_aug_prob = 0.8
+ self.stretch_prob = 0.8
+ self.max_stretch = 0.2
+
+ # flip augmentation params
+ self.do_flip = do_flip
+ self.h_flip_prob = 0.5
+ self.v_flip_prob = 0.1
+
+ # photometric augmentation params
+ self.photo_aug = Compose([ColorJitter(brightness=0.3, contrast=0.3, saturation=saturation_range, hue=0.3/3.14), AdjustGamma(*gamma)])
+ self.asymmetric_color_aug_prob = 0.2
+ self.eraser_aug_prob = 0.5
+
+ def color_transform(self, img1, img2):
+ image_stack = np.concatenate([img1, img2], axis=0)
+ image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
+ img1, img2 = np.split(image_stack, 2, axis=0)
+ return img1, img2
+
+ def eraser_transform(self, img1, img2):
+ ht, wd = img1.shape[:2]
+ if np.random.rand() < self.eraser_aug_prob:
+ mean_color = np.mean(img2.reshape(-1, 3), axis=0)
+ for _ in range(np.random.randint(1, 3)):
+ x0 = np.random.randint(0, wd)
+ y0 = np.random.randint(0, ht)
+ dx = np.random.randint(50, 100)
+ dy = np.random.randint(50, 100)
+ img2[y0:y0+dy, x0:x0+dx, :] = mean_color
+
+ return img1, img2
+
+ def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
+ ht, wd = flow.shape[:2]
+ coords = np.meshgrid(np.arange(wd), np.arange(ht))
+ coords = np.stack(coords, axis=-1)
+
+ coords = coords.reshape(-1, 2).astype(np.float32)
+ flow = flow.reshape(-1, 2).astype(np.float32)
+ valid = valid.reshape(-1).astype(np.float32)
+
+ coords0 = coords[valid>=1]
+ flow0 = flow[valid>=1]
+
+ ht1 = int(round(ht * fy))
+ wd1 = int(round(wd * fx))
+
+ coords1 = coords0 * [fx, fy]
+ flow1 = flow0 * [fx, fy]
+
+ xx = np.round(coords1[:,0]).astype(np.int32)
+ yy = np.round(coords1[:,1]).astype(np.int32)
+
+ v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
+ xx = xx[v]
+ yy = yy[v]
+ flow1 = flow1[v]
+
+ flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
+ valid_img = np.zeros([ht1, wd1], dtype=np.int32)
+
+ flow_img[yy, xx] = flow1
+ valid_img[yy, xx] = 1
+
+ return flow_img, valid_img
+
+ def spatial_transform(self, img1, img2, flow, valid):
+ # randomly sample scale
+
+ ht, wd = img1.shape[:2]
+ min_scale = np.maximum(
+ (self.crop_size[0] + 1) / float(ht),
+ (self.crop_size[1] + 1) / float(wd))
+
+ scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
+ scale_x = np.clip(scale, min_scale, None)
+ scale_y = np.clip(scale, min_scale, None)
+
+ if np.random.rand() < self.spatial_aug_prob:
+ # rescale the images
+ img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
+
+ if self.do_flip:
+ if np.random.rand() < self.h_flip_prob and self.do_flip == 'hf': # h-flip
+ img1 = img1[:, ::-1]
+ img2 = img2[:, ::-1]
+ flow = flow[:, ::-1] * [-1.0, 1.0]
+
+ if np.random.rand() < self.h_flip_prob and self.do_flip == 'h': # h-flip for stereo
+ tmp = img1[:, ::-1]
+ img1 = img2[:, ::-1]
+ img2 = tmp
+
+ if np.random.rand() < self.v_flip_prob and self.do_flip == 'v': # v-flip
+ img1 = img1[::-1, :]
+ img2 = img2[::-1, :]
+ flow = flow[::-1, :] * [1.0, -1.0]
+
+ margin_y = 20
+ margin_x = 50
+
+ y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
+ x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
+
+ y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
+ x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
+
+ img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ return img1, img2, flow, valid
+
+
+ def __call__(self, img1, img2, flow, valid):
+ img1, img2 = self.color_transform(img1, img2)
+ img1, img2 = self.eraser_transform(img1, img2)
+ img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
+
+ img1 = np.ascontiguousarray(img1)
+ img2 = np.ascontiguousarray(img2)
+ flow = np.ascontiguousarray(flow)
+ valid = np.ascontiguousarray(valid)
+
+ return img1, img2, flow, valid
diff --git a/core/utils/frame_utils.py b/core/utils/frame_utils.py
new file mode 100644
index 0000000..10d3d85
--- /dev/null
+++ b/core/utils/frame_utils.py
@@ -0,0 +1,187 @@
+import numpy as np
+from PIL import Image
+from os.path import *
+import re
+import json
+import imageio
+import cv2
+cv2.setNumThreads(0)
+cv2.ocl.setUseOpenCL(False)
+
+TAG_CHAR = np.array([202021.25], np.float32)
+
+def readFlow(fn):
+ """ Read .flo file in Middlebury format"""
+ # Code adapted from:
+ # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
+
+ # WARNING: this will work on little-endian architectures (eg Intel x86) only!
+ # print 'fn = %s'%(fn)
+ with open(fn, 'rb') as f:
+ magic = np.fromfile(f, np.float32, count=1)
+ if 202021.25 != magic:
+ print('Magic number incorrect. Invalid .flo file')
+ return None
+ else:
+ w = np.fromfile(f, np.int32, count=1)
+ h = np.fromfile(f, np.int32, count=1)
+ # print 'Reading %d x %d flo file\n' % (w, h)
+ data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
+ # Reshape data into 3D array (columns, rows, bands)
+ # The reshape here is for visualization, the original code is (w,h,2)
+ return np.resize(data, (int(h), int(w), 2))
+
+def readPFM(file):
+ file = open(file, 'rb')
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header == b'PF':
+ color = True
+ elif header == b'Pf':
+ color = False
+ else:
+ raise Exception('Not a PFM file.')
+
+ dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
+ if dim_match:
+ width, height = map(int, dim_match.groups())
+ else:
+ raise Exception('Malformed PFM header.')
+
+ scale = float(file.readline().rstrip())
+ if scale < 0: # little-endian
+ endian = '<'
+ scale = -scale
+ else:
+ endian = '>' # big-endian
+
+ data = np.fromfile(file, endian + 'f')
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+ return data
+
+def writePFM(file, array):
+ import os
+ assert type(file) is str and type(array) is np.ndarray and \
+ os.path.splitext(file)[1] == ".pfm"
+ with open(file, 'wb') as f:
+ H, W = array.shape
+ headers = ["Pf\n", f"{W} {H}\n", "-1\n"]
+ for header in headers:
+ f.write(str.encode(header))
+ array = np.flip(array, axis=0).astype(np.float32)
+ f.write(array.tobytes())
+
+
+
+def writeFlow(filename,uv,v=None):
+ """ Write optical flow to file.
+
+ If v is None, uv is assumed to contain both u and v channels,
+ stacked in depth.
+ Original code by Deqing Sun, adapted from Daniel Scharstein.
+ """
+ nBands = 2
+
+ if v is None:
+ assert(uv.ndim == 3)
+ assert(uv.shape[2] == 2)
+ u = uv[:,:,0]
+ v = uv[:,:,1]
+ else:
+ u = uv
+
+ assert(u.shape == v.shape)
+ height,width = u.shape
+ f = open(filename,'wb')
+ # write the header
+ f.write(TAG_CHAR)
+ np.array(width).astype(np.int32).tofile(f)
+ np.array(height).astype(np.int32).tofile(f)
+ # arrange into matrix form
+ tmp = np.zeros((height, width*nBands))
+ tmp[:,np.arange(width)*2] = u
+ tmp[:,np.arange(width)*2 + 1] = v
+ tmp.astype(np.float32).tofile(f)
+ f.close()
+
+
+def readFlowKITTI(filename):
+ flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
+ flow = flow[:,:,::-1].astype(np.float32)
+ flow, valid = flow[:, :, :2], flow[:, :, 2]
+ flow = (flow - 2**15) / 64.0
+ return flow, valid
+
+def readDispKITTI(filename):
+ disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
+ valid = disp > 0.0
+ return disp, valid
+
+# Method taken from /n/fs/raft-depth/RAFT-Stereo/datasets/SintelStereo/sdk/python/sintel_io.py
+def readDispSintelStereo(file_name):
+ a = np.array(Image.open(file_name))
+ d_r, d_g, d_b = np.split(a, axis=2, indices_or_sections=3)
+ disp = (d_r * 4 + d_g / (2**6) + d_b / (2**14))[..., 0]
+ mask = np.array(Image.open(file_name.replace('disparities', 'occlusions')))
+ valid = ((mask == 0) & (disp > 0))
+ return disp, valid
+
+# Method taken from https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
+def readDispFallingThings(file_name):
+ a = np.array(Image.open(file_name))
+ with open('/'.join(file_name.split('/')[:-1] + ['_camera_settings.json']), 'r') as f:
+ intrinsics = json.load(f)
+ fx = intrinsics['camera_settings'][0]['intrinsic_settings']['fx']
+ disp = (fx * 6.0 * 100) / a.astype(np.float32)
+ valid = disp > 0
+ return disp, valid
+
+# Method taken from https://github.com/castacks/tartanair_tools/blob/master/data_type.md
+def readDispTartanAir(file_name):
+ depth = np.load(file_name)
+ disp = 80.0 / depth
+ valid = disp > 0
+ return disp, valid
+
+
+def readDispMiddlebury(file_name):
+ assert basename(file_name) == 'disp0GT.pfm'
+ disp = readPFM(file_name).astype(np.float32)
+ assert len(disp.shape) == 2
+ nocc_pix = file_name.replace('disp0GT.pfm', 'mask0nocc.png')
+ assert exists(nocc_pix)
+ nocc_pix = imageio.imread(nocc_pix) == 255
+ assert np.any(nocc_pix)
+ return disp, nocc_pix
+
+def writeFlowKITTI(filename, uv):
+ uv = 64.0 * uv + 2**15
+ valid = np.ones([uv.shape[0], uv.shape[1], 1])
+ uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
+ cv2.imwrite(filename, uv[..., ::-1])
+
+
+def read_gen(file_name, pil=False):
+ ext = splitext(file_name)[-1]
+ if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
+ return Image.open(file_name)
+ elif ext == '.bin' or ext == '.raw':
+ return np.load(file_name)
+ elif ext == '.flo':
+ return readFlow(file_name).astype(np.float32)
+ elif ext == '.pfm':
+ flow = readPFM(file_name).astype(np.float32)
+ if len(flow.shape) == 2:
+ return flow
+ else:
+ return flow[:, :, :-1]
+ return []
\ No newline at end of file
diff --git a/core/utils/utils.py b/core/utils/utils.py
new file mode 100644
index 0000000..d6d5953
--- /dev/null
+++ b/core/utils/utils.py
@@ -0,0 +1,93 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy import interpolate
+
+
+class InputPadder:
+ """ Pads images such that dimensions are divisible by 8 """
+ def __init__(self, dims, mode='sintel', divis_by=8):
+ self.ht, self.wd = dims[-2:]
+ pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
+ pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
+ if mode == 'sintel':
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
+ else:
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
+
+ def pad(self, *inputs):
+ assert all((x.ndim == 4) for x in inputs)
+ return [F.pad(x, self._pad, mode='replicate') for x in inputs]
+
+ def unpad(self, x):
+ assert x.ndim == 4
+ ht, wd = x.shape[-2:]
+ c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
+ return x[..., c[0]:c[1], c[2]:c[3]]
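+
+# Illustrative usage (a sketch mirroring demo.py and the validation scripts):
+#   padder = InputPadder(image1.shape, divis_by=32)
+#   image1, image2 = padder.pad(image1, image2)
+#   _, flow = model(image1, image2, iters=32, test_mode=True)
+#   flow = padder.unpad(flow)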
+
+def forward_interpolate(flow):
+ flow = flow.detach().cpu().numpy()
+ dx, dy = flow[0], flow[1]
+
+ ht, wd = dx.shape
+ x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
+
+ x1 = x0 + dx
+ y1 = y0 + dy
+
+ x1 = x1.reshape(-1)
+ y1 = y1.reshape(-1)
+ dx = dx.reshape(-1)
+ dy = dy.reshape(-1)
+
+ valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
+ x1 = x1[valid]
+ y1 = y1[valid]
+ dx = dx[valid]
+ dy = dy[valid]
+
+ flow_x = interpolate.griddata(
+ (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
+
+ flow_y = interpolate.griddata(
+ (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
+
+ flow = np.stack([flow_x, flow_y], axis=0)
+ return torch.from_numpy(flow).float()
+
+
+def bilinear_sampler(img, coords, mode='bilinear', mask=False):
+ """ Wrapper for grid_sample, uses pixel coordinates """
+ H, W = img.shape[-2:]
+ xgrid, ygrid = coords.split([1,1], dim=-1)
+ xgrid = 2*xgrid/(W-1) - 1
+ assert torch.unique(ygrid).numel() == 1 and H == 1 # This is a stereo problem
+
+ grid = torch.cat([xgrid, ygrid], dim=-1)
+ img = F.grid_sample(img, grid, align_corners=True)
+
+ if mask:
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
+ return img, mask.float()
+
+ return img
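+
+# Note: with align_corners=True, grid_sample expects coordinates in [-1, 1], so the
+# normalization above maps pixel x=0 to -1 and x=W-1 to +1 (e.g. for W=5, pixel
+# coordinate 2 maps to 2*2/(5-1) - 1 = 0).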
+
+
+def coords_grid(batch, ht, wd):
+ coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
+ coords = torch.stack(coords[::-1], dim=0).float()
+ return coords[None].repeat(batch, 1, 1, 1)
+
+
+def upflow8(flow, mode='bilinear'):
+ new_size = (8 * flow.shape[2], 8 * flow.shape[3])
+ return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
+
+def gauss_blur(input, N=5, std=1):
+ B, D, H, W = input.shape
+ x, y = torch.meshgrid(torch.arange(N).float() - N//2, torch.arange(N).float() - N//2)
+ unnormalized_gaussian = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * std ** 2))
+ weights = unnormalized_gaussian / unnormalized_gaussian.sum().clamp(min=1e-4)
+ weights = weights.view(1,1,N,N).to(input)
+ output = F.conv2d(input.reshape(B*D,1,H,W), weights, padding=N//2)
+ return output.view(B, D, H, W)
\ No newline at end of file
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000..f0a029a
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,75 @@
+import sys
+sys.path.append('core')
+
+import argparse
+import glob
+import numpy as np
+import torch
+from tqdm import tqdm
+from pathlib import Path
+from raft_stereo import RAFTStereo
+from utils.utils import InputPadder
+from PIL import Image
+from matplotlib import pyplot as plt
+
+
+DEVICE = 'cuda'
+
+def load_image(imfile):
+ img = np.array(Image.open(imfile)).astype(np.uint8)
+ img = torch.from_numpy(img).permute(2, 0, 1).float()
+ return img[None].to(DEVICE)
+
+def demo(args):
+ model = torch.nn.DataParallel(RAFTStereo(args), device_ids=[0])
+ model.load_state_dict(torch.load(args.restore_ckpt))
+
+ model = model.module
+ model.to(DEVICE)
+ model.eval()
+
+ output_directory = Path(args.output_directory)
+ output_directory.mkdir(exist_ok=True)
+
+ with torch.no_grad():
+ left_images = sorted(glob.glob(args.left_imgs, recursive=True))
+ right_images = sorted(glob.glob(args.right_imgs, recursive=True))
+ print(f"Found {len(left_images)} images. Saving files to {output_directory}/")
+
+ for (imfile1, imfile2) in tqdm(list(zip(left_images, right_images))):
+ image1 = load_image(imfile1)
+ image2 = load_image(imfile2)
+
+ padder = InputPadder(image1.shape, divis_by=32)
+ image1, image2 = padder.pad(image1, image2)
+
+ _, flow_up = model(image1, image2, iters=args.valid_iters, test_mode=True)
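+ # flow_up is the predicted horizontal displacement field; it is negated below
+ # so that the saved disparity visualization comes out positive.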
+ file_stem = imfile1.split('/')[-2]
+ if args.save_numpy:
+ np.save(output_directory / f"{file_stem}.npy", flow_up.cpu().numpy().squeeze())
+ plt.imsave(output_directory / f"{file_stem}.png", -flow_up.cpu().numpy().squeeze(), cmap='jet')
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--restore_ckpt', help="restore checkpoint", required=True)
+ parser.add_argument('--save_numpy', action='store_true', help='save output as numpy arrays')
+ parser.add_argument('-l', '--left_imgs', help="path to all first (left) frames", default="datasets/Middlebury/MiddEval3/testH/*/im0.png")
+ parser.add_argument('-r', '--right_imgs', help="path to all second (right) frames", default="datasets/Middlebury/MiddEval3/testH/*/im1.png")
+ parser.add_argument('--output_directory', help="directory to save output", default="demo_output")
+ parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
+ parser.add_argument('--valid_iters', type=int, default=32, help='number of flow-field updates during forward pass')
+
+ # Architecture choices
+ parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
+ parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
+ parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
+ parser.add_argument('--corr_levels', type=int, default=4, help="number of levels in the correlation pyramid")
+ parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
+ parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
+ parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
+ parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
+
+ args = parser.parse_args()
+
+ demo(args)
diff --git a/depth_eq.png b/depth_eq.png
new file mode 100644
index 0000000..d731f06
Binary files /dev/null and b/depth_eq.png differ
diff --git a/download_datasets.sh b/download_datasets.sh
new file mode 100755
index 0000000..ee48fd7
--- /dev/null
+++ b/download_datasets.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+mkdir datasets/Middlebury -p
+cd datasets/Middlebury/
+wget https://www.dropbox.com/s/fn8siy5muak3of3/official_train.txt -P MiddEval3/
+wget https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-Q.zip
+unzip MiddEval3-data-Q.zip
+wget https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-GT0-Q.zip
+unzip MiddEval3-GT0-Q.zip
+wget https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-H.zip
+unzip MiddEval3-data-H.zip
+wget https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-GT0-H.zip
+unzip MiddEval3-GT0-H.zip
+wget https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-F.zip
+unzip MiddEval3-data-F.zip
+wget https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-GT0-F.zip
+unzip MiddEval3-GT0-F.zip
+rm *.zip
+cd ../..
+
+mkdir datasets/ETH3D -p
+cd datasets/ETH3D/
+wget https://www.eth3d.net/data/two_view_test.7z
+echo "Unzipping two_view_test.7z using p7zip (installed from environment.yaml)"
+7za x two_view_test.7z
+cd ../..
\ No newline at end of file
diff --git a/download_models.sh b/download_models.sh
new file mode 100755
index 0000000..fa84e29
--- /dev/null
+++ b/download_models.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+mkdir models -p
+cd models
+wget https://www.dropbox.com/s/q4312z8g5znhhkp/models.zip
+unzip models.zip
+rm models.zip -f
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000..6532755
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,19 @@
+name: raftstereo
+channels:
+ - pytorch
+ - bioconda
+ - defaults
+dependencies:
+ - python=3.7.6
+ - pytorch=1.7.0
+ - torchvision=0.8.1
+ - cudatoolkit=10.2.89
+ - matplotlib
+ - tensorboard
+ - scipy
+ - opencv
+ - tqdm
+ - opt_einsum
+ - imageio
+ - scikit-image
+ - p7zip
diff --git a/evaluate_stereo.py b/evaluate_stereo.py
new file mode 100644
index 0000000..746c843
--- /dev/null
+++ b/evaluate_stereo.py
@@ -0,0 +1,242 @@
+from __future__ import print_function, division
+import sys
+sys.path.append('core')
+
+import argparse
+import time
+import logging
+import numpy as np
+import torch
+from tqdm import tqdm
+from raft_stereo import RAFTStereo, autocast
+import stereo_datasets as datasets
+from utils.utils import InputPadder
+
+def count_parameters(model):
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+@torch.no_grad()
+def validate_eth3d(model, iters=32, mixed_prec=False):
+ """ Peform validation using the ETH3D (train) split """
+ model.eval()
+ aug_params = {}
+ val_dataset = datasets.ETH3D(aug_params)
+
+ out_list, epe_list = [], []
+ for val_id in range(len(val_dataset)):
+ _, image1, image2, flow_gt, valid_gt = val_dataset[val_id]
+ image1 = image1[None].cuda()
+ image2 = image2[None].cuda()
+
+ padder = InputPadder(image1.shape, divis_by=32)
+ image1, image2 = padder.pad(image1, image2)
+
+ with autocast(enabled=mixed_prec):
+ _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
+ flow_pr = padder.unpad(flow_pr.float()).cpu().squeeze(0)
+ assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)
+ epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
+
+ epe_flattened = epe.flatten()
+ val = valid_gt.flatten() >= 0.5
+ out = (epe_flattened > 1.0)
+ image_out = out[val].float().mean().item()
+ image_epe = epe_flattened[val].mean().item()
+ logging.info(f"ETH3D {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}")
+ epe_list.append(image_epe)
+ out_list.append(image_out)
+
+ epe_list = np.array(epe_list)
+ out_list = np.array(out_list)
+
+ epe = np.mean(epe_list)
+ d1 = 100 * np.mean(out_list)
+
+ print("Validation ETH3D: EPE %f, D1 %f" % (epe, d1))
+ return {'eth3d-epe': epe, 'eth3d-d1': d1}
+
+
+@torch.no_grad()
+def validate_kitti(model, iters=32, mixed_prec=False):
+ """ Peform validation using the KITTI-2015 (train) split """
+ model.eval()
+ aug_params = {}
+ val_dataset = datasets.KITTI(aug_params, image_set='training')
+ torch.backends.cudnn.benchmark = True
+
+ out_list, epe_list, elapsed_list = [], [], []
+ for val_id in range(len(val_dataset)):
+ _, image1, image2, flow_gt, valid_gt = val_dataset[val_id]
+ image1 = image1[None].cuda()
+ image2 = image2[None].cuda()
+
+ padder = InputPadder(image1.shape, divis_by=32)
+ image1, image2 = padder.pad(image1, image2)
+
+ with autocast(enabled=mixed_prec):
+ start = time.time()
+ _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
+ end = time.time()
+
+ if val_id > 50:
+ elapsed_list.append(end-start)
+ flow_pr = padder.unpad(flow_pr).cpu().squeeze(0)
+
+ assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)
+ epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
+
+ epe_flattened = epe.flatten()
+ val = valid_gt.flatten() >= 0.5
+
+ out = (epe_flattened > 3.0)
+ image_out = out[val].float().mean().item()
+ image_epe = epe_flattened[val].mean().item()
+ if val_id < 9 or (val_id+1)%10 == 0:
+ logging.info(f"KITTI Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}. Runtime: {format(end-start, '.3f')}s ({format(1/(end-start), '.2f')}-FPS)")
+ epe_list.append(epe_flattened[val].mean().item())
+ out_list.append(out[val].cpu().numpy())
+
+ epe_list = np.array(epe_list)
+ out_list = np.concatenate(out_list)
+
+ epe = np.mean(epe_list)
+ d1 = 100 * np.mean(out_list)
+
+ avg_runtime = np.mean(elapsed_list)
+
+ print(f"Validation KITTI: EPE {epe}, D1 {d1}, {format(1/avg_runtime, '.2f')}-FPS ({format(avg_runtime, '.3f')}s)")
+ return {'kitti-epe': epe, 'kitti-d1': d1}
+
+
+@torch.no_grad()
+def validate_things(model, iters=32, mixed_prec=False):
+ """ Peform validation using the FlyingThings3D (TEST) split """
+ model.eval()
+ val_dataset = datasets.SceneFlowDatasets(dstype='frames_finalpass', things_test=True)
+
+ out_list, epe_list = [], []
+ for val_id in tqdm(range(len(val_dataset))):
+ _, image1, image2, flow_gt, valid_gt = val_dataset[val_id]
+ image1 = image1[None].cuda()
+ image2 = image2[None].cuda()
+
+ padder = InputPadder(image1.shape, divis_by=32)
+ image1, image2 = padder.pad(image1, image2)
+
+ with autocast(enabled=mixed_prec):
+ _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
+ flow_pr = padder.unpad(flow_pr).cpu().squeeze(0)
+ assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)
+ epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
+
+ epe = epe.flatten()
+ val = (valid_gt.flatten() >= 0.5) & (flow_gt.abs().flatten() < 192)
+
+ out = (epe > 1.0)
+ epe_list.append(epe[val].mean().item())
+ out_list.append(out[val].cpu().numpy())
+
+ epe_list = np.array(epe_list)
+ out_list = np.concatenate(out_list)
+
+ epe = np.mean(epe_list)
+ d1 = 100 * np.mean(out_list)
+
+ print("Validation FlyingThings: %f, %f" % (epe, d1))
+ return {'things-epe': epe, 'things-d1': d1}
+
+
+@torch.no_grad()
+def validate_middlebury(model, iters=32, split='F', mixed_prec=False):
+ """ Peform validation using the Middlebury-V3 dataset """
+ model.eval()
+ aug_params = {}
+ val_dataset = datasets.Middlebury(aug_params, split=split)
+
+ out_list, epe_list = [], []
+ for val_id in range(len(val_dataset)):
+ (imageL_file, _, _), image1, image2, flow_gt, valid_gt = val_dataset[val_id]
+ image1 = image1[None].cuda()
+ image2 = image2[None].cuda()
+
+ padder = InputPadder(image1.shape, divis_by=32)
+ image1, image2 = padder.pad(image1, image2)
+
+ with autocast(enabled=mixed_prec):
+ _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
+ flow_pr = padder.unpad(flow_pr).cpu().squeeze(0)
+
+ assert flow_pr.shape == flow_gt.shape, (flow_pr.shape, flow_gt.shape)
+ epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
+
+ epe_flattened = epe.flatten()
+ val = (valid_gt.reshape(-1) >= -0.5) & (flow_gt[0].reshape(-1) > -1000)
+
+ out = (epe_flattened > 2.0)
+ image_out = out[val].float().mean().item()
+ image_epe = epe_flattened[val].mean().item()
+ logging.info(f"Middlebury Iter {val_id+1} out of {len(val_dataset)}. EPE {round(image_epe,4)} D1 {round(image_out,4)}")
+ epe_list.append(image_epe)
+ out_list.append(image_out)
+
+ epe_list = np.array(epe_list)
+ out_list = np.array(out_list)
+
+ epe = np.mean(epe_list)
+ d1 = 100 * np.mean(out_list)
+
+ print(f"Validation Middlebury{split}: EPE {epe}, D1 {d1}")
+ return {f'middlebury{split}-epe': epe, f'middlebury{split}-d1': d1}
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--restore_ckpt', help="restore checkpoint", default=None)
+ parser.add_argument('--dataset', help="dataset for evaluation", required=True, choices=["eth3d", "kitti", "things"] + [f"middlebury_{s}" for s in 'FHQ'])
+ parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
+ parser.add_argument('--valid_iters', type=int, default=32, help='number of flow-field updates during forward pass')
+
+ # Architecture choices
+ parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
+ parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
+ parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
+ parser.add_argument('--corr_levels', type=int, default=4, help="number of levels in the correlation pyramid")
+ parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
+ parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
+ parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
+ parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
+ args = parser.parse_args()
+
+ model = torch.nn.DataParallel(RAFTStereo(args), device_ids=[0])
+
+ logging.basicConfig(level=logging.INFO,
+ format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')
+
+ if args.restore_ckpt is not None:
+ assert args.restore_ckpt.endswith(".pth")
+ logging.info("Loading checkpoint...")
+ checkpoint = torch.load(args.restore_ckpt)
+ model.load_state_dict(checkpoint, strict=True)
+ logging.info(f"Done loading checkpoint")
+
+ model.cuda()
+ model.eval()
+
+ print(f"The model has {format(count_parameters(model)/1e6, '.2f')}M learnable parameters.")
+
+ # The CUDA implementations of the correlation volume prevent half-precision
+ # rounding errors in the correlation lookup. This allows us to use mixed precision
+ # in the entire forward pass, not just in the GRUs & feature extractors.
+ use_mixed_precision = args.corr_implementation.endswith("_cuda")
+
+ if args.dataset == 'eth3d':
+ validate_eth3d(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
+
+ elif args.dataset == 'kitti':
+ validate_kitti(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
+
+ elif args.dataset in [f"middlebury_{s}" for s in 'FHQ']:
+ validate_middlebury(model, iters=args.valid_iters, split=args.dataset[-1], mixed_prec=use_mixed_precision)
+
+ elif args.dataset == 'things':
+ validate_things(model, iters=args.valid_iters, mixed_prec=use_mixed_precision)
diff --git a/sampler/__init__.py b/sampler/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/sampler/sampler.cpp b/sampler/sampler.cpp
new file mode 100644
index 0000000..b1ed09b
--- /dev/null
+++ b/sampler/sampler.cpp
@@ -0,0 +1,51 @@
+#include <torch/extension.h>
+
+#include <vector>
+
+// CUDA forward declarations
+
+
+std::vector<torch::Tensor> sampler_cuda_forward(
+ torch::Tensor volume,
+ torch::Tensor coords,
+ int radius);
+
+std::vector<torch::Tensor> sampler_cuda_backward(
+ torch::Tensor volume,
+ torch::Tensor coords,
+ torch::Tensor corr_grad,
+ int radius);
+
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+std::vector<torch::Tensor> sampler_forward(
+ torch::Tensor volume,
+ torch::Tensor coords,
+ int radius) {
+ CHECK_INPUT(volume);
+ CHECK_INPUT(coords);
+
+ return sampler_cuda_forward(volume, coords, radius);
+}
+
+std::vector<torch::Tensor> sampler_backward(
+ torch::Tensor volume,
+ torch::Tensor coords,
+ torch::Tensor corr_grad,
+ int radius) {
+ CHECK_INPUT(volume);
+ CHECK_INPUT(coords);
+ CHECK_INPUT(corr_grad);
+
+ auto volume_grad = sampler_cuda_backward(volume, coords, corr_grad, radius);
+ return {volume_grad};
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("forward", &sampler_forward, "SAMPLER forward");
+ m.def("backward", &sampler_backward, "SAMPLER backward");
+}
\ No newline at end of file
diff --git a/sampler/sampler_kernel.cu b/sampler/sampler_kernel.cu
new file mode 100644
index 0000000..7dee2a8
--- /dev/null
+++ b/sampler/sampler_kernel.cu
@@ -0,0 +1,167 @@
+#include <torch/extension.h>
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/Dispatch.h>
+#include <vector>
+#include <cmath>
+
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+#define BLOCK 16
+
+__forceinline__ __device__ bool within_bounds(int h, int w, int H, int W) {
+ return h >= 0 && h < H && w >= 0 && w < W;
+}
+
+template <typename scalar_t>
+__global__ void sampler_forward_kernel(
+ const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> volume,
+ const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> coords,
+ torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> corr,
+ int r)
+{
+ // batch index
+ const int x = blockIdx.x * blockDim.x + threadIdx.x;
+ const int y = blockIdx.y * blockDim.y + threadIdx.y;
+ const int n = blockIdx.z;
+
+ const int h1 = volume.size(1);
+ const int w1 = volume.size(2);
+ const int w2 = volume.size(3);
+
+ if (!within_bounds(y, x, h1, w1)) {
+ return;
+ }
+
+ float x0 = coords[n][0][y][x];
+ float y0 = coords[n][1][y][x];
+
+ float dx = x0 - floor(x0);
+ float dy = y0 - floor(y0);
+
+ int rd = 2*r + 1;
+ for (int i=0; i<rd+1; i++) {
+ int x1 = static_cast<int>(floor(x0)) - r + i;
+
+ if (within_bounds(0, x1, 1, w2)) {
+ scalar_t s = volume[n][y][x][x1];
+
+ if (i > 0)
+ corr[n][i-1][y][x] += s * scalar_t(dx);
+
+ if (i < rd)
+ corr[n][i][y][x] += s * scalar_t((1.0f-dx));
+
+ }
+ }
+}
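+
+// In effect, each output channel j of the lookup linearly interpolates two
+// neighbouring entries of the 1D correlation volume:
+//   corr[n][j][y][x] = (1-dx) * volume[n][y][x][floor(x0)-r+j]
+//                    +   dx   * volume[n][y][x][floor(x0)-r+j+1]
+// (a restatement of the loop above, not additional functionality).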
+
+
+template <typename scalar_t>
+__global__ void sampler_backward_kernel(
+ const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> coords,
+ const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> corr_grad,
+ torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> volume_grad,
+ int r)
+{
+ // batch index
+ const int x = blockIdx.x * blockDim.x + threadIdx.x;
+ const int y = blockIdx.y * blockDim.y + threadIdx.y;
+ const int n = blockIdx.z;
+
+ const int h1 = volume_grad.size(1);
+ const int w1 = volume_grad.size(2);
+ const int w2 = volume_grad.size(3);
+
+ if (!within_bounds(y, x, h1, w1)) {
+ return;
+ }
+
+ float x0 = coords[n][0][y][x];
+ float y0 = coords[n][1][y][x];
+
+ float dx = x0 - floor(x0);
+ float dy = y0 - floor(y0);
+
+ int rd = 2*r + 1;
+ for (int i=0; i<rd+1; i++) {
+ int x1 = static_cast<int>(floor(x0)) - r + i;
+
+ if (within_bounds(0, x1, 1, w2)) {
+ scalar_t g = 0.0;
+
+ if (i > 0)
+ g += corr_grad[n][i-1][y][x] * scalar_t(dx);
+
+ if (i < rd)
+ g += corr_grad[n][i][y][x] * scalar_t((1.0f-dx));
+
+ volume_grad[n][y][x][x1] += g;
+ }
+ }
+}
+
+std::vector<torch::Tensor> sampler_cuda_forward(
+ torch::Tensor volume,
+ torch::Tensor coords,
+ int radius)
+{
+ const auto batch_size = volume.size(0);
+ const auto ht = volume.size(1);
+ const auto wd = volume.size(2);
+
+ const dim3 blocks((wd + BLOCK - 1) / BLOCK,
+ (ht + BLOCK - 1) / BLOCK,
+ batch_size);
+
+ const dim3 threads(BLOCK, BLOCK);
+
+ auto opts = volume.options();
+ torch::Tensor corr = torch::zeros(
+ {batch_size, 2*radius+1, ht, wd}, opts);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.type(), "sampler_forward_kernel", ([&] {
+ sampler_forward_kernel<scalar_t><<<blocks, threads>>>(
+ volume.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
+ coords.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
+ corr.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
+ radius);
+ }));
+
+ return {corr};
+
+}
+
+std::vector<torch::Tensor> sampler_cuda_backward(
+ torch::Tensor volume,
+ torch::Tensor coords,
+ torch::Tensor corr_grad,
+ int radius)
+{
+ const auto batch_size = volume.size(0);
+ const auto ht = volume.size(1);
+ const auto wd = volume.size(2);
+
+ auto volume_grad = torch::zeros_like(volume);
+
+ const dim3 blocks((wd + BLOCK - 1) / BLOCK,
+ (ht + BLOCK - 1) / BLOCK,
+ batch_size);
+
+ const dim3 threads(BLOCK, BLOCK);
+
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.type(), "sampler_backward_kernel", ([&] {
+ sampler_backward_kernel<scalar_t><<<blocks, threads>>>(
+ coords.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
+ corr_grad.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
+ volume_grad.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
+ radius);
+ }));
+
+ return {volume_grad};
+}
+
diff --git a/sampler/setup.py b/sampler/setup.py
new file mode 100644
index 0000000..51843ea
--- /dev/null
+++ b/sampler/setup.py
@@ -0,0 +1,28 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+extra_compile_flags = {}
+gencodes = ['-arch=sm_50',
+ '-gencode', 'arch=compute_50,code=sm_50',
+ '-gencode', 'arch=compute_52,code=sm_52',
+ '-gencode', 'arch=compute_60,code=sm_60',
+ '-gencode', 'arch=compute_61,code=sm_61',
+ '-gencode', 'arch=compute_70,code=sm_70',
+ '-gencode', 'arch=compute_75,code=sm_75',
+ '-gencode', 'arch=compute_75,code=compute_75',]
+
+# extra_compile_flags['nvcc'] = gencodes
+
+setup(
+ name='corr_sampler',
+ ext_modules=[
+ CUDAExtension('corr_sampler', [
+ 'sampler.cpp', 'sampler_kernel.cu',
+ ],
+ extra_compile_args=extra_compile_flags)
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension
+ })
+
+
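+# To build and install the sampler extension (a typical invocation; adjust to your
+# environment): cd sampler && python setup.py install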
diff --git a/train_stereo.py b/train_stereo.py
new file mode 100644
index 0000000..7c3897f
--- /dev/null
+++ b/train_stereo.py
@@ -0,0 +1,258 @@
+from __future__ import print_function, division
+
+import argparse
+import logging
+import numpy as np
+from pathlib import Path
+from tqdm import tqdm
+
+from torch.utils.tensorboard import SummaryWriter
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from core.raft_stereo import RAFTStereo
+
+from evaluate_stereo import *
+import core.stereo_datasets as datasets
+
+try:
+ from torch.cuda.amp import GradScaler
+except:
+ # dummy GradScaler for PyTorch < 1.6
+ class GradScaler:
+ def __init__(self):
+ pass
+ def scale(self, loss):
+ return loss
+ def unscale_(self, optimizer):
+ pass
+ def step(self, optimizer):
+ optimizer.step()
+ def update(self):
+ pass
+
+
+def sequence_loss(flow_preds, flow_gt, valid, loss_gamma=0.9, max_flow=700):
+ """ Loss function defined over sequence of flow predictions """
+
+ n_predictions = len(flow_preds)
+ assert n_predictions >= 1
+ flow_loss = 0.0
+
+ # exclude invalid pixels and extremely large displacements
+ mag = torch.sum(flow_gt**2, dim=1).sqrt()
+
+ # exclude extremely large displacements
+ valid = ((valid >= 0.5) & (mag < max_flow)).unsqueeze(1)
+ assert valid.shape == flow_gt.shape, [valid.shape, flow_gt.shape]
+ assert not torch.isinf(flow_gt[valid.bool()]).any()
+
+ for i in range(n_predictions):
+ assert not torch.isnan(flow_preds[i]).any() and not torch.isinf(flow_preds[i]).any()
+ # We adjust the loss_gamma so it is consistent for any number of RAFT-Stereo iterations
+ adjusted_loss_gamma = loss_gamma**(15/(n_predictions - 1))
+ i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
+ i_loss = (flow_preds[i] - flow_gt).abs()
+ assert i_loss.shape == valid.shape, [i_loss.shape, valid.shape, flow_gt.shape, flow_preds[i].shape]
+ flow_loss += i_weight * i_loss[valid.bool()].mean()
+
+ epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
+ epe = epe.view(-1)[valid.view(-1)]
+
+ metrics = {
+ 'epe': epe.mean().item(),
+ '1px': (epe < 1).float().mean().item(),
+ '3px': (epe < 3).float().mean().item(),
+ '5px': (epe < 5).float().mean().item(),
+ }
+
+ return flow_loss, metrics
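+
+# Worked example of the weighting above: with loss_gamma=0.9 and 16 predictions,
+# adjusted_loss_gamma = 0.9**(15/15) = 0.9, so the final prediction receives
+# weight 1.0 and the first receives 0.9**15 ≈ 0.21.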
+
+
+def fetch_optimizer(args, model):
+ """ Create the optimizer and learning rate scheduler """
+ optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
+
+ scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
+ pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')
+
+ return optimizer, scheduler
+
+
+class Logger:
+
+ SUM_FREQ = 100
+
+ def __init__(self, model, scheduler):
+ self.model = model
+ self.scheduler = scheduler
+ self.total_steps = 0
+ self.running_loss = {}
+ self.writer = SummaryWriter(log_dir='runs')
+
+ def _print_training_status(self):
+ metrics_data = [self.running_loss[k]/Logger.SUM_FREQ for k in sorted(self.running_loss.keys())]
+ training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
+ metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
+
+ # print the training status
+ logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}")
+
+ if self.writer is None:
+ self.writer = SummaryWriter(log_dir='runs')
+
+ for k in self.running_loss:
+ self.writer.add_scalar(k, self.running_loss[k]/Logger.SUM_FREQ, self.total_steps)
+ self.running_loss[k] = 0.0
+
+ def push(self, metrics):
+ self.total_steps += 1
+
+ for key in metrics:
+ if key not in self.running_loss:
+ self.running_loss[key] = 0.0
+
+ self.running_loss[key] += metrics[key]
+
+ if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ-1:
+ self._print_training_status()
+ self.running_loss = {}
+
+ def write_dict(self, results):
+ if self.writer is None:
+ self.writer = SummaryWriter(log_dir='runs')
+
+ for key in results:
+ self.writer.add_scalar(key, results[key], self.total_steps)
+
+ def close(self):
+ self.writer.close()
+
+
+def train(args):
+
+ model = nn.DataParallel(RAFTStereo(args))
+ print("Parameter Count: %d" % count_parameters(model))
+
+ train_loader = datasets.fetch_dataloader(args)
+ optimizer, scheduler = fetch_optimizer(args, model)
+ total_steps = 0
+ logger = Logger(model, scheduler)
+
+ if args.restore_ckpt is not None:
+ assert args.restore_ckpt.endswith(".pth")
+ logging.info("Loading checkpoint...")
+ checkpoint = torch.load(args.restore_ckpt)
+ model.load_state_dict(checkpoint, strict=True)
+ logging.info(f"Done loading checkpoint")
+
+ model.cuda()
+ model.train()
+ model.module.freeze_bn() # We keep BatchNorm frozen
+
+ validation_frequency = 10000
+
+ scaler = GradScaler(enabled=args.mixed_precision)
+
+ should_keep_training = True
+ global_batch_num = 0
+ while should_keep_training:
+
+ for i_batch, (_, *data_blob) in enumerate(tqdm(train_loader)):
+ optimizer.zero_grad()
+ image1, image2, flow, valid = [x.cuda() for x in data_blob]
+
+ assert model.training
+ flow_predictions = model(image1, image2, iters=args.train_iters)
+ assert model.training
+
+ loss, metrics = sequence_loss(flow_predictions, flow, valid)
+ logger.writer.add_scalar("live_loss", loss.item(), global_batch_num)
+ logger.writer.add_scalar(f'learning_rate', optimizer.param_groups[0]['lr'], global_batch_num)
+ global_batch_num += 1
+ scaler.scale(loss).backward()
+ scaler.unscale_(optimizer)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+
+ scaler.step(optimizer)
+ scheduler.step()
+ scaler.update()
+
+ logger.push(metrics)
+
+ if total_steps % validation_frequency == validation_frequency - 1:
+ save_path = Path('checkpoints/%d_%s.pth' % (total_steps + 1, args.name))
+ logging.info(f"Saving file {save_path.absolute()}")
+ torch.save(model.state_dict(), save_path)
+
+ results = validate_things(model.module, iters=args.valid_iters)
+
+ logger.write_dict(results)
+
+ model.train()
+ model.module.freeze_bn()
+
+ total_steps += 1
+
+ if total_steps > args.num_steps:
+ should_keep_training = False
+ break
+
+ if len(train_loader) >= 10000:
+ save_path = Path('checkpoints/%d_epoch_%s.pth.gz' % (total_steps + 1, args.name))
+ logging.info(f"Saving file {save_path}")
+ torch.save(model.state_dict(), save_path)
+
+ print("FINISHED TRAINING")
+ logger.close()
+ PATH = 'checkpoints/%s.pth' % args.name
+ torch.save(model.state_dict(), PATH)
+
+ return PATH
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--name', default='raft-stereo', help="name your experiment")
+ parser.add_argument('--restore_ckpt', help="restore checkpoint")
+ parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
+
+ # Training parameters
+ parser.add_argument('--batch_size', type=int, default=6, help="batch size used during training.")
+ parser.add_argument('--train_datasets', nargs='+', default=['sceneflow'], help="training datasets.")
+ parser.add_argument('--lr', type=float, default=0.0002, help="max learning rate.")
+ parser.add_argument('--num_steps', type=int, default=100000, help="length of training schedule.")
+ parser.add_argument('--image_size', type=int, nargs='+', default=[320, 720], help="size of the random image crops used during training.")
+ parser.add_argument('--train_iters', type=int, default=16, help="number of updates to the disparity field in each forward pass.")
+ parser.add_argument('--wdecay', type=float, default=.00001, help="Weight decay in optimizer.")
+
+ # Validation parameters
+ parser.add_argument('--valid_iters', type=int, default=32, help='number of flow-field updates during validation forward pass')
+
+ # Architecture choices
+ parser.add_argument('--corr_implementation', choices=["reg", "alt", "reg_cuda", "alt_cuda"], default="reg", help="correlation volume implementation")
+ parser.add_argument('--shared_backbone', action='store_true', help="use a single backbone for the context and feature encoders")
+ parser.add_argument('--corr_levels', type=int, default=4, help="number of levels in the correlation pyramid")
+ parser.add_argument('--corr_radius', type=int, default=4, help="width of the correlation pyramid")
+ parser.add_argument('--n_downsample', type=int, default=2, help="resolution of the disparity field (1/2^K)")
+ parser.add_argument('--slow_fast_gru', action='store_true', help="iterate the low-res GRUs more frequently")
+ parser.add_argument('--n_gru_layers', type=int, default=3, help="number of hidden GRU levels")
+ parser.add_argument('--hidden_dims', nargs='+', type=int, default=[128]*3, help="hidden state and context dimensions")
+
+ # Data augmentation
+ parser.add_argument('--img_gamma', type=float, nargs='+', default=None, help="gamma range")
+ parser.add_argument('--saturation_range', type=float, nargs='+', default=None, help='color saturation')
+ parser.add_argument('--do_flip', default=False, choices=['h', 'v'], help='flip the images horizontally or vertically')
+ parser.add_argument('--spatial_scale', type=float, nargs='+', default=[0, 0], help='re-scale the images randomly')
+ parser.add_argument('--noyjitter', action='store_true', help='don\'t simulate imperfect rectification')
+ args = parser.parse_args()
+
+ torch.manual_seed(1234)
+ np.random.seed(1234)
+
+ logging.basicConfig(level=logging.INFO,
+ format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')
+
+ Path("checkpoints").mkdir(exist_ok=True, parents=True)
+
+ train(args)
\ No newline at end of file