
Commit 1e304d8

Merge pull request vllm-project#20 from vllm-model-0920/wye-deepgemm-integration
[Feature] DeepGEMM integration
2 parents fa13a8b + 93eade0 commit 1e304d8

2 files changed: +419 −2 lines

Lines changed: 298 additions & 0 deletions
@@ -0,0 +1,298 @@
import random
import pytest
import torch

from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm, cdiv
from vllm.utils.deep_gemm import (
    _ceil_to_ue8m0,
    fp8_mqa_logits,
    calc_diff,
    get_paged_mqa_logits_metadata,
    fp8_paged_mqa_logits,
)


def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
    # x: (num_blocks, block_size, 1, head_dim)
    num_blocks, block_size, num_heads, head_dim = x.shape
    assert num_heads == 1
    x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
    sf = x_amax / 448.0
    x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
    x_fp8 = torch.empty(
        (num_blocks, block_size * (head_dim + 4)),
        device=x.device,
        dtype=torch.uint8,
    )
    x_fp8[:, : block_size * head_dim] = x_scaled.view(
        num_blocks, block_size * head_dim
    ).view(dtype=torch.uint8)
    x_fp8[:, block_size * head_dim :] = sf.view(num_blocks, block_size).view(
        dtype=torch.uint8
    )
    return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)


def per_custom_dims_cast_to_fp8(
    x: torch.Tensor, dims: tuple, use_ue8m0: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
    x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
    sf = x_amax / 448.0
    sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
    x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
    return x_scaled, sf.squeeze()


def _generate_cp_test_data(seq_len: int, seq_len_kv: int):
    assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0
    chunk_size = seq_len // 2
    cp_size = seq_len_kv // seq_len
    cp_id = cp_size // 3
    ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
    ke = torch.zeros(seq_len, dtype=torch.int, device="cuda")
    for i in range(chunk_size):
        ke[i] = cp_id * chunk_size + i
        ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i
    return ks, ke


def _ref_fp8_mqa_logits(
    q: torch.Tensor,
    kv: torch.Tensor,
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
):
    seq_len_kv = kv.shape[0]

    k = kv
    q = q.float()
    k = k.float()

    mask_lo = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :]
        >= cu_seqlen_ks[:, None]
    )
    mask_hi = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :]
        < cu_seqlen_ke[:, None]
    )
    mask = mask_lo & mask_hi

    score = torch.einsum("mhd,nd->hmn", q, k)
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits


@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
def test_deepgemm_fp8_mqa_logits():
    torch.manual_seed(0)
    random.seed(0)
    num_heads, head_dim = 32, 128
    for seq_len in (512,):
        for seq_len_kv in (1024,):
            for disable_cp in (False, True):
                q = torch.randn(
                    seq_len,
                    num_heads,
                    head_dim,
                    device="cuda",
                    dtype=torch.bfloat16,
                )
                kv = torch.randn(
                    seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16
                )
                weights = torch.randn(
                    seq_len, num_heads, device="cuda", dtype=torch.float32
                )

                if disable_cp:
                    ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
                    ke = torch.arange(
                        seq_len, dtype=torch.int, device="cuda"
                    ) + (seq_len_kv - seq_len)
                else:
                    ks, ke = _generate_cp_test_data(seq_len, seq_len_kv)

                q_fp8 = q.to(torch.float8_e4m3fn)
                kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False)
                logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)

                ref_logits = _ref_fp8_mqa_logits(
                    q=q,
                    kv=kv,
                    weights=weights,
                    cu_seqlen_ks=ks,
                    cu_seqlen_ke=ke,
                )

                ref_neginf_mask = ref_logits == float("-inf")
                neginf_mask = logits == float("-inf")
                assert torch.equal(neginf_mask, ref_neginf_mask)

                ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
                logits = logits.masked_fill(neginf_mask, 0)
                diff = calc_diff(logits, ref_logits)
                assert diff < 1e-3, f"{diff=}"


def _ref_fp8_paged_mqa_logits(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
):
    batch_size, next_n, _, _ = q.size()
    _, block_size, _, _ = kv_cache.size()
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=q.device,
        dtype=torch.float32,
    )
    context_lens_list = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens_list[i]
        q_offsets = torch.arange(
            context_len - next_n, context_len, device="cuda"
        )
        weight_slice = (
            weights[i * next_n : (i + 1) * next_n, :]
            .transpose(0, 1)
            .contiguous()
        )
        for block_rk in range(cdiv(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(
                block_rk * block_size,
                (block_rk + 1) * block_size,
                device="cuda",
            )
            mask = (k_offsets[None, :] < context_len) & (
                k_offsets[None, :] <= q_offsets[:, None]
            )
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype
                ),
                float("-inf"),
            )
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n : (i + 1) * next_n,
                block_rk * block_size : (block_rk + 1) * block_size,
            ] = torch.where(
                k_offsets[None, :] <= q_offsets[:, None], s, float("-inf")
            )
    return logits


@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
def test_deepgemm_fp8_paged_mqa_logits():
    torch.manual_seed(0)
    random.seed(0)

    max_model_len = 4096
    for batch_size, next_n in [(4, 1), (2, 2)]:
        for heads, index_dim in [(16, 128)]:
            for avg_kv in (2048,):
                num_blocks, blocksize = max_model_len * 2, 64

                q = torch.randn(
                    (batch_size, next_n, heads, index_dim),
                    device="cuda",
                    dtype=torch.bfloat16,
                )
                kv_cache = torch.randn(
                    (num_blocks, blocksize, 1, index_dim),
                    device="cuda",
                    dtype=torch.bfloat16,
                )
                weights = torch.randn(
                    (batch_size * next_n, heads),
                    device="cuda",
                    dtype=torch.float32,
                )

                context_lens = (
                    torch.randint(
                        int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,)
                    )
                    .cuda()
                    .to(torch.int32)
                )
                max_block_len = (
                    (context_lens.max().item() + blocksize - 1)
                    // blocksize
                    * blocksize
                )
                block_tables = torch.zeros(
                    (batch_size, max_block_len),
                    device="cuda",
                    dtype=torch.int32,
                )

                counter = 0
                block_idx_pool = list(range(num_blocks))
                random.shuffle(block_idx_pool)
                for i in range(batch_size):
                    ctx_len = int(context_lens[i].item())
                    for j in range((ctx_len + blocksize - 1) // blocksize):
                        block_tables[i][j] = block_idx_pool[counter]
                        counter += 1

                q_fp8 = q.to(torch.float8_e4m3fn)
                kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)

                schedule_metadata = get_paged_mqa_logits_metadata(
                    context_lens, blocksize, 132
                )
                logits = fp8_paged_mqa_logits(
                    q_fp8,
                    kv_cache_fp8,
                    weights,
                    context_lens,
                    block_tables,
                    schedule_metadata,
                    max_model_len,
                )

                ref_logits = _ref_fp8_paged_mqa_logits(
                    q,
                    kv_cache,
                    weights,
                    context_lens,
                    block_tables,
                    max_model_len,
                )

                positions = (
                    torch.arange(max_model_len, device="cuda")
                    .unsqueeze(0)
                    .expand(batch_size * next_n, -1)
                )
                row_indices = (
                    torch.arange(batch_size * next_n, device="cuda") // next_n
                )
                next_n_offset = (
                    torch.arange(batch_size * next_n, device="cuda") % next_n
                )
                mask = positions <= (
                    context_lens[row_indices] - next_n + next_n_offset
                ).unsqueeze(1)

                logits = logits.masked_fill(~mask, 0)
                ref_logits = ref_logits.masked_fill(~mask, 0)
                diff = calc_diff(logits, ref_logits)
                assert diff < 1e-3, f"{diff=}"
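
Note on the packed KV-cache layout used above: kv_cache_cast_to_fp8 stores, per block, block_size * head_dim one-byte fp8 values followed by block_size float32 per-token scale factors (4 bytes each), which is why the packed head dimension is head_dim + 4. The sketch below shows one way to unpack such a cache back to float for inspection, assuming exactly that layout; unpack_fp8_kv_cache is a hypothetical helper written for illustration (placed alongside kv_cache_cast_to_fp8 from the diff above), not part of this change, and the shapes in the round-trip check are arbitrary.

def unpack_fp8_kv_cache(x_fp8: torch.Tensor) -> torch.Tensor:
    # Illustrative inverse of kv_cache_cast_to_fp8 (assumed layout).
    # x_fp8: (num_blocks, block_size, 1, head_dim + 4), dtype=torch.uint8.
    num_blocks, block_size, num_heads, packed_dim = x_fp8.shape
    head_dim = packed_dim - 4
    flat = x_fp8.reshape(num_blocks, block_size * packed_dim)
    # First block_size * head_dim bytes per block: fp8 payload, token-major.
    payload = flat[:, : block_size * head_dim].contiguous().view(torch.float8_e4m3fn)
    # Remaining block_size * 4 bytes per block: one float32 scale per token.
    scales = flat[:, block_size * head_dim :].contiguous().view(torch.float32)
    kv = payload.view(num_blocks, block_size, num_heads, head_dim).float()
    return kv * scales.view(num_blocks, block_size, 1, 1)


# Round trip with arbitrary shapes: dequantized values should match the
# originals up to fp8 (e4m3, 3 mantissa bits) quantization error.
kv = torch.randn(4, 64, 1, 128, device="cuda", dtype=torch.bfloat16)
err = (unpack_fp8_kv_cache(kv_cache_cast_to_fp8(kv)) - kv.float()).abs().max()
assert err < 0.5  # loose bound on the per-element quantization error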

0 commit comments
