@@ -517,6 +517,19 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
517517
518518
519519class Raft_Large_Weights (WeightsEnum ):
520+ """The metrics reported here are as follows.
521+
522+ ``epe`` is the "end-point-error" and indicates how far (in pixels) the
523+ predicted flow is from its true value. This is averaged over all pixels
524+ of all images. ``per_image_epe`` is similar, but the average is different:
525+ the epe is first computed on each image independently, and then averaged
526+ over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
527+ in the original paper, and it's only used on Kitti. ``fl_all`` is also a
528+ Kitti-specific metric, defined by the author of the dataset and used for the
529+ Kitti leaderboard. It corresponds to the percentage of outlier pixels, i.e.
530+ pixels whose epe is both >3px and >5% of the ground-truth flow's 2-norm.
531+ """
532+
520533 C_T_V1 = Weights (
521534 # Weights ported from https://github.com/princeton-vl/RAFT
522535 url = "https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth" ,
@@ -530,7 +543,9 @@ class Raft_Large_Weights(WeightsEnum):
530543 "Sintel-Train-Finalpass" : {"epe" : 2.7894 },
531544 "Kitti-Train" : {"per_image_epe" : 5.0172 , "fl_all" : 17.4506 },
532545 },
533- "_docs" : """These weights were ported from the original paper. They are trained on Chairs + Things.""" ,
546+ "_docs" : """These weights were ported from the original paper. They
547+ are trained on :class:`~torchvision.datasets.FlyingChairs` +
548+ :class:`~torchvision.datasets.FlyingThings3D`.""" ,
534549 },
535550 )
536551
@@ -546,7 +561,9 @@ class Raft_Large_Weights(WeightsEnum):
546561 "Sintel-Train-Finalpass" : {"epe" : 2.7161 },
547562 "Kitti-Train" : {"per_image_epe" : 4.5118 , "fl_all" : 16.0679 },
548563 },
549- "_docs" : """These weights were trained from scratch on Chairs + Things.""" ,
564+ "_docs" : """These weights were trained from scratch on
565+ :class:`~torchvision.datasets.FlyingChairs` +
566+ :class:`~torchvision.datasets.FlyingThings3D`.""" ,
550567 },
551568 )
552569
@@ -563,8 +580,14 @@ class Raft_Large_Weights(WeightsEnum):
563580 "Sintel-Test-Finalpass" : {"epe" : 3.18 },
564581 },
565582 "_docs" : """
566- These weights were ported from the original paper. They are trained on Chairs + Things and fine-tuned on
567- Sintel (C+T+S+K+H).
583+ These weights were ported from the original paper. They are
584+ trained on :class:`~torchvision.datasets.FlyingChairs` +
585+ :class:`~torchvision.datasets.FlyingThings3D` and fine-tuned on
586+ Sintel. The Sintel fine-tuning step is a combination of
587+ :class:`~torchvision.datasets.Sintel`,
588+ :class:`~torchvision.datasets.KittiFlow`,
589+ :class:`~torchvision.datasets.HD1K`, and
590+ :class:`~torchvision.datasets.FlyingThings3D` (clean pass).
568591 """ ,
569592 },
570593 )
@@ -581,7 +604,14 @@ class Raft_Large_Weights(WeightsEnum):
581604 "Sintel-Test-Finalpass" : {"epe" : 3.067 },
582605 },
583606 "_docs" : """
584- These weights were trained from scratch on Chairs + Things and fine-tuned on Sintel (C+T+S+K+H).
607+ These weights were trained from scratch. They are
608+ pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
609+ :class:`~torchvision.datasets.FlyingThings3D` and then
610+ fine-tuned on Sintel. The Sintel fine-tuning step is a
611+ combination of :class:`~torchvision.datasets.Sintel`,
612+ :class:`~torchvision.datasets.KittiFlow`,
613+ :class:`~torchvision.datasets.HD1K`, and
614+ :class:`~torchvision.datasets.FlyingThings3D` (clean pass).
585615 """ ,
586616 },
587617 )
@@ -598,8 +628,12 @@ class Raft_Large_Weights(WeightsEnum):
598628 "Kitti-Test" : {"fl_all" : 5.10 },
599629 },
600630 "_docs" : """
601- These weights were ported from the original paper. They are trained on Chairs + Things, fine-tuned on
602- Sintel and then on Kitti.
631+ These weights were ported from the original paper. They are
632+ pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
633+ :class:`~torchvision.datasets.FlyingThings3D`,
634+ fine-tuned on Sintel, and then fine-tuned on
635+ :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
636+ step was described above.
603637 """ ,
604638 },
605639 )
@@ -615,7 +649,12 @@ class Raft_Large_Weights(WeightsEnum):
615649 "Kitti-Test" : {"fl_all" : 5.19 },
616650 },
617651 "_docs" : """
618- These weights were trained from scratch on Chairs + Things, fine-tuned on Sintel and then on Kitti.
652+ These weights were trained from scratch. They are
653+ pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
654+ :class:`~torchvision.datasets.FlyingThings3D`,
655+ fine-tuned on Sintel, and then fine-tuned on
656+ :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
657+ step was described above.
619658 """ ,
620659 },
621660 )
@@ -624,6 +663,19 @@ class Raft_Large_Weights(WeightsEnum):
624663
625664
626665class Raft_Small_Weights (WeightsEnum ):
666+ """The metrics reported here are as follows.
667+
668+ ``epe`` is the "end-point-error" and indicates how far (in pixels) the
669+ predicted flow is from its true value. This is averaged over all pixels
670+ of all images. ``per_image_epe`` is similar, but the average is different:
671+ the epe is first computed on each image independently, and then averaged
672+ over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
673+ in the original paper, and it's only used on Kitti. ``fl_all`` is also a
674+ Kitti-specific metric, defined by the author of the dataset and used for the
675+ Kitti leaderboard. It corresponds to the percentage of outlier pixels, i.e.
676+ pixels whose epe is both >3px and >5% of the ground-truth flow's 2-norm.
677+ """
678+
627679 C_T_V1 = Weights (
628680 # Weights ported from https://github.com/princeton-vl/RAFT
629681 url = "https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth" ,
@@ -637,7 +689,9 @@ class Raft_Small_Weights(WeightsEnum):
637689 "Sintel-Train-Finalpass" : {"epe" : 3.2790 },
638690 "Kitti-Train" : {"per_image_epe" : 7.6557 , "fl_all" : 25.2801 },
639691 },
640- "_docs" : """These weights were ported from the original paper. They are trained on Chairs + Things.""" ,
692+ "_docs" : """These weights were ported from the original paper. They
693+ are trained on :class:`~torchvision.datasets.FlyingChairs` +
694+ :class:`~torchvision.datasets.FlyingThings3D`.""" ,
641695 },
642696 )
643697 C_T_V2 = Weights (
@@ -652,7 +706,9 @@ class Raft_Small_Weights(WeightsEnum):
652706 "Sintel-Train-Finalpass" : {"epe" : 3.2831 },
653707 "Kitti-Train" : {"per_image_epe" : 7.5978 , "fl_all" : 25.2369 },
654708 },
655- "_docs" : """These weights were trained from scratch on Chairs + Things.""" ,
709+ "_docs" : """These weights were trained from scratch on
710+ :class:`~torchvision.datasets.FlyingChairs` +
711+ :class:`~torchvision.datasets.FlyingThings3D`.""" ,
656712 },
657713 )
658714
@@ -750,13 +806,19 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
750806 Please see the example below for a tutorial on how to use this model.
751807
752808 Args:
753- weights(Raft_Large_weights, optional): The pretrained weights for the model
754- progress (bool): If True, displays a progress bar of the download to stderr
755- kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
756- to override any default.
757-
758- Returns:
759- RAFT: The model.
809+ weights (:class:`~torchvision.models.optical_flow.Raft_Large_Weights`, optional): The
810+ pretrained weights to use. See
811+ :class:`~torchvision.models.optical_flow.Raft_Large_Weights`
812+ below for more details, and possible values. By default, no
813+ pre-trained weights are used.
814+ progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
815+ **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
816+ base class. Please refer to the `source code
817+ <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
818+ for more details about this class.
819+
820+ .. autoclass:: torchvision.models.optical_flow.Raft_Large_Weights
821+ :members:
760822 """
761823
762824 weights = Raft_Large_Weights .verify (weights )
@@ -794,19 +856,24 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
794856@handle_legacy_interface (weights = ("pretrained" , Raft_Small_Weights .C_T_V2 ))
795857def raft_small (* , weights : Optional [Raft_Small_Weights ] = None , progress = True , ** kwargs ) -> RAFT :
796858 """RAFT "small" model from
797- `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_ .
859+ `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`__ .
798860
799861 Please see the example below for a tutorial on how to use this model.
800862
801863 Args:
802- weights(Raft_Small_weights, optional): The pretrained weights for the model
803- progress (bool): If True, displays a progress bar of the download to stderr
804- kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
805- to override any default.
806-
807- Returns:
808- RAFT: The model.
809-
864+ weights (:class:`~torchvision.models.optical_flow.Raft_Small_Weights`, optional): The
865+ pretrained weights to use. See
866+ :class:`~torchvision.models.optical_flow.Raft_Small_Weights`
867+ below for more details, and possible values. By default, no
868+ pre-trained weights are used.
869+ progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
870+ **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
871+ base class. Please refer to the `source code
872+ <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
873+ for more details about this class.
874+
875+ .. autoclass:: torchvision.models.optical_flow.Raft_Small_Weights
876+ :members:
810877 """
811878 weights = Raft_Small_Weights .verify (weights )
812879
0 commit comments