Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,7 @@ def _init_aiter_shared_experts_topK_buffer(
max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
* dp_size,
is_EP=self.use_ep,
device=current_platform.device_type,
)
self.local_num_experts += self.num_fused_shared_experts

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def init_aiter_topK_meta_data(
shared_experts_score: float = 1.0,
max_num_tokens: int = 32768,
is_EP: bool = False,
device: str = "cuda",
):
global aiter_topK_meta_data
fake_expertid = n_routed_experts + n_shared_experts
Expand All @@ -64,7 +65,7 @@ def init_aiter_topK_meta_data(
total_topk_ids = torch.empty(
(max_num_tokens, top_k + n_shared_experts + is_EP),
dtype=torch.int32,
device="cuda",
device=device,
)
ns_topk_ids, s_topk_ids = total_topk_ids.split(
[top_k, n_shared_experts + is_EP], dim=1
Expand All @@ -80,12 +81,12 @@ def init_aiter_topK_meta_data(
s_topk_ids_list = [
list(range(n_routed_experts, fake_expertid))
] * max_num_tokens
s_topk_ids[:] = torch.tensor(s_topk_ids_list, dtype=torch.int32, device="cuda")
s_topk_ids[:] = torch.tensor(s_topk_ids_list, dtype=torch.int32, device=device)

total_topk_weights = torch.empty(
(max_num_tokens, top_k + n_shared_experts + is_EP),
dtype=torch.float32,
device="cuda",
device=device,
)
ns_topk_weights, s_topk_weights = total_topk_weights.split(
[top_k, n_shared_experts + is_EP], dim=1
Expand Down
10 changes: 5 additions & 5 deletions vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ def fp8_mqa_logits_torch(
q = q.to(torch.bfloat16)

mask_lo = (
torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
torch.arange(0, seq_len_kv, device=q.device)[None, :] >= cu_seqlen_ks[:, None]
)
mask_hi = (
torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
torch.arange(0, seq_len_kv, device=q.device)[None, :] < cu_seqlen_ke[:, None]
)
mask = mask_lo & mask_hi

Expand Down Expand Up @@ -124,15 +124,15 @@ def fp8_paged_mqa_logits_torch(
context_lens = context_lens.tolist()
for i in range(batch_size):
context_len = context_lens[i]
q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
weight_slice = (
weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
)
for block_rk in range(cdiv(context_len, block_size)):
block_idx = block_tables[i][block_rk]
qx, kx = q[i], kv_cache[block_idx]
k_offsets = torch.arange(
block_rk * block_size, (block_rk + 1) * block_size, device="cuda"
block_rk * block_size, (block_rk + 1) * block_size, device=q.device
)
mask = (k_offsets[None, :] < context_len) & (
k_offsets[None, :] <= q_offsets[:, None]
Expand Down Expand Up @@ -191,7 +191,7 @@ def rocm_fp8_paged_mqa_logits(
out_qk = torch.full(
(heads, batch_size * next_n, max_model_len),
float("-inf"),
device="cuda",
device=q_fp8.device,
dtype=torch.float32,
)
deepgemm_fp8_paged_mqa_logits_stage1(
Expand Down