Skip to content

Commit 393a24e

Browse files
authored
Fix: load training data from all sub dirs (#396)
* Fix: load training data from all sub dirs * fix flake8 formatting * fix local integ test * resolve some comments
1 parent 90a17b8 commit 393a24e

File tree

6 files changed

+5094
-31
lines changed

6 files changed

+5094
-31
lines changed

src/sagemaker_xgboost_container/data_utils.py

Lines changed: 63 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
PARQUET = "parquet"
3939
RECORDIO_PROTOBUF = "recordio-protobuf"
4040

41+
MAX_FOLDER_DEPTH = 3
42+
4143
VALID_CONTENT_TYPES = [
4244
CSV,
4345
LIBSVM,
@@ -528,33 +530,74 @@ def _get_pipe_mode_files_path(data_path: Union[List[str], str]) -> List[str]:
528530
return files_path
529531

530532

533+
def _make_symlinks_from_a_folder(dest_path: str, data_path: str, depth: int):
    """Symlink every file found under ``data_path`` into the flat directory ``dest_path``.

    :param dest_path: directory that receives the symlinks
    :param data_path: either a single file or a directory to traverse
    :param depth: depth of ``data_path`` in the traversal
    :raises exc.UserError: if ``depth`` is already beyond MAX_FOLDER_DEPTH, or if a
        sub directory lies beyond MAX_FOLDER_DEPTH. In the second case the error is
        raised only after every file within the limit has been linked.
    """
    if depth > MAX_FOLDER_DEPTH:
        raise exc.UserError(f"Folder depth exceed the limit: {MAX_FOLDER_DEPTH}.")

    if os.path.isfile(data_path):
        _make_symlink(data_path, dest_path, os.path.basename(data_path))
        return

    # Raising as soon as the first too-deep directory is met would unwind all the
    # enclosing scans and drop sibling entries that are within the limit. Finish
    # the traversal first, then report the problem with the same error as before.
    if _symlink_folder_within_depth(dest_path, data_path, depth):
        raise exc.UserError(f"Folder depth exceed the limit: {MAX_FOLDER_DEPTH}.")


def _symlink_folder_within_depth(dest_path: str, data_path: str, depth: int) -> bool:
    """Link the files of directory ``data_path`` and of its sub dirs within the depth limit.

    :param dest_path: directory that receives the symlinks
    :param data_path: directory to scan
    :param depth: depth of ``data_path`` in the traversal
    :return: True if at least one sub directory was skipped for being too deep
    """
    logging.info(f"Making smlinks from folder {data_path} to folder {dest_path}")
    skipped_deeper_dirs = False
    # The context manager closes the scandir iterator even if a symlink call fails.
    with os.scandir(data_path) as entries:
        for item in entries:
            if item.is_file():
                _make_symlink(item.path, dest_path, item.name)
            elif item.is_dir():
                if depth + 1 > MAX_FOLDER_DEPTH:
                    skipped_deeper_dirs = True
                elif _symlink_folder_within_depth(dest_path, item.path, depth + 1):
                    skipped_deeper_dirs = True
    return skipped_deeper_dirs
547+
548+
549+
def _make_symlinks_from_a_folder_with_warning(dest_path: str, data_path: str):
    """Create symlinks for ``data_path`` inside ``dest_path``, downgrading the depth error to a warning.

    :param dest_path: A dir
    :param data_path: Either dir or file
    :raises exc.AlgorithmError: if either path does not exist or ``dest_path`` is not a dir
    :raises exc.UserError: any user error other than the folder depth limit one
    """

    # If data_path is a single file A, create symlink A -> dest_path/A
    # If data_path is a dir, create symlinks for files located within depth of MAX_FOLDER_DEPTH
    # under this dir. When the traversal reports that the depth limit was exceeded,
    # log a warning instead of failing.

    if (not os.path.exists(dest_path)) or (not os.path.exists(data_path)):
        raise exc.AlgorithmError(f"Unable to create symlinks as {data_path} or {dest_path} doesn't exist ")

    if not os.path.isdir(dest_path):
        raise exc.AlgorithmError(f"Unable to create symlinks as dest_path {dest_path} is not a dir")

    depth_message = f"Folder depth exceed the limit: {MAX_FOLDER_DEPTH}."
    try:
        _make_symlinks_from_a_folder(dest_path, data_path, 1)
    except exc.UserError as e:
        # Containment rather than equality, so the depth error is still recognised
        # if the exception class decorates the message it was given.
        if depth_message in str(e.message):
            logging.warning(
                f"The depth of folder {data_path} exceed the limit {MAX_FOLDER_DEPTH}."
                f" Files in deeper sub dirs won't be loaded."
                f" Please adjust the folder structure accordingly."
            )
        else:
            # Only the depth error is tolerated; anything else must not be swallowed.
            raise
575+
576+
531577
def _get_file_mode_files_path(data_path: Union[List[str], str]) -> List[str]:
    """Build a flat directory of symlinks to the input files and return its path.

    :param data_path: Either directory or file, or a list of those
    :return: path of the directory holding the symlinks, or None when a single
        (non-list) ``data_path`` does not exist
    :raises exc.UserError: if a single ``data_path`` exists but is neither a file nor a dir
    """
    # In file mode, we create a temp directory with symlink to all input files or
    # directories to meet XGB's assumption that all files are in the same directory.

    logging.info("File path {} of input files".format(data_path))
    # Create a directory with symlinks to input files.
    files_path = "/tmp/sagemaker_xgboost_input_data"
    shutil.rmtree(files_path, ignore_errors=True)
    os.mkdir(files_path)
    if isinstance(data_path, list):
        for path in data_path:
            _make_symlinks_from_a_folder_with_warning(files_path, path)
    else:
        if not os.path.exists(data_path):
            logging.info("File path {} does not exist!".format(data_path))
            return None
        elif os.path.isdir(data_path) or os.path.isfile(data_path):
            # traverse all sub-dirs to load all training data
            _make_symlinks_from_a_folder_with_warning(files_path, data_path)
        else:
            # The exception used to be constructed but never raised, so an
            # unsupported path silently produced an empty input directory.
            raise exc.UserError("Unknown input files path: {}".format(data_path))

    return files_path
560603

@@ -635,22 +678,11 @@ def get_size(data_path, is_pipe=False):
635678
return total_size
636679

637680

638-
def get_files_path_from_string(data_path: Union[List[str], str]) -> List[str]:
639-
if os.path.isfile(data_path):
640-
files_path = data_path
641-
else:
642-
for root, dirs, files in os.walk(data_path):
643-
if dirs == []:
644-
files_path = root
645-
break
646-
647-
return files_path
648-
649-
650-
def _make_symlink(path, source_path, name, index):
651-
base_name = os.path.join(source_path, f"{name}_{str(index)}")
652-
logging.info(f"creating symlink between Path {source_path} and destination {base_name}")
653-
os.symlink(path, base_name)
681+
def _make_symlink(path, source_path, name):
    """Create a symlink to ``path`` inside the directory ``source_path``.

    :param path: target the new symlink points to
    :param source_path: directory in which the symlink is created
    :param name: base name of the symlink; ``hash(path)`` is appended to it
    """
    # NOTE: hash() of a str is salted per interpreter process unless
    # PYTHONHASHSEED is fixed, so the generated link name differs between runs.
    link_location = os.path.join(source_path, f"{name}{hash(path)}")
    logging.info(f"creating symlink between Path {path} and destination {link_location}")
    os.symlink(path, link_location)
654686

655687

656688
def check_data_redundancy(train_path, validate_path):

0 commit comments

Comments
 (0)