Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions kubeflow/trainer/backends/container/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,7 @@ def get_runtime_packages(self, runtime: types.Runtime):
"""
Spawn a short-lived container to report Python version, pip list, and nvidia-smi.
"""
image = container_utils.resolve_image(runtime)
container_utils.maybe_pull_image(self._adapter, image, self.cfg.pull_policy)
container_utils.maybe_pull_image(self._adapter, runtime.trainer.image, self.cfg.pull_policy)

command = [
"bash",
Expand All @@ -220,14 +219,17 @@ def get_runtime_packages(self, runtime: types.Runtime):
"(nvidia-smi || echo 'nvidia-smi not found')",
]

logs = self._adapter.run_oneoff_container(image=image, command=command)
logs = self._adapter.run_oneoff_container(image=runtime.trainer.image, command=command)
print(logs)

def train(
self,
runtime: Optional[types.Runtime] = None,
initializer: Optional[types.Initializer] = None,
trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
trainer: Optional[
Union[types.CustomTrainer, types.CustomTrainerContainer, types.BuiltinTrainer]
] = None,
options: Optional[list] = None,
Comment on lines +227 to +230
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should fix E2Es in this PR: kubeflow/trainer#2907

cc @Fiona-Waters @astefanutti @kramaranya

) -> str:
if runtime is None:
runtime = self.get_runtime("torch-distributed")
Expand All @@ -249,11 +251,12 @@ def train(
logger.debug("Generated training script code")

# Resolve image and pull if needed
image = container_utils.resolve_image(runtime)
logger.debug(f"Using image: {image}")
logger.debug(f"Using image: {runtime.trainer.image}")

container_utils.maybe_pull_image(self._adapter, image, self.cfg.pull_policy)
logger.debug(f"Image ready: {image}")
container_utils.maybe_pull_image(
self._adapter, runtime.trainer.image, self.cfg.pull_policy
)
logger.debug(f"Image ready: {runtime.trainer.image}")

# Build base environment
env = container_utils.build_environment(trainer)
Expand Down Expand Up @@ -368,7 +371,7 @@ def train(
logger.debug(f"Creating container {rank}/{num_nodes}: {container_name}")

container_id = self._adapter.create_and_start_container(
image=image,
image=runtime.trainer.image,
command=full_cmd,
name=container_name,
network_id=network_id,
Expand Down
6 changes: 2 additions & 4 deletions kubeflow/trainer/backends/container/runtime_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,8 @@ def _create_default_runtimes() -> list[base_types.Runtime]:
trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
framework=framework,
num_nodes=1,
image=image,
),
pretrained_model=None,
image=image,
)
default_runtimes.append(runtime)
logger.debug(f"Created default runtime: {runtime.name} with image {image}")
Expand Down Expand Up @@ -414,9 +413,8 @@ def _parse_runtime_yaml(data: dict[str, Any], source: str = "unknown") -> base_t
trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
framework=framework,
num_nodes=num_nodes,
image=image,
),
pretrained_model=None,
image=image,
)


Expand Down
71 changes: 7 additions & 64 deletions kubeflow/trainer/backends/container/runtime_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def test_list_training_runtimes_from_sources(test_case):
trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
framework="torch",
num_nodes=1,
image="example.com/container",
),
)
deepspeed_runtime = base_types.Runtime(
Expand All @@ -291,6 +292,7 @@ def test_list_training_runtimes_from_sources(test_case):
trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
framework="deepspeed",
num_nodes=1,
image="example.com/container",
),
)
mock_github.side_effect = [[torch_runtime], [deepspeed_runtime]]
Expand All @@ -303,6 +305,7 @@ def test_list_training_runtimes_from_sources(test_case):
trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
framework="torch",
num_nodes=1,
image="example.com/container",
),
)
torch_runtime_2 = base_types.Runtime(
Expand All @@ -311,6 +314,7 @@ def test_list_training_runtimes_from_sources(test_case):
trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
framework="torch",
num_nodes=2,
image="example.com/container",
),
)
mock_github.side_effect = [[torch_runtime_1], [torch_runtime_2]]
Expand All @@ -324,6 +328,7 @@ def test_list_training_runtimes_from_sources(test_case):
trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
framework="torch",
num_nodes=1,
image="example.com/container",
),
)
mock_defaults.return_value = [default_runtime]
Expand Down Expand Up @@ -358,7 +363,7 @@ def test_create_default_runtimes():
assert torch_runtimes[0].trainer.trainer_type == base_types.TrainerType.CUSTOM_TRAINER
assert torch_runtimes[0].trainer.num_nodes == 1
# Verify default image is set
assert torch_runtimes[0].image == constants.DEFAULT_FRAMEWORK_IMAGES["torch"]
assert torch_runtimes[0].trainer.image == constants.DEFAULT_FRAMEWORK_IMAGES["torch"]
print("test execution complete")


Expand Down Expand Up @@ -620,72 +625,10 @@ def test_parse_runtime_yaml_extracts_image(test_case):
runtime = runtime_loader._parse_runtime_yaml(runtime_yaml, "test")

# Verify image is extracted and stored
assert runtime.image == test_case.config["custom_image"]
assert runtime.name == test_case.config["runtime_name"]
assert runtime.trainer.framework == test_case.config["framework"]
assert runtime.trainer.num_nodes == test_case.config["num_nodes"]

assert test_case.expected_status == SUCCESS

except Exception as e:
assert type(e) is test_case.expected_error
print("test execution complete")


@pytest.mark.parametrize(
"test_case",
[
TestCase(
name="resolve image uses custom image",
expected_status=SUCCESS,
config={
"custom_image": "my-registry.io/pytorch-custom:arm64",
"framework": "torch",
"expect_custom": True,
},
),
TestCase(
name="resolve image falls back to default when no custom image",
expected_status=SUCCESS,
config={
"custom_image": None,
"framework": "torch",
"expect_custom": False,
},
),
],
)
def test_resolve_image_uses_custom_image(test_case):
"""
Test that resolve_image prioritizes runtime.image over default framework images.
This ensures custom images from ClusterTrainingRuntimes are actually used.
"""
print("Executing test:", test_case.name)
try:
from kubeflow.trainer.backends.container import utils

# Create runtime with or without custom image
runtime = base_types.Runtime(
name="test-runtime",
trainer=base_types.RuntimeTrainer(
trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
framework=test_case.config["framework"],
num_nodes=1,
),
image=test_case.config["custom_image"],
)

resolved_image = utils.resolve_image(runtime)

if test_case.config["expect_custom"]:
# Should use custom image
assert resolved_image == test_case.config["custom_image"]
else:
# Should fall back to default
assert (
resolved_image == constants.DEFAULT_FRAMEWORK_IMAGES[test_case.config["framework"]]
)
assert "pytorch/pytorch" in resolved_image
assert runtime.trainer.image == test_case.config["custom_image"]

assert test_case.expected_status == SUCCESS

Expand Down
36 changes: 2 additions & 34 deletions kubeflow/trainer/backends/container/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def container_status_to_trainjob_status(status: str, exit_code: int) -> str:
if status == "exited":
# Exit code 0 -> complete, else failed
return constants.TRAINJOB_COMPLETE if exit_code == 0 else constants.TRAINJOB_FAILED
return constants.UNKNOWN
return UNKNOWN


def aggregate_status_from_containers(container_statuses: list[str]) -> str:
Expand All @@ -150,38 +150,6 @@ def aggregate_status_from_containers(container_statuses: list[str]) -> str:
return UNKNOWN


def resolve_image(runtime: types.Runtime) -> str:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Fiona-Waters It looks like we don't need the image resolver, since we fall back to the default runtime when we can't fetch the runtime online:

def _create_default_runtimes() -> list[base_types.Runtime]:

Is that correct?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, if the image is not optional then we don't need this function.

"""
Resolve the container image for a runtime.

Priority:
1. Use runtime.image if specified in the ClusterTrainingRuntime
2. Fall back to DEFAULT_FRAMEWORK_IMAGES based on framework

Args:
runtime: Runtime object.

Returns:
Container image name.

Raises:
ValueError: If no image is found for the runtime's framework.
"""
# Use image from runtime if specified
if runtime.image:
return runtime.image

# Fall back to default framework images
framework = runtime.trainer.framework
if framework in constants.DEFAULT_FRAMEWORK_IMAGES:
return constants.DEFAULT_FRAMEWORK_IMAGES[framework]

raise ValueError(
f"No default image found for framework '{framework}'. "
f"Supported frameworks: {list(constants.DEFAULT_FRAMEWORK_IMAGES.keys())}"
)


def maybe_pull_image(adapter, image: str, pull_policy: str):
"""
Pull image based on pull policy.
Expand Down Expand Up @@ -227,7 +195,7 @@ def get_container_status(adapter, container_id: str) -> str:
status, exit_code = adapter.container_status(container_id)
return container_status_to_trainjob_status(status, exit_code)
except Exception:
return constants.UNKNOWN
return UNKNOWN


def aggregate_container_statuses(adapter, containers: list[dict]) -> str:
Expand Down
7 changes: 4 additions & 3 deletions kubeflow/trainer/backends/kubernetes/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def get_replicated_job() -> models.JobsetV1alpha2ReplicatedJob:
def get_container() -> models.IoK8sApiCoreV1Container:
return models.IoK8sApiCoreV1Container(
name="node",
image="image",
image="example.com/test-runtime",
command=["echo", "Hello World"],
resources=get_resource_requirements(),
)
Expand All @@ -543,11 +543,11 @@ def create_runtime_type(
num_nodes=2,
device="gpu",
device_count=RUNTIME_DEVICES,
image="example.com/test-runtime",
)
trainer.set_command(constants.TORCH_COMMAND)
return types.Runtime(
name=name,
pretrained_model=None,
trainer=trainer,
)

Expand All @@ -564,14 +564,14 @@ def get_train_job_data_type(
device="gpu",
device_count=RUNTIME_DEVICES,
num_nodes=2,
image="example.com/test-runtime",
)
trainer.set_command(constants.TORCH_COMMAND)
return types.TrainJob(
name=train_job_name,
creation_timestamp=datetime.datetime(2025, 6, 1, 10, 30, 0),
runtime=types.Runtime(
name=runtime_name,
pretrained_model=None,
trainer=trainer,
),
steps=[
Expand Down Expand Up @@ -696,6 +696,7 @@ def test_list_runtimes(kubernetes_backend, test_case):
num_nodes=1,
device="cpu",
device_count="1",
image="example.com/image",
),
)
},
Expand Down
1 change: 1 addition & 0 deletions kubeflow/trainer/backends/kubernetes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def get_runtime_trainer(
else types.TrainerType.CUSTOM_TRAINER
),
framework=framework,
image=trainer_container.image,
)

# Get the container devices.
Expand Down
1 change: 1 addition & 0 deletions kubeflow/trainer/backends/kubernetes/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def _build_runtime() -> types.Runtime:
framework="torch",
device="cpu",
device_count="1",
image="example.com/image",
)
runtime_trainer.set_command(constants.DEFAULT_COMMAND)
return types.Runtime(name="test-runtime", trainer=runtime_trainer)
Expand Down
2 changes: 1 addition & 1 deletion kubeflow/trainer/backends/localprocess/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,6 @@ def __convert_local_runtime_to_runtime(self, local_runtime) -> types.Runtime:
num_nodes=local_runtime.trainer.num_nodes,
device_count=local_runtime.trainer.device_count,
device=local_runtime.trainer.device,
image=local_runtime.trainer.image,
),
pretrained_model=local_runtime.pretrained_model,
)
4 changes: 4 additions & 0 deletions kubeflow/trainer/backends/localprocess/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

TORCH_FRAMEWORK_TYPE = "torch"

# Image name for the local runtime.
LOCAL_RUNTIME_IMAGE = "local"

local_runtimes = [
base_types.Runtime(
name=constants.TORCH_RUNTIME,
Expand All @@ -32,6 +35,7 @@
device_count=common_constants.UNKNOWN,
device=common_constants.UNKNOWN,
packages=["torch"],
image=LOCAL_RUNTIME_IMAGE,
),
)
]
Expand Down
1 change: 1 addition & 0 deletions kubeflow/trainer/backends/localprocess/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def get_local_runtime_trainer(
trainer_type=types.TrainerType.CUSTOM_TRAINER,
framework=framework,
packages=local_runtime.trainer.packages,
image=local_exec_constants.LOCAL_RUNTIME_IMAGE,
)

# set command to run from venv
Expand Down
2 changes: 1 addition & 1 deletion kubeflow/trainer/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ class TrainerType(Enum):
class RuntimeTrainer:
trainer_type: TrainerType
framework: str
image: str
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@astefanutti @Fiona-Waters I made the image mandatory for the RuntimeTrainer.
The container should always have an image, but for the local subprocess backend, we can populate a constant value there.
What do you think about it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense for the image to be mandatory. Adding a dummy constant to the local process backend seems fine to me, as long as we make sure its purpose is clearly mentioned in a comment.

num_nodes: int = 1 # The default value is set in the APIs.
device: str = common_constants.UNKNOWN
device_count: str = common_constants.UNKNOWN
Expand All @@ -251,7 +252,6 @@ class Runtime:
name: str
trainer: RuntimeTrainer
pretrained_model: Optional[str] = None
image: Optional[str] = None


# Representation for the TrainJob steps.
Expand Down
Loading