A JAX/Flax implementation of the RAFT optical flow estimator (https://arxiv.org/abs/2003.12039), ported from PyTorch (https://docs.pytorch.org/vision/main/models/raft.html). Checkpoints have been ported, too. The implementation has been tested to reproduce the original results.
With pre-trained checkpoints, `jax-raft` achieves the following metrics on Sintel (train), compared to the original PyTorch implementation. EPE is the average end-point error (lower is better). This comparison uses the `raft_large_C_T_SKHT_V2` and `raft_small_C_T_V2` checkpoints, respectively. FPS was measured on a single RTX 3090 Ti.
| Model | EPE (clean) ↓ | EPE (final) ↓ | FPS |
|---|---|---|---|
| raft_large (`jax-raft`) | 0.649 | 1.020 | 11.8 |
| raft_large (PyTorch) | 0.649 | 1.020 | 8.1 |
| raft_small (`jax-raft`) | 1.993 | 3.268 | 36.6 |
| raft_small (PyTorch) | 1.998 | 3.279 | 15.0 |
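For reference, the EPE metric in the table is the Euclidean distance between predicted and ground-truth flow vectors, averaged over pixels. The sketch below assumes flow fields stored channels-last with shape `(..., 2)`; the function name `end_point_error` is hypothetical and not part of `jax_raft`, and the repository's Sintel validation script may aggregate over the dataset differently.

```python
import numpy as np


def end_point_error(flow_pred: np.ndarray, flow_gt: np.ndarray) -> float:
    """Mean end-point error between two flow fields of shape (..., 2).

    Assumption: the last axis holds the (u, v) flow components.
    """
    # Per-pixel difference vector between prediction and ground truth
    diff = flow_pred - flow_gt
    # Euclidean length of each difference vector
    epe = np.sqrt(np.sum(diff ** 2, axis=-1))
    # Average over all remaining axes (pixels, and batch if present)
    return float(epe.mean())
```

The same calls exist in `jax.numpy`, so swapping `np` for `jnp` should work unchanged if the predictions stay on the accelerator.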
```python
from jax_raft import raft_large  # or raft_small

model, variables = raft_large(pretrained=True)
# image1, image2: the two consecutive frames to estimate flow between
model.apply(variables, image1, image2, train=False)
```
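When timing the call above yourself, keep in mind that JAX compiles on the first call of a jitted function and dispatches work asynchronously, so a fair throughput number should skip warm-up calls and wait for results before stopping the clock. The helper below is a minimal, hypothetical sketch (`measure_fps` is not part of `jax_raft`, and it is not necessarily how the FPS column above was produced); it assumes a JAX version that provides `jax.block_until_ready`.

```python
import time
from typing import Any, Callable

try:
    # Waits on every array leaf of a pytree (list, tuple, dict, ...)
    from jax import block_until_ready as _block_until_ready
except ImportError:  # JAX not available: nothing asynchronous to wait for
    def _block_until_ready(x: Any) -> Any:
        return x


def measure_fps(fn: Callable[..., Any], *args: Any,
                warmup: int = 3, iters: int = 20, **kwargs: Any) -> float:
    """Return calls per second of fn(*args, **kwargs), excluding warm-up calls."""
    def run_once() -> None:
        out = fn(*args, **kwargs)
        # Block until the computation has actually finished
        _block_until_ready(out)

    # Warm-up calls absorb one-off costs such as JIT compilation
    for _ in range(warmup):
        run_once()

    start = time.perf_counter()
    for _ in range(iters):
        run_once()
    elapsed = time.perf_counter() - start
    return iters / elapsed
```

For example, `measure_fps(lambda a, b: model.apply(variables, a, b, train=False), image1, image2)` times the plain call; wrapping the lambda in `jax.jit` first would time the compiled version instead.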
To install:

```shell
pip install git+https://github.com/alebeck/jax-raft
```
In the `scripts` directory, we provide scripts for converting official PyTorch RAFT checkpoints to Flax and for validation on Sintel. The `examples` directory contains example usage scripts.
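The conversion scripts themselves are not reproduced here. As an illustration of the kind of layout change such a conversion typically involves: PyTorch stores `Conv2d` weights as `(out_channels, in_channels, kH, kW)`, while Flax's `nn.Conv` expects kernels as `(kH, kW, in_channels, out_channels)`. A minimal sketch follows; the helper name `torch_conv_to_flax` is hypothetical and not taken from the repository.

```python
import numpy as np


def torch_conv_to_flax(weight: np.ndarray) -> np.ndarray:
    """Convert a PyTorch Conv2d weight to a Flax nn.Conv kernel.

    PyTorch layout: (out_channels, in_channels, kH, kW)
    Flax layout:    (kH, kW, in_channels, out_channels)
    """
    if weight.ndim != 4:
        raise ValueError(f"expected a 4-D conv weight, got shape {weight.shape}")
    # Move the spatial axes to the front and swap the channel axes
    return np.transpose(weight, (2, 3, 1, 0))
```

Biases have the same 1-D shape in both frameworks and can be copied as-is; mapping PyTorch parameter names onto the Flax parameter tree depends on how the modules are named and is not shown.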