Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 82 additions & 80 deletions python/ray/_private/state.py

Large diffs are not rendered by default.

10 changes: 4 additions & 6 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2157,7 +2157,7 @@ def custom_excepthook(type, value, tb):
worker_type = common_pb2.DRIVER
worker_info = {"exception": error_message}

ray._private.state.state._check_connected()
ray._private.state.state._connect_and_get_accessor()
ray._private.state.state.add_worker(worker_id, worker_type, worker_info)
# Call the normal excepthook.
normal_excepthook(type, value, tb)
Expand Down Expand Up @@ -3042,13 +3042,11 @@ def put(
if _owner is None:
serialize_owner_address = None
elif isinstance(_owner, ray.actor.ActorHandle):
# Ensure `ray._private.state.state.global_state_accessor` is not None
ray._private.state.state._check_connected()
# Ensure GlobalState is connected
ray._private.state.state._connect_and_get_accessor()
serialize_owner_address = (
ray._raylet._get_actor_serialized_owner_address_or_none(
ray._private.state.state.global_state_accessor.get_actor_info(
_owner._actor_id
)
ray._private.state.state.get_actor_info(_owner._actor_id)
)
)
if not serialize_owner_address:
Expand Down
4 changes: 0 additions & 4 deletions python/ray/includes/global_state_accessor.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ cdef class GlobalStateAccessor:
result = self.inner.get().Connect()
return result

def disconnect(self):
with nogil:
self.inner.get().Disconnect()

def get_job_table(
self, *, skip_submission_job_info_field=False, skip_is_running_tasks_field=False
):
Expand Down
4 changes: 0 additions & 4 deletions python/ray/tests/test_actor_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,8 +933,6 @@ class Actor2:
== 2
)

global_state_accessor.disconnect()


def test_kill_pending_actor_with_no_restart_true():
cluster = ray.init()
Expand Down Expand Up @@ -964,7 +962,6 @@ def condition1():
# Actor is dead, so the infeasible task queue length is 0.
wait_for_condition(condition1, timeout=10)

global_state_accessor.disconnect()
ray.shutdown()


Expand Down Expand Up @@ -1072,7 +1069,6 @@ def condition2():

wait_for_condition(condition2, timeout=10)

global_state_accessor.disconnect()
ray.shutdown()


Expand Down
12 changes: 3 additions & 9 deletions python/ray/tests/test_actor_lineage_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ def pid(self):

def verify1():
gc.collect()
actor_info = ray._private.state.state.global_state_accessor.get_actor_info(
actor_id
)
actor_info = ray._private.state.state.get_actor_info(actor_id)
assert actor_info is not None
actor_info = gcs_pb2.ActorTableData.FromString(actor_info)
assert actor_info.state == gcs_pb2.ActorTableData.ActorState.DEAD
Expand All @@ -81,9 +79,7 @@ def verify1():
assert ray.get(obj2) == [1] * 1024 * 1024

def verify2():
actor_info = ray._private.state.state.global_state_accessor.get_actor_info(
actor_id
)
actor_info = ray._private.state.state.get_actor_info(actor_id)
assert actor_info is not None
actor_info = gcs_pb2.ActorTableData.FromString(actor_info)
assert actor_info.state == gcs_pb2.ActorTableData.ActorState.DEAD
Expand All @@ -102,9 +98,7 @@ def verify2():
del obj2

def verify3():
actor_info = ray._private.state.state.global_state_accessor.get_actor_info(
actor_id
)
actor_info = ray._private.state.state.get_actor_info(actor_id)
assert actor_info is not None
actor_info = gcs_pb2.ActorTableData.FromString(actor_info)
assert actor_info.state == gcs_pb2.ActorTableData.ActorState.DEAD
Expand Down
4 changes: 1 addition & 3 deletions python/ray/tests/test_core_worker_fault_tolerance.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,7 @@ def ping(self):
del actor

def verify_actor_ref_deleted():
actor_info = ray._private.state.state.global_state_accessor.get_actor_info(
actor_id
)
actor_info = ray._private.state.state.get_actor_info(actor_id)
if actor_info is None:
return False
actor_info = gcs_pb2.ActorTableData.FromString(actor_info)
Expand Down
6 changes: 0 additions & 6 deletions python/ray/tests/test_global_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ def test_node_name_cluster(ray_start_cluster):
else:
assert node["NodeName"] == "worker_node"

global_state_accessor.disconnect()
ray.shutdown()
cluster.shutdown()

Expand Down Expand Up @@ -317,7 +316,6 @@ def check_load_report(self):
else:
assert demand.num_ready_requests_queued > 0
assert demand.num_infeasible_requests_queued == 0
global_state_accessor.disconnect()


def test_placement_group_load_report(ray_start_cluster):
Expand Down Expand Up @@ -383,7 +381,6 @@ def _read_resource_usage(self):
_, unready = ray.wait([pg_infeasible_second.ready()], timeout=0)
assert len(unready) == 1
wait_for_condition(checker.two_infeasible_pg)
global_state_accessor.disconnect()


def test_backlog_report(shutdown_only):
Expand Down Expand Up @@ -428,7 +425,6 @@ def backlog_size_set():
# request is sent to the raylet with backlog=7

wait_for_condition(backlog_size_set, timeout=2)
global_state_accessor.disconnect()


def test_default_load_reports(shutdown_only):
Expand Down Expand Up @@ -472,7 +468,6 @@ def actor_and_task_queued_together():
ref = foo.remote()

wait_for_condition(actor_and_task_queued_together, timeout=2)
global_state_accessor.disconnect()

# Do something with the variables so lint is happy.
del handle
Expand All @@ -494,7 +489,6 @@ def self_ip_is_set():
return resources_data.node_manager_address == self_ip

wait_for_condition(self_ip_is_set, timeout=2)
global_state_accessor.disconnect()


def test_next_job_id(ray_start_regular):
Expand Down
32 changes: 32 additions & 0 deletions python/ray/tests/test_state_api_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import tempfile
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import pytest
Expand Down Expand Up @@ -375,5 +376,36 @@ def f():
assert len(dumped) > 0


def test_state_init_multiple_threads(shutdown_only):
    """Exercise the global state from many threads while it is re-initialized.

    One batch of tasks disconnects the global state and initializes it again
    with the previously captured GCS options, while two other batches read the
    node table. Every task must report success and all 150 tasks must finish.
    """
    ray.init()
    state = ray._private.state.state
    state._connect_and_get_accessor()
    # Captured once up front so every reconnect reuses the same options.
    saved_gcs_options = state.gcs_options

    def reconnect():
        state.disconnect()
        state._initialize_global_state(saved_gcs_options)
        return True

    def sees_single_node():
        try:
            node_count = len(state.node_table())
        except ray.exceptions.RaySystemError:
            # Between disconnect and _initialize_global_state there is a
            # window in which connecting raises; that outcome is tolerated.
            return True
        return node_count == 1

    reconnect()

    # Submission order matters: readers, then reconnects, then readers again.
    workloads = [sees_single_node] * 50 + [reconnect] * 50 + [sees_single_node] * 50
    with ThreadPoolExecutor(max_workers=50) as pool:
        pending = [pool.submit(job) for job in workloads]
        outcomes = [task.result() for task in pending]

    # Every call must have returned True.
    assert all(outcomes)
    assert len(outcomes) == 150


if __name__ == "__main__":
    # Run this file's tests directly (verbose, no output capture) and
    # propagate pytest's exit status to the shell.
    exit_status = pytest.main(["-sv", __file__])
    sys.exit(exit_status)