Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions vllm_ascend/eplb/eplb_updator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
from vllm_ascend.eplb.utils import moe_load_async_stream
from vllm_ascend.utils import npu_stream_switch


class EplbUpdator:
Expand Down Expand Up @@ -153,21 +155,22 @@ def compute_and_set_moe_load(self, is_clear=False):

self._gather_buffer = None
if dist.is_initialized():
self.world_size = dist.get_world_size()
self.device = local_load.device
if self._gather_buffer is None:
shape = (self.world_size, *local_load.shape)
self._gather_buffer = torch.empty(shape,
dtype=local_load.dtype,
device=self.device)

dist.all_gather_into_tensor(self._gather_buffer, local_load)

moe_load = self._gather_buffer.permute(1, 0, 2)
self.shared_dict["moe_load"] = moe_load.cpu()
logger.debug(
f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
)
with npu_stream_switch(moe_load_async_stream()):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe better to set moe_load_async_stream as class attribute of EplbUpdator

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

already move this function to eplb module, since other file would also call this stream, so move to eplb utils is better

self.world_size = dist.get_world_size()
self.device = local_load.device
if self._gather_buffer is None:
shape = (self.world_size, *local_load.shape)
self._gather_buffer = torch.empty(shape,
dtype=local_load.dtype,
device=self.device)

dist.all_gather_into_tensor(self._gather_buffer, local_load)

moe_load = self._gather_buffer.permute(1, 0, 2)
self.shared_dict["moe_load"] = moe_load.cpu()
logger.debug(
f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
)
else:
moe_load = local_load.unsqueeze(1)
self.shared_dict["moe_load"] = moe_load.cpu()
Expand Down
12 changes: 12 additions & 0 deletions vllm_ascend/eplb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import types

import torch
import torch_npu

_MOE_LOAD_ASYNC_STREAM = None


def get_expert_map(self, layer_id):
Expand Down Expand Up @@ -75,3 +78,12 @@ def model_register(model, model_config):
model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
else:
raise NotImplementedError("EPLB is not supported.")


def moe_load_async_stream() -> torch_npu.npu.Stream:
    """Return the process-wide side stream used for async MoE-load accounting.

    The stream is created lazily on first use and cached in the module-level
    ``_MOE_LOAD_ASYNC_STREAM`` global, so every caller (eplb updator, fused
    MoE forward, ...) shares one and the same NPU stream.

    Returns:
        torch_npu.npu.Stream: the shared lazily-created NPU stream.
    """
    global _MOE_LOAD_ASYNC_STREAM
    if _MOE_LOAD_ASYNC_STREAM is None:
        # NOTE: this allocates a fresh NPU stream, not the default stream;
        # laziness avoids touching the NPU device at import time.
        _MOE_LOAD_ASYNC_STREAM = torch_npu.npu.Stream()
    return _MOE_LOAD_ASYNC_STREAM
12 changes: 10 additions & 2 deletions vllm_ascend/ops/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
from vllm_ascend.eplb.utils import moe_load_async_stream
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
Expand Down Expand Up @@ -368,8 +369,15 @@ def forward_impl(self, hidden_states: torch.Tensor,
if isinstance(final_hidden_states, tuple):
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
if self.dynamic_eplb:
self.moe_load += expert_tokens if group_list_type == 1 else \
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])

moe_load_stream = moe_load_async_stream()
cur_stream = torch.npu.current_stream()

moe_load_stream.wait_stream(cur_stream)
with npu_stream_switch(moe_load_stream):
self.moe_load += expert_tokens if group_list_type == 1 else \
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
cur_stream.wait_stream(moe_load_stream)

final_hidden_states = forward_context.moe_comm_method.finalize(
hidden_states=final_hidden_states,
Expand Down
Loading