Describe the bug
The PyTorch estimator does not allow setting checkpoint_s3_uri when working with a heterogeneous cluster; it fails with the following error:
│ /Users/bpistone/miniforge3/envs/ray-env/lib/python3.12/site-packages/sagemaker/estimator.py:3646 │
│ in _validate_and_set_debugger_configs                                                            │
│                                                                                                  │
│   3643 │   │   │   │   │   │   "the debugger_hook_config is disabled."                           │
│   3644 │   │   │   │   │   )                                                                     │
│   3645 │   │   │   │   │   self.debugger_hook_config = False                                     │
│ ❱ 3646 │   │   │   │   elif self.instance_count > 1 or (                                         │
│   3647 │   │   │   │   │   hasattr(self, "distribution")                                         │
│   3648 │   │   │   │   │   and self.distribution is not None  # pylint: disable=no-member        │
│   3649 │   │   │   │   ):                                                                        │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: '>' not supported between instances of 'NoneType' and 'int'
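The root cause appears to be that when only instance_groups is supplied, the estimator's instance_count stays None, so the check shown in the traceback compares None with an int. A minimal, SageMaker-independent sketch of the failing comparison:

# Minimal illustration of the failure (does not use SageMaker):
# with instance_groups only, instance_count appears to remain None,
# and the debugger validation then effectively evaluates `None > 1`.
instance_count = None
try:
    instance_count > 1
except TypeError as err:
    print(err)  # '>' not supported between instances of 'NoneType' and 'int'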
To reproduce
from sagemaker.instance_group import InstanceGroup
from sagemaker.pytorch import PyTorch
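# model_id, bucket_name, role, and image_uri are assumed to be defined earlier in the notebook.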
instance_groups = [
    InstanceGroup(
        instance_group_name="head-instance-group",
        instance_type="ml.t3.xlarge",
        instance_count=1,
    ),
    InstanceGroup(
        instance_group_name="worker-instance-group",
        instance_type="ml.g5.xlarge",
        instance_count=4,
    ),
]
# define Training Job Name
job_name = f"train-{model_id.split('/')[-1].replace('.', '-')}-sft"
output_path = f"s3://{bucket_name}/{job_name}"
estimator = PyTorch(
    source_dir="./scripts",
    entry_point="launcher.py",
    output_path=output_path,
    base_job_name=job_name,
    role=role,
    instance_groups=instance_groups,
    max_run=432000,
    image_uri=image_uri,
    environment={
        "head_instance_group": "head-instance-group",
        "head_num_cpus": "0",
        "head_num_gpus": "0",
    },
    hyperparameters={
        "entrypoint": "train_ray.py",
        "config": "/opt/ml/input/data/config/args.yaml",  # path to TRL config which was uploaded to s3
    },
    enable_remote_debug=True,
    checkpoint_local_path="/opt/ml/checkpoints",
    checkpoint_s3_uri=output_path + "/checkpoint",  # combined with instance_groups, this triggers the error above
)
Whether this error also affects ModelTrainer cannot be verified, because a separate bug with heterogeneous clusters and ModelTrainer (reported in issue #5225) prevents testing it there.
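Based only on the snippet visible in the traceback, a possible fix would be to treat a missing instance_count as 0 before comparing. The helper below is a hypothetical stand-in for that guard, not the actual SDK code:

def needs_multi_node_debugger_handling(instance_count, distribution):
    """Hypothetical stand-in for the check in _validate_and_set_debugger_configs.

    Treating instance_count=None (as happens when only instance_groups is
    passed) as 0 avoids the TypeError while keeping the behavior for
    explicit multi-node counts.
    """
    return (instance_count or 0) > 1 or distribution is not None


print(needs_multi_node_debugger_handling(None, None))  # False, no TypeError
print(needs_multi_node_debugger_handling(4, None))     # True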
Expected behavior
The estimator should be created and the training job should start with estimator.fit(inputs=data, wait=False).
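For completeness, the call that should succeed looks like the sketch below; the "config" channel name is implied by the /opt/ml/input/data/config path in the hyperparameters, while the "train" channel and S3 prefixes are placeholders, not part of the original report:

# Hypothetical input channels; names and prefixes other than "config" are assumptions.
data = {
    "config": f"s3://{bucket_name}/{job_name}/config",
    "train": f"s3://{bucket_name}/{job_name}/train",
}

# Expected: the job is created and starts asynchronously, without the TypeError above.
estimator.fit(inputs=data, wait=False)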
Screenshots or logs
If applicable, add screenshots or logs to help explain your problem.
System information
A description of your system. Please provide:
- SageMaker Python SDK version: 2.271.
- Framework name (eg. PyTorch) or algorithm (eg. KMeans): PyTorch
- Framework version: 2.6.0
- Python version: 3.12
- CPU or GPU: CPU and GPU
- Custom Docker image (Y/N): N
 
Additional context
Add any other context about the problem here.