diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index a24bf75dec..cdac66661f 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -545,9 +545,12 @@ def hyperparameters(self): mpi_dict = self.distributions["mpi"] mpi_enabled = mpi_dict.get("enabled", False) additional_hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled - additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get( - "processes_per_host", 1 - ) + + if mpi_dict.get("processes_per_host"): + additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get( + "processes_per_host" + ) + additional_hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get( "custom_mpi_options", "" )