Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def sampling_config():
def model_name() -> str:
    """Return the model name string."""
    # NOTE(review): presumably a pytest fixture supplying the `model_name`
    # argument of the tests in this file — the decorator is not visible
    # in this hunk; confirm against the full file.
    return "wemaster/deepseek_mtp_main_random_bf16"


@pytest.skip("exist OOM error")
def mtp_torchair_correctness(
sampling_config: SamplingParams,
model_name: str,
Expand Down
14 changes: 11 additions & 3 deletions vllm_ascend/models/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from functools import partial
from typing import Callable, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -143,14 +144,21 @@ def cal_cos_sin(self, rotary_pos_emb):
def forward(
self,
x: torch.Tensor,
grid_thw: list[list[int]],
grid_thw: torch.Tensor | list[list[int]],
) -> torch.Tensor:
hidden_states = x.to(device=self.device, dtype=self.dtype)
hidden_states = self.patch_embed(hidden_states)

pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
if isinstance(grid_thw, list):
grid_thw_list = grid_thw
grid_thw = np.array(grid_thw, dtype=np.int32)
else:
grid_thw_list = grid_thw.tolist()
grid_thw = grid_thw.numpy()

pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
rotary_pos_emb = self.rot_pos_emb(grid_thw_list)
grid_thw_tensor = torch.tensor(grid_thw,
device=self.device,
dtype=torch.int32)
Comment on lines +152 to 164
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

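The current implementation for handling grid_thw when it is a torch.Tensor is inefficient. It converts the tensor to a NumPy array using .numpy() (which only works for CPU tensors; a tensor that lives on an accelerator device would first have to be copied to the host) and then converts it back to a tensor using torch.tensor(), which copies the data again. This can be optimized by handling the list and torch.Tensor cases separately to avoid unnecessary conversions: build the tensor directly from the list in the first case, and call .to() on the existing tensor in the second.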

        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw_tensor = torch.tensor(grid_thw,
                                           device=self.device,
                                           dtype=torch.int32)
        else:
            grid_thw_list = grid_thw.tolist()
            grid_thw_tensor = grid_thw.to(device=self.device, dtype=torch.int32)

        pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
        hidden_states = hidden_states + pos_embeds
        rotary_pos_emb = self.rot_pos_emb(grid_thw_list)

Expand Down
Loading