Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,9 +545,12 @@ def hyperparameters(self):
mpi_dict = self.distributions["mpi"]
mpi_enabled = mpi_dict.get("enabled", False)
additional_hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled
additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get(
"processes_per_host", 1
)

if mpi_dict.get("processes_per_host"):
additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get(
"processes_per_host"
)

additional_hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get(
"custom_mpi_options", ""
)
Expand Down