-
Couldn't load subscription status.
- Fork 90
Fix: load training data from all sub dirs #396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -528,33 +528,50 @@ def _get_pipe_mode_files_path(data_path: Union[List[str], str]) -> List[str]: | |
| return files_path | ||
|
|
||
|
|
||
| def _make_symlinks_for_files_under_a_folder(dest_path: str, data_path: str): | ||
| if (not os.path.exists(dest_path)) or (not os.path.exists(data_path)): | ||
| raise exc.AlgorithmError("Unable to create symlinks as {data_path} or {dest_path} doesn't exist ") | ||
|
|
||
| logging.info("Making smlinks from folder {} to folder {}".format(data_path, dest_path)) | ||
|
||
|
|
||
| if os.path.isfile(data_path): | ||
| _make_symlink(data_path, dest_path, os.path.basename(data_path)) | ||
| return | ||
|
|
||
| else: | ||
| for item in os.scandir(data_path): | ||
| if item.is_file(): | ||
| _make_symlink(item.path, dest_path, item.name) | ||
| elif item.is_dir(): | ||
| _make_symlinks_for_files_under_a_folder(dest_path, item.path) | ||
|
||
|
|
||
|
|
||
| def _get_file_mode_files_path(data_path: Union[List[str], str]) -> List[str]: | ||
| """ | ||
| :param data_path: Either directory or file | ||
| """ | ||
| # In file mode, we create a temp directory with symlink to all input files or | ||
| # directories to meet XGB's assumption that all files are in the same directory. | ||
|
|
||
| logging.info("File path {} of input files".format(data_path)) | ||
| # Create a directory with symlinks to input files. | ||
| files_path = "/tmp/sagemaker_xgboost_input_data" | ||
| shutil.rmtree(files_path, ignore_errors=True) | ||
| os.mkdir(files_path) | ||
| if isinstance(data_path, list): | ||
| logging.info("File path {} of input files".format(data_path)) | ||
| # Create a directory with symlinks to input files. | ||
| files_path = "/tmp/sagemaker_xgboost_input_data" | ||
| shutil.rmtree(files_path, ignore_errors=True) | ||
| os.mkdir(files_path) | ||
| for index, path in enumerate(data_path): | ||
| if not os.path.exists(path): | ||
| return None | ||
| if os.path.isfile(path): | ||
| _make_symlink(path, files_path, os.path.basename(path), index) | ||
| else: | ||
| for file in os.scandir(path): | ||
| _make_symlink(file, files_path, file.name, index) | ||
|
|
||
| for path in data_path: | ||
| _make_symlinks_for_files_under_a_folder(files_path, path) | ||
| else: | ||
| if not os.path.exists(data_path): | ||
| logging.info("File path {} does not exist!".format(data_path)) | ||
| return None | ||
| files_path = get_files_path_from_string(data_path) | ||
| elif os.path.isdir(data_path): | ||
| # traverse all sub-dirs to load all training data | ||
| _make_symlinks_for_files_under_a_folder(files_path, data_path) | ||
| elif os.path.isfile(data_path): | ||
| files_path = data_path | ||
|
||
| else: | ||
| exc.UserError("Unknown input files path: {}".format(data_path)) | ||
|
|
||
| return files_path | ||
|
|
||
|
|
@@ -635,22 +652,11 @@ def get_size(data_path, is_pipe=False): | |
| return total_size | ||
|
|
||
|
|
||
| def get_files_path_from_string(data_path: Union[List[str], str]) -> List[str]: | ||
| if os.path.isfile(data_path): | ||
| files_path = data_path | ||
| else: | ||
| for root, dirs, files in os.walk(data_path): | ||
| if dirs == []: | ||
| files_path = root | ||
| break | ||
|
|
||
| return files_path | ||
|
|
||
|
|
||
| def _make_symlink(path, source_path, name, index): | ||
| base_name = os.path.join(source_path, f"{name}_{str(index)}") | ||
| logging.info(f"creating symlink between Path {source_path} and destination {base_name}") | ||
| os.symlink(path, base_name) | ||
| def _make_symlink(path, source_path, name): | ||
| base_name = os.path.join(source_path, name) | ||
| file_name = base_name + str(hash(path)) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens in case of hash collisions? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. An error will be thrown out saying there is an existing symlink between A and B. The reason for |
||
| logging.info(f"creating symlink between Path {path} and destination {file_name}") | ||
| os.symlink(path, file_name) | ||
|
|
||
|
|
||
| def check_data_redundancy(train_path, validate_path): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this meant to be an f-string?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch, yeah, I will update