Skip to content
17 changes: 13 additions & 4 deletions ray-operator/controllers/ray/common/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,17 +410,26 @@ func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, wo
return podTemplate
}

func initLivenessAndReadinessProbe(rayContainer *corev1.Container, rayNodeType rayv1.RayNodeType, creatorCRDType utils.CRDType) {
func initLivenessAndReadinessProbe(rayContainer *corev1.Container, rayNodeType rayv1.RayNodeType, creatorCRDType utils.CRDType, rayStartParams map[string]string) {
getPort := func(key string, defaultVal int) int {
if portStr, ok := rayStartParams[key]; ok {
if port, err := strconv.Atoi(portStr); err == nil {
return port
}
}
return defaultVal
}

rayAgentRayletHealthCommand := fmt.Sprintf(
utils.BaseWgetHealthCommand,
utils.DefaultReadinessProbeTimeoutSeconds,
utils.DefaultDashboardAgentListenPort,
getPort("dashboard-agent-listen-port", utils.DefaultDashboardAgentListenPort),
utils.RayAgentRayletHealthPath,
)
rayDashboardGCSHealthCommand := fmt.Sprintf(
utils.BaseWgetHealthCommand,
utils.DefaultReadinessProbeFailureThreshold,
utils.DefaultDashboardPort,
getPort("dashboard-port", utils.DefaultDashboardPort),
utils.RayDashboardGCSHealthPath,
)

Expand Down Expand Up @@ -566,7 +575,7 @@ func BuildPod(ctx context.Context, podTemplateSpec corev1.PodTemplateSpec, rayNo
// Configure the readiness and liveness probes for the Ray container. These probes
// play a crucial role in KubeRay health checks. Without them, certain failures,
// such as the Raylet process crashing, may go undetected.
initLivenessAndReadinessProbe(&pod.Spec.Containers[utils.RayContainerIndex], rayNodeType, creatorCRDType)
initLivenessAndReadinessProbe(&pod.Spec.Containers[utils.RayContainerIndex], rayNodeType, creatorCRDType, rayStartParams)
}

return pod
Expand Down
77 changes: 74 additions & 3 deletions ray-operator/controllers/ray/common/pod_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1672,6 +1672,7 @@ func TestInitLivenessAndReadinessProbe(t *testing.T) {
podName := strings.ToLower(cluster.Name + utils.DashSymbol + string(rayv1.HeadNode) + utils.DashSymbol + utils.FormatInt32(0))
podTemplateSpec := DefaultHeadPodTemplate(context.Background(), *cluster, cluster.Spec.HeadGroupSpec, podName, "6379")
rayContainer := &podTemplateSpec.Spec.Containers[utils.RayContainerIndex]
rayStartParams := make(map[string]string)

// Test 1: User defines a custom HTTPGet probe.
httpGetProbe := corev1.Probe{
Expand All @@ -1686,7 +1687,7 @@ func TestInitLivenessAndReadinessProbe(t *testing.T) {

rayContainer.LivenessProbe = &httpGetProbe
rayContainer.ReadinessProbe = &httpGetProbe
initLivenessAndReadinessProbe(rayContainer, rayv1.HeadNode, "")
initLivenessAndReadinessProbe(rayContainer, rayv1.HeadNode, "", rayStartParams)
assert.NotNil(t, rayContainer.LivenessProbe.HTTPGet)
assert.NotNil(t, rayContainer.ReadinessProbe.HTTPGet)
assert.Nil(t, rayContainer.LivenessProbe.Exec)
Expand All @@ -1697,7 +1698,7 @@ func TestInitLivenessAndReadinessProbe(t *testing.T) {
// implying that an additional serve health check will be added to the readiness probe.
rayContainer.LivenessProbe = nil
rayContainer.ReadinessProbe = nil
initLivenessAndReadinessProbe(rayContainer, rayv1.WorkerNode, utils.RayServiceCRD)
initLivenessAndReadinessProbe(rayContainer, rayv1.WorkerNode, utils.RayServiceCRD, rayStartParams)
assert.NotNil(t, rayContainer.LivenessProbe.Exec)
assert.NotNil(t, rayContainer.ReadinessProbe.Exec)
assert.NotContains(t, strings.Join(rayContainer.LivenessProbe.Exec.Command, " "), utils.RayServeProxyHealthPath)
Expand All @@ -1710,14 +1711,84 @@ func TestInitLivenessAndReadinessProbe(t *testing.T) {
// implying that an additional serve health check will be added to the readiness probe.
rayContainer.LivenessProbe = nil
rayContainer.ReadinessProbe = nil
initLivenessAndReadinessProbe(rayContainer, rayv1.HeadNode, utils.RayServiceCRD)
initLivenessAndReadinessProbe(rayContainer, rayv1.HeadNode, utils.RayServiceCRD, rayStartParams)
assert.NotNil(t, rayContainer.LivenessProbe.Exec)
assert.NotNil(t, rayContainer.ReadinessProbe.Exec)
// head pod should not have Ray Serve proxy health probes
assert.NotContains(t, strings.Join(rayContainer.LivenessProbe.Exec.Command, " "), utils.RayServeProxyHealthPath)
assert.NotContains(t, strings.Join(rayContainer.ReadinessProbe.Exec.Command, " "), utils.RayServeProxyHealthPath)
assert.Equal(t, int32(5), rayContainer.LivenessProbe.TimeoutSeconds)
assert.Equal(t, int32(5), rayContainer.ReadinessProbe.TimeoutSeconds)

// Test 4: Test custom ports in rayStartParams for head node.
rayContainer.LivenessProbe = nil
rayContainer.ReadinessProbe = nil
customRayStartParams := map[string]string{
"dashboard-agent-listen-port": "8266",
"dashboard-port": "8365",
}
initLivenessAndReadinessProbe(rayContainer, rayv1.HeadNode, utils.RayClusterCRD, customRayStartParams)
assert.NotNil(t, rayContainer.LivenessProbe.Exec)
assert.NotNil(t, rayContainer.ReadinessProbe.Exec)

livenessCommand := strings.Join(rayContainer.LivenessProbe.Exec.Command, " ")
readinessCommand := strings.Join(rayContainer.ReadinessProbe.Exec.Command, " ")

assert.Contains(t, livenessCommand, ":8266", "Head pod liveness probe should use custom dashboard-agent-listen-port")
assert.Contains(t, livenessCommand, ":8365", "Head pod liveness probe should use custom dashboard-port")
assert.Contains(t, readinessCommand, ":8266", "Head pod readiness probe should use custom dashboard-agent-listen-port")
assert.Contains(t, readinessCommand, ":8365", "Head pod readiness probe should use custom dashboard-port")

// Test 5: Test custom ports in rayStartParams for worker node
rayContainer.LivenessProbe = nil
rayContainer.ReadinessProbe = nil
workerRayStartParams := map[string]string{
"dashboard-agent-listen-port": "9000",
}
initLivenessAndReadinessProbe(rayContainer, rayv1.WorkerNode, utils.RayClusterCRD, workerRayStartParams)
assert.NotNil(t, rayContainer.LivenessProbe.Exec)
assert.NotNil(t, rayContainer.ReadinessProbe.Exec)

workerLivenessCommand := strings.Join(rayContainer.LivenessProbe.Exec.Command, " ")
workerReadinessCommand := strings.Join(rayContainer.ReadinessProbe.Exec.Command, " ")

assert.Contains(t, workerLivenessCommand, ":9000", "Worker pod should use custom dashboard-agent-listen-port")
assert.Contains(t, workerReadinessCommand, ":9000", "Worker pod should use custom dashboard-agent-listen-port")
assert.NotContains(t, workerLivenessCommand, fmt.Sprintf(":%d", utils.DefaultDashboardPort), "Worker pod should not check dashboard-port")
assert.NotContains(t, workerReadinessCommand, fmt.Sprintf(":%d", utils.DefaultDashboardPort), "Worker pod should not check dashboard-port")

// Test 6: Test RayService worker with custom ports and serve proxy health check
rayContainer.LivenessProbe = nil
rayContainer.ReadinessProbe = nil
rayContainer.Ports = []corev1.ContainerPort{
{
Name: utils.ServingPortName,
ContainerPort: int32(utils.DefaultServingPort),
},
}
rayServiceWorkerParams := map[string]string{
"dashboard-agent-listen-port": "8500",
}
initLivenessAndReadinessProbe(rayContainer, rayv1.WorkerNode, utils.RayServiceCRD, rayServiceWorkerParams)
rayServiceReadinessCommand := strings.Join(rayContainer.ReadinessProbe.Exec.Command, " ")
assert.Contains(t, rayServiceReadinessCommand, ":8500", "RayService worker should use custom dashboard-agent-listen-port")
assert.Contains(t, rayServiceReadinessCommand, utils.RayServeProxyHealthPath, "RayService worker should include serve proxy health check")
assert.Equal(t, int32(utils.ServeReadinessProbeFailureThreshold), rayContainer.ReadinessProbe.FailureThreshold, "RayService worker should have correct failure threshold")

// Test 8: Test invalid port values (should fall back to defaults)
rayContainer.LivenessProbe = nil
rayContainer.ReadinessProbe = nil
invalidPortParams := map[string]string{
"dashboard-agent-listen-port": "invalid-port",
"dashboard-port": "not-a-number",
}
initLivenessAndReadinessProbe(rayContainer, rayv1.HeadNode, utils.RayClusterCRD, invalidPortParams)

invalidPortLivenessCommand := strings.Join(rayContainer.LivenessProbe.Exec.Command, " ")

// Should fall back to default ports when invalid values are provided
assert.Contains(t, invalidPortLivenessCommand, fmt.Sprintf(":%d", utils.DefaultDashboardAgentListenPort), "Should fall back to default dashboard-agent-listen-port for invalid input")
assert.Contains(t, invalidPortLivenessCommand, fmt.Sprintf(":%d", utils.DefaultDashboardPort), "Should fall back to default dashboard-port for invalid input")
}

func TestGenerateRayStartCommand(t *testing.T) {
Expand Down
132 changes: 132 additions & 0 deletions ray-operator/controllers/ray/raycluster_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ func rayClusterTemplate(name string, namespace string) *rayv1.RayCluster {
}
}

func getRayContainerFromPod(pod *corev1.Pod) *corev1.Container {
for i := range pod.Spec.Containers {
if pod.Spec.Containers[i].Name == "ray-head" || pod.Spec.Containers[i].Name == "ray-worker" {
return &pod.Spec.Containers[i]
}
}
return nil
}

var _ = Context("Inside the default namespace", func() {
Describe("Static RayCluster", Ordered, func() {
ctx := context.Background()
Expand Down Expand Up @@ -1486,4 +1495,127 @@ var _ = Context("Inside the default namespace", func() {
time.Second*3, time.Millisecond*500).Should(Succeed())
})
})

Describe("RayCluster with custom probe ports", Ordered, func() {
ctx := context.Background()
namespace := "default"
rayClusterName := "raycluster-custom-probe-ports"

rayCluster := rayClusterTemplate(rayClusterName, namespace)

rayCluster.Spec.HeadGroupSpec.RayStartParams = map[string]string{
"dashboard-port": "8365",
"dashboard-agent-listen-port": "8266",
}

rayCluster.Spec.WorkerGroupSpecs[0].RayStartParams = map[string]string{
"dashboard-agent-listen-port": "9000",
}

rayCluster.Spec.WorkerGroupSpecs[0].Replicas = ptr.To[int32](1)
rayCluster.Spec.WorkerGroupSpecs[0].MaxReplicas = ptr.To[int32](1)

headPods := corev1.PodList{}
workerPods := corev1.PodList{}
workerFilters := common.RayClusterGroupPodsAssociationOptions(rayCluster, rayCluster.Spec.WorkerGroupSpecs[0].GroupName).ToListOptions()
headFilters := common.RayClusterHeadPodsAssociationOptions(rayCluster).ToListOptions()

It("Verify RayCluster spec", func() {
Expect(rayCluster.Spec.HeadGroupSpec.RayStartParams["dashboard-port"]).To(Equal("8365"))
Expect(rayCluster.Spec.HeadGroupSpec.RayStartParams["dashboard-agent-listen-port"]).To(Equal("8266"))
Expect(rayCluster.Spec.WorkerGroupSpecs[0].RayStartParams["dashboard-agent-listen-port"]).To(Equal("9000"))
Expect(rayCluster.Spec.WorkerGroupSpecs).To(HaveLen(1))
Expect(rayCluster.Spec.WorkerGroupSpecs[0].Replicas).To(Equal(ptr.To[int32](1)))
})

It("Create a RayCluster custom resource", func() {
err := k8sClient.Create(ctx, rayCluster)
Expect(err).NotTo(HaveOccurred(), "Failed to create RayCluster")
Eventually(
getResourceFunc(ctx, client.ObjectKey{Name: rayCluster.Name, Namespace: namespace}, rayCluster),
time.Second*3, time.Millisecond*500).Should(Succeed(), "Should be able to see RayCluster: %v", rayCluster.Name)
})

It("Check the number of head Pods", func() {
numHeadPods := 1
Eventually(
listResourceFunc(ctx, &headPods, headFilters...),
time.Second*3, time.Millisecond*500).Should(Equal(numHeadPods), fmt.Sprintf("headGroup %v", headPods.Items))
})

It("Check the number of worker Pods", func() {
numWorkerPods := 1
Eventually(
listResourceFunc(ctx, &workerPods, workerFilters...),
time.Second*3, time.Millisecond*500).Should(Equal(numWorkerPods), fmt.Sprintf("workerGroup %v", workerPods.Items))
})

It("Update all Pods to Running", func() {
for _, headPod := range headPods.Items {
headPod.Status.Phase = corev1.PodRunning
Expect(k8sClient.Status().Update(ctx, &headPod)).Should(Succeed())
}

Eventually(
isAllPodsRunningByFilters).WithContext(ctx).WithArguments(headPods, headFilters).WithTimeout(time.Second*3).WithPolling(time.Millisecond*500).Should(BeTrue(), "Head Pod should be running.")

for _, workerPod := range workerPods.Items {
workerPod.Status.Phase = corev1.PodRunning
Expect(k8sClient.Status().Update(ctx, &workerPod)).Should(Succeed())
}

Eventually(
isAllPodsRunningByFilters).WithContext(ctx).WithArguments(workerPods, workerFilters).WithTimeout(time.Second*3).WithPolling(time.Millisecond*500).Should(BeTrue(), "All worker Pods should be running.")
})

It("Should have head pod with correct probe configuration", func() {
Expect(headPods.Items).Should(HaveLen(1), "Should have exactly one head pod")
headPod := headPods.Items[0]

rayContainer := getRayContainerFromPod(&headPod)
Expect(rayContainer).NotTo(BeNil(), "Ray container should exist")

Expect(rayContainer.LivenessProbe).NotTo(BeNil(), "LivenessProbe should be configured")
Expect(rayContainer.LivenessProbe.Exec).NotTo(BeNil(), "LivenessProbe should use Exec")

livenessCommand := strings.Join(rayContainer.LivenessProbe.Exec.Command, " ")
Expect(livenessCommand).To(ContainSubstring(":8266"), "Head pod liveness probe should use custom dashboard-agent-listen-port")
Expect(livenessCommand).To(ContainSubstring(":8365"), "Head pod liveness probe should use custom dashboard-port")

Expect(rayContainer.ReadinessProbe).NotTo(BeNil(), "ReadinessProbe should be configured")
Expect(rayContainer.ReadinessProbe.Exec).NotTo(BeNil(), "ReadinessProbe should use Exec")

readinessCommand := strings.Join(rayContainer.ReadinessProbe.Exec.Command, " ")
Expect(readinessCommand).To(ContainSubstring(":8266"), "Head pod readiness probe should use custom dashboard-agent-listen-port")
Expect(readinessCommand).To(ContainSubstring(":8365"), "Head pod readiness probe should use custom dashboard-port")
})

It("Should have worker pod with correct probe configuration", func() {
Expect(workerPods.Items).Should(HaveLen(1), "Should have exactly one worker pod")
workerPod := workerPods.Items[0]

rayContainer := getRayContainerFromPod(&workerPod)
Expect(rayContainer).NotTo(BeNil(), "Ray container should exist")

Expect(rayContainer.LivenessProbe).NotTo(BeNil(), "LivenessProbe should be configured")
Expect(rayContainer.LivenessProbe.Exec).NotTo(BeNil(), "LivenessProbe should use Exec")

livenessCommand := strings.Join(rayContainer.LivenessProbe.Exec.Command, " ")
Expect(livenessCommand).To(ContainSubstring(":9000"), "Worker pod should use custom dashboard-agent-listen-port")
Expect(livenessCommand).NotTo(ContainSubstring(":8365"), "Worker pod should not check dashboard-port")

Expect(rayContainer.ReadinessProbe).NotTo(BeNil(), "ReadinessProbe should be configured")
Expect(rayContainer.ReadinessProbe.Exec).NotTo(BeNil(), "ReadinessProbe should use Exec")

readinessCommand := strings.Join(rayContainer.ReadinessProbe.Exec.Command, " ")
Expect(readinessCommand).To(ContainSubstring(":9000"), "Worker pod should use custom dashboard-agent-listen-port")
Expect(readinessCommand).NotTo(ContainSubstring(":8365"), "Worker pod should not check dashboard-port")
})

It("RayCluster's .status.state should be updated to 'ready' shortly after all Pods are Running", func() {
Eventually(
getClusterState(ctx, namespace, rayCluster.Name),
time.Second*3, time.Millisecond*500).Should(Equal(rayv1.Ready))
})
})
})
Loading