Skip to content

Commit 3503dca

Browse files
committed
Refactor RayJob submitter template handling
Signed-off-by: You-Cheng Lin (Owen) <[email protected]>
1 parent 3066c41 commit 3503dca

File tree

6 files changed

+35
-59
lines changed

6 files changed

+35
-59
lines changed

ray-operator/controllers/ray/batchscheduler/yunikorn/yunikorn_task_groups.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,7 @@ func newTaskGroupsFromRayJobSpec(rayJobSpec *v1.RayJobSpec) *TaskGroups {
7272
taskGroups := newTaskGroupsFromRayClusterSpec(rayJobSpec.RayClusterSpec)
7373

7474
// submitter group
75-
var submitterGroupSpec corev1.PodSpec
76-
if rayJobSpec.SubmitterPodTemplate != nil {
77-
submitterGroupSpec = rayJobSpec.SubmitterPodTemplate.Spec
78-
} else {
79-
submitterGroupSpec = common.GetDefaultSubmitterTemplate(rayJobSpec.RayClusterSpec).Spec
80-
}
75+
submitterGroupSpec := common.GetSubmitterTemplateFromRayJobSpec(rayJobSpec).Spec
8176

8277
submitterPodMinResource := utils.CalculatePodResource(submitterGroupSpec)
8378
taskGroups.addTaskGroup(

ray-operator/controllers/ray/common/job.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,12 +198,15 @@ func BuildJobSubmitCommand(rayJobInstance *rayv1.RayJob, submissionMode rayv1.Jo
198198
return cmd, nil
199199
}
200200

201-
// GetDefaultSubmitterTemplate creates a default submitter template for the Ray job.
202-
func GetDefaultSubmitterTemplate(rayClusterSpec *rayv1.RayClusterSpec) corev1.PodTemplateSpec {
201+
// GetSubmitterTemplate creates a default submitter template for the Ray job.
202+
func GetSubmitterTemplateFromRayJobSpec(rayJobSpec *rayv1.RayJobSpec) corev1.PodTemplateSpec {
203+
if rayJobSpec.SubmitterPodTemplate != nil {
204+
return *rayJobSpec.SubmitterPodTemplate.DeepCopy()
205+
}
203206
return corev1.PodTemplateSpec{
204207
Spec: corev1.PodSpec{
205208
Containers: []corev1.Container{
206-
GetDefaultSubmitterContainer(rayClusterSpec),
209+
GetDefaultSubmitterContainer(rayJobSpec.RayClusterSpec),
207210
},
208211
RestartPolicy: corev1.RestartPolicyNever,
209212
},

ray-operator/controllers/ray/common/job_test.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -219,21 +219,23 @@ func TestMetadataRaisesErrorBeforeRay26(t *testing.T) {
219219
}
220220

221221
func TestGetDefaultSubmitterTemplate(t *testing.T) {
222-
rayCluster := &rayv1.RayCluster{
223-
Spec: rayv1.RayClusterSpec{
224-
HeadGroupSpec: rayv1.HeadGroupSpec{
225-
Template: corev1.PodTemplateSpec{
226-
Spec: corev1.PodSpec{
227-
Containers: []corev1.Container{
228-
{
229-
Image: "rayproject/ray:test-submitter-template",
222+
rayJob := &rayv1.RayJob{
223+
Spec: rayv1.RayJobSpec{
224+
RayClusterSpec: &rayv1.RayClusterSpec{
225+
HeadGroupSpec: rayv1.HeadGroupSpec{
226+
Template: corev1.PodTemplateSpec{
227+
Spec: corev1.PodSpec{
228+
Containers: []corev1.Container{
229+
{
230+
Image: "rayproject/ray:test-submitter-template",
231+
},
230232
},
231233
},
232234
},
233235
},
234236
},
235237
},
236238
}
237-
template := GetDefaultSubmitterTemplate(&rayCluster.Spec)
238-
assert.Equal(t, template.Spec.Containers[0].Image, rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[utils.RayContainerIndex].Image)
239+
template := GetSubmitterTemplateFromRayJobSpec(&rayJob.Spec)
240+
assert.Equal(t, template.Spec.Containers[0].Image, rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Containers[utils.RayContainerIndex].Image)
239241
}

ray-operator/controllers/ray/rayjob_controller.go

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ func (r *RayJobReconciler) Reconcile(ctx context.Context, request ctrl.Request)
219219
}
220220

221221
if rayJobInstance.Spec.SubmissionMode == rayv1.K8sJobMode {
222-
if err := r.createK8sJobIfNeed(ctx, rayJobInstance, rayClusterInstance); err != nil {
222+
if err := r.createK8sJobIfNeed(ctx, rayJobInstance); err != nil {
223223
return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err
224224
}
225225
}
@@ -570,13 +570,13 @@ func checkBackoffLimitAndUpdateStatusIfNeeded(ctx context.Context, rayJob *rayv1
570570
}
571571

572572
// createK8sJobIfNeed creates a Kubernetes Job for the RayJob if it doesn't exist.
573-
func (r *RayJobReconciler) createK8sJobIfNeed(ctx context.Context, rayJobInstance *rayv1.RayJob, rayClusterInstance *rayv1.RayCluster) error {
573+
func (r *RayJobReconciler) createK8sJobIfNeed(ctx context.Context, rayJobInstance *rayv1.RayJob) error {
574574
logger := ctrl.LoggerFrom(ctx)
575575
job := &batchv1.Job{}
576576
namespacedName := common.RayJobK8sJobNamespacedName(rayJobInstance)
577577
if err := r.Client.Get(ctx, namespacedName, job); err != nil {
578578
if errors.IsNotFound(err) {
579-
submitterTemplate, err := getSubmitterTemplate(ctx, rayJobInstance, rayClusterInstance)
579+
submitterTemplate, err := getSubmitterTemplate(rayJobInstance)
580580
if err != nil {
581581
return err
582582
}
@@ -597,18 +597,8 @@ func (r *RayJobReconciler) createK8sJobIfNeed(ctx context.Context, rayJobInstanc
597597
}
598598

599599
// getSubmitterTemplate builds the submitter pod template for the Ray job.
600-
func getSubmitterTemplate(ctx context.Context, rayJobInstance *rayv1.RayJob, rayClusterInstance *rayv1.RayCluster) (corev1.PodTemplateSpec, error) {
601-
logger := ctrl.LoggerFrom(ctx)
602-
var submitterTemplate corev1.PodTemplateSpec
603-
604-
// Set the default value for the optional field SubmitterPodTemplate if not provided.
605-
if rayJobInstance.Spec.SubmitterPodTemplate == nil {
606-
submitterTemplate = common.GetDefaultSubmitterTemplate(&rayClusterInstance.Spec)
607-
logger.Info("default submitter template is used")
608-
} else {
609-
submitterTemplate = *rayJobInstance.Spec.SubmitterPodTemplate.DeepCopy()
610-
logger.Info("user-provided submitter template is used; the first container is assumed to be the submitter")
611-
}
600+
func getSubmitterTemplate(rayJobInstance *rayv1.RayJob) (corev1.PodTemplateSpec, error) {
601+
submitterTemplate := common.GetSubmitterTemplateFromRayJobSpec(&rayJobInstance.Spec)
612602

613603
if err := configureSubmitterContainer(&submitterTemplate.Spec.Containers[utils.RayContainerIndex], rayJobInstance, rayv1.K8sJobMode); err != nil {
614604
return corev1.PodTemplateSpec{}, err

ray-operator/controllers/ray/rayjob_controller_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ var _ = Context("RayJob with different submission modes", func() {
357357
namespace := "default"
358358
rayJob := rayJobTemplate("rayjob-invalid-test", namespace)
359359
rayCluster := &rayv1.RayCluster{Spec: *rayJob.Spec.RayClusterSpec}
360-
template := common.GetDefaultSubmitterTemplate(&rayCluster.Spec)
360+
template := common.GetSubmitterTemplateFromRayJobSpec(&rayJob.Spec)
361361
template.Spec.RestartPolicy = "" // Make it invalid to create a submitter. Ref: https://github.com/ray-project/kuberay/pull/2389#issuecomment-2359564334
362362
rayJob.Spec.SubmitterPodTemplate = &template
363363

ray-operator/controllers/ray/rayjob_controller_unit_test.go

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ func TestCreateRayJobSubmitterIfNeed(t *testing.T) {
6060
Name: "test-rayjob",
6161
Namespace: "default",
6262
},
63+
Spec: rayv1.RayJobSpec{
64+
RayClusterSpec: &rayCluster.Spec,
65+
},
6366
}
6467

6568
k8sJob := &batchv1.Job{
@@ -79,14 +82,14 @@ func TestCreateRayJobSubmitterIfNeed(t *testing.T) {
7982
Recorder: &record.FakeRecorder{},
8083
}
8184

82-
err := rayJobReconciler.createK8sJobIfNeed(ctx, rayJob, rayCluster)
85+
err := rayJobReconciler.createK8sJobIfNeed(ctx, rayJob)
8386
require.NoError(t, err)
8487

8588
// Test 2: Create a new k8s job if it does not already exist
8689
fakeClient = clientFake.NewClientBuilder().WithScheme(newScheme).WithRuntimeObjects(rayCluster, rayJob).Build()
8790
rayJobReconciler.Client = fakeClient
8891

89-
err = rayJobReconciler.createK8sJobIfNeed(ctx, rayJob, rayCluster)
92+
err = rayJobReconciler.createK8sJobIfNeed(ctx, rayJob)
9093
require.NoError(t, err)
9194

9295
err = fakeClient.Get(ctx, types.NamespacedName{
@@ -144,53 +147,36 @@ func TestGetSubmitterTemplate(t *testing.T) {
144147
JobId: "test-job-id",
145148
},
146149
}
147-
rayClusterInstance := &rayv1.RayCluster{
148-
Spec: rayv1.RayClusterSpec{
149-
HeadGroupSpec: rayv1.HeadGroupSpec{
150-
Template: corev1.PodTemplateSpec{
151-
Spec: corev1.PodSpec{
152-
Containers: []corev1.Container{
153-
{
154-
Image: "rayproject/ray:custom-version",
155-
},
156-
},
157-
},
158-
},
159-
},
160-
},
161-
}
162-
163-
ctx := context.Background()
164150

165151
// Test 1: User provided template with command
166-
submitterTemplate, err := getSubmitterTemplate(ctx, rayJobInstanceWithTemplate, nil)
152+
submitterTemplate, err := getSubmitterTemplate(rayJobInstanceWithTemplate)
167153
require.NoError(t, err)
168154
assert.Equal(t, "user-command", submitterTemplate.Spec.Containers[utils.RayContainerIndex].Command[0])
169155

170156
// Test 2: User provided template without command
171157
rayJobInstanceWithTemplate.Spec.SubmitterPodTemplate.Spec.Containers[utils.RayContainerIndex].Command = []string{}
172-
submitterTemplate, err = getSubmitterTemplate(ctx, rayJobInstanceWithTemplate, nil)
158+
submitterTemplate, err = getSubmitterTemplate(rayJobInstanceWithTemplate)
173159
require.NoError(t, err)
174160
assert.Equal(t, []string{"/bin/bash", "-ce", "--"}, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Command)
175161
assert.Equal(t, []string{"if ! ray job status --address http://test-url test-job-id >/dev/null 2>&1 ; then ray job submit --address http://test-url --no-wait --submission-id test-job-id -- echo no quote 'single quote' \"double quote\" ; fi ; ray job logs --address http://test-url --follow test-job-id"}, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Args)
176162

177163
// Test 3: User did not provide template, should use the image of the Ray Head
178-
submitterTemplate, err = getSubmitterTemplate(ctx, rayJobInstanceWithoutTemplate, rayClusterInstance)
164+
submitterTemplate, err = getSubmitterTemplate(rayJobInstanceWithoutTemplate)
179165
require.NoError(t, err)
180166
assert.Equal(t, []string{"/bin/bash", "-ce", "--"}, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Command)
181167
assert.Equal(t, []string{"if ! ray job status --address http://test-url test-job-id >/dev/null 2>&1 ; then ray job submit --address http://test-url --no-wait --submission-id test-job-id -- echo no quote 'single quote' \"double quote\" ; fi ; ray job logs --address http://test-url --follow test-job-id"}, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Args)
182168
assert.Equal(t, "rayproject/ray:custom-version", submitterTemplate.Spec.Containers[utils.RayContainerIndex].Image)
183169

184170
// Test 4: Check default PYTHONUNBUFFERED setting
185-
submitterTemplate, err = getSubmitterTemplate(ctx, rayJobInstanceWithoutTemplate, rayClusterInstance)
171+
submitterTemplate, err = getSubmitterTemplate(rayJobInstanceWithoutTemplate)
186172
require.NoError(t, err)
187173

188174
envVar, found := utils.EnvVarByName(PythonUnbufferedEnvVarName, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Env)
189175
assert.True(t, found)
190176
assert.Equal(t, "1", envVar.Value)
191177

192178
// Test 5: Check default RAY_DASHBOARD_ADDRESS env var
193-
submitterTemplate, err = getSubmitterTemplate(ctx, rayJobInstanceWithTemplate, nil)
179+
submitterTemplate, err = getSubmitterTemplate(rayJobInstanceWithTemplate)
194180
require.NoError(t, err)
195181

196182
envVar, found = utils.EnvVarByName(utils.RAY_DASHBOARD_ADDRESS, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Env)

0 commit comments

Comments
 (0)