Skip to content
Merged
26 changes: 19 additions & 7 deletions python/ray/_private/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ def __init__(
self._raylet_socket_name = self._prepare_socket_file(
self._ray_params.raylet_socket_name, default_prefix="raylet"
)
# Set node labels from RayParams or environment override variables.
self._node_labels = self._get_node_labels()
if (
self._ray_params.env_vars is not None
and "RAY_OVERRIDE_NODE_ID_FOR_TESTING" in self._ray_params.env_vars
Expand Down Expand Up @@ -375,12 +377,17 @@ def __init__(
"could happen because some of the Ray processes "
"failed to startup."
) from te
node_info = ray._private.services.get_node(
self.gcs_address,
self._node_id,
)
if self._ray_params.node_manager_port == 0:
self._ray_params.node_manager_port = node_info["node_manager_port"]

# Fetch node info to update port or get labels.
node_info = ray._private.services.get_node(
self.gcs_address,
self._node_id,
)
if not connect_only and self._ray_params.node_manager_port == 0:
self._ray_params.node_manager_port = node_info["node_manager_port"]
elif connect_only:
# Set node labels from GCS if provided at node init.
self._node_labels = node_info.get("labels", {})

# Makes sure the Node object has valid addresses after setup.
self.validate_ip_port(self.address)
Expand Down Expand Up @@ -741,6 +748,11 @@ def address_info(self):
"dashboard_agent_listen_port": self.dashboard_agent_listen_port,
}

@property
def node_labels(self):
"""Get the node labels."""
return self._node_labels

def is_head(self):
return self.head

Expand Down Expand Up @@ -1323,7 +1335,7 @@ def start_raylet(
env_updates=self._ray_params.env_vars,
node_name=self._ray_params.node_name,
webui=self._webui_url,
labels=self._get_node_labels(),
labels=self.node_labels,
resource_isolation_config=self.resource_isolation_config,
)
assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
Expand Down
5 changes: 5 additions & 0 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,11 @@ def job_logging_config(self):
logging_config = pickle.loads(job_config.serialized_py_logging_config)
return logging_config

@property
def current_node_labels(self):
# Return the node labels of this worker's current node.
return self.node.node_labels

def set_debugger_port(self, port):
worker_id = self.core_worker.get_worker_id()
ray._private.state.update_worker_debugger_port(worker_id, port)
Expand Down
2 changes: 2 additions & 0 deletions python/ray/includes/global_state_accessor.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,11 @@ cdef class GlobalStateAccessor:
if not status.ok():
raise RuntimeError(status.message())
c_node_info.ParseFromString(cnode_info_str)
c_labels = PythonGetNodeLabels(c_node_info)
return {
"object_store_socket_name": c_node_info.object_store_socket_name().decode(),
"raylet_socket_name": c_node_info.raylet_socket_name().decode(),
"node_manager_port": c_node_info.node_manager_port(),
"node_id": c_node_info.node_id().hex(),
"labels": {key.decode(): value.decode() for key, value in c_labels},
}
12 changes: 12 additions & 0 deletions python/ray/runtime_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,18 @@ def get_accelerator_ids(self) -> Dict[str, List[str]]:
ids_dict[accelerator_resource_name] = [str(id) for id in accelerator_ids]
return ids_dict

def get_node_labels(self) -> Dict[str, List[str]]:
"""
Get the node labels of the current worker.

Returns:
A dictionary of label key-value pairs.
"""
worker = self.worker
worker.check_connected()

return worker.current_node_labels


_runtime_context = None
_runtime_context_lock = threading.Lock()
Expand Down
38 changes: 38 additions & 0 deletions python/ray/tests/test_runtime_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,5 +416,43 @@ def verify():
wait_for_condition(verify)


def test_get_node_labels(ray_start_cluster_head):
cluster = ray_start_cluster_head
cluster.add_node(
resources={"worker1": 1},
num_cpus=1,
labels={
"accelerator-type": "A100",
"region": "us-west4",
"market-type": "spot",
},
)
# ray.init(address=cluster.address)

@ray.remote
class Actor:
def get_node_id(self):
return ray.get_runtime_context().get_node_id()

def get_node_labels(self):
return ray.get_runtime_context().get_node_labels()

expected_node_labels = {
"accelerator-type": "A100",
"region": "us-west4",
"market-type": "spot",
}

# Check node labels from Actor runtime context
a = Actor.options(label_selector={"accelerator-type": "A100"}).remote()
node_labels = ray.get(a.get_node_labels.remote())
expected_node_labels["ray.io/node_id"] = ray.get(a.get_node_id.remote())
assert expected_node_labels == node_labels

# Check node labels from driver runtime context (none are set except default)
driver_labels = ray.get_runtime_context().get_node_labels()
assert {"ray.io/node_id": ray.get_runtime_context().get_node_id()} == driver_labels


if __name__ == "__main__":
sys.exit(pytest.main(["-sv", __file__]))