Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.bfloat16, torch.float16]
MAX_POSITION_EMBEDDINGS = [262144]
HEAD_SIZES = [64, 128]
ROTARY_DIMS = [32, 64]
NUM_Q_HEADS = [64]
Expand Down Expand Up @@ -139,3 +140,83 @@ def test_rotary_embedding_triton_kernel(
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()



@pytest.mark.parametrize("max_position_embeddings", MAX_POSITION_EMBEDDINGS)
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads", NUM_Q_HEADS)
@pytest.mark.parametrize("num_k_heads", NUM_K_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding_triton_kernel_with_cos_sin_cache(
    max_position_embeddings: int,
    is_neox_style: bool,
    num_tokens: int,
    num_q_heads: int,
    num_k_heads: int,
    head_size: int,
    rotary_dim: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare ``rope_forward_triton`` with ``_rope_pytorch_native``.

    The Triton path is handed a random ``cos_sin_cache`` of shape
    ``(max_position_embeddings, rotary_dim)`` plus random ``positions``.
    The reference path receives the rows of that cache selected by
    ``positions``, split in half along the last dimension into ``cos``
    and ``sin``. Both paths start from identical query/key values, and
    their outputs must agree within ``DEFAULT_ATOL`` / ``DEFAULT_RTOL``.
    """
    torch.manual_seed(seed)
    torch.set_default_device(device)
    init_device_properties_triton()

    # A rotary_dim of -1 is replaced by the full head size.
    rotary_dim = head_size if rotary_dim == -1 else rotary_dim

    tensor_opts = {"dtype": dtype, "device": device}
    q_shape = (num_tokens, num_q_heads, head_size)
    k_shape = (num_tokens, num_k_heads, head_size)

    cos_sin_cache = torch.randn(
        max_position_embeddings, rotary_dim, **tensor_opts
    )
    positions = torch.randint(
        low=0,
        high=max_position_embeddings,
        size=(num_tokens,),
        dtype=torch.int64,
        device=device,
    )

    # Draw order matters for reproducibility under a fixed seed: the
    # kernel-side buffers are sampled first, then the reference inputs.
    q_trt = torch.randn(q_shape, **tensor_opts)
    k_trt = torch.randn(k_shape, **tensor_opts)
    q_gold = torch.randn(q_shape, **tensor_opts)
    k_gold = torch.randn(k_shape, **tensor_opts)

    # Overwrite the kernel-side buffers so both paths see the same data
    # while still owning separate storage.
    q_trt.copy_(q_gold)
    k_trt.copy_(k_gold)

    # Path under test: the kernel looks up cos/sin itself from the cache.
    q_trt, k_trt = rope_forward_triton(
        q_trt,
        k_trt,
        cos_sin_cache=cos_sin_cache,
        positions=positions,
        rope_dim=rotary_dim,
        is_neox_style=is_neox_style,
    )

    # Reference path: gather the cache rows for each token, then split the
    # last dimension into its cos half and its sin half.
    gathered = cos_sin_cache.index_select(0, positions)
    cos, sin = gathered.chunk(2, dim=-1)
    q_gold, k_gold = _rope_pytorch_native(
        q_gold,
        k_gold,
        cos,
        sin,
        rope_dim=rotary_dim,
        is_neox_style=is_neox_style,
    )

    # Compare the results (query first, then key).
    for actual, expected in ((q_trt, q_gold), (k_trt, k_gold)):
        torch.testing.assert_close(
            actual.view(expected.size()),
            expected,
            atol=DEFAULT_ATOL,
            rtol=DEFAULT_RTOL,
        )

    # Release device memory so later parametrisations start clean.
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import vllm_ascend.ops.register_custom_ops # noqa
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton

MAX_POSITION_EMBEDDINGS = [262144]
NUM_TOKENS = [1, 4, 8, 16, 1024]
NUM_QKV_HEADS = [(12, 1), (16, 1), (32, 4), (64, 4)]
HEAD_SIZES = [128]
Expand Down Expand Up @@ -55,6 +56,7 @@ def rms_norm(
return out


@pytest.mark.parametrize("max_position_embeddings", MAX_POSITION_EMBEDDINGS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
Expand All @@ -63,7 +65,7 @@ def rms_norm(
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads,
def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads, num_kv_heads,
head_size, eps, dtype, seed, device):
torch.manual_seed(seed)
torch.set_default_device(device)
Expand All @@ -77,12 +79,10 @@ def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads,
device=device)
q_weight = torch.randn(head_size, dtype=dtype, device=device)
k_weight = torch.randn(head_size, dtype=dtype, device=device)
sin = torch.from_numpy(
cos_sin_cache = torch.from_numpy(
np.random.uniform(0, 1,
[num_tokens, 1, 1, head_size])).to(dtype).npu()
cos = torch.from_numpy(
np.random.uniform(0, 1,
[num_tokens, 1, 1, head_size])).to(dtype).npu()
[max_position_embeddings, head_size])).to(dtype).npu()
positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device)
# fused kernel
q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
q_weight=q_weight,
Expand All @@ -91,8 +91,12 @@ def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads,
kv_hidden_size=kv_hidden_size,
head_dim=head_size,
eps=eps,
cos=cos,
sin=sin)
cos_sin_cache=cos_sin_cache,
positions=positions)

cos, sin = cos_sin_cache.index_select(0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)

# split
_q, _k, v_gold = qkv.cpu().split(
Expand Down Expand Up @@ -129,6 +133,7 @@ def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads,
torch.npu.reset_peak_memory_stats()


@pytest.mark.parametrize("max_position_embeddings", MAX_POSITION_EMBEDDINGS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
Expand All @@ -137,7 +142,7 @@ def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads,
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope_with_bias(num_tokens, num_q_heads,
def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, num_q_heads,
num_kv_heads, head_size, eps, dtype,
seed, device):
torch.manual_seed(seed)
Expand All @@ -154,12 +159,10 @@ def test_split_qkv_rmsnorm_rope_with_bias(num_tokens, num_q_heads,
k_weight = torch.randn(head_size, dtype=dtype, device=device)
q_bias = torch.randn(head_size, dtype=dtype, device=device)
k_bias = torch.randn(head_size, dtype=dtype, device=device)
sin = torch.from_numpy(
np.random.uniform(0, 1,
[num_tokens, 1, 1, head_size])).to(dtype).npu()
cos = torch.from_numpy(
cos_sin_cache = torch.from_numpy(
np.random.uniform(0, 1,
[num_tokens, 1, 1, head_size])).to(dtype).npu()
[max_position_embeddings, head_size])).to(dtype).npu()
positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device)
# fused kernel
q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
q_weight=q_weight,
Expand All @@ -170,8 +173,12 @@ def test_split_qkv_rmsnorm_rope_with_bias(num_tokens, num_q_heads,
eps=eps,
q_bias=q_bias,
k_bias=k_bias,
cos=cos,
sin=sin)
cos_sin_cache=cos_sin_cache,
positions=positions)

cos, sin = cos_sin_cache.index_select(0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)

# split
_q, _k, v_gold = qkv.cpu().split(
Expand Down
20 changes: 9 additions & 11 deletions tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
self.q_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
self.k_weight = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))

def forward(self, qkv, cos, sin):
def forward(self, qkv, cos_sin_cache, positions):
"""
Args:
qkv: [T, q_size + 2*kv_size]
Expand All @@ -82,13 +82,12 @@ def forward(self, qkv, cos, sin):

# Reshape for RoPE: [T, num_heads, head_dim] -> [1, T, num_heads, head_dim]
q_flat = q_norm_out.view(q.shape)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)

k_flat = k_norm_out.view(k.shape)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)

# Apply RoPE
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
)

return q_rope, k_rope, v

Expand Down Expand Up @@ -116,7 +115,7 @@ def __init__(
self.q_bias = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))
self.k_bias = nn.Parameter(torch.randn(head_dim, dtype=dtype, device=device))

def forward(self, qkv, cos, sin):
def forward(self, qkv, cos_sin_cache, positions):
# Split QKV
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

Expand All @@ -132,13 +131,12 @@ def forward(self, qkv, cos, sin):

# Reshape for RoPE
q_flat = q_normed.view(q.shape)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)

k_flat = k_normed.view(k.shape)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)

# Apply RoPE
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
)

return q_rope, k_rope, v

Expand All @@ -147,7 +145,7 @@ def assert_qknorm_rope_fusion(after_gm, expect_fused=True, use_bias=False):
check_rules = [
(torch.ops.vllm.qkv_rmsnorm_rope.default, expect_fused),
(torch.ops.npu.npu_rms_norm.default, not expect_fused),
(torch.ops.npu.npu_apply_rotary_pos_emb.default, not expect_fused),
(torch.ops.vllm.npu_rotary_embedding.default, not expect_fused),
]
if use_bias:
check_rules.append((torch.ops.aten.add.Tensor, not expect_fused))
Expand Down
16 changes: 8 additions & 8 deletions tests/e2e/singlecard/test_aclgraph_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
model="Qwen/Qwen3-0.6B",
prompts=PROMPTS_SHORT,
golden_answers=[
" Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I want to know if there are any",
" Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I'm looking for a job in the",
' the same as the president of the United Nations. This is because the president of the United States is the same as the president of the United Nations. The president',
' Paris. The capital of France is also the capital of the Republic of France. The capital of France is also the capital of the European Union. The capital of',
' not just a technological frontier but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and'
' not just a technological challenge but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and'
],
)

Expand All @@ -48,10 +48,10 @@
model="Qwen/Qwen3-0.6B",
prompts=PROMPTS_SHORT,
golden_answers=[
" Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I want to know if there are any",
" Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I'm looking for a job in the",
' the same as the president of the United Nations. This is because the president of the United States is the same as the president of the United Nations. The president',
' Paris. The capital of France is also the capital of the Republic of France. The capital of France is also the capital of the European Union. The capital of',
' not just a technological frontier but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and'
' not just a technological challenge but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and'
],
)

Expand All @@ -72,8 +72,8 @@
prompts=PROMPTS_LONG,
golden_answers=[
' \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the',
" \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area",
' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations $x^2 +'
" \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over",
' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can'
])

CASE_DS_FULL_DECODE_ONLY = LLMTestCase(
Expand All @@ -91,8 +91,8 @@
prompts=PROMPTS_LONG,
golden_answers=[
' \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the',
" \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area",
' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations $x^2 +'
" \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over",
' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can'
])

CASE_DS_EX = LLMTestCase(model="vllm-ascend/DeepSeek-V2-Lite-W8A8",
Expand Down
2 changes: 0 additions & 2 deletions vllm_ascend/ascend_forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,6 @@ def set_ascend_forward_context(

_mc2_tokens_capacity: int | None = None
_reserved_mc2_mask: torch.Tensor | None = None
_sin: torch.Tensor | None = None
_cos: torch.Tensor | None = None


def set_mc2_tokens_capacity(vllm_config, max_num_reqs, uniform_decode_query_len):
Expand Down
Loading