|
40 | 40 | PARQUET = 'parquet' |
41 | 41 | RECORDIO_PROTOBUF = 'recordio-protobuf' |
42 | 42 |
|
43 | | -VALID_CONTENT_TYPES = [CSV, LIBSVM, PARQUET, RECORDIO_PROTOBUF, |
44 | | - _content_types.CSV, xgb_content_types.LIBSVM, |
45 | | - xgb_content_types.X_LIBSVM, xgb_content_types.X_PARQUET, |
46 | | - xgb_content_types.X_RECORDIO_PROTOBUF] |
| 43 | +MAX_FOLDER_DEPTH = 3 |
| 44 | + |
| 45 | +VALID_CONTENT_TYPES = [ |
| 46 | + CSV, |
| 47 | + LIBSVM, |
| 48 | + PARQUET, |
| 49 | + RECORDIO_PROTOBUF, |
| 50 | + _content_types.CSV, |
| 51 | + xgb_content_types.LIBSVM, |
| 52 | + xgb_content_types.X_LIBSVM, |
| 53 | + xgb_content_types.X_PARQUET, |
| 54 | + xgb_content_types.X_RECORDIO_PROTOBUF, |
| 55 | +] |
47 | 56 |
|
48 | 57 | VALID_PIPED_CONTENT_TYPES = [CSV, PARQUET, RECORDIO_PROTOBUF, |
49 | 58 | _content_types.CSV, xgb_content_types.X_PARQUET, |
@@ -501,33 +510,74 @@ def _get_pipe_mode_files_path(data_path: Union[List[str], str]) -> List[str]: |
501 | 510 | return files_path |
502 | 511 |
|
503 | 512 |
|
| 513 | +def _make_symlinks_from_a_folder(dest_path: str, data_path: str, depth: int): |
| 514 | + if (depth > MAX_FOLDER_DEPTH): |
| 515 | + raise exc.UserError(f"Folder depth exceed the limit: {MAX_FOLDER_DEPTH}.") |
| 516 | + |
| 517 | + if os.path.isfile(data_path): |
| 518 | + _make_symlink(data_path, dest_path, os.path.basename(data_path)) |
| 519 | + return |
| 520 | + else: |
| 521 | + logging.info(f"Making smlinks from folder {data_path} to folder {dest_path}") |
| 522 | + for item in os.scandir(data_path): |
| 523 | + if item.is_file(): |
| 524 | + _make_symlink(item.path, dest_path, item.name) |
| 525 | + elif item.is_dir(): |
| 526 | + _make_symlinks_from_a_folder(dest_path, item.path, depth + 1) |
| 527 | + |
| 528 | + |
| 529 | +def _make_symlinks_from_a_folder_with_warning(dest_path: str, data_path: str): |
| 530 | + """ |
| 531 | + :param dest_path: A dir |
| 532 | + :param data_path: Either dir or file |
| 533 | + :param depth: current folder depth, Integer |
| 534 | + """ |
| 535 | + |
| 536 | + # If data_path is a single file A, create smylink A -> dest_path/A |
| 537 | + # If data_path is a dir, create symlinks for files located within depth of MAX_FOLDER_DEPTH |
| 538 | + # under this dir. Ignore the files in deeper sub dirs and log a warning if they exist. |
| 539 | + |
| 540 | + if (not os.path.exists(dest_path)) or (not os.path.exists(data_path)): |
| 541 | + raise exc.AlgorithmError(f"Unable to create symlinks as {data_path} or {dest_path} doesn't exist ") |
| 542 | + |
| 543 | + if (not os.path.isdir(dest_path)): |
| 544 | + raise exc.AlgorithmError(f"Unable to create symlinks as dest_path {dest_path} is not a dir") |
| 545 | + |
| 546 | + try: |
| 547 | + _make_symlinks_from_a_folder(dest_path, data_path, 1) |
| 548 | + except exc.UserError as e: |
| 549 | + if e.message == f"Folder depth exceed the limit: {MAX_FOLDER_DEPTH}.": |
| 550 | + logging.warning( |
| 551 | + f"The depth of folder {data_path} exceed the limit {MAX_FOLDER_DEPTH}." |
| 552 | + f" Files in deeper sub dirs won't be loaded." |
| 553 | + f" Please adjust the folder structure accordingly." |
| 554 | + ) |
| 555 | + |
| 556 | + |
504 | 557 | def _get_file_mode_files_path(data_path: Union[List[str], str]) -> List[str]: |
505 | 558 | """ |
506 | 559 | :param data_path: Either directory or file |
507 | 560 | """ |
508 | 561 | # In file mode, we create a temp directory with symlink to all input files or |
509 | 562 | # directories to meet XGB's assumption that all files are in the same directory. |
510 | 563 |
|
| 564 | + logging.info("File path {} of input files".format(data_path)) |
| 565 | + # Create a directory with symlinks to input files. |
| 566 | + files_path = "/tmp/sagemaker_xgboost_input_data" |
| 567 | + shutil.rmtree(files_path, ignore_errors=True) |
| 568 | + os.mkdir(files_path) |
511 | 569 | if isinstance(data_path, list): |
512 | | - logging.info('File path {} of input files'.format(data_path)) |
513 | | - # Create a directory with symlinks to input files. |
514 | | - files_path = "/tmp/sagemaker_xgboost_input_data" |
515 | | - shutil.rmtree(files_path, ignore_errors=True) |
516 | | - os.mkdir(files_path) |
517 | | - for index, path in enumerate(data_path): |
518 | | - if not os.path.exists(path): |
519 | | - return None |
520 | | - if os.path.isfile(path): |
521 | | - _make_symlink(path, files_path, os.path.basename(path), index) |
522 | | - else: |
523 | | - for file in os.scandir(path): |
524 | | - _make_symlink(file, files_path, file.name, index) |
525 | | - |
| 570 | + for path in data_path: |
| 571 | + _make_symlinks_from_a_folder_with_warning(files_path, path) |
526 | 572 | else: |
527 | 573 | if not os.path.exists(data_path): |
528 | 574 | logging.info('File path {} does not exist!'.format(data_path)) |
529 | 575 | return None |
530 | | - files_path = get_files_path_from_string(data_path) |
| 576 | + elif os.path.isdir(data_path) or os.path.isfile(data_path): |
| 577 | + # traverse all sub-dirs to load all training data |
| 578 | + _make_symlinks_from_a_folder_with_warning(files_path, data_path) |
| 579 | + else: |
| 580 | + exc.UserError("Unknown input files path: {}".format(data_path)) |
531 | 581 |
|
532 | 582 | return files_path |
533 | 583 |
|
@@ -607,22 +657,11 @@ def get_size(data_path, is_pipe=False): |
607 | 657 | return total_size |
608 | 658 |
|
609 | 659 |
|
610 | | -def get_files_path_from_string(data_path: Union[List[str], str]) -> List[str]: |
611 | | - if os.path.isfile(data_path): |
612 | | - files_path = data_path |
613 | | - else: |
614 | | - for root, dirs, files in os.walk(data_path): |
615 | | - if dirs == []: |
616 | | - files_path = root |
617 | | - break |
618 | | - |
619 | | - return files_path |
620 | | - |
621 | | - |
622 | | -def _make_symlink(path, source_path, name, index): |
623 | | - base_name = os.path.join(source_path, f"{name}_{str(index)}") |
624 | | - logging.info(f'creating symlink between Path {source_path} and destination {base_name}') |
625 | | - os.symlink(path, base_name) |
| 660 | +def _make_symlink(path, source_path, name): |
| 661 | + base_name = os.path.join(source_path, name) |
| 662 | + file_name = base_name + str(hash(path)) |
| 663 | + logging.info(f"creating symlink between Path {path} and destination {file_name}") |
| 664 | + os.symlink(path, file_name) |
626 | 665 |
|
627 | 666 |
|
628 | 667 | def check_data_redundancy(train_path, validate_path): |
|
0 commit comments