Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Feb 15, 2022
1 parent 2047b35 commit 8cc1bc4
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions flash/video/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ def from_files(
>>> _ = [os.remove(f"predict_video_{i}.mp4") for i in range(1, 4)]
"""
ds_kw = dict(
target_formatter=target_formatter,
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
clip_sampler=clip_sampler,
Expand All @@ -197,6 +196,7 @@ def from_files(
train_targets,
transform=train_transform,
video_sampler=video_sampler,
target_formatter=target_formatter,
**ds_kw,
)
target_formatter = getattr(train_input, "target_formatter", None)
Expand Down Expand Up @@ -357,7 +357,6 @@ def from_folders(
>>> shutil.rmtree("predict_folder")
"""
ds_kw = dict(
target_formatter=target_formatter,
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
clip_sampler=clip_sampler,
Expand All @@ -368,7 +367,12 @@ def from_folders(
)

train_input = input_cls(
RunningStage.TRAINING, train_folder, transform=train_transform, video_sampler=video_sampler, **ds_kw
RunningStage.TRAINING,
train_folder,
transform=train_transform,
video_sampler=video_sampler,
target_formatter=target_formatter,
**ds_kw,
)
target_formatter = getattr(train_input, "target_formatter", None)

Expand Down Expand Up @@ -546,7 +550,6 @@ def from_data_frame(
>>> del predict_data_frame
"""
ds_kw = dict(
target_formatter=target_formatter,
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
clip_sampler=clip_sampler,
Expand All @@ -562,7 +565,12 @@ def from_data_frame(
predict_data = (predict_data_frame, input_field, predict_videos_root, predict_resolver)

train_input = input_cls(
RunningStage.TRAINING, *train_data, transform=train_transform, video_sampler=video_sampler, **ds_kw
RunningStage.TRAINING,
*train_data,
transform=train_transform,
video_sampler=video_sampler,
target_formatter=target_formatter,
**ds_kw,
)
target_formatter = getattr(train_input, "target_formatter", None)

Expand Down Expand Up @@ -754,7 +762,6 @@ def from_csv(
>>> os.remove("predict_data.csv")
"""
ds_kw = dict(
target_formatter=target_formatter,
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
clip_sampler=clip_sampler,
Expand All @@ -770,7 +777,12 @@ def from_csv(
predict_data = (predict_file, input_field, predict_videos_root, predict_resolver)

train_input = input_cls(
RunningStage.TRAINING, *train_data, transform=train_transform, video_sampler=video_sampler, **ds_kw
RunningStage.TRAINING,
*train_data,
transform=train_transform,
video_sampler=video_sampler,
target_formatter=target_formatter,
**ds_kw,
)
target_formatter = getattr(train_input, "target_formatter", None)

Expand Down Expand Up @@ -917,7 +929,6 @@ def from_fiftyone(
>>> del predict_dataset
"""
ds_kw = dict(
target_formatter=target_formatter,
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
clip_sampler=clip_sampler,
Expand All @@ -933,6 +944,7 @@ def from_fiftyone(
transform=train_transform,
video_sampler=video_sampler,
label_field=label_field,
target_formatter=target_formatter,
**ds_kw,
)
target_formatter = getattr(train_input, "target_formatter", None)
Expand Down

0 comments on commit 8cc1bc4

Please sign in to comment.