Skip to content

Commit 906ab83

Browse files
committed
Refactor RayJob submitter template handling
Signed-off-by: You-Cheng Lin (Owen) <[email protected]>
1 parent 3066c41 commit 906ab83

File tree

6 files changed

+20
-31
lines changed

6 files changed

+20
-31
lines changed

ray-operator/controllers/ray/batchscheduler/yunikorn/yunikorn_task_groups.go

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,7 @@ func newTaskGroupsFromRayClusterSpec(rayClusterSpec *v1.RayClusterSpec) *TaskGro
7171
func newTaskGroupsFromRayJobSpec(rayJobSpec *v1.RayJobSpec) *TaskGroups {
7272
taskGroups := newTaskGroupsFromRayClusterSpec(rayJobSpec.RayClusterSpec)
7373

74-
// submitter group
75-
var submitterGroupSpec corev1.PodSpec
76-
if rayJobSpec.SubmitterPodTemplate != nil {
77-
submitterGroupSpec = rayJobSpec.SubmitterPodTemplate.Spec
78-
} else {
79-
submitterGroupSpec = common.GetDefaultSubmitterTemplate(rayJobSpec.RayClusterSpec).Spec
80-
}
74+
submitterGroupSpec := common.GetSubmitterTemplate(rayJobSpec, rayJobSpec.RayClusterSpec).Spec
8175

8276
submitterPodMinResource := utils.CalculatePodResource(submitterGroupSpec)
8377
taskGroups.addTaskGroup(

ray-operator/controllers/ray/common/job.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,11 @@ func BuildJobSubmitCommand(rayJobInstance *rayv1.RayJob, submissionMode rayv1.Jo
198198
return cmd, nil
199199
}
200200

201-
// GetDefaultSubmitterTemplate creates a default submitter template for the Ray job.
202-
func GetDefaultSubmitterTemplate(rayClusterSpec *rayv1.RayClusterSpec) corev1.PodTemplateSpec {
201+
// GetSubmitterTemplate returns a deep copy of the user-provided SubmitterPodTemplate if one is set; otherwise it creates a default submitter template for the Ray job.
202+
func GetSubmitterTemplate(rayJobSpec *rayv1.RayJobSpec, rayClusterSpec *rayv1.RayClusterSpec) corev1.PodTemplateSpec {
203+
if rayJobSpec.SubmitterPodTemplate != nil {
204+
return *rayJobSpec.SubmitterPodTemplate.DeepCopy()
205+
}
203206
return corev1.PodTemplateSpec{
204207
Spec: corev1.PodSpec{
205208
Containers: []corev1.Container{

ray-operator/controllers/ray/common/job_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,10 @@ func TestMetadataRaisesErrorBeforeRay26(t *testing.T) {
218218
require.Error(t, err)
219219
}
220220

221-
func TestGetDefaultSubmitterTemplate(t *testing.T) {
221+
func TestGetSubmitterTemplate(t *testing.T) {
222+
rayJob := &rayv1.RayJob{
223+
Spec: rayv1.RayJobSpec{},
224+
}
222225
rayCluster := &rayv1.RayCluster{
223226
Spec: rayv1.RayClusterSpec{
224227
HeadGroupSpec: rayv1.HeadGroupSpec{
@@ -234,6 +237,6 @@ func TestGetDefaultSubmitterTemplate(t *testing.T) {
234237
},
235238
},
236239
}
237-
template := GetDefaultSubmitterTemplate(&rayCluster.Spec)
240+
template := GetSubmitterTemplate(&rayJob.Spec, &rayCluster.Spec)
238241
assert.Equal(t, template.Spec.Containers[0].Image, rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[utils.RayContainerIndex].Image)
239242
}

ray-operator/controllers/ray/rayjob_controller.go

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ func (r *RayJobReconciler) createK8sJobIfNeed(ctx context.Context, rayJobInstanc
576576
namespacedName := common.RayJobK8sJobNamespacedName(rayJobInstance)
577577
if err := r.Client.Get(ctx, namespacedName, job); err != nil {
578578
if errors.IsNotFound(err) {
579-
submitterTemplate, err := getSubmitterTemplate(ctx, rayJobInstance, rayClusterInstance)
579+
submitterTemplate, err := getSubmitterTemplate(rayJobInstance, rayClusterInstance)
580580
if err != nil {
581581
return err
582582
}
@@ -597,18 +597,9 @@ func (r *RayJobReconciler) createK8sJobIfNeed(ctx context.Context, rayJobInstanc
597597
}
598598

599599
// getSubmitterTemplate builds the submitter pod template for the Ray job.
600-
func getSubmitterTemplate(ctx context.Context, rayJobInstance *rayv1.RayJob, rayClusterInstance *rayv1.RayCluster) (corev1.PodTemplateSpec, error) {
601-
logger := ctrl.LoggerFrom(ctx)
602-
var submitterTemplate corev1.PodTemplateSpec
603-
600+
func getSubmitterTemplate(rayJobInstance *rayv1.RayJob, rayClusterInstance *rayv1.RayCluster) (corev1.PodTemplateSpec, error) {
604601
// Set the default value for the optional field SubmitterPodTemplate if not provided.
605-
if rayJobInstance.Spec.SubmitterPodTemplate == nil {
606-
submitterTemplate = common.GetDefaultSubmitterTemplate(&rayClusterInstance.Spec)
607-
logger.Info("default submitter template is used")
608-
} else {
609-
submitterTemplate = *rayJobInstance.Spec.SubmitterPodTemplate.DeepCopy()
610-
logger.Info("user-provided submitter template is used; the first container is assumed to be the submitter")
611-
}
602+
submitterTemplate := common.GetSubmitterTemplate(&rayJobInstance.Spec, &rayClusterInstance.Spec)
612603

613604
if err := configureSubmitterContainer(&submitterTemplate.Spec.Containers[utils.RayContainerIndex], rayJobInstance, rayv1.K8sJobMode); err != nil {
614605
return corev1.PodTemplateSpec{}, err

ray-operator/controllers/ray/rayjob_controller_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ var _ = Context("RayJob with different submission modes", func() {
357357
namespace := "default"
358358
rayJob := rayJobTemplate("rayjob-invalid-test", namespace)
359359
rayCluster := &rayv1.RayCluster{Spec: *rayJob.Spec.RayClusterSpec}
360-
template := common.GetDefaultSubmitterTemplate(&rayCluster.Spec)
360+
template := common.GetSubmitterTemplate(&rayJob.Spec, &rayCluster.Spec)
361361
template.Spec.RestartPolicy = "" // Make it invalid to create a submitter. Ref: https://github.com/ray-project/kuberay/pull/2389#issuecomment-2359564334
362362
rayJob.Spec.SubmitterPodTemplate = &template
363363

ray-operator/controllers/ray/rayjob_controller_unit_test.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,37 +160,35 @@ func TestGetSubmitterTemplate(t *testing.T) {
160160
},
161161
}
162162

163-
ctx := context.Background()
164-
165163
// Test 1: User provided template with command
166-
submitterTemplate, err := getSubmitterTemplate(ctx, rayJobInstanceWithTemplate, nil)
164+
submitterTemplate, err := getSubmitterTemplate(rayJobInstanceWithTemplate, rayClusterInstance)
167165
require.NoError(t, err)
168166
assert.Equal(t, "user-command", submitterTemplate.Spec.Containers[utils.RayContainerIndex].Command[0])
169167

170168
// Test 2: User provided template without command
171169
rayJobInstanceWithTemplate.Spec.SubmitterPodTemplate.Spec.Containers[utils.RayContainerIndex].Command = []string{}
172-
submitterTemplate, err = getSubmitterTemplate(ctx, rayJobInstanceWithTemplate, nil)
170+
submitterTemplate, err = getSubmitterTemplate(rayJobInstanceWithTemplate, rayClusterInstance)
173171
require.NoError(t, err)
174172
assert.Equal(t, []string{"/bin/bash", "-ce", "--"}, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Command)
175173
assert.Equal(t, []string{"if ! ray job status --address http://test-url test-job-id >/dev/null 2>&1 ; then ray job submit --address http://test-url --no-wait --submission-id test-job-id -- echo no quote 'single quote' \"double quote\" ; fi ; ray job logs --address http://test-url --follow test-job-id"}, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Args)
176174

177175
// Test 3: User did not provide template, should use the image of the Ray Head
178-
submitterTemplate, err = getSubmitterTemplate(ctx, rayJobInstanceWithoutTemplate, rayClusterInstance)
176+
submitterTemplate, err = getSubmitterTemplate(rayJobInstanceWithoutTemplate, rayClusterInstance)
179177
require.NoError(t, err)
180178
assert.Equal(t, []string{"/bin/bash", "-ce", "--"}, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Command)
181179
assert.Equal(t, []string{"if ! ray job status --address http://test-url test-job-id >/dev/null 2>&1 ; then ray job submit --address http://test-url --no-wait --submission-id test-job-id -- echo no quote 'single quote' \"double quote\" ; fi ; ray job logs --address http://test-url --follow test-job-id"}, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Args)
182180
assert.Equal(t, "rayproject/ray:custom-version", submitterTemplate.Spec.Containers[utils.RayContainerIndex].Image)
183181

184182
// Test 4: Check default PYTHONUNBUFFERED setting
185-
submitterTemplate, err = getSubmitterTemplate(ctx, rayJobInstanceWithoutTemplate, rayClusterInstance)
183+
submitterTemplate, err = getSubmitterTemplate(rayJobInstanceWithoutTemplate, rayClusterInstance)
186184
require.NoError(t, err)
187185

188186
envVar, found := utils.EnvVarByName(PythonUnbufferedEnvVarName, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Env)
189187
assert.True(t, found)
190188
assert.Equal(t, "1", envVar.Value)
191189

192190
// Test 5: Check default RAY_DASHBOARD_ADDRESS env var
193-
submitterTemplate, err = getSubmitterTemplate(ctx, rayJobInstanceWithTemplate, nil)
191+
submitterTemplate, err = getSubmitterTemplate(rayJobInstanceWithTemplate, rayClusterInstance)
194192
require.NoError(t, err)
195193

196194
envVar, found = utils.EnvVarByName(utils.RAY_DASHBOARD_ADDRESS, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Env)

0 commit comments

Comments
 (0)