Skip to content

Commit e96e53a

Browse files
committed
Allow CustomTrainer to run a Python script directly
CustomTrainer now supports a python_file argument, together with an optional python_args list. If python_file is set, the job runs the specified script as the main process (e.g., python myscript.py, followed by any python_args) instead of requiring a function. python_file is mutually exclusive with func. This change makes it easier to migrate script-based workflows and matches user expectations for direct script execution. Existing function-based usage is unchanged. Validation is added when the trainer CRD is built to ensure that exactly one of func or python_file is set. Signed-off-by: Krishnaswamy Subramanian <[email protected]>
1 parent d90dbce commit e96e53a

File tree

4 files changed

+195
-3
lines changed

4 files changed

+195
-3
lines changed

python/kubeflow/trainer/types/types.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from dataclasses import dataclass, field
1717
from datetime import datetime
1818
from enum import Enum
19-
from typing import Callable, Dict, Optional
19+
from typing import Callable, Dict, List, Optional
2020

2121
from kubeflow.trainer.constants import constants
2222

@@ -25,10 +25,13 @@
2525
@dataclass
2626
class CustomTrainer:
2727
"""Custom Trainer configuration. Configure the self-contained function
28-
that encapsulates the entire model training process.
28+
that encapsulates the entire model training process, or run a Python script directly.
2929
3030
Args:
3131
func (`Callable`): The function that encapsulates the entire model training process.
32+
python_file (`Optional[str]`): Path to a Python script to run directly (e.g., 'train.py').
33+
python_args (`Optional[List[str]]`): Arguments to pass to the Python script.
34+
Only one of func or python_file should be set.
3235
func_args (`Optional[Dict]`): The arguments to pass to the function.
3336
packages_to_install (`Optional[List[str]]`):
3437
A list of Python packages to install before running the function.
@@ -38,7 +41,9 @@ class CustomTrainer:
3841
env (`Optional[Dict[str, str]]`): The environment variables to set in the training nodes.
3942
"""
4043

41-
func: Callable
44+
func: Optional[Callable] = None
45+
python_file: Optional[str] = None
46+
python_args: Optional[List[str]] = None
4247
func_args: Optional[Dict] = None
4348
packages_to_install: Optional[list[str]] = None
4449
pip_index_url: str = constants.DEFAULT_PIP_INDEX_URL
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from kubeflow.trainer.types import types
2+
3+
4+
class TestTrainerConfigurations:
5+
"""Test cases for trainer configurations and types."""
6+
7+
def test_centralized_trainer_configs(self):
8+
"""Test that centralized trainer configurations are properly defined."""
9+
# Verify all trainer frameworks have configurations
10+
for framework in types.Framework:
11+
assert framework in types.TRAINER_CONFIGS
12+
trainer = types.TRAINER_CONFIGS[framework]
13+
assert trainer.framework == framework
14+
15+
def test_default_trainer_uses_centralized_config(self):
16+
"""Test that DEFAULT_TRAINER uses centralized configuration."""
17+
assert types.DEFAULT_TRAINER == types.TRAINER_CONFIGS[types.Framework.TORCH]
18+
assert types.DEFAULT_TRAINER.framework == types.Framework.TORCH
19+
20+
def test_custom_trainer_python_file_with_args(self):
21+
"""Test CustomTrainer with python_file and python_args."""
22+
# Test basic python_file without args
23+
trainer = types.CustomTrainer(python_file="train.py")
24+
assert trainer.python_file == "train.py"
25+
assert trainer.python_args is None
26+
27+
# Test python_file with args
28+
trainer = types.CustomTrainer(
29+
python_file="train.py",
30+
python_args=["--epochs", "100", "--batch-size", "32"]
31+
)
32+
assert trainer.python_file == "train.py"
33+
assert trainer.python_args == ["--epochs", "100", "--batch-size", "32"]
34+
35+
# Test python_file with complex args
36+
trainer = types.CustomTrainer(
37+
python_file="train.py",
38+
python_args=["--epochs", "100", "--batch-size", "32", "--lr", "0.001", "--model-path", "/workspace/model"]
39+
)
40+
assert trainer.python_file == "train.py"
41+
assert trainer.python_args == ["--epochs", "100", "--batch-size", "32", "--lr", "0.001", "--model-path", "/workspace/model"]
42+
43+
def test_custom_trainer_mutual_exclusivity(self):
44+
"""Test that func and python_file are mutually exclusive."""
45+
# This should work
46+
trainer = types.CustomTrainer(python_file="train.py")
47+
assert trainer.func is None
48+
assert trainer.python_file == "train.py"
49+
50+
# This should work
51+
def dummy_func():
52+
pass
53+
trainer = types.CustomTrainer(func=dummy_func)
54+
assert trainer.func == dummy_func
55+
assert trainer.python_file is None
56+
57+
def test_custom_trainer_python_args_only(self):
58+
"""Test CustomTrainer with python_args but no python_file (should be None)."""
59+
trainer = types.CustomTrainer(python_args=["--epochs", "100"])
60+
assert trainer.python_file is None
61+
assert trainer.python_args == ["--epochs", "100"]
62+
63+
def test_custom_trainer_python_args_with_func(self):
64+
"""Test CustomTrainer with func and python_args (should be allowed)."""
65+
def dummy_func():
66+
pass
67+
68+
trainer = types.CustomTrainer(
69+
func=dummy_func,
70+
python_args=["--epochs", "100"]
71+
)
72+
assert trainer.func == dummy_func
73+
assert trainer.python_file is None
74+
assert trainer.python_args == ["--epochs", "100"]

python/kubeflow/trainer/utils/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,20 @@ def get_trainer_crd_from_custom_trainer(
372372
trainer.resources_per_node
373373
)
374374

375+
if trainer.python_file:
376+
if trainer.func:
377+
raise ValueError("Specify only one of func or python_file in CustomTrainer.")
378+
trainer_crd.command = ["python"]
379+
# Combine python_file with python_args
380+
args = [trainer.python_file]
381+
if trainer.python_args:
382+
args.extend(trainer.python_args)
383+
trainer_crd.args = args
384+
return trainer_crd
385+
386+
if not trainer.func:
387+
raise ValueError("You must specify either func or python_file in CustomTrainer.")
388+
375389
# Add command to the Trainer.
376390
# TODO: Support train function parameters.
377391
trainer_crd.command = get_command_using_train_func(
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2024 The Kubeflow Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
from unittest.mock import Mock
17+
18+
from kubeflow.trainer.utils import utils
19+
from kubeflow.trainer.types import types
20+
21+
22+
class TestCustomTrainerPythonFileSupport(unittest.TestCase):
23+
"""Test cases for the new python_file and python_args functionality in CustomTrainer."""
24+
25+
def test_get_trainer_crd_from_custom_trainer_python_file_with_args(self):
26+
"""Test get_trainer_crd_from_custom_trainer with python_file and python_args."""
27+
runtime = Mock()
28+
trainer = types.CustomTrainer(
29+
python_file="train.py",
30+
python_args=["--epochs", "100", "--batch-size", "32"],
31+
num_nodes=2,
32+
resources_per_node={"gpu": "4"},
33+
)
34+
35+
result = utils.get_trainer_crd_from_custom_trainer(runtime, trainer)
36+
37+
self.assertEqual(result.num_nodes, 2)
38+
self.assertEqual(result.command, ["python"])
39+
self.assertEqual(result.args, ["train.py", "--epochs", "100", "--batch-size", "32"])
40+
41+
def test_get_trainer_crd_from_custom_trainer_python_file_no_args(self):
42+
"""Test get_trainer_crd_from_custom_trainer with python_file but no args."""
43+
runtime = Mock()
44+
trainer = types.CustomTrainer(
45+
python_file="train.py", num_nodes=2, resources_per_node={"gpu": "4"}
46+
)
47+
48+
result = utils.get_trainer_crd_from_custom_trainer(runtime, trainer)
49+
50+
self.assertEqual(result.num_nodes, 2)
51+
self.assertEqual(result.command, ["python"])
52+
self.assertEqual(result.args, ["train.py"])
53+
54+
def test_get_trainer_crd_from_custom_trainer_mutual_exclusivity_both_specified(self):
55+
"""Test that func and python_file cannot be specified together."""
56+
runtime = Mock()
57+
trainer = types.CustomTrainer(func=lambda: None, python_file="train.py")
58+
59+
with self.assertRaises(ValueError) as context:
60+
utils.get_trainer_crd_from_custom_trainer(runtime, trainer)
61+
62+
self.assertIn(
63+
"Specify only one of func or python_file in CustomTrainer", str(context.exception)
64+
)
65+
66+
def test_get_trainer_crd_from_custom_trainer_mutual_exclusivity_neither_specified(self):
67+
"""Test that either func or python_file must be specified."""
68+
runtime = Mock()
69+
trainer = types.CustomTrainer()
70+
71+
with self.assertRaises(ValueError) as context:
72+
utils.get_trainer_crd_from_custom_trainer(runtime, trainer)
73+
74+
self.assertIn(
75+
"You must specify either func or python_file in CustomTrainer", str(context.exception)
76+
)
77+
78+
def test_get_trainer_crd_from_custom_trainer_with_func_unchanged(self):
79+
"""Test that existing func functionality remains unchanged."""
80+
runtime = Mock()
81+
runtime.trainer = Mock()
82+
runtime.trainer.command = ["python", "script.py"]
83+
84+
def dummy_func():
85+
pass
86+
87+
trainer = types.CustomTrainer(
88+
func=dummy_func, func_args={"lr": 0.001}, num_nodes=2, resources_per_node={"gpu": "4"}
89+
)
90+
91+
with unittest.mock.patch(
92+
"kubeflow.trainer.utils.utils.get_command_using_train_func"
93+
) as mock_get_command:
94+
mock_get_command.return_value = ["python", "script.py"]
95+
result = utils.get_trainer_crd_from_custom_trainer(runtime, trainer)
96+
97+
self.assertEqual(result.num_nodes, 2)
98+
# Verify that the existing func path still works
99+
mock_get_command.assert_called_once()

0 commit comments

Comments
 (0)