Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions vllm_ascend/attention/mla_v1.py
Original file line number Diff line number Diff line change
@@ -22,8 +22,7 @@
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
- from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, npu_prefetch,
- npu_stream_switch, npu_wait_tensor)
+ from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -711,12 +710,6 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
self.W_UV = W_UV.transpose(0, 1).contiguous()
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
- if get_ascend_config().enable_weight_nz_layout:
- # cast quantized weight tensors in NZ layout for higher inference speed
- self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data,
- ACL_FORMAT_FRACTAL_NZ)
- self.W_UK_T.data = torch_npu.npu_format_cast(
- self.W_UK_T.data, ACL_FORMAT_FRACTAL_NZ)

def _compute_prefill_context(
self,
Expand Down