Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed May 10, 2021
1 parent 5f254b6 commit 8745191
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions flash/vision/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,20 @@ def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]]) -
input_data, target_data = data

if self.isdir(input_data) and self.isdir(target_data):
files = os.listdir(input_data)
input_files = [os.path.join(input_data, file) for file in files]
target_files = [os.path.join(target_data, file) for file in files]
input_files = os.listdir(input_data)
target_files = os.listdir(target_data)

target_files = list(filter(os.path.isfile, target_files))
all_files = set(input_files).intersection(set(target_files))

if len(input_files) != len(target_files):
if len(all_files) != len(input_files) or len(all_files) != len(target_files):
rank_zero_warn(
f"Found inconsistent files in input_dir: {input_data} and target_dir: {target_data}. Some files"
" have been dropped.",
UserWarning,
)

input_data = input_files
target_data = target_files
input_data = [os.path.join(input_data, file) for file in all_files]
target_data = [os.path.join(target_data, file) for file in all_files]

if not isinstance(input_data, list) and not isinstance(target_data, list):
input_data = [input_data]
Expand Down

0 comments on commit 8745191

Please sign in to comment.