diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index bcd313d22d7..4fbfadce903 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -285,7 +285,8 @@ def fused_experts(hidden_states: torch.Tensor,
         valid_token_mask = torch.arange(
             0, sorted_token_indices.shape[0],
             device=device).unsqueeze(1) < num_valid_tokens
-        down_out_list.mul_(valid_token_mask)
+        down_out_list = down_out_list.masked_fill_(~valid_token_mask,
+                                                   0).to(dtype)
         final_hidden_states.index_add_(0, sorted_token_indices, down_out_list)
     else:
         # TODO: Reorder device memory 2 times here, replace the current