Commit 806dba6

Fix int/float typing in video_utils.py (#8234)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 0be6c7e commit 806dba6

File tree

1 file changed, +8 −8 lines changed


torchvision/datasets/video_utils.py

+8 −8
@@ -89,7 +89,7 @@ class VideoClips:
         video_paths (List[str]): paths to the video files
         clip_length_in_frames (int): size of a clip in number of frames
         frames_between_clips (int): step (in frames) between each clip
-        frame_rate (int, optional): if specified, it will resample the video
+        frame_rate (float, optional): if specified, it will resample the video
             so that it has `frame_rate`, and then the clips will be defined
             on the resampled video
         num_workers (int): how many subprocesses to use for data loading.
@@ -102,7 +102,7 @@ def __init__(
         video_paths: List[str],
         clip_length_in_frames: int = 16,
         frames_between_clips: int = 1,
-        frame_rate: Optional[int] = None,
+        frame_rate: Optional[float] = None,
         _precomputed_metadata: Optional[Dict[str, Any]] = None,
         num_workers: int = 0,
         _video_width: int = 0,
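
For context, a minimal usage sketch of the widened constructor signature. The import path and parameter names come from the diff itself; the file names are hypothetical, and any readable video files would do:

from torchvision.datasets.video_utils import VideoClips

# Hypothetical file names, for illustration only.
video_paths = ["clip_a.mp4", "clip_b.mp4"]

# frame_rate can now be annotated honestly as a float: real-world rates
# such as 29.97 fps were never integers to begin with.
clips = VideoClips(
    video_paths,
    clip_length_in_frames=16,
    frames_between_clips=1,
    frame_rate=29.97,
)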
@@ -136,7 +136,7 @@ def __init__(
 
     def _compute_frame_pts(self) -> None:
         self.video_pts = []  # len = num_videos. Each entry is a tensor of shape (num_frames_in_video,)
-        self.video_fps: List[int] = []  # len = num_videos
+        self.video_fps: List[float] = []  # len = num_videos
 
         # strategy: use a DataLoader to parallelize read_video_timestamps
         # so need to create a dummy dataset first
@@ -203,15 +203,15 @@ def subset(self, indices: List[int]) -> "VideoClips":
 
     @staticmethod
     def compute_clips_for_video(
-        video_pts: torch.Tensor, num_frames: int, step: int, fps: int, frame_rate: Optional[int] = None
+        video_pts: torch.Tensor, num_frames: int, step: int, fps: Optional[float], frame_rate: Optional[float] = None
     ) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]:
         if fps is None:
             # if for some reason the video doesn't have fps (because doesn't have a video stream)
             # set the fps to 1. The value doesn't matter, because video_pts is empty anyway
             fps = 1
         if frame_rate is None:
             frame_rate = fps
-        total_frames = len(video_pts) * (float(frame_rate) / fps)
+        total_frames = len(video_pts) * frame_rate / fps
         _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
         video_pts = video_pts[_idxs]
         clips = unfold(video_pts, num_frames, step)
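
The rewritten total_frames line is plain float arithmetic (the explicit float() cast is redundant once the annotations are floats). A tiny worked example with made-up numbers shows what it computes:

import math

# Illustrative values only, not taken from the commit:
# a 300-frame video at a native 30 fps, resampled to 15 fps.
fps = 30.0
frame_rate = 15.0
num_video_pts = 300

total_frames = num_video_pts * frame_rate / fps  # 150.0
print(int(math.floor(total_frames)))             # 150 frames survive resampling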
@@ -227,7 +227,7 @@ def compute_clips_for_video(
             idxs = unfold(_idxs, num_frames, step)
         return clips, idxs
 
-    def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[int] = None) -> None:
+    def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[float] = None) -> None:
         """
         Compute all consecutive sequences of clips from video_pts.
         Always returns clips of size `num_frames`, meaning that the
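
The same widening applies to compute_clips, which can also be called on an existing instance to redefine its clips. A sketch, reusing the hypothetical `clips` object from the first example:

# Re-compute the clip index with a fractional target rate; before this
# commit the annotation only permitted an int here.
clips.compute_clips(num_frames=16, step=8, frame_rate=23.976)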
@@ -275,8 +275,8 @@ def get_clip_location(self, idx: int) -> Tuple[int, int]:
         return video_idx, clip_idx
 
     @staticmethod
-    def _resample_video_idx(num_frames: int, original_fps: int, new_fps: int) -> Union[slice, torch.Tensor]:
-        step = float(original_fps) / new_fps
+    def _resample_video_idx(num_frames: int, original_fps: float, new_fps: float) -> Union[slice, torch.Tensor]:
+        step = original_fps / new_fps
         if step.is_integer():
             # optimization: if step is integer, don't need to perform
             # advanced indexing
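
The step.is_integer() fast path is why the division matters: Python's true division always yields a float, so the explicit float() cast was only masking the int annotations. A simplified re-implementation of the same idea, as a sketch rather than the torchvision code itself:

import torch
from typing import Union

def resample_video_idx_sketch(num_frames: int, original_fps: float, new_fps: float) -> Union[slice, torch.Tensor]:
    step = original_fps / new_fps  # true division: always a float, so .is_integer() is safe
    if step.is_integer():
        # Integral stride: a cheap slice avoids advanced indexing.
        return slice(None, None, int(step))
    # Fractional stride: gather an explicit (floored) source index per output frame.
    idxs = torch.arange(num_frames, dtype=torch.float32) * step
    return idxs.floor().to(torch.int64)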
