diff --git a/python/ray/dashboard/modules/aggregator/aggregator_agent.py b/python/ray/dashboard/modules/aggregator/aggregator_agent.py
index 9e662c0ab516..cf6fb2dfe242 100644
--- a/python/ray/dashboard/modules/aggregator/aggregator_agent.py
+++ b/python/ray/dashboard/modules/aggregator/aggregator_agent.py
@@ -66,6 +66,12 @@
 PUBLISH_EVENTS_TO_EXTERNAL_HTTP_SERVICE = ray_constants.env_bool(
     f"{env_var_prefix}_PUBLISH_EVENTS_TO_EXTERNAL_HTTP_SERVICE", True
 )
+# Flag to control whether to preserve proto field names when converting events to
+# JSON. If True, proto field names are preserved as-is; if False, they are
+# converted to lowerCamelCase.
+PRESERVE_PROTO_FIELD_NAME = ray_constants.env_bool(
+    f"{env_var_prefix}_PRESERVE_PROTO_FIELD_NAME", False
+)
 
 
 class AggregatorAgent(
@@ -124,6 +130,7 @@ def __init__(self, dashboard_agent) -> None:
                 endpoint=self._events_export_addr,
                 executor=self._executor,
                 events_filter_fn=self._can_expose_event,
+                preserve_proto_field_name=PRESERVE_PROTO_FIELD_NAME,
             ),
             event_buffer=self._event_buffer,
             common_metric_tags=self._common_tags,
diff --git a/python/ray/dashboard/modules/aggregator/publisher/async_publisher_client.py b/python/ray/dashboard/modules/aggregator/publisher/async_publisher_client.py
index 15fd7382d4f1..0b9a447a62e6 100644
--- a/python/ray/dashboard/modules/aggregator/publisher/async_publisher_client.py
+++ b/python/ray/dashboard/modules/aggregator/publisher/async_publisher_client.py
@@ -66,12 +66,14 @@ def __init__(
         executor: ThreadPoolExecutor,
         events_filter_fn: Callable[[object], bool],
         timeout: float = PUBLISHER_TIMEOUT_SECONDS,
+        preserve_proto_field_name: bool = False,
     ) -> None:
         self._endpoint = endpoint
         self._executor = executor
         self._events_filter_fn = events_filter_fn
         self._timeout = aiohttp.ClientTimeout(total=timeout)
         self._session = None
+        self._preserve_proto_field_name = preserve_proto_field_name
 
     async def publish(self, batch: PublishBatch) -> PublishStats:
         events_batch: list[events_base_event_pb2.RayEvent] = batch.events
@@ -89,7 +91,11 @@ async def publish(self, batch: PublishBatch) -> PublishStats:
             self._executor,
             lambda: [
                 json.loads(
-                    message_to_json(e, always_print_fields_with_no_presence=True)
+                    message_to_json(
+                        e,
+                        always_print_fields_with_no_presence=True,
+                        preserving_proto_field_name=self._preserve_proto_field_name,
+                    )
                 )
                 for e in filtered
             ],
diff --git a/python/ray/dashboard/modules/aggregator/tests/test_aggregator_agent.py b/python/ray/dashboard/modules/aggregator/tests/test_aggregator_agent.py
index 61e793fadd77..81f7bf43aab4 100644
--- a/python/ray/dashboard/modules/aggregator/tests/test_aggregator_agent.py
+++ b/python/ray/dashboard/modules/aggregator/tests/test_aggregator_agent.py
@@ -1,6 +1,7 @@
 import base64
 import json
 import sys
+from typing import Optional
 from unittest.mock import MagicMock
 
 import pytest
@@ -74,16 +75,42 @@ def fake_timestamp():
     return Timestamp(seconds=seconds, nanos=nanos), "2025-06-30T16:50:30.130457542Z"
 
 
-_with_aggregator_port = pytest.mark.parametrize(
-    "ray_start_cluster_head_with_env_vars",
-    [
-        {
-            "env_vars": {
-                "RAY_DASHBOARD_AGGREGATOR_AGENT_EVENTS_EXPORT_ADDR": _EVENT_AGGREGATOR_AGENT_TARGET_ADDR,
+def generate_event_export_env_vars(
+    preserve_proto_field_name: Optional[bool] = None, additional_env_vars: dict = None
+) -> dict:
+    if additional_env_vars is None:
+        additional_env_vars = {}
+
+    event_export_env_vars = {
+        "RAY_DASHBOARD_AGGREGATOR_AGENT_EVENTS_EXPORT_ADDR": _EVENT_AGGREGATOR_AGENT_TARGET_ADDR,
+    } | additional_env_vars
+
+    if preserve_proto_field_name is not None:
+        event_export_env_vars[
+            "RAY_DASHBOARD_AGGREGATOR_AGENT_PRESERVE_PROTO_FIELD_NAME"
+        ] = ("1" if preserve_proto_field_name is True else "0")
+
+    return event_export_env_vars
+
+
+def build_export_env_vars_param_list(additional_env_vars: dict = None) -> list:
+    return [
+        pytest.param(
+            preserve_proto_field_name,
+            {
+                "env_vars": generate_event_export_env_vars(
+                    preserve_proto_field_name, additional_env_vars
+                )
             },
-        },
-    ],
-    indirect=True,
+        )
+        for preserve_proto_field_name in [True, False]
+    ]
+
+
+_with_preserve_proto_field_name_flag = pytest.mark.parametrize(
+    ("preserve_proto_field_name", "ray_start_cluster_head_with_env_vars"),
+    build_export_env_vars_param_list(),
+    indirect=["ray_start_cluster_head_with_env_vars"],
 )
@@ -175,9 +202,12 @@ def test_aggregator_agent_event_processing_disabled(
         stub.AddEvents(request)
 
 
-@_with_aggregator_port
+@_with_preserve_proto_field_name_flag
 def test_aggregator_agent_receive_publish_events_normally(
-    ray_start_cluster_head_with_env_vars, httpserver, fake_timestamp
+    ray_start_cluster_head_with_env_vars,
+    httpserver,
+    fake_timestamp,
+    preserve_proto_field_name,
 ):
     cluster = ray_start_cluster_head_with_env_vars
     stub = get_event_aggregator_grpc_stub(
@@ -211,28 +241,34 @@
     req_json = json.loads(req.data)
     assert len(req_json) == 1
-    assert req_json[0]["eventId"] == base64.b64encode(b"1").decode()
-    assert req_json[0]["sourceType"] == "CORE_WORKER"
-    assert req_json[0]["eventType"] == "TASK_DEFINITION_EVENT"
+    if preserve_proto_field_name:
+        assert req_json[0]["event_id"] == base64.b64encode(b"1").decode()
+        assert req_json[0]["source_type"] == "CORE_WORKER"
+        assert req_json[0]["event_type"] == "TASK_DEFINITION_EVENT"
+    else:
+        assert req_json[0]["eventId"] == base64.b64encode(b"1").decode()
+        assert req_json[0]["sourceType"] == "CORE_WORKER"
+        assert req_json[0]["eventType"] == "TASK_DEFINITION_EVENT"
+
     assert req_json[0]["severity"] == "INFO"
     assert req_json[0]["message"] == "hello"
     assert req_json[0]["timestamp"] == fake_timestamp[1]
 
 
 @pytest.mark.parametrize(
-    "ray_start_cluster_head_with_env_vars",
-    [
-        {
-            "env_vars": {
-                "RAY_DASHBOARD_AGGREGATOR_AGENT_MAX_EVENT_BUFFER_SIZE": 1,
-                "RAY_DASHBOARD_AGGREGATOR_AGENT_EVENTS_EXPORT_ADDR": _EVENT_AGGREGATOR_AGENT_TARGET_ADDR,
-            },
-        },
-    ],
-    indirect=True,
+    ("preserve_proto_field_name", "ray_start_cluster_head_with_env_vars"),
+    build_export_env_vars_param_list(
+        additional_env_vars={
+            "RAY_DASHBOARD_AGGREGATOR_AGENT_MAX_EVENT_BUFFER_SIZE": 1,
+        }
+    ),
+    indirect=["ray_start_cluster_head_with_env_vars"],
 )
 def test_aggregator_agent_receive_event_full(
-    ray_start_cluster_head_with_env_vars, httpserver, fake_timestamp
+    ray_start_cluster_head_with_env_vars,
+    httpserver,
+    fake_timestamp,
+    preserve_proto_field_name,
 ):
     cluster = ray_start_cluster_head_with_env_vars
     stub = get_event_aggregator_grpc_stub(
@@ -274,12 +310,18 @@
     req_json = json.loads(req.data)
     assert len(req_json) == 1
-    assert req_json[0]["eventId"] == base64.b64encode(b"3").decode()
+    if preserve_proto_field_name:
+        assert req_json[0]["event_id"] == base64.b64encode(b"3").decode()
+    else:
+        assert req_json[0]["eventId"] == base64.b64encode(b"3").decode()
 
 
-@_with_aggregator_port
+@_with_preserve_proto_field_name_flag
 def test_aggregator_agent_receive_multiple_events(
-    ray_start_cluster_head_with_env_vars, httpserver, fake_timestamp
+    ray_start_cluster_head_with_env_vars,
+    httpserver,
+    fake_timestamp,
+    preserve_proto_field_name,
 ):
     cluster = ray_start_cluster_head_with_env_vars
     stub = get_event_aggregator_grpc_stub(
@@ -317,26 +359,31 @@
     req, _ = httpserver.log[0]
     req_json = json.loads(req.data)
     assert len(req_json) == 2
-    assert req_json[0]["eventId"] == base64.b64encode(b"4").decode()
+    if preserve_proto_field_name:
+        assert req_json[0]["event_id"] == base64.b64encode(b"4").decode()
+        assert req_json[1]["event_id"] == base64.b64encode(b"5").decode()
+    else:
+        assert req_json[0]["eventId"] == base64.b64encode(b"4").decode()
+        assert req_json[1]["eventId"] == base64.b64encode(b"5").decode()
+
     assert req_json[0]["message"] == "event1"
-    assert req_json[1]["eventId"] == base64.b64encode(b"5").decode()
     assert req_json[1]["message"] == "event2"
 
 
 @pytest.mark.parametrize(
-    "ray_start_cluster_head_with_env_vars",
-    [
-        {
-            "env_vars": {
-                "RAY_DASHBOARD_AGGREGATOR_AGENT_MAX_EVENT_BUFFER_SIZE": 1,
-                "RAY_DASHBOARD_AGGREGATOR_AGENT_EVENTS_EXPORT_ADDR": _EVENT_AGGREGATOR_AGENT_TARGET_ADDR,
-            },
-        },
-    ],
-    indirect=True,
+    ("preserve_proto_field_name", "ray_start_cluster_head_with_env_vars"),
+    build_export_env_vars_param_list(
+        additional_env_vars={
+            "RAY_DASHBOARD_AGGREGATOR_AGENT_MAX_EVENT_BUFFER_SIZE": 1,
+        }
+    ),
+    indirect=["ray_start_cluster_head_with_env_vars"],
 )
 def test_aggregator_agent_receive_multiple_events_failures(
-    ray_start_cluster_head_with_env_vars, httpserver, fake_timestamp
+    ray_start_cluster_head_with_env_vars,
+    httpserver,
+    fake_timestamp,
+    preserve_proto_field_name,
 ):
     cluster = ray_start_cluster_head_with_env_vars
     stub = get_event_aggregator_grpc_stub(
@@ -378,12 +425,20 @@
     req, _ = httpserver.log[0]
     req_json = json.loads(req.data)
     assert len(req_json) == 1
-    assert req_json[0]["eventId"] == base64.b64encode(b"3").decode()
+    if preserve_proto_field_name:
+        assert req_json[0]["event_id"] == base64.b64encode(b"3").decode()
+    else:
+        assert req_json[0]["eventId"] == base64.b64encode(b"3").decode()
 
 
-@_with_aggregator_port
+@pytest.mark.parametrize(
+    "ray_start_cluster_head_with_env_vars",
+    [{"env_vars": generate_event_export_env_vars()}],
+    indirect=True,
+)
 def test_aggregator_agent_receive_empty_events(
-    ray_start_cluster_head_with_env_vars, httpserver
+    ray_start_cluster_head_with_env_vars,
+    httpserver,
 ):
     cluster = ray_start_cluster_head_with_env_vars
     stub = get_event_aggregator_grpc_stub(
@@ -401,9 +456,12 @@
         stub.AddEvents(request)
 
 
-@_with_aggregator_port
+@_with_preserve_proto_field_name_flag
 def test_aggregator_agent_profile_events_not_exposed(
-    ray_start_cluster_head_with_env_vars, httpserver, fake_timestamp
+    ray_start_cluster_head_with_env_vars,
+    httpserver,
+    fake_timestamp,
+    preserve_proto_field_name,
 ):
     """Test that profile events are not sent when not in exposable event types."""
     cluster = ray_start_cluster_head_with_env_vars
@@ -442,7 +500,10 @@
     assert len(req_json) == 1
     assert req_json[0]["message"] == "event1"
-    assert req_json[0]["eventType"] == "TASK_DEFINITION_EVENT"
+    if preserve_proto_field_name:
+        assert req_json[0]["event_type"] == "TASK_DEFINITION_EVENT"
+    else:
+        assert req_json[0]["eventType"] == "TASK_DEFINITION_EVENT"
 
 
 def _create_task_definition_event_proto(timestamp):
@@ -483,72 +544,140 @@
     )
 
 
-def _verify_task_definition_event_json(req_json, expected_timestamp):
+def _verify_task_definition_event_json(
+    req_json, expected_timestamp, preserve_proto_field_name
+):
     assert len(req_json) == 1
-    # Verify the base event fields
-    assert req_json[0]["eventId"] == base64.b64encode(b"1").decode()
-    assert req_json[0]["sourceType"] == "CORE_WORKER"
-    assert req_json[0]["eventType"] == "TASK_DEFINITION_EVENT"
-    assert req_json[0]["timestamp"] == expected_timestamp
-    assert req_json[0]["severity"] == "INFO"
-    assert (
-        req_json[0]["message"] == ""
-    )  # Make sure the default value is included when it is not set
-    assert req_json[0]["sessionName"] == "test_session"
-
-    # Verify the task definition event specific fields
-    assert (
-        req_json[0]["taskDefinitionEvent"]["taskId"] == base64.b64encode(b"1").decode()
-    )
-    assert req_json[0]["taskDefinitionEvent"]["taskAttempt"] == 1
-    assert req_json[0]["taskDefinitionEvent"]["taskType"] == "NORMAL_TASK"
-    assert req_json[0]["taskDefinitionEvent"]["language"] == "PYTHON"
-    assert (
-        req_json[0]["taskDefinitionEvent"]["taskFunc"]["pythonFunctionDescriptor"][
-            "moduleName"
-        ]
-        == "test_module"
-    )
-    assert (
-        req_json[0]["taskDefinitionEvent"]["taskFunc"]["pythonFunctionDescriptor"][
-            "className"
-        ]
-        == "test_class"
-    )
-    assert (
-        req_json[0]["taskDefinitionEvent"]["taskFunc"]["pythonFunctionDescriptor"][
-            "functionName"
-        ]
-        == "test_function"
-    )
-    assert (
-        req_json[0]["taskDefinitionEvent"]["taskFunc"]["pythonFunctionDescriptor"][
-            "functionHash"
-        ]
-        == "test_hash"
-    )
-    assert req_json[0]["taskDefinitionEvent"]["taskName"] == "test_task"
-    assert req_json[0]["taskDefinitionEvent"]["requiredResources"] == {
-        "CPU": 1.0,
-        "GPU": 0.0,
-    }
-    assert req_json[0]["taskDefinitionEvent"]["serializedRuntimeEnv"] == "{}"
-    assert (
-        req_json[0]["taskDefinitionEvent"]["jobId"] == base64.b64encode(b"1").decode()
-    )
-    assert (
-        req_json[0]["taskDefinitionEvent"]["parentTaskId"]
-        == base64.b64encode(b"1").decode()
-    )
-    assert (
-        req_json[0]["taskDefinitionEvent"]["placementGroupId"]
-        == base64.b64encode(b"1").decode()
-    )
-    assert req_json[0]["taskDefinitionEvent"]["refIds"] == {
-        "key1": base64.b64encode(b"value1").decode(),
-        "key2": base64.b64encode(b"value2").decode(),
-    }
+    if preserve_proto_field_name:
+        assert req_json[0]["event_id"] == base64.b64encode(b"1").decode()
+        assert req_json[0]["source_type"] == "CORE_WORKER"
+        assert req_json[0]["event_type"] == "TASK_DEFINITION_EVENT"
+        assert req_json[0]["timestamp"] == expected_timestamp
+        assert req_json[0]["severity"] == "INFO"
+        assert (
+            req_json[0]["message"] == ""
+        )  # Make sure the default value is included when it is not set
+        assert req_json[0]["session_name"] == "test_session"
+        assert (
+            req_json[0]["task_definition_event"]["task_id"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert req_json[0]["task_definition_event"]["task_attempt"] == 1
+        assert req_json[0]["task_definition_event"]["task_type"] == "NORMAL_TASK"
+        assert req_json[0]["task_definition_event"]["language"] == "PYTHON"
+        assert (
+            req_json[0]["task_definition_event"]["task_func"][
+                "python_function_descriptor"
+            ]["module_name"]
+            == "test_module"
+        )
+        assert (
+            req_json[0]["task_definition_event"]["task_func"][
+                "python_function_descriptor"
+            ]["class_name"]
+            == "test_class"
+        )
+        assert (
+            req_json[0]["task_definition_event"]["task_func"][
+                "python_function_descriptor"
+            ]["function_name"]
+            == "test_function"
+        )
+        assert (
+            req_json[0]["task_definition_event"]["task_func"][
+                "python_function_descriptor"
+            ]["function_hash"]
+            == "test_hash"
+        )
+        assert req_json[0]["task_definition_event"]["task_name"] == "test_task"
+        assert req_json[0]["task_definition_event"]["required_resources"] == {
+            "CPU": 1.0,
+            "GPU": 0.0,
+        }
+        assert req_json[0]["task_definition_event"]["serialized_runtime_env"] == "{}"
+        assert (
+            req_json[0]["task_definition_event"]["job_id"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert (
+            req_json[0]["task_definition_event"]["parent_task_id"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert (
+            req_json[0]["task_definition_event"]["placement_group_id"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert req_json[0]["task_definition_event"]["ref_ids"] == {
+            "key1": base64.b64encode(b"value1").decode(),
+            "key2": base64.b64encode(b"value2").decode(),
+        }
+    else:
+        # Verify the base event fields
+        assert req_json[0]["eventId"] == base64.b64encode(b"1").decode()
+        assert req_json[0]["sourceType"] == "CORE_WORKER"
+        assert req_json[0]["eventType"] == "TASK_DEFINITION_EVENT"
+        assert req_json[0]["timestamp"] == expected_timestamp
+        assert req_json[0]["severity"] == "INFO"
+        assert (
+            req_json[0]["message"] == ""
+        )  # Make sure the default value is included when it is not set
+        assert req_json[0]["sessionName"] == "test_session"
+
+        # Verify the task definition event specific fields
+        assert (
+            req_json[0]["taskDefinitionEvent"]["taskId"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert req_json[0]["taskDefinitionEvent"]["taskAttempt"] == 1
+        assert req_json[0]["taskDefinitionEvent"]["taskType"] == "NORMAL_TASK"
+        assert req_json[0]["taskDefinitionEvent"]["language"] == "PYTHON"
+        assert (
+            req_json[0]["taskDefinitionEvent"]["taskFunc"]["pythonFunctionDescriptor"][
+                "moduleName"
+            ]
+            == "test_module"
+        )
+        assert (
+            req_json[0]["taskDefinitionEvent"]["taskFunc"]["pythonFunctionDescriptor"][
+                "className"
+            ]
+            == "test_class"
+        )
+        assert (
+            req_json[0]["taskDefinitionEvent"]["taskFunc"]["pythonFunctionDescriptor"][
+                "functionName"
+            ]
+            == "test_function"
+        )
+        assert (
+            req_json[0]["taskDefinitionEvent"]["taskFunc"]["pythonFunctionDescriptor"][
+                "functionHash"
+            ]
+            == "test_hash"
+        )
+        assert req_json[0]["taskDefinitionEvent"]["taskName"] == "test_task"
+        assert req_json[0]["taskDefinitionEvent"]["requiredResources"] == {
+            "CPU": 1.0,
+            "GPU": 0.0,
+        }
+        assert req_json[0]["taskDefinitionEvent"]["serializedRuntimeEnv"] == "{}"
+        assert (
+            req_json[0]["taskDefinitionEvent"]["jobId"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert (
+            req_json[0]["taskDefinitionEvent"]["parentTaskId"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert (
+            req_json[0]["taskDefinitionEvent"]["placementGroupId"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert req_json[0]["taskDefinitionEvent"]["refIds"] == {
+            "key1": base64.b64encode(b"value1").decode(),
+            "key2": base64.b64encode(b"value2").decode(),
+        }
 
 
 def _create_task_lifecycle_event_proto(timestamp):
@@ -578,42 +707,82 @@
     )
 
 
-def _verify_task_lifecycle_event_json(req_json, expected_timestamp):
+def _verify_task_lifecycle_event_json(
+    req_json, expected_timestamp, preserve_proto_field_name
+):
     assert len(req_json) == 1
-    # Verify the base event fields
-    assert req_json[0]["eventId"] == base64.b64encode(b"1").decode()
-    assert req_json[0]["sourceType"] == "CORE_WORKER"
-    assert req_json[0]["eventType"] == "TASK_LIFECYCLE_EVENT"
-    assert req_json[0]["timestamp"] == expected_timestamp
-    assert req_json[0]["severity"] == "INFO"
-    assert (
-        req_json[0]["message"] == ""
-    )  # Make sure the default value is included when it is not set
-    assert req_json[0]["sessionName"] == "test_session"
-
-    # Verify the task execution event specific fields
-    assert (
-        req_json[0]["taskLifecycleEvent"]["taskId"] == base64.b64encode(b"1").decode()
-    )
-    assert req_json[0]["taskLifecycleEvent"]["taskAttempt"] == 1
-    assert req_json[0]["taskLifecycleEvent"]["stateTransitions"] == [
-        {
-            "state": "RUNNING",
-            "timestamp": expected_timestamp,
-        }
-    ]
-    assert (
-        req_json[0]["taskLifecycleEvent"]["rayErrorInfo"]["errorType"]
-        == "TASK_EXECUTION_EXCEPTION"
-    )
-    assert (
-        req_json[0]["taskLifecycleEvent"]["nodeId"] == base64.b64encode(b"1").decode()
-    )
-    assert (
-        req_json[0]["taskLifecycleEvent"]["workerId"] == base64.b64encode(b"1").decode()
-    )
-    assert req_json[0]["taskLifecycleEvent"]["workerPid"] == 1
+    if preserve_proto_field_name:
+        assert req_json[0]["event_id"] == base64.b64encode(b"1").decode()
+        assert req_json[0]["source_type"] == "CORE_WORKER"
+        assert req_json[0]["event_type"] == "TASK_LIFECYCLE_EVENT"
+        assert req_json[0]["timestamp"] == expected_timestamp
+        assert req_json[0]["severity"] == "INFO"
+        assert (
+            req_json[0]["message"] == ""
+        )  # Make sure the default value is included when it is not set
+        assert req_json[0]["session_name"] == "test_session"
+        assert (
+            req_json[0]["task_lifecycle_event"]["task_id"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert req_json[0]["task_lifecycle_event"]["task_attempt"] == 1
+        assert req_json[0]["task_lifecycle_event"]["state_transitions"] == [
+            {
+                "state": "RUNNING",
+                "timestamp": expected_timestamp,
+            }
+        ]
+        assert (
+            req_json[0]["task_lifecycle_event"]["ray_error_info"]["error_type"]
+            == "TASK_EXECUTION_EXCEPTION"
+        )
+        assert (
+            req_json[0]["task_lifecycle_event"]["node_id"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert (
+            req_json[0]["task_lifecycle_event"]["worker_id"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert req_json[0]["task_lifecycle_event"]["worker_pid"] == 1
+    else:
+        # Verify the base event fields
+        assert req_json[0]["eventId"] == base64.b64encode(b"1").decode()
+        assert req_json[0]["sourceType"] == "CORE_WORKER"
+        assert req_json[0]["eventType"] == "TASK_LIFECYCLE_EVENT"
+        assert req_json[0]["timestamp"] == expected_timestamp
+        assert req_json[0]["severity"] == "INFO"
+        assert (
+            req_json[0]["message"] == ""
+        )  # Make sure the default value is included when it is not set
+        assert req_json[0]["sessionName"] == "test_session"
+
+        # Verify the task execution event specific fields
+        assert (
+            req_json[0]["taskLifecycleEvent"]["taskId"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert req_json[0]["taskLifecycleEvent"]["taskAttempt"] == 1
+        assert req_json[0]["taskLifecycleEvent"]["stateTransitions"] == [
+            {
+                "state": "RUNNING",
+                "timestamp": expected_timestamp,
+            }
+        ]
+        assert (
+            req_json[0]["taskLifecycleEvent"]["rayErrorInfo"]["errorType"]
+            == "TASK_EXECUTION_EXCEPTION"
+        )
+        assert (
+            req_json[0]["taskLifecycleEvent"]["nodeId"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert (
+            req_json[0]["taskLifecycleEvent"]["workerId"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert req_json[0]["taskLifecycleEvent"]["workerPid"] == 1
 
 
 def _create_profile_event_request(timestamp):
@@ -647,35 +816,91 @@
     )
 
 
-def _verify_profile_event_json(req_json, expected_timestamp):
+def _verify_profile_event_json(req_json, expected_timestamp, preserve_proto_field_name):
     """Helper function to verify profile event JSON structure."""
-    assert len(req_json) == 1
-    assert req_json[0]["eventId"] == base64.b64encode(b"1").decode()
-    assert req_json[0]["sourceType"] == "CORE_WORKER"
-    assert req_json[0]["eventType"] == "TASK_PROFILE_EVENT"
-    assert req_json[0]["severity"] == "INFO"
-    assert req_json[0]["message"] == "profile event test"
-    assert req_json[0]["timestamp"] == expected_timestamp
-
-    # Verify task profile event specific fields
-    assert "taskProfileEvents" in req_json[0]
-    task_profile_events = req_json[0]["taskProfileEvents"]
-    assert task_profile_events["taskId"] == base64.b64encode(b"100").decode()
-    assert task_profile_events["attemptNumber"] == 3
-    assert task_profile_events["jobId"] == base64.b64encode(b"200").decode()
-
-    # Verify profile event specific fields
-    profile_event = task_profile_events["profileEvents"]
-    assert profile_event["componentType"] == "worker"
-    assert profile_event["componentId"] == base64.b64encode(b"worker_123").decode()
-    assert profile_event["nodeIpAddress"] == "127.0.0.1"
-    assert len(profile_event["events"]) == 1
-
-    event_entry = profile_event["events"][0]
-    assert event_entry["eventName"] == "task_execution"
-    assert event_entry["startTime"] == "1751302230130000000"
-    assert event_entry["endTime"] == "1751302230131000000"
-    assert event_entry["extraData"] == '{"cpu_usage": 0.8}'
+
+    if preserve_proto_field_name:
+        assert len(req_json) == 1
+        assert req_json[0]["event_id"] == base64.b64encode(b"1").decode()
+        assert req_json[0]["source_type"] == "CORE_WORKER"
+        assert req_json[0]["event_type"] == "TASK_PROFILE_EVENT"
+        assert req_json[0]["timestamp"] == expected_timestamp
+        assert req_json[0]["severity"] == "INFO"
+        assert req_json[0]["message"] == "profile event test"
+        assert (
+            req_json[0]["task_profile_events"]["task_id"]
+            == base64.b64encode(b"100").decode()
+        )
+        assert req_json[0]["task_profile_events"]["attempt_number"] == 3
+        assert (
+            req_json[0]["task_profile_events"]["job_id"]
+            == base64.b64encode(b"200").decode()
+        )
+        assert (
+            req_json[0]["task_profile_events"]["profile_events"]["component_type"]
+            == "worker"
+        )
+        assert (
+            req_json[0]["task_profile_events"]["profile_events"]["component_id"]
+            == base64.b64encode(b"worker_123").decode()
+        )
+        assert (
+            req_json[0]["task_profile_events"]["profile_events"]["node_ip_address"]
+            == "127.0.0.1"
+        )
+        assert len(req_json[0]["task_profile_events"]["profile_events"]["events"]) == 1
+        assert (
+            req_json[0]["task_profile_events"]["profile_events"]["events"][0][
+                "start_time"
+            ]
+            == "1751302230130000000"
+        )
+        assert (
+            req_json[0]["task_profile_events"]["profile_events"]["events"][0][
+                "end_time"
+            ]
+            == "1751302230131000000"
+        )
+        assert (
+            req_json[0]["task_profile_events"]["profile_events"]["events"][0][
+                "extra_data"
+            ]
+            == '{"cpu_usage": 0.8}'
+        )
+        assert (
+            req_json[0]["task_profile_events"]["profile_events"]["events"][0][
+                "event_name"
+            ]
+            == "task_execution"
+        )
+    else:
+        assert len(req_json) == 1
+        assert req_json[0]["eventId"] == base64.b64encode(b"1").decode()
+        assert req_json[0]["sourceType"] == "CORE_WORKER"
+        assert req_json[0]["eventType"] == "TASK_PROFILE_EVENT"
+        assert req_json[0]["severity"] == "INFO"
+        assert req_json[0]["message"] == "profile event test"
+        assert req_json[0]["timestamp"] == expected_timestamp
+
+        # Verify task profile event specific fields
+        assert "taskProfileEvents" in req_json[0]
+        task_profile_events = req_json[0]["taskProfileEvents"]
+        assert task_profile_events["taskId"] == base64.b64encode(b"100").decode()
+        assert task_profile_events["attemptNumber"] == 3
+        assert task_profile_events["jobId"] == base64.b64encode(b"200").decode()
+
+        # Verify profile event specific fields
+        profile_event = task_profile_events["profileEvents"]
+        assert profile_event["componentType"] == "worker"
+        assert profile_event["componentId"] == base64.b64encode(b"worker_123").decode()
+        assert profile_event["nodeIpAddress"] == "127.0.0.1"
+        assert len(profile_event["events"]) == 1
+
+        event_entry = profile_event["events"][0]
+        assert event_entry["eventName"] == "task_execution"
+        assert event_entry["startTime"] == "1751302230130000000"
+        assert event_entry["endTime"] == "1751302230131000000"
+        assert event_entry["extraData"] == '{"cpu_usage": 0.8}'
 
 
 # tuple: (create_event, verify)
@@ -698,16 +923,13 @@
 
 @pytest.mark.parametrize("create_event, verify_event", EVENT_TYPES_TO_TEST)
 @pytest.mark.parametrize(
-    "ray_start_cluster_head_with_env_vars",
-    [
-        {
-            "env_vars": {
-                "RAY_DASHBOARD_AGGREGATOR_AGENT_EVENTS_EXPORT_ADDR": _EVENT_AGGREGATOR_AGENT_TARGET_ADDR,
-                "RAY_DASHBOARD_AGGREGATOR_AGENT_EXPOSABLE_EVENT_TYPES": "TASK_DEFINITION_EVENT,TASK_LIFECYCLE_EVENT,ACTOR_TASK_DEFINITION_EVENT,TASK_PROFILE_EVENT",
-            },
-        },
-    ],
-    indirect=True,
+    ("preserve_proto_field_name", "ray_start_cluster_head_with_env_vars"),
+    build_export_env_vars_param_list(
+        additional_env_vars={
+            "RAY_DASHBOARD_AGGREGATOR_AGENT_EXPOSABLE_EVENT_TYPES": "TASK_DEFINITION_EVENT,TASK_LIFECYCLE_EVENT,ACTOR_TASK_DEFINITION_EVENT,TASK_PROFILE_EVENT",
+        }
+    ),
+    indirect=["ray_start_cluster_head_with_env_vars"],
 )
 def test_aggregator_agent_receive_events(
     create_event,
@@ -715,6 +937,7 @@ def test_aggregator_agent_receive_events(
     ray_start_cluster_head_with_env_vars,
     httpserver,
     fake_timestamp,
+    preserve_proto_field_name,
 ):
     cluster = ray_start_cluster_head_with_env_vars
     stub = get_event_aggregator_grpc_stub(
@@ -734,12 +957,14 @@
     wait_for_condition(lambda: len(httpserver.log) == 1)
     req, _ = httpserver.log[0]
     req_json = json.loads(req.data)
-    verify_event(req_json, fake_timestamp[1])
+    verify_event(req_json, fake_timestamp[1], preserve_proto_field_name)
 
 
-@_with_aggregator_port
+@_with_preserve_proto_field_name_flag
 def test_aggregator_agent_receive_driver_job_definition_event(
-    ray_start_cluster_head_with_env_vars, httpserver
+    ray_start_cluster_head_with_env_vars,
+    httpserver,
+    preserve_proto_field_name,
 ):
     cluster = ray_start_cluster_head_with_env_vars
     stub = get_event_aggregator_grpc_stub(
@@ -778,15 +1003,25 @@
     req, _ = httpserver.log[0]
     req_json = json.loads(req.data)
     assert req_json[0]["message"] == "driver job event"
-    assert (
-        req_json[0]["driverJobDefinitionEvent"]["config"]["serializedRuntimeEnv"]
-        == "{}"
-    )
+    if preserve_proto_field_name:
+        assert (
+            req_json[0]["driver_job_definition_event"]["config"][
+                "serialized_runtime_env"
+            ]
+            == "{}"
+        )
+    else:
+        assert (
+            req_json[0]["driverJobDefinitionEvent"]["config"]["serializedRuntimeEnv"]
+            == "{}"
+        )
 
 
-@_with_aggregator_port
+@_with_preserve_proto_field_name_flag
 def test_aggregator_agent_receive_driver_job_lifecycle_event(
-    ray_start_cluster_head_with_env_vars, httpserver
+    ray_start_cluster_head_with_env_vars,
+    httpserver,
+    preserve_proto_field_name,
 ):
     cluster = ray_start_cluster_head_with_env_vars
     stub = get_event_aggregator_grpc_stub(
@@ -831,29 +1066,45 @@
     req, _ = httpserver.log[0]
     req_json = json.loads(req.data)
    assert req_json[0]["message"] == "driver job lifecycle event"
-    assert (
-        req_json[0]["driverJobLifecycleEvent"]["jobId"]
-        == base64.b64encode(b"1").decode()
-    )
-    assert len(req_json[0]["driverJobLifecycleEvent"]["stateTransitions"]) == 2
-    assert (
-        req_json[0]["driverJobLifecycleEvent"]["stateTransitions"][0]["state"]
-        == "CREATED"
-    )
-    assert (
-        req_json[0]["driverJobLifecycleEvent"]["stateTransitions"][1]["state"]
-        == "FINISHED"
-    )
+    if preserve_proto_field_name:
+        assert (
+            req_json[0]["driver_job_lifecycle_event"]["job_id"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert len(req_json[0]["driver_job_lifecycle_event"]["state_transitions"]) == 2
+        assert (
+            req_json[0]["driver_job_lifecycle_event"]["state_transitions"][0]["state"]
+            == "CREATED"
+        )
+        assert (
+            req_json[0]["driver_job_lifecycle_event"]["state_transitions"][1]["state"]
+            == "FINISHED"
+        )
+    else:
+        assert (
+            req_json[0]["driverJobLifecycleEvent"]["jobId"]
+            == base64.b64encode(b"1").decode()
+        )
+        assert len(req_json[0]["driverJobLifecycleEvent"]["stateTransitions"]) == 2
+        assert (
+            req_json[0]["driverJobLifecycleEvent"]["stateTransitions"][0]["state"]
+            == "CREATED"
+        )
+        assert (
+            req_json[0]["driverJobLifecycleEvent"]["stateTransitions"][1]["state"]
+            == "FINISHED"
+        )
 
 
 @pytest.mark.parametrize(
     "ray_start_cluster_head_with_env_vars",
     [
         {
-            "env_vars": {
-                "RAY_DASHBOARD_AGGREGATOR_AGENT_PUBLISH_EVENTS_TO_EXTERNAL_HTTP_SERVICE": "False",
-                "RAY_DASHBOARD_AGGREGATOR_AGENT_EVENTS_EXPORT_ADDR": _EVENT_AGGREGATOR_AGENT_TARGET_ADDR,
-            },
+            "env_vars": generate_event_export_env_vars(
+                additional_env_vars={
+                    "RAY_DASHBOARD_AGGREGATOR_AGENT_PUBLISH_EVENTS_TO_EXTERNAL_HTTP_SERVICE": "False",
+                }
+            )
        },
    ],
    indirect=True,
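
Reviewer note: the new flag only changes the key style of the JSON the publisher emits. Ray's message_to_json here forwards the flag to google.protobuf.json_format.MessageToJson as its preserving_proto_field_name argument, which is what the parametrized assertions above are exercising. A minimal sketch of the behavior difference, using protobuf's bundled Api/SourceContext messages purely as stand-ins for Ray's event protos (they are convenient because they have multi-word field names):

# Hypothetical illustration, not part of this PR. Api/SourceContext stand in
# for events_base_event_pb2.RayEvent; the naming toggle behaves the same way.
from google.protobuf import api_pb2
from google.protobuf.json_format import MessageToJson

msg = api_pb2.Api(name="demo")
msg.source_context.file_name = "example.proto"

# Default (RAY_DASHBOARD_AGGREGATOR_AGENT_PRESERVE_PROTO_FIELD_NAME=0):
# standard proto3 JSON mapping, i.e. lowerCamelCase keys.
print(MessageToJson(msg, indent=None))
# -> {"name": "demo", "sourceContext": {"fileName": "example.proto"}}

# Flag enabled (...=1): keys keep the snake_case names from the .proto file.
print(MessageToJson(msg, preserving_proto_field_name=True, indent=None))
# -> {"name": "demo", "source_context": {"file_name": "example.proto"}}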