Skip to content

Commit 94bb147

Browse files
haixiwhaixiw
authored andcommitted
Fix: load training data from all sub dirs (#396)
* Fix: load training data from all sub dirs * fix flake8 formatting * fix local integ test * resolve some comments
1 parent 07ece4c commit 94bb147

File tree

6 files changed

+5105
-35
lines changed

6 files changed

+5105
-35
lines changed

src/sagemaker_xgboost_container/data_utils.py

Lines changed: 74 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,19 @@
4040
PARQUET = 'parquet'
4141
RECORDIO_PROTOBUF = 'recordio-protobuf'
4242

43-
VALID_CONTENT_TYPES = [CSV, LIBSVM, PARQUET, RECORDIO_PROTOBUF,
44-
_content_types.CSV, xgb_content_types.LIBSVM,
45-
xgb_content_types.X_LIBSVM, xgb_content_types.X_PARQUET,
46-
xgb_content_types.X_RECORDIO_PROTOBUF]
43+
MAX_FOLDER_DEPTH = 3
44+
45+
VALID_CONTENT_TYPES = [
46+
CSV,
47+
LIBSVM,
48+
PARQUET,
49+
RECORDIO_PROTOBUF,
50+
_content_types.CSV,
51+
xgb_content_types.LIBSVM,
52+
xgb_content_types.X_LIBSVM,
53+
xgb_content_types.X_PARQUET,
54+
xgb_content_types.X_RECORDIO_PROTOBUF,
55+
]
4756

4857
VALID_PIPED_CONTENT_TYPES = [CSV, PARQUET, RECORDIO_PROTOBUF,
4958
_content_types.CSV, xgb_content_types.X_PARQUET,
@@ -501,33 +510,74 @@ def _get_pipe_mode_files_path(data_path: Union[List[str], str]) -> List[str]:
501510
return files_path
502511

503512

513+
def _make_symlinks_from_a_folder(dest_path: str, data_path: str, depth: int):
514+
if (depth > MAX_FOLDER_DEPTH):
515+
raise exc.UserError(f"Folder depth exceed the limit: {MAX_FOLDER_DEPTH}.")
516+
517+
if os.path.isfile(data_path):
518+
_make_symlink(data_path, dest_path, os.path.basename(data_path))
519+
return
520+
else:
521+
logging.info(f"Making smlinks from folder {data_path} to folder {dest_path}")
522+
for item in os.scandir(data_path):
523+
if item.is_file():
524+
_make_symlink(item.path, dest_path, item.name)
525+
elif item.is_dir():
526+
_make_symlinks_from_a_folder(dest_path, item.path, depth + 1)
527+
528+
529+
def _make_symlinks_from_a_folder_with_warning(dest_path: str, data_path: str):
530+
"""
531+
:param dest_path: A dir
532+
:param data_path: Either dir or file
533+
:param depth: current folder depth, Integer
534+
"""
535+
536+
# If data_path is a single file A, create smylink A -> dest_path/A
537+
# If data_path is a dir, create symlinks for files located within depth of MAX_FOLDER_DEPTH
538+
# under this dir. Ignore the files in deeper sub dirs and log a warning if they exist.
539+
540+
if (not os.path.exists(dest_path)) or (not os.path.exists(data_path)):
541+
raise exc.AlgorithmError(f"Unable to create symlinks as {data_path} or {dest_path} doesn't exist ")
542+
543+
if (not os.path.isdir(dest_path)):
544+
raise exc.AlgorithmError(f"Unable to create symlinks as dest_path {dest_path} is not a dir")
545+
546+
try:
547+
_make_symlinks_from_a_folder(dest_path, data_path, 1)
548+
except exc.UserError as e:
549+
if e.message == f"Folder depth exceed the limit: {MAX_FOLDER_DEPTH}.":
550+
logging.warning(
551+
f"The depth of folder {data_path} exceed the limit {MAX_FOLDER_DEPTH}."
552+
f" Files in deeper sub dirs won't be loaded."
553+
f" Please adjust the folder structure accordingly."
554+
)
555+
556+
504557
def _get_file_mode_files_path(data_path: Union[List[str], str]) -> List[str]:
505558
"""
506559
:param data_path: Either directory or file
507560
"""
508561
# In file mode, we create a temp directory with symlink to all input files or
509562
# directories to meet XGB's assumption that all files are in the same directory.
510563

564+
logging.info("File path {} of input files".format(data_path))
565+
# Create a directory with symlinks to input files.
566+
files_path = "/tmp/sagemaker_xgboost_input_data"
567+
shutil.rmtree(files_path, ignore_errors=True)
568+
os.mkdir(files_path)
511569
if isinstance(data_path, list):
512-
logging.info('File path {} of input files'.format(data_path))
513-
# Create a directory with symlinks to input files.
514-
files_path = "/tmp/sagemaker_xgboost_input_data"
515-
shutil.rmtree(files_path, ignore_errors=True)
516-
os.mkdir(files_path)
517-
for index, path in enumerate(data_path):
518-
if not os.path.exists(path):
519-
return None
520-
if os.path.isfile(path):
521-
_make_symlink(path, files_path, os.path.basename(path), index)
522-
else:
523-
for file in os.scandir(path):
524-
_make_symlink(file, files_path, file.name, index)
525-
570+
for path in data_path:
571+
_make_symlinks_from_a_folder_with_warning(files_path, path)
526572
else:
527573
if not os.path.exists(data_path):
528574
logging.info('File path {} does not exist!'.format(data_path))
529575
return None
530-
files_path = get_files_path_from_string(data_path)
576+
elif os.path.isdir(data_path) or os.path.isfile(data_path):
577+
# traverse all sub-dirs to load all training data
578+
_make_symlinks_from_a_folder_with_warning(files_path, data_path)
579+
else:
580+
exc.UserError("Unknown input files path: {}".format(data_path))
531581

532582
return files_path
533583

@@ -607,22 +657,11 @@ def get_size(data_path, is_pipe=False):
607657
return total_size
608658

609659

610-
def get_files_path_from_string(data_path: Union[List[str], str]) -> List[str]:
611-
if os.path.isfile(data_path):
612-
files_path = data_path
613-
else:
614-
for root, dirs, files in os.walk(data_path):
615-
if dirs == []:
616-
files_path = root
617-
break
618-
619-
return files_path
620-
621-
622-
def _make_symlink(path, source_path, name, index):
623-
base_name = os.path.join(source_path, f"{name}_{str(index)}")
624-
logging.info(f'creating symlink between Path {source_path} and destination {base_name}')
625-
os.symlink(path, base_name)
660+
def _make_symlink(path, source_path, name):
661+
base_name = os.path.join(source_path, name)
662+
file_name = base_name + str(hash(path))
663+
logging.info(f"creating symlink between Path {path} and destination {file_name}")
664+
os.symlink(path, file_name)
626665

627666

628667
def check_data_redundancy(train_path, validate_path):

0 commit comments

Comments
 (0)