Expose node labels through the Ray runtime context.

Summary of this patch:
  * Node caches its labels in `_node_labels` and exposes them through a
    `node_labels` property. A node that starts its own processes computes
    them with `_get_node_labels()`; a `connect_only` node reads them from
    the node info returned by `ray._private.services.get_node`.
  * `start_raylet` passes the cached labels instead of recomputing them.
  * The node-info dict built in `global_state_accessor.pxi` gains a
    "labels" entry (str -> str).
  * Worker gains `current_node_labels`; RuntimeContext gains
    `get_node_labels()`, which returns a str -> str dictionary.
  * Adds `test_get_node_labels`.

NOTE(review): this copy was re-flowed from a whitespace-collapsed patch, so
the indentation of context lines is reconstructed. Confirm it against the
target tree (or apply with whitespace tolerance) before merging. The `index`
lines are carried over from the original patch.

diff --git a/python/ray/_private/node.py b/python/ray/_private/node.py
index 93f30fbc0c6d..65ddb6b73e41 100644
--- a/python/ray/_private/node.py
+++ b/python/ray/_private/node.py
@@ -300,6 +300,8 @@ def __init__(
             self._raylet_socket_name = self._prepare_socket_file(
                 self._ray_params.raylet_socket_name, default_prefix="raylet"
             )
+            # Set node labels from RayParams or environment override variables.
+            self._node_labels = self._get_node_labels()
             if (
                 self._ray_params.env_vars is not None
                 and "RAY_OVERRIDE_NODE_ID_FOR_TESTING" in self._ray_params.env_vars
@@ -375,12 +377,17 @@ def __init__(
                     "could happen because some of the Ray processes "
                     "failed to startup."
                 ) from te
-            node_info = ray._private.services.get_node(
-                self.gcs_address,
-                self._node_id,
-            )
-            if self._ray_params.node_manager_port == 0:
-                self._ray_params.node_manager_port = node_info["node_manager_port"]
+
+        # Fetch node info to update port or get labels.
+        node_info = ray._private.services.get_node(
+            self.gcs_address,
+            self._node_id,
+        )
+        if not connect_only and self._ray_params.node_manager_port == 0:
+            self._ray_params.node_manager_port = node_info["node_manager_port"]
+        elif connect_only:
+            # Set node labels from GCS if provided at node init.
+            self._node_labels = node_info.get("labels", {})
 
         # Makes sure the Node object has valid addresses after setup.
         self.validate_ip_port(self.address)
@@ -741,6 +748,11 @@ def address_info(self):
             "dashboard_agent_listen_port": self.dashboard_agent_listen_port,
         }
 
+    @property
+    def node_labels(self):
+        """Get the node labels."""
+        return self._node_labels
+
     def is_head(self):
         return self.head
 
@@ -1323,7 +1335,7 @@ def start_raylet(
             env_updates=self._ray_params.env_vars,
             node_name=self._ray_params.node_name,
             webui=self._webui_url,
-            labels=self._get_node_labels(),
+            labels=self.node_labels,
             resource_isolation_config=self.resource_isolation_config,
         )
         assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py
index 6b9f2e51f349..f505ae268d96 100644
--- a/python/ray/_private/worker.py
+++ b/python/ray/_private/worker.py
@@ -593,6 +593,11 @@ def job_logging_config(self):
         logging_config = pickle.loads(job_config.serialized_py_logging_config)
         return logging_config
 
+    @property
+    def current_node_labels(self):
+        """Get the node labels of this worker's current node."""
+        return self.node.node_labels
+
     def set_debugger_port(self, port):
         worker_id = self.core_worker.get_worker_id()
         ray._private.state.update_worker_debugger_port(worker_id, port)
diff --git a/python/ray/includes/global_state_accessor.pxi b/python/ray/includes/global_state_accessor.pxi
index 1ee717553c9f..1a78529e8bac 100644
--- a/python/ray/includes/global_state_accessor.pxi
+++ b/python/ray/includes/global_state_accessor.pxi
@@ -295,9 +295,11 @@ cdef class GlobalStateAccessor:
         if not status.ok():
             raise RuntimeError(status.message())
         c_node_info.ParseFromString(cnode_info_str)
+        c_labels = PythonGetNodeLabels(c_node_info)
         return {
             "object_store_socket_name": c_node_info.object_store_socket_name().decode(),
             "raylet_socket_name": c_node_info.raylet_socket_name().decode(),
             "node_manager_port": c_node_info.node_manager_port(),
             "node_id": c_node_info.node_id().hex(),
+            "labels": {key.decode(): value.decode() for key, value in c_labels},
         }
diff --git a/python/ray/runtime_context.py b/python/ray/runtime_context.py
index cba283914e35..567b7ac2c7e7 100644
--- a/python/ray/runtime_context.py
+++ b/python/ray/runtime_context.py
@@ -533,6 +533,18 @@ def get_accelerator_ids(self) -> Dict[str, List[str]]:
             ids_dict[accelerator_resource_name] = [str(id) for id in accelerator_ids]
         return ids_dict
 
+    def get_node_labels(self) -> Dict[str, str]:
+        """
+        Get the labels of the node that the current worker is running on.
+
+        Returns:
+            A dictionary mapping each label key to its value.
+        """
+        worker = self.worker
+        worker.check_connected()
+
+        return worker.current_node_labels
+
 
 _runtime_context = None
 _runtime_context_lock = threading.Lock()
diff --git a/python/ray/tests/test_runtime_context.py b/python/ray/tests/test_runtime_context.py
index 6c8eb2f7da6a..c6bb29ca80b9 100644
--- a/python/ray/tests/test_runtime_context.py
+++ b/python/ray/tests/test_runtime_context.py
@@ -416,5 +416,43 @@ def verify():
     wait_for_condition(verify)
 
 
+def test_get_node_labels(ray_start_cluster_head):
+    """Runtime context reports the labels of the node it runs on."""
+    cluster = ray_start_cluster_head
+    cluster.add_node(
+        resources={"worker1": 1},
+        num_cpus=1,
+        labels={
+            "accelerator-type": "A100",
+            "region": "us-west4",
+            "market-type": "spot",
+        },
+    )
+
+    @ray.remote
+    class Actor:
+        def get_node_id(self):
+            return ray.get_runtime_context().get_node_id()
+
+        def get_node_labels(self):
+            return ray.get_runtime_context().get_node_labels()
+
+    expected_node_labels = {
+        "accelerator-type": "A100",
+        "region": "us-west4",
+        "market-type": "spot",
+    }
+
+    # Check node labels from Actor runtime context
+    a = Actor.options(label_selector={"accelerator-type": "A100"}).remote()
+    node_labels = ray.get(a.get_node_labels.remote())
+    expected_node_labels["ray.io/node_id"] = ray.get(a.get_node_id.remote())
+    assert expected_node_labels == node_labels
+
+    # Check node labels from driver runtime context (none are set except default)
+    driver_labels = ray.get_runtime_context().get_node_labels()
+    assert {"ray.io/node_id": ray.get_runtime_context().get_node_id()} == driver_labels
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-sv", __file__]))