Skip to content

Commit

Permalink
[aDAG] Rename variables in CompiledDAG (#47290)
Browse files Browse the repository at this point in the history
Rename the variable based on what it represents. This makes the code easier to follow.
  • Loading branch information
ruisearch42 authored Aug 23, 2024
1 parent c8baeb2 commit fd84b9d
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 30 deletions.
47 changes: 24 additions & 23 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ def __init__(self, idx: int, dag_node: "ray.dag.DAGNode"):
self.idx = idx
self.dag_node = dag_node

self.downstream_node_idxs: Dict[int, "ray.actor.ActorHandle"] = {}
# Dict from task index to actor handle for immediate downstream tasks.
self.downstream_actor_idxs: Dict[int, "ray.actor.ActorHandle"] = {}
self.output_channel = None
self.arg_type_hints: List["ChannelOutputType"] = []

Expand All @@ -166,7 +167,7 @@ def kwargs(self) -> Dict[str, Any]:

@property
def num_readers(self) -> int:
return len(self.downstream_node_idxs)
return len(self.downstream_actor_idxs)

def __str__(self) -> str:
return f"""
Expand Down Expand Up @@ -731,7 +732,7 @@ def _preprocess(self) -> None:

# For each task node, set its upstream and downstream task nodes.
# Also collect the set of tasks that produce torch.tensors.
for node_idx, task in self.idx_to_task.items():
for task_idx, task in self.idx_to_task.items():
dag_node = task.dag_node
if not (
isinstance(dag_node, InputNode)
Expand Down Expand Up @@ -790,19 +791,19 @@ def _preprocess(self) -> None:
continue

upstream_node_idx = self.dag_node_to_idx[arg]
upstream_node = self.idx_to_task[upstream_node_idx]
upstream_task = self.idx_to_task[upstream_node_idx]
downstream_actor_handle = None
if isinstance(dag_node, ClassMethodNode):
downstream_actor_handle = dag_node._get_actor_handle()

if isinstance(upstream_node.dag_node, InputAttributeNode):
if isinstance(upstream_task.dag_node, InputAttributeNode):
# Record all of the keys used to index the InputNode.
# During execution, we will check that the user provides
# the same args and kwargs.
if isinstance(upstream_node.dag_node.key, int):
input_positional_args.add(upstream_node.dag_node.key)
elif isinstance(upstream_node.dag_node.key, str):
input_kwargs.add(upstream_node.dag_node.key)
if isinstance(upstream_task.dag_node.key, int):
input_positional_args.add(upstream_task.dag_node.key)
elif isinstance(upstream_task.dag_node.key, str):
input_kwargs.add(upstream_task.dag_node.key)
else:
raise ValueError(
"InputNode() can only be indexed using int "
Expand All @@ -819,17 +820,17 @@ def _preprocess(self) -> None:

# If the upstream node is an InputAttributeNode, treat the
# DAG's input node as the actual upstream node
upstream_node = self.idx_to_task[self.input_task_idx]
upstream_task = self.idx_to_task[self.input_task_idx]

elif isinstance(upstream_node.dag_node, InputNode):
elif isinstance(upstream_task.dag_node, InputNode):
if direct_input is not None and not direct_input:
raise ValueError(
"All tasks must either use InputNode() directly, "
"or they must index to specific args or kwargs."
)
direct_input = True

elif isinstance(upstream_node.dag_node, ClassMethodNode):
elif isinstance(upstream_task.dag_node, ClassMethodNode):
from ray.dag.constants import RAY_ADAG_ENABLE_DETECT_DEADLOCK

if (
Expand All @@ -841,23 +842,23 @@ def _preprocess(self) -> None:
not RAY_ADAG_ENABLE_DETECT_DEADLOCK
and downstream_actor_handle is not None
and downstream_actor_handle
== upstream_node.dag_node._get_actor_handle()
and upstream_node.dag_node.type_hint.requires_nccl()
== upstream_task.dag_node._get_actor_handle()
and upstream_task.dag_node.type_hint.requires_nccl()
):
raise ValueError(
"Compiled DAG does not support NCCL communication between "
"methods on the same actor. NCCL type hint is specified "
"for the channel from method "
f"{upstream_node.dag_node.get_method_name()} to method "
f"{upstream_task.dag_node.get_method_name()} to method "
f"{dag_node.get_method_name()} on actor "
f"{downstream_actor_handle}. Please remove the NCCL "
"type hint between these methods."
)

upstream_node.downstream_node_idxs[node_idx] = downstream_actor_handle
task.arg_type_hints.append(upstream_node.dag_node.type_hint)
upstream_task.downstream_actor_idxs[task_idx] = downstream_actor_handle
task.arg_type_hints.append(upstream_task.dag_node.type_hint)

if upstream_node.dag_node.type_hint.requires_nccl():
if upstream_task.dag_node.type_hint.requires_nccl():
# Add all readers to the NCCL group.
nccl_actors.add(downstream_actor_handle)

Expand All @@ -870,7 +871,7 @@ def _preprocess(self) -> None:
task.dag_node, InputAttributeNode
):
continue
if len(task.downstream_node_idxs) == 0:
if len(task.downstream_actor_idxs) == 0:
assert self.output_task_idx is None, "More than one output node found"
self.output_task_idx = idx

Expand Down Expand Up @@ -974,7 +975,7 @@ def _get_or_compile(
if isinstance(task.dag_node, ClassMethodNode):
# `readers` is the nodes that are ordered after the current one (`task`)
# in the DAG.
readers = [self.idx_to_task[idx] for idx in task.downstream_node_idxs]
readers = [self.idx_to_task[idx] for idx in task.downstream_actor_idxs]
reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]] = []
dag_nodes = [reader.dag_node for reader in readers]
read_by_multi_output_node = False
Expand Down Expand Up @@ -1046,7 +1047,7 @@ def _get_or_compile(
# when we support multiple readers for both shared memory channel
# and IntraProcessChannel.
reader_handles_set = set()
for idx in task.downstream_node_idxs:
for idx in task.downstream_actor_idxs:
reader_task = self.idx_to_task[idx]
assert isinstance(reader_task.dag_node, ClassMethodNode)
reader_handle = reader_task.dag_node._get_actor_handle()
Expand All @@ -1065,7 +1066,7 @@ def _get_or_compile(
task.dag_node, MultiOutputNode
)

for idx in task.downstream_node_idxs:
for idx in task.downstream_actor_idxs:
frontier.append(idx)

# Validate input channels for tasks that have not been visited
Expand Down Expand Up @@ -1416,7 +1417,7 @@ def _is_same_actor(idx1: int, idx2: int) -> bool:
# on the same actor.
next_task_idx = _get_next_task_idx(task)
_add_edge(graph, idx, next_task_idx)
for downstream_idx in task.downstream_node_idxs:
for downstream_idx in task.downstream_actor_idxs:
# Add an edge from the writer to the reader.
_add_edge(graph, idx, downstream_idx)
if task.dag_node.type_hint.requires_nccl():
Expand Down
6 changes: 3 additions & 3 deletions python/ray/dag/dag_node_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,13 +274,13 @@ def _build_dag_node_operation_graph(
# The edge from the InputNode has no impact on the final execution
# schedule.
continue
for downstream_dag_idx in task.downstream_node_idxs:
downstream_dag_node = idx_to_task[downstream_dag_idx].dag_node
for downstream_actor_idx in task.downstream_actor_idxs:
downstream_dag_node = idx_to_task[downstream_actor_idx].dag_node
if isinstance(downstream_dag_node, MultiOutputNode):
continue
_add_edge(
graph[dag_idx][_DAGNodeOperationType.WRITE],
graph[downstream_dag_idx][_DAGNodeOperationType.READ],
graph[downstream_actor_idx][_DAGNodeOperationType.READ],
)
return graph

Expand Down
8 changes: 4 additions & 4 deletions python/ray/dag/tests/experimental/test_execution_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def test_edge_between_writer_and_reader(self, monkeypatch):
2: CompiledTask(2, ClassMethodNode()),
3: CompiledTask(3, MultiOutputNode()),
}
idx_to_task[1].downstream_node_idxs = {2: fake_actor_2}
idx_to_task[1].downstream_actor_idxs = {2: fake_actor_2}

actor_to_operation_nodes = {
fake_actor_1: [
Expand Down Expand Up @@ -402,7 +402,7 @@ def test_edge_between_compute_nodes(self, monkeypatch):
dag_idx_2: CompiledTask(dag_idx_2, ClassMethodNode()),
3: CompiledTask(3, MultiOutputNode()),
}
idx_to_task[dag_idx_1].downstream_node_idxs = {dag_idx_2: fake_actor}
idx_to_task[dag_idx_1].downstream_actor_idxs = {dag_idx_2: fake_actor}

actor_to_operation_nodes = {
fake_actor: [
Expand Down Expand Up @@ -450,8 +450,8 @@ def test_two_actors(self, monkeypatch):
dag_idx_4: CompiledTask(dag_idx_4, ClassMethodNode()),
5: CompiledTask(5, MultiOutputNode()),
}
idx_to_task[dag_idx_1].downstream_node_idxs = {dag_idx_4: fake_actor_2}
idx_to_task[dag_idx_2].downstream_node_idxs = {dag_idx_3: fake_actor_1}
idx_to_task[dag_idx_1].downstream_actor_idxs = {dag_idx_4: fake_actor_2}
idx_to_task[dag_idx_2].downstream_actor_idxs = {dag_idx_3: fake_actor_1}

actor_to_operation_nodes = {
fake_actor_1: [
Expand Down

0 comments on commit fd84b9d

Please sign in to comment.