Commit dceb419

Author: zhaorong.bd
Commit message: add npu modify
Parent: 8cddc2e

File tree: 2 files changed (+236, -13 lines)

native_sparse_attention/ops/parallel.py

Lines changed: 13 additions & 13 deletions
@@ -162,7 +162,7 @@ def parallel_nsa_compression_bwd_kernel_dq(
     BS: tl.constexpr,  # sequence block size
     BK: tl.constexpr,  # key block size
     BV: tl.constexpr,  # value block size
-    USE_OFFSETS: tl.constexpr,  # flag indicating whether offsets are used
+    USE_OFFSETS: tl.constexpr  # flag indicating whether offsets are used
 ):
     # get the program IDs used for the parallel launch
     i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

@@ -881,12 +881,12 @@ def parallel_nsa_compression_fwd(
     H = k.shape[2]
     G = HQ // H
     BC = BS = block_size
-    # if torch.cuda.get_device_capability()[0] >= 9:
-    #     BK = min(256, triton.next_power_of_2(K))
-    #     BV = min(256, triton.next_power_of_2(V))
-    # else:
-    BK = min(128, triton.next_power_of_2(K))
-    BV = min(128, triton.next_power_of_2(V))
+    BKV_VALUE=256
+    if torch.cuda.is_available():
+        if torch.cuda.get_device_capability()[0] < 9:
+            BKV_VALUE=128
+    BK = min(BKV_VALUE, triton.next_power_of_2(K))
+    BV = min(BKV_VALUE, triton.next_power_of_2(V))
     NK = triton.cdiv(K, BK)
     NV = triton.cdiv(V, BV)
     assert NK == 1, "The key dimension can not be larger than 256"

@@ -1130,12 +1130,12 @@ def parallel_nsa_fwd(
     HQ = q.shape[2]
     G = HQ // H
     BS = block_size
-    if torch.cuda.get_device_capability()[0] >= 9:
-        BK = min(256, triton.next_power_of_2(K))
-        BV = min(256, triton.next_power_of_2(V))
-    else:
-        BK = min(128, triton.next_power_of_2(K))
-        BV = min(128, triton.next_power_of_2(V))
+    BKV_VALUE=256
+    if torch.cuda.is_available():
+        if torch.cuda.get_device_capability()[0] < 9:
+            BKV_VALUE=128
+    BK = min(BKV_VALUE, triton.next_power_of_2(K))
+    BV = min(BKV_VALUE, triton.next_power_of_2(V))
     NK = triton.cdiv(K, BK)
     NV = triton.cdiv(V, BV)
     assert NK == 1, "The key dimension can not be larger than 256"
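
The two BK/BV changes above share the same intent: the per-block key/value cap now defaults to 256, and torch.cuda.get_device_capability is only consulted when a CUDA device is actually present, so pre-Hopper GPUs (compute capability < 9) fall back to 128 while non-CUDA backends such as NPU never reach the CUDA query. A minimal standalone sketch of that selection, assuming a hypothetical helper name choose_kv_block_sizes (the commit inlines the same logic directly inside parallel_nsa_compression_fwd and parallel_nsa_fwd):

import torch
import triton


def choose_kv_block_sizes(K: int, V: int):
    # default cap: 256 (Hopper-class CUDA GPUs and non-CUDA backends such as NPU,
    # where torch.cuda.is_available() returns False)
    bkv_cap = 256
    if torch.cuda.is_available():
        # pre-Hopper CUDA GPUs fall back to a 128-element cap
        if torch.cuda.get_device_capability()[0] < 9:
            bkv_cap = 128
    # round head dims up to powers of two, then clamp to the cap
    BK = min(bkv_cap, triton.next_power_of_2(K))
    BV = min(bkv_cap, triton.next_power_of_2(V))
    return BK, BV

For the head dimensions used in the tests below, D = 100 gives BK = BV = 128 (triton.next_power_of_2(100) is 128), while D = 64 keeps BK = BV = 64.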

tests/test_nsa_npu.py

Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@

# -*- coding: utf-8 -*-

import os

import pytest
import torch
import triton
import torch_npu
from fla.ops.common.utils import prepare_token_indices
from native_sparse_attention.ops.naive import naive_nsa
from native_sparse_attention.ops.parallel import parallel_nsa


def get_abs_err(x, y):
    return (x-y).flatten().abs().max().item()


def get_err_ratio(x, y):
    err = (x-y).flatten().square().mean().sqrt().item()
    base = (x).flatten().square().mean().sqrt().item()
    return err / base


def assert_close(prefix, ref, tri, ratio):
    msg = f"{prefix} diff: {get_abs_err(ref, tri):.6f} ratio: {get_err_ratio(ref, tri):.6f}"
    print(msg)
    assert get_err_ratio(ref, tri) < ratio, msg


@pytest.mark.parametrize("B", [1])
@pytest.mark.parametrize("T", [256, 1024, 2000])
@pytest.mark.parametrize("H", [4])
@pytest.mark.parametrize("HQ", [64])
@pytest.mark.parametrize("D", [100, 64])
@pytest.mark.parametrize("S", [16])
@pytest.mark.parametrize("block_size", [32])
@pytest.mark.parametrize("window_size", [0, 32])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("scale", [0.1])
def test_parallel(
    B: int,
    H: int,
    HQ: int,
    T: int,
    D: int,
    S: int,
    block_size: int,
    window_size: int,
    dtype: torch.dtype,
    scale: float
):
    torch.manual_seed(42)
    os.environ['TRITON_F32_DEFAULT'] = 'ieee'

    perm_q = torch.randperm(T, device='npu')
    perm_k = torch.randperm(T, device='npu')
    perm_v = torch.randperm(T, device='npu')
    q = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_q].view(1, T, 1, 1).expand(B, T, HQ, D).clone().requires_grad_(True)
    k = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_k].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
    v = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_v].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
    g_cmp = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
    g_slc = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
    g_swa = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
    do = torch.randn((B, T, HQ, D), dtype=dtype, device='npu')

    block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='npu')
    for b in range(B):
        for t in range(T):
            for h in range(H):
                i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
                block_indices[b, t, h, :len(i_i)] = i_i
    block_indices = block_indices.sort(-1)[0]
    block_counts = torch.randint(1, S + 1, (B, T, H), dtype=torch.long, device='npu')

    ref = naive_nsa(
        q=q,
        k=k,
        v=v,
        g_slc=g_slc,
        g_swa=g_swa,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        window_size=window_size,
        scale=scale
    )
    ref.backward(do)
    ref_dq, q.grad = q.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dv, v.grad = v.grad.clone(), None
    ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None
    if window_size > 0:
        ref_dg_swa, g_swa.grad = g_swa.grad.clone(), None

    tri = parallel_nsa(
        q=q,
        k=k,
        v=v,
        g_cmp=g_cmp,
        g_slc=g_slc,
        g_swa=g_swa,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        window_size=window_size,
        scale=scale
    )
    tri.backward(do)
    tri_dq, q.grad = q.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dv, v.grad = v.grad.clone(), None
    tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None
    if window_size > 0:
        tri_dg_swa, g_swa.grad = g_swa.grad.clone(), None

    assert_close(" o", ref, tri, 0.005)
    assert_close("dq", ref_dq, tri_dq, 0.005)
    assert_close("dk", ref_dk, tri_dk, 0.005)
    assert_close("dv", ref_dv, tri_dv, 0.005)
    assert_close("dg_slc", ref_dg_slc, tri_dg_slc, 0.005)
    if window_size > 0:
        assert_close("dg_swa", ref_dg_swa, tri_dg_swa, 0.005)


@pytest.mark.parametrize("N", [4])
@pytest.mark.parametrize("T", [64, 128, 200, 250, 256, 300, 400, 512, 1000, 2048])
@pytest.mark.parametrize("H", [4])
@pytest.mark.parametrize("HQ", [64])
@pytest.mark.parametrize("D", [100, 64])
@pytest.mark.parametrize("S", [16])
@pytest.mark.parametrize("block_size", [32])
@pytest.mark.parametrize("window_size", [0, 32])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_parallel_varlen(
    N: int,
    T: int,
    H: int,
    HQ: int,
    D: int,
    S: int,
    block_size: int,
    window_size: int,
    dtype: torch.dtype,
):
    torch.manual_seed(42)
    os.environ['TRITON_F32_DEFAULT'] = 'ieee'

    # randomly split the sequence into N segments
    offsets = torch.cat([
        torch.tensor([0], dtype=torch.long),
        torch.arange(16, T)[torch.randperm(T - 1)[:N-1]],
        torch.tensor([T], dtype=torch.long)
    ], 0).npu().sort()[0]
    # seq-first required for inputs with variable lengths
    perm_q = torch.randperm(T, device='npu')
    perm_k = torch.randperm(T, device='npu')
    perm_v = torch.randperm(T, device='npu')
    q = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True)
    k = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
    v = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
    g_cmp = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
    g_slc = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
    g_swa = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
    do = torch.randn((1, T, HQ, D), dtype=dtype, device='npu')

    token_indices = prepare_token_indices(offsets).tolist()
    block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device='npu')
    for i in range(T):
        _, t = token_indices[i]
        for h in range(H):
            i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
            block_indices[0, i, h, :len(i_i)] = i_i
    block_indices = block_indices.sort(-1)[0]
    block_counts = torch.randint(1, S + 1, (1, T, H), device='npu')

    ref = naive_nsa(
        q=q,
        k=k,
        v=v,
        g_slc=g_slc,
        g_swa=g_swa,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        window_size=window_size,
        cu_seqlens=offsets
    )
    ref.backward(do)
    ref_dq, q.grad = q.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dv, v.grad = v.grad.clone(), None
    ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None
    if window_size > 0:
        ref_dg_swa, g_swa.grad = g_swa.grad.clone(), None

    tri = parallel_nsa(
        q=q,
        k=k,
        v=v,
        g_cmp=g_cmp,
        g_slc=g_slc,
        g_swa=g_swa,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        window_size=window_size,
        cu_seqlens=offsets
    )
    tri.backward(do)
    tri_dq, q.grad = q.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dv, v.grad = v.grad.clone(), None
    tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None
    if window_size > 0:
        tri_dg_swa, g_swa.grad = g_swa.grad.clone(), None

    assert_close(" o", ref, tri, 0.004)
    assert_close("dq", ref_dq, tri_dq, 0.005)
    assert_close("dk", ref_dk, tri_dk, 0.005)
    assert_close("dv", ref_dv, tri_dv, 0.005)
    assert_close("dg_slc", ref_dg_slc, tri_dg_slc, 0.005)
    if window_size > 0:
        assert_close("dg_swa", ref_dg_swa, tri_dg_swa, 0.005)
