6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,11 @@
# Changelog

## v.3.5.0 (2025-12-03)

### Features
* Elastic training support for the HyperPodTrainingOperator, announced in re:Invent 2025 keynote 3. Elastic training dynamically scales distributed machine learning workloads.


## v.3.4.0 (2025-11-20)

### Features
6 changes: 6 additions & 0 deletions hyperpod-pytorch-job-template/CHANGELOG.md
@@ -1,3 +1,9 @@
## v1.3.0 (2025-12-03)

### Features

* Support for elastic training

## v1.2.0 (2025-11-20)

### Features
@@ -11,7 +11,8 @@
Metadata,
Volumes,
HostPath,
PersistentVolumeClaim
PersistentVolumeClaim,
ElasticPolicy
)
from sagemaker.hyperpod.training.hyperpod_pytorch_job import HyperPodPytorchJob
import yaml
@@ -239,6 +240,38 @@ class PyTorchJobConfig(BaseModel):
alias="required_topology",
description="Required topology annotation for scheduling",
)
elastic_replica_increment_step: Optional[int] = Field(
default=None,
alias="elastic_replica_increment_step",
description="Scaling step size for elastic training",
ge=1,
)
max_node_count: Optional[int] = Field(
default=None,
alias="max_node_count",
description="Maximum number of nodes for elastic training",
ge=1,
)
elastic_graceful_shutdown_timeout_in_seconds: Optional[int] = Field(
default=None,
alias="elastic_graceful_shutdown_timeout_in_seconds",
description="Graceful shutdown timeout in seconds for elastic scaling operations"
)
elastic_scaling_timeout_in_seconds: Optional[int] = Field(
default=None,
alias="elastic_scaling_timeout_in_seconds",
description="Scaling timeout for elastic training"
)
elastic_scale_up_snooze_time_in_seconds: Optional[int] = Field(
default=None,
alias="elastic_scale_up_snooze_time_in_seconds",
description="Timeout period after job restart during which no scale up/workload admission is allowed"
)
elastic_replica_discrete_values: Optional[List[int]] = Field(
default=None,
alias="elastic_replica_discrete_values",
description="Alternative to replica increment step. Provides exact values for total replicas count"
)

@field_validator('tasks_per_node', mode='before')
@classmethod
@@ -363,6 +396,45 @@ def validate_accelerator_partition_options(self):
)
if not valid:
raise ValueError(error)

return self

@model_validator(mode='after')
def validate_elastic_replica_config(self):
"""Validate elastic replica configuration."""
has_increment_step = self.elastic_replica_increment_step is not None
has_discrete_values = self.elastic_replica_discrete_values is not None

# Check mutual exclusivity
if has_increment_step and has_discrete_values:
raise ValueError(
"Only one of 'elastic_replica_increment_step' or 'elastic_replica_discrete_values' "
"can be specified, not both. Please use either:\n"
" - elastic_replica_increment_step for uniform scaling steps, or\n"
" - elastic_replica_discrete_values for specific replica counts"
)

# Validate discrete values are within valid range
if has_discrete_values:
discrete_values = self.elastic_replica_discrete_values

# Check that all values are positive
if any(val <= 0 for val in discrete_values):
raise ValueError(
f"All values in 'elastic_replica_discrete_values' must be positive integers. "
f"Got: {discrete_values}"
)

# Check against max_node_count if specified
if self.max_node_count is not None:
invalid_values = [val for val in discrete_values if val > self.max_node_count]
if invalid_values:
raise ValueError(
f"All values in 'elastic_replica_discrete_values' must be ≤ max_node_count ({self.max_node_count}). "
f"Invalid values: {invalid_values}. "
f"Please either increase max_node_count or remove values exceeding it."
)

return self

def to_domain(self) -> Dict:
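The mutual-exclusivity and range checks in `validate_elastic_replica_config` can be exercised in isolation with a trimmed-down model. The sketch below is not the template's actual `PyTorchJobConfig`; it reproduces only the three elastic fields it needs, with names copied from the diff, and the class name is hypothetical.

```python
# Standalone sketch of the validation added above; not the real PyTorchJobConfig.
from typing import List, Optional

from pydantic import BaseModel, Field, model_validator


class ElasticConfigSketch(BaseModel):
    elastic_replica_increment_step: Optional[int] = Field(default=None, ge=1)
    elastic_replica_discrete_values: Optional[List[int]] = None
    max_node_count: Optional[int] = Field(default=None, ge=1)

    @model_validator(mode="after")
    def validate_elastic_replica_config(self):
        # Only one scaling mode may be set at a time.
        if (self.elastic_replica_increment_step is not None
                and self.elastic_replica_discrete_values is not None):
            raise ValueError(
                "Only one of 'elastic_replica_increment_step' or "
                "'elastic_replica_discrete_values' can be specified, not both."
            )
        if self.elastic_replica_discrete_values is not None:
            if any(v <= 0 for v in self.elastic_replica_discrete_values):
                raise ValueError("Discrete replica values must be positive integers.")
            if self.max_node_count is not None:
                too_large = [v for v in self.elastic_replica_discrete_values
                             if v > self.max_node_count]
                if too_large:
                    raise ValueError(f"Values exceed max_node_count: {too_large}")
        return self


# Accepted: discrete values within max_node_count.
ElasticConfigSketch(elastic_replica_discrete_values=[2, 4, 8], max_node_count=8)

# Rejected: both scaling modes given at once (pydantic surfaces the
# ValueError as a ValidationError, which is itself a ValueError subclass).
try:
    ElasticConfigSketch(elastic_replica_increment_step=2,
                        elastic_replica_discrete_values=[2, 4, 8])
except ValueError as exc:
    print(exc)
```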
@@ -467,15 +539,61 @@ def build_dict(**kwargs):
replica_kwargs = build_dict(
name="pod",
template=Template(metadata=Metadata(**metadata_kwargs), spec=Spec(**spec_kwargs)),
replicas=self.node_count
replicas=self.node_count,
max_replicas=self.max_node_count
)

# Build elastic policy
elastic_policy = None
if any([
self.elastic_replica_increment_step is not None,
self.max_node_count is not None,
self.elastic_graceful_shutdown_timeout_in_seconds is not None,
self.elastic_scaling_timeout_in_seconds is not None,
self.elastic_replica_discrete_values is not None
]):
# Build base elastic policy kwargs
elastic_policy_kwargs = build_dict(
min_replicas=self.node_count,
max_replicas=self.max_node_count,
graceful_shutdown_timeout_in_seconds=self.elastic_graceful_shutdown_timeout_in_seconds,
scaling_timeout_in_seconds=self.elastic_scaling_timeout_in_seconds
)

if self.elastic_replica_discrete_values is not None:
elastic_policy_kwargs['replica_discrete_values'] = self.elastic_replica_discrete_values
elif self.elastic_replica_increment_step is not None:
elastic_policy_kwargs['replica_increment_step'] = self.elastic_replica_increment_step

elastic_policy = ElasticPolicy(**elastic_policy_kwargs)

# Build run policy
run_policy = None
if self.max_retry is not None or self.elastic_scale_up_snooze_time_in_seconds is not None:
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import RestartPolicy

run_policy_kwargs = build_dict(
clean_pod_policy="None",
job_max_retry_count=self.max_retry
)

# Add restart policy if scale_up_snooze_interval is provided
if self.elastic_scale_up_snooze_time_in_seconds is not None:
restart_policy = RestartPolicy(
eval_period_seconds=3600,
scale_up_snooze_time_in_seconds=self.elastic_scale_up_snooze_time_in_seconds
)
run_policy_kwargs['restart_policy'] = restart_policy

run_policy = RunPolicy(**run_policy_kwargs)

# Build job
job_kwargs = build_dict(
metadata=metadata_kwargs,
replica_specs=[ReplicaSpec(**replica_kwargs)],
nproc_per_node=str(self.tasks_per_node) if self.tasks_per_node else None,
run_policy=RunPolicy(clean_pod_policy="None", job_max_retry_count=self.max_retry) if self.max_retry else None
run_policy=run_policy,
elastic_policy=elastic_policy
)

result = HyperPodPytorchJob(**job_kwargs)
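For reference, the SDK objects that `to_domain` now assembles can also be built directly. The keyword names below are taken from the diff; the import path mirrors the `RestartPolicy` import shown above and is otherwise an assumption, and all values are illustrative only.

```python
# Illustrative sketch of the ElasticPolicy / RunPolicy wiring that to_domain()
# performs above. Import path assumed from the RestartPolicy import in the diff.
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import (
    ElasticPolicy,
    RestartPolicy,
    RunPolicy,
)

elastic_policy = ElasticPolicy(
    min_replicas=2,                      # job starts at node_count
    max_replicas=8,                      # max_node_count
    replica_discrete_values=[2, 4, 8],   # or replica_increment_step=2 (not both)
    graceful_shutdown_timeout_in_seconds=300,
    scaling_timeout_in_seconds=600,
)

run_policy = RunPolicy(
    clean_pod_policy="None",
    job_max_retry_count=3,
    restart_policy=RestartPolicy(
        eval_period_seconds=3600,        # hard-coded in the template above
        scale_up_snooze_time_in_seconds=900,
    ),
)

# These would then be passed along as in job_kwargs above:
# HyperPodPytorchJob(..., run_policy=run_policy, elastic_policy=elastic_policy)
```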
@@ -395,7 +395,94 @@
"type": "string",
"description": "Required topology annotation for scheduling",
"$ref": "#/$defs/topologyLabels"
},
"elastic_replica_increment_step": {
"anyOf": [
{
"minimum": 1,
"type": "integer"
},
{
"type": "null"
}
],
"default": null,
"description": "Scaling step size for elastic training",
"title": "Elastic Training Replica Increment Step"
},
"max_node_count": {
"anyOf": [
{
"minimum": 1,
"type": "integer"
},
{
"type": "null"
}
],
"default": null,
"description": "Maximum number of nodes for elastic training",
"title": "Max Node Count"
},
"elastic_graceful_shutdown_timeout_in_seconds": {
"anyOf": [
{
"minimum": 0,
"type": "integer"
},
{
"type": "null"
}
],
"default": null,
"description": "Graceful shutdown timeout in seconds for elastic scaling operations",
"title": "Elastic Graceful Shutdown Timeout In Seconds"
},
"elastic_scaling_timeout_in_seconds": {
"anyOf": [
{
"minimum": 0,
"type": "integer"
},
{
"type": "null"
}
],
"default": null,
"description": "Scaling timeout for elastic training",
"title": "Elastic Scaling Timeout In Seconds"
},
"elastic_scale_up_snooze_time_in_seconds": {
"anyOf": [
{
"minimum": 0,
"type": "integer"
},
{
"type": "null"
}
],
"default": null,
"description": "Timeout period after job restart during which no scale up/workload admission is allowed",
"title": "Elastic Scale Up Snooze Time In Seconds"
},
"elastic_replica_discrete_values": {
"anyOf": [
{
"items": {
"type": "integer"
},
"type": "array"
},
{
"type": "null"
}
],
"default": null,
"description": "Alternative to replica increment step. Provides exact values for total replicas count",
"title": "Elastic Replica Discrete Values"
}

},
"required": [
"job_name",
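A quick way to see how the `anyOf`/null pattern above behaves is to validate a couple of values against a copied fragment of the schema. This sketch uses the third-party `jsonschema` package and only the two properties reproduced here; it is not the full template schema.

```python
# Validates sample values against a fragment copied from the schema diff above.
from jsonschema import validate

fragment = {
    "type": "object",
    "properties": {
        "max_node_count": {
            "anyOf": [{"minimum": 1, "type": "integer"}, {"type": "null"}],
            "default": None,
        },
        "elastic_replica_discrete_values": {
            "anyOf": [{"items": {"type": "integer"}, "type": "array"}, {"type": "null"}],
            "default": None,
        },
    },
}

validate({"max_node_count": 8, "elastic_replica_discrete_values": [2, 4, 8]}, fragment)
validate({"max_node_count": None}, fragment)   # null is explicitly allowed
# validate({"max_node_count": 0}, fragment)    # would fail: minimum is 1
```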
2 changes: 1 addition & 1 deletion hyperpod-pytorch-job-template/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "hyperpod-pytorch-job-template"
version = "1.2.0"
version = "1.3.0"
readme = "README.md"
authors = [{name = "Amazon Web Services"}]
license = {text = "Apache-2.0"}
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
dynamic = ["dependencies"]
name = "sagemaker-hyperpod"
version = "3.4.0"
version = "3.5.0"
description = "Amazon SageMaker HyperPod SDK and CLI"
readme = "README.md"
requires-python = ">=3.8"
2 changes: 1 addition & 1 deletion setup.py
@@ -47,7 +47,7 @@
setup(
data_files=sagemaker_hyperpod_recipes,
name="sagemaker-hyperpod",
version="3.4.0",
version="3.5.0",
description="Amazon SageMaker HyperPod SDK and CLI",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
15 changes: 13 additions & 2 deletions src/sagemaker/hyperpod/cli/training_utils.py
@@ -42,7 +42,16 @@ def _parse_list_flag(ctx, param, value):
return None
# Remove brackets and split by comma
value = value.strip("[]")
return [item.strip() for item in value.split(",") if item.strip()]
items = [item.strip() for item in value.split(",") if item.strip()]

# Convert to integers for elastic_replica_discrete_values
if param and hasattr(param, 'name') and param.name == 'elastic_replica_discrete_values':
try:
return [int(item) for item in items]
except ValueError as e:
raise click.BadParameter(f"elastic-replica-discrete-values must contain only integers: {e}")

return items

def _parse_volume_param(ctx, param, value):
"""Parse volume parameters from command line format to dictionary format."""
@@ -134,11 +143,12 @@ def wrapped_func(*args, **kwargs):
list_params = {
"command": "List of command arguments",
"args": "List of script arguments, e.g. '[--batch-size, 32, --learning-rate, 0.001]'",
"elastic_replica_discrete_values": "List of discrete replica values for elastic training, e.g. '[2, 4, 8, 16]'",
}

for param_name, help_text in list_params.items():
wrapped_func = click.option(
f"--{param_name}",
f"--{param_name.replace('_', '-')}",
callback=_parse_list_flag,
type=str,
default=None,
@@ -154,6 +164,7 @@ def wrapped_func(*args, **kwargs):
"command",
"args",
"volume",
"elastic_replica_discrete_values"
]
)

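The integer conversion added to `_parse_list_flag` only fires for the `elastic_replica_discrete_values` parameter; other list flags still come back as strings. A small sketch, calling the helper directly with a stand-in for the click parameter object (module path as shown in this diff):

```python
# Calls the private helper directly; SimpleNamespace stands in for the click
# Parameter object, since only its .name attribute is inspected.
from types import SimpleNamespace

from sagemaker.hyperpod.cli.training_utils import _parse_list_flag

discrete = SimpleNamespace(name="elastic_replica_discrete_values")
print(_parse_list_flag(None, discrete, "[2, 4, 8, 16]"))
# -> [2, 4, 8, 16]  (converted to integers)

args = SimpleNamespace(name="args")
print(_parse_list_flag(None, args, "[--batch-size, 32]"))
# -> ['--batch-size', '32']  (strings, behaviour unchanged)

# On the command line this corresponds to the kebab-case flag generated above,
# e.g. --elastic-replica-discrete-values "[2, 4, 8, 16]"; the exact CLI command
# name is not shown in this diff.
```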