Skip to content

Commit c8f7b97

Browse files
YosuaMichaelfacebook-github-bot
authored andcommitted
[fbsync] Doc revamp for optical flow models (#5895)
Summary: * Doc revamp for optical flow models * Some more Reviewed By: NicolasHug Differential Revision: D36760919 fbshipit-source-id: b43d45ec9fe9d3a1663a341ef6dec746c7bb1ace
1 parent 0c6aabb commit c8f7b97

File tree

4 files changed

+132
-26
lines changed

4 files changed

+132
-26
lines changed

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
347347
metrics = meta.pop("_metrics")
348348
for dataset, dataset_metrics in metrics.items():
349349
for metric_name, metric_value in dataset_metrics.items():
350+
metric_name = metric_name.replace("_", "-")
350351
table.append((f"{metric_name} (on {dataset})", str(metric_value)))
351352

352353
for k, v in meta.items():

docs/source/models/raft.rst

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
RAFT
2+
====
3+
4+
.. currentmodule:: torchvision.models.optical_flow
5+
6+
The RAFT model is based on the `RAFT: Recurrent All-Pairs Field Transforms for
7+
Optical Flow <https://arxiv.org/abs/2003.12039>`__ paper.
8+
9+
10+
Model builders
11+
--------------
12+
13+
The following model builders can be used to instantiate a RAFT model, with or
14+
without pre-trained weights. All the model builders internally rely on the
15+
``torchvision.models.optical_flow.RAFT`` base class. Please refer to the `source
16+
code
17+
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_ for
18+
more details about this class.
19+
20+
.. autosummary::
21+
:toctree: generated/
22+
:template: function.rst
23+
24+
raft_large
25+
raft_small

docs/source/models_new.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ Box MAPs are reported on COCO val2017:
376376

377377
.. include:: generated/detection_table.rst
378378

379+
379380
Instance Segmentation
380381
---------------------
381382

@@ -481,6 +482,18 @@ Accuracies are reported on Kinetics-400 using single crops for clip length 16:
481482

482483
.. include:: generated/video_table.rst
483484

485+
Optical Flow
486+
============
487+
488+
.. currentmodule:: torchvision.models.optical_flow
489+
490+
The following Optical Flow models are available, with or without pre-trained
491+
492+
.. toctree::
493+
:maxdepth: 1
494+
495+
models/raft
496+
484497
Using models from Hub
485498
=====================
486499

torchvision/models/optical_flow/raft.py

Lines changed: 93 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,19 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
517517

518518

519519
class Raft_Large_Weights(WeightsEnum):
520+
"""The metrics reported here are as follows.
521+
522+
``epe`` is the "end-point-error" and indicates how far (in pixels) the
523+
predicted flow is from its true value. This is averaged over all pixels
524+
of all images. ``per_image_epe`` is similar, but the average is different:
525+
the epe is first computed on each image independently, and then averaged
526+
over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
527+
in the original paper, and it's only used on Kitti. ``fl-all`` is also a
528+
Kitti-specific metric, defined by the author of the dataset and used for the
529+
Kitti leaderboard. It corresponds to the average of pixels whose epe is
530+
either <3px, or <5% of flow's 2-norm.
531+
"""
532+
520533
C_T_V1 = Weights(
521534
# Weights ported from https://github.com/princeton-vl/RAFT
522535
url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
@@ -530,7 +543,9 @@ class Raft_Large_Weights(WeightsEnum):
530543
"Sintel-Train-Finalpass": {"epe": 2.7894},
531544
"Kitti-Train": {"per_image_epe": 5.0172, "fl_all": 17.4506},
532545
},
533-
"_docs": """These weights were ported from the original paper. They are trained on Chairs + Things.""",
546+
"_docs": """These weights were ported from the original paper. They
547+
are trained on :class:`~torchvision.datasets.FlyingChairs` +
548+
:class:`~torchvision.datasets.FlyingThings3D`.""",
534549
},
535550
)
536551

@@ -546,7 +561,9 @@ class Raft_Large_Weights(WeightsEnum):
546561
"Sintel-Train-Finalpass": {"epe": 2.7161},
547562
"Kitti-Train": {"per_image_epe": 4.5118, "fl_all": 16.0679},
548563
},
549-
"_docs": """These weights were trained from scratch on Chairs + Things.""",
564+
"_docs": """These weights were trained from scratch on
565+
:class:`~torchvision.datasets.FlyingChairs` +
566+
:class:`~torchvision.datasets.FlyingThings3D`.""",
550567
},
551568
)
552569

@@ -563,8 +580,14 @@ class Raft_Large_Weights(WeightsEnum):
563580
"Sintel-Test-Finalpass": {"epe": 3.18},
564581
},
565582
"_docs": """
566-
These weights were ported from the original paper. They are trained on Chairs + Things and fine-tuned on
567-
Sintel (C+T+S+K+H).
583+
These weights were ported from the original paper. They are
584+
trained on :class:`~torchvision.datasets.FlyingChairs` +
585+
:class:`~torchvision.datasets.FlyingThings3D` and fine-tuned on
586+
Sintel. The Sintel fine-tuning step is a combination of
587+
:class:`~torchvision.datasets.Sintel`,
588+
:class:`~torchvision.datasets.KittiFlow`,
589+
:class:`~torchvision.datasets.HD1K`, and
590+
:class:`~torchvision.datasets.FlyingThings3D` (clean pass).
568591
""",
569592
},
570593
)
@@ -581,7 +604,14 @@ class Raft_Large_Weights(WeightsEnum):
581604
"Sintel-Test-Finalpass": {"epe": 3.067},
582605
},
583606
"_docs": """
584-
These weights were trained from scratch on Chairs + Things and fine-tuned on Sintel (C+T+S+K+H).
607+
These weights were trained from scratch. They are
608+
pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
609+
:class:`~torchvision.datasets.FlyingThings3D` and then
610+
fine-tuned on Sintel. The Sintel fine-tuning step is a
611+
combination of :class:`~torchvision.datasets.Sintel`,
612+
:class:`~torchvision.datasets.KittiFlow`,
613+
:class:`~torchvision.datasets.HD1K`, and
614+
:class:`~torchvision.datasets.FlyingThings3D` (clean pass).
585615
""",
586616
},
587617
)
@@ -598,8 +628,12 @@ class Raft_Large_Weights(WeightsEnum):
598628
"Kitti-Test": {"fl_all": 5.10},
599629
},
600630
"_docs": """
601-
These weights were ported from the original paper. They are trained on Chairs + Things, fine-tuned on
602-
Sintel and then on Kitti.
631+
These weights were ported from the original paper. They are
632+
pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
633+
:class:`~torchvision.datasets.FlyingThings3D`,
634+
fine-tuned on Sintel, and then fine-tuned on
635+
:class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
636+
step was described above.
603637
""",
604638
},
605639
)
@@ -615,7 +649,12 @@ class Raft_Large_Weights(WeightsEnum):
615649
"Kitti-Test": {"fl_all": 5.19},
616650
},
617651
"_docs": """
618-
These weights were trained from scratch on Chairs + Things, fine-tuned on Sintel and then on Kitti.
652+
These weights were trained from scratch. They are
653+
pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
654+
:class:`~torchvision.datasets.FlyingThings3D`,
655+
fine-tuned on Sintel, and then fine-tuned on
656+
:class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
657+
step was described above.
619658
""",
620659
},
621660
)
@@ -624,6 +663,19 @@ class Raft_Large_Weights(WeightsEnum):
624663

625664

626665
class Raft_Small_Weights(WeightsEnum):
666+
"""The metrics reported here are as follows.
667+
668+
``epe`` is the "end-point-error" and indicates how far (in pixels) the
669+
predicted flow is from its true value. This is averaged over all pixels
670+
of all images. ``per_image_epe`` is similar, but the average is different:
671+
the epe is first computed on each image independently, and then averaged
672+
over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
673+
in the original paper, and it's only used on Kitti. ``fl-all`` is also a
674+
Kitti-specific metric, defined by the author of the dataset and used for the
675+
Kitti leaderboard. It corresponds to the average of pixels whose epe is
676+
either <3px, or <5% of flow's 2-norm.
677+
"""
678+
627679
C_T_V1 = Weights(
628680
# Weights ported from https://github.com/princeton-vl/RAFT
629681
url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
@@ -637,7 +689,9 @@ class Raft_Small_Weights(WeightsEnum):
637689
"Sintel-Train-Finalpass": {"epe": 3.2790},
638690
"Kitti-Train": {"per_image_epe": 7.6557, "fl_all": 25.2801},
639691
},
640-
"_docs": """These weights were ported from the original paper. They are trained on Chairs + Things.""",
692+
"_docs": """These weights were ported from the original paper. They
693+
are trained on :class:`~torchvision.datasets.FlyingChairs` +
694+
:class:`~torchvision.datasets.FlyingThings3D`.""",
641695
},
642696
)
643697
C_T_V2 = Weights(
@@ -652,7 +706,9 @@ class Raft_Small_Weights(WeightsEnum):
652706
"Sintel-Train-Finalpass": {"epe": 3.2831},
653707
"Kitti-Train": {"per_image_epe": 7.5978, "fl_all": 25.2369},
654708
},
655-
"_docs": """These weights were trained from scratch on Chairs + Things.""",
709+
"_docs": """These weights were trained from scratch on
710+
:class:`~torchvision.datasets.FlyingChairs` +
711+
:class:`~torchvision.datasets.FlyingThings3D`.""",
656712
},
657713
)
658714

@@ -750,13 +806,19 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
750806
Please see the example below for a tutorial on how to use this model.
751807
752808
Args:
753-
weights(Raft_Large_weights, optional): The pretrained weights for the model
754-
progress (bool): If True, displays a progress bar of the download to stderr
755-
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
756-
to override any default.
757-
758-
Returns:
759-
RAFT: The model.
809+
weights(:class:`~torchvision.models.optical_flow.Raft_Large_Weights`, optional): The
810+
pretrained weights to use. See
811+
:class:`~torchvision.models.optical_flow.Raft_Large_Weights`
812+
below for more details, and possible values. By default, no
813+
pre-trained weights are used.
814+
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
815+
**kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
816+
base class. Please refer to the `source code
817+
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
818+
for more details about this class.
819+
820+
.. autoclass:: torchvision.models.optical_flow.Raft_Large_Weights
821+
:members:
760822
"""
761823

762824
weights = Raft_Large_Weights.verify(weights)
@@ -794,19 +856,24 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
794856
@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2))
795857
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT:
796858
"""RAFT "small" model from
797-
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
859+
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`__.
798860
799861
Please see the example below for a tutorial on how to use this model.
800862
801863
Args:
802-
weights(Raft_Small_weights, optional): The pretrained weights for the model
803-
progress (bool): If True, displays a progress bar of the download to stderr
804-
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
805-
to override any default.
806-
807-
Returns:
808-
RAFT: The model.
809-
864+
weights(:class:`~torchvision.models.optical_flow.Raft_Small_Weights`, optional): The
865+
pretrained weights to use. See
866+
:class:`~torchvision.models.optical_flow.Raft_Small_Weights`
867+
below for more details, and possible values. By default, no
868+
pre-trained weights are used.
869+
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
870+
**kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
871+
base class. Please refer to the `source code
872+
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
873+
for more details about this class.
874+
875+
.. autoclass:: torchvision.models.optical_flow.Raft_Small_Weights
876+
:members:
810877
"""
811878
weights = Raft_Small_Weights.verify(weights)
812879

0 commit comments

Comments
 (0)