
Commit b96551c

Merge remote-tracking branch 'origin/main' into bump-version-bot

2 parents cec8c7c + aacc8df

File tree: 178 files changed, +20332 additions, -5465 deletions


.github/CODEOWNERS

Lines changed: 19 additions & 19 deletions
```diff
@@ -3,41 +3,41 @@
 # Analysis period: 180 days
 # Minimum commits threshold: 1
 
-benchmarks/ @bkryu @cyx-6 @nv-yunzheq @kahyunnam @nvmbreughe
+benchmarks/ @bkryu @cyx-6 @jiahanc @nv-yunzheq @kahyunnam
 benchmarks/routines/ @bkryu @nv-yunzheq @cyx-6 @nvmbreughe @Anerudhan
 ci/ @cyx-6 @yzh119 @nvmbreughe
 ci/scripts/ @cyx-6
 ci/scripts/jenkins/ @cyx-6
-csrc/ @yzh119 @wenscarl @cyx-6 @yongwww @kahyunnam
-csrc/fused_moe/ @yzh119 @yongwww @wenscarl @cyx-6 @yongwww
-csrc/fused_moe/cutlass_backend/ @yzh119 @yongwww @wenscarl @cyx-6 @yongwww
-csrc/nv_internal/ @wenscarl @yzh119 @cyx-6 @yongwww @aleozlx
-csrc/nv_internal/cpp/ @wenscarl @yongwww @joker-eph @ttyio @azhurkevich
+csrc/ @wenscarl @yzh119 @cyx-6 @djmmoss @yongwww
+csrc/fused_moe/ @yzh119 @yongwww @djmmoss @cyx-6 @wenscarl
+csrc/fused_moe/cutlass_backend/ @yzh119 @yongwww @djmmoss @cyx-6 @wenscarl
+csrc/nv_internal/ @wenscarl @djmmoss @cyx-6 @yzh119 @yongwww
+csrc/nv_internal/cpp/ @wenscarl @yongwww @djmmoss @joker-eph @ttyio
 csrc/nv_internal/include/ @wenscarl
-csrc/nv_internal/tensorrt_llm/ @wenscarl @yzh119 @cyx-6 @yongwww @aleozlx
-csrc/xqa/ @yzh119 @cyx-6
+csrc/nv_internal/tensorrt_llm/ @wenscarl @djmmoss @cyx-6 @yzh119 @yongwww
+csrc/xqa/ @cyx-6 @yzh119
 docs/ @yzh119 @cyx-6 @wenscarl @nv-yunzheq @aleozlx
-flashinfer/ @yzh119 @cyx-6 @nvmbreughe @wenscarl @yongwww
+flashinfer/ @yzh119 @cyx-6 @wenscarl @nvmbreughe @yongwww
 flashinfer-cubin/ @yzh119 @cyx-6
 flashinfer-cubin/flashinfer_cubin/ @yzh119
 flashinfer-jit-cache/ @yzh119 @cyx-6
 flashinfer-jit-cache/flashinfer_jit_cache/ @yzh119
-flashinfer/comm/ @yzh119 @cyx-6 @nvmbreughe @wenscarl @aleozlx
+flashinfer/comm/ @yzh119 @cyx-6 @nvmbreughe @wenscarl @djmmoss
 flashinfer/cudnn/ @Anerudhan @yzh119 @cyx-6 @Anerudhan
 flashinfer/cute_dsl/ @yzh119 @kaixih @Amir-19 @aleozlx
-flashinfer/fused_moe/ @yzh119 @cyx-6 @wenscarl @IwakuraRein @joker-eph
-flashinfer/jit/ @yzh119 @cyx-6 @aleozlx @yongwww @bkryu
-flashinfer/jit/attention/ @yzh119 @Anerudhan @joker-eph
+flashinfer/fused_moe/ @djmmoss @yzh119 @cyx-6 @wenscarl @IwakuraRein
+flashinfer/jit/ @yzh119 @cyx-6 @djmmoss @jiahanc @aleozlx
+flashinfer/jit/attention/ @yzh119 @cyx-6 @Anerudhan @joker-eph
 flashinfer/jit/gemm/ @yzh119
 flashinfer/logits_processor/ @cyx-6 @yzh119
 flashinfer/profiler/ @cyx-6
 flashinfer/triton/ @cyx-6 @nvmbreughe @yzh119
 flashinfer/tuning_configs/ @kaixih
-include/ @yzh119 @cyx-6 @kahyunnam @joker-eph @aleozlx
-include/flashinfer/ @yzh119 @cyx-6 @kahyunnam @joker-eph @aleozlx
+include/ @yzh119 @wenscarl @kahyunnam @joker-eph @cyx-6
+include/flashinfer/ @yzh119 @wenscarl @kahyunnam @joker-eph @cyx-6
 include/flashinfer/attention/ @yzh119 @kahyunnam @joker-eph
-include/flashinfer/comm/ @yongwww @nvmbreughe @yzh119 @cyx-6
-include/flashinfer/gemm/ @ttyio @yongwww @aleozlx @cyx-6
-include/flashinfer/trtllm/ @joker-eph @aleozlx @yzh119 @cyx-6 @aleozlx
+include/flashinfer/comm/ @yongwww @nvmbreughe @djmmoss @yzh119 @cyx-6
+include/flashinfer/gemm/ @ttyio @yongwww @aleozlx
+include/flashinfer/trtllm/ @joker-eph @aleozlx @yzh119 @cyx-6 @wenscarl
 profiler/ @cyx-6
-scripts/ @yzh119 @nvmbreughe @yongwww @bkryu @dierksen
+scripts/ @yzh119 @nvmbreughe @dierksen @yongwww @bkryu
```

.github/workflows/nightly-release.yml

Lines changed: 0 additions & 1 deletion
```diff
@@ -98,7 +98,6 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install build twine wheel
-        pip install setuptools>=61.0 requests filelock torch tqdm numpy apache-tvm-ffi==0.1.0b15
 
     - name: Build flashinfer-cubin wheel
       env:
```

.github/workflows/release.yml

Lines changed: 0 additions & 1 deletion
```diff
@@ -136,7 +136,6 @@ jobs:
       run: |
        python -m pip install --upgrade pip
         pip install build twine wheel
-        pip install setuptools>=61.0 requests filelock torch tqdm numpy apache-tvm-ffi==0.1.0b15
 
     - name: Build flashinfer-cubin wheel
       run: |
```

CONTRIBUTING.md

Lines changed: 6 additions & 7 deletions
```diff
@@ -36,12 +36,11 @@ Code Contribution Procedure
 
 # Release Versioning
 
-When incrementing a version and creating a release, follow [Semantic Versioning](https://packaging.python.org/en/latest/discussions/versioning/) (`major.minor.patch`) [^1]. In particular:
+When incrementing a version and creating a release, follow a "right-shifted" versioning scheme similar to [vLLM Release Versioning](https://github.com/vllm-project/vllm/blob/main/RELEASE.md) (`major.minor.patch[.post1]`) [^1]. In particular:
 
-* major increment signals incompatible API changes
-* minor increment signals added functionality that is backwards-compatible (e.g. new kernels, new SM support, etc)
-* patch increment signals backwards-compatible bug fixes (both for functional and performance issues)
+* _major_ increment signals architectural milestone and/or when incompatible API changes are made, similar to PyTorch 2.0.
+* _minor_ increment signals significant backwards-compatible new features
+* _patch_ increment signals small backwards-compatible features (e.g. new kernels, new SM support, etc) and backwards-compatible bug fixes
+* _post1_ is an optional suffix for a quick follow up release with just backwards-compatible bug fixes
 
-Optionally, use post-releases (e.g., `X.Y.Z.post1`) for minor changes, like a documentation change.
-
-[^1]: We have not followed this strictly through v0.2.14.post1. But after v0.2.14.post1, the versioning should follow SemVer.
+[^1]: We have not followed this strictly through v0.4.0. But after v0.4.0, the versioning should follow this "right-shifted" versioning scheme.
```
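
The new scheme still produces PEP 440-compliant version strings, so standard Python tooling orders releases as expected. A minimal sketch of that ordering, using the third-party `packaging` library; the concrete version numbers are illustrative only, not taken from the release history:

```python
# Illustrative only: how "right-shifted" versions (major.minor.patch[.post1])
# compare under PEP 440. Requires the `packaging` library (pip install packaging).
from packaging.version import Version

versions = ["0.4.0", "0.4.0.post1", "0.4.1", "0.5.0", "1.0.0"]
# The list above is already in ascending release order.
assert sorted(versions, key=Version) == versions

# A post-release sorts after its base release but before the next patch release.
assert Version("0.4.0") < Version("0.4.0.post1") < Version("0.4.1")
```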

benchmarks/README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -16,6 +16,7 @@ Currently supports testing most attention, gemm, and fused MOE APIs:
   - `BatchPrefillWithPagedKVCacheWrapper` - Prefill attention with paged KV cache.
     - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_batch_context_with_kv_cache`.
   - `BatchPrefillWithRaggedKVCacheWrapper` - Prefill attention with ragged KV cache.
+    - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_ragged_attention_deepseek`.
   - `BatchMLAPagedAttentionWrapper` - MLA attention proposed in DeepSeek series of models.
     - Also supports computationally similar `trtllm_batch_decode_with_kv_cache_mla`.
 - GEMM:
```

benchmarks/bench_mixed_attention.py

Lines changed: 69 additions & 62 deletions
```diff
@@ -72,6 +72,24 @@ def run_bench(
     measurements = bench_gpu_time(lambda: wrapper_old.run(q, kv_data))
     ms_old = np.median(measurements)
 
+    wrapper_persistent = flashinfer.BatchAttention(kv_layout="NHD")
+    wrapper_persistent.plan(
+        q_indptr.to(device),
+        kv_indptr.to(device),
+        torch.arange(num_blocks, dtype=torch.int32, device=device),
+        seq_lens.to(device),
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        head_dim,
+        page_block_size,
+        causal=causal,
+        q_data_type=torch.bfloat16,
+        kv_data_type=torch.bfloat16,
+    )
+    o_persistent, _ = wrapper_persistent.run(q, kv_data)
+    measurements_persistent = bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data))
+    ms_persistent = np.mean(measurements_persistent)
     if len(p_kv_lens) == 1:
         q_d = q[: d_q_indptr[-1]]
         kv_d = kv_data[: d_kv_indptr[-1]].unbind(1)
@@ -123,9 +141,46 @@ def run_bench(
             )
         )
         ms_pod = np.median(measurements)
+
+        # Sequential two kernels: single prefill + batch decode (tensor cores)
+        # Prefill using single_prefill_with_kv_cache
+        def _run_single_prefill():
+            return flashinfer.prefill.single_prefill_with_kv_cache(
+                q_p,
+                k_p,
+                v_p,
+                causal=causal,
+                pos_encoding_mode="NONE",
+                backend="fa2",
+            )
+
+        measurements_prefill = bench_gpu_time(lambda: _run_single_prefill())
+        ms_prefill = np.median(measurements_prefill)
+
+        # Batch decode using tensor cores
+        wrapper_decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+            workspace_buffer, kv_layout=kv_layout, use_tensor_cores=True
+        )
+        wrapper_decode.plan(
+            d_kv_indptr.to(device),
+            kv_indices_d.to(device),
+            last_page_len_d,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            page_block_size,
+            data_type=torch.bfloat16,
+            q_data_type=torch.bfloat16,
+        )
+        measurements_decode = bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d))
+        ms_decode = np.median(measurements_decode)
+        ms_seq_two_kernels = ms_prefill + ms_decode
+
     print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms")
     if len(p_kv_lens) == 1:
         print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms")
+        print(f"Elapsed time (Sequential two kernels): {ms_seq_two_kernels:.2f} ms")
+    print(f"Elapsed time (Persistent BatchAttention): {ms_persistent:.2f} ms")
     total_bytes = (
         q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
     )
@@ -137,77 +192,29 @@ def run_bench(
     if len(p_kv_lens) == 1:
         bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3)
         print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s")
+        bandwidth_seq_gb_s = total_bytes / (ms_seq_two_kernels * 1e-3) / (1024**3)
+        print(
+            f"Memory bandwidth (Sequential two kernels): {bandwidth_seq_gb_s:.2f} GB/s"
+        )
+    bandwidth_persistent_gb_s = total_bytes / (ms_persistent * 1e-3) / (1024**3)
+    print(
+        f"Memory bandwidth (Persistent BatchAttention): {bandwidth_persistent_gb_s:.2f} GB/s"
+    )
 
 
 if __name__ == "__main__":
     np.random.seed(42)
     torch.random.manual_seed(42)
 
     # Irregular sequence lengths for prefill and decode
-    d_q_len_configs = [[1] * 122, [1] * 128, [1] * 242, [1] * 256]
-    d_kv_len_configs = [[600] * 122, [10000] * 128, [400] * 242, [8192] * 256]
-    p_q_configs = [[17] * 1, [10000], [17] * 1, []]
-    p_kv_configs = [[10000] * 1, [10000], [8192] * 1, []]
-
-    # construct random length testcases
-    for _ in range(1):
-        bsz = 256
-        stride = 16
-        sparsity = 0.05
-
-        full_kv_len = np.random.randint(1000, 8192, size=bsz)
-        p_q_lens = []
-        p_kv_lens = []
-        d_q_lens = []
-        d_kv_lens = []
-        for i in range(bsz):
-            if i % stride == 0:
-                kv_len = full_kv_len[i]
-                qo_len = stride + 1
-                p_q_lens.append(qo_len)
-                p_kv_lens.append(kv_len)
-            else:
-                kv_len = int(full_kv_len[i] * sparsity)
-                qo_len = 1
-                d_q_lens.append(qo_len)
-                d_kv_lens.append(kv_len)
-
-        p_q_configs.append(p_q_lens)
-        p_kv_configs.append(p_kv_lens)
-        d_q_len_configs.append(d_q_lens)
-        d_kv_len_configs.append(d_kv_lens)
-
-    for _ in range(1):
-        bsz = 128
-        stride = 16
-        sparsity = 0.05
-
-        full_kv_len = np.random.randint(2000, 16000, size=bsz)
-        p_q_lens = []
-        p_kv_lens = []
-        d_q_lens = []
-        d_kv_lens = []
-
-        for i in range(bsz):
-            if i % stride == 0:
-                kv_len = full_kv_len[i]
-                qo_len = stride + 1
-                p_q_lens.append(qo_len)
-                p_kv_lens.append(kv_len)
-            else:
-                kv_len = int(full_kv_len[i] * sparsity)
-                qo_len = 1
-                d_q_lens.append(qo_len)
-                d_kv_lens.append(kv_len)
-
-        p_q_configs.append(p_q_lens)
-        p_kv_configs.append(p_kv_lens)
-        d_q_len_configs.append(d_q_lens)
-        d_kv_len_configs.append(d_kv_lens)
+    d_q_len_configs = [[1] * 128, [1] * 128, [1] * 128, [1] * 128]
+    d_kv_len_configs = [[2048] * 128, [4096] * 128, [8192] * 128, [8192] * 128]
+    p_q_configs = [[2048], [4096], [4096], [6000]]
+    p_kv_configs = [[2048], [4096], [4096], [7000]]
 
     page_block_size = 1
-    num_kv_heads = 4
-    num_qo_heads = 28
+    num_kv_heads = 8
+    num_qo_heads = 32
     head_dim = 128
 
     for idx, (p_q_lens, p_kv_lens, d_q_len, d_kv_len) in enumerate(
```
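
Each variant above is timed with the repository's `bench_gpu_time` helper and summarized as a median latency plus an effective memory bandwidth. Below is a minimal, self-contained sketch of that measurement pattern, assuming only PyTorch, NumPy, and a CUDA device; `time_gpu_ms` and the matmul workload are illustrative stand-ins, not the actual helper:

```python
# Sketch of the latency/bandwidth reporting pattern used in the benchmark above.
# `time_gpu_ms` is a hypothetical stand-in for the repo's `bench_gpu_time` helper:
# it returns per-call GPU latencies so the caller can take a median.
import numpy as np
import torch


def time_gpu_ms(fn, warmup=10, iters=50):
    """Return a list of per-call GPU latencies in milliseconds using CUDA events."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        end.synchronize()
        times.append(start.elapsed_time(end))  # milliseconds
    return times


# Usage mirroring the benchmark: median latency, then bytes moved / time.
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
ms = float(np.median(time_gpu_ms(lambda: x @ x)))
total_bytes = 3 * x.numel() * x.element_size()  # rough read/write estimate for this toy workload
print(f"median {ms:.2f} ms, {total_bytes / (ms * 1e-3) / (1024**3):.2f} GB/s")
```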

benchmarks/bench_rope_quantize_fp8.py

Lines changed: 23 additions & 1 deletion
```diff
@@ -88,7 +88,7 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
-def benchmark_config(config_name, num_tokens, provider):
+def benchmark_config(config_name, num_tokens, provider, enable_pdl=False):
     """Benchmark a specific attention configuration."""
     input_dtype = torch.bfloat16
     device = "cuda"
@@ -177,6 +177,7 @@ def execute():
             k_nope_out=k_nope_out,
             quant_scale_q=1.0,
             quant_scale_kv=1.0,
+            enable_pdl=enable_pdl,
         )
 
         if mode_ncu and run_idx == 20:
@@ -278,6 +279,23 @@ def benchmark_mha(provider, num_tokens):
     return benchmark_config("mha", num_tokens, provider)
 
 
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["num_tokens"],
+        x_vals=[768] if mode_ncu else [1, 2, 4, 8, 16, 32, 64, 128, 256, 384, 512, 768],
+        line_arg="enable_pdl",
+        line_vals=[False, True],
+        line_names=["enable_pdl=False", "enable_pdl=True"],
+        styles=[("blue", "-"), ("red", "-")],
+        ylabel="Latency (ms)",
+        plot_name="rope-pdl-benchmark",
+        args={},
+    )
+)
+def benchmark_pdl(enable_pdl, num_tokens):
+    return benchmark_config("mla", num_tokens, "flashinfer", enable_pdl=enable_pdl)
+
+
 if __name__ == "__main__":
     # Run all benchmarks and generate individual plots
     print("Running MLA benchmark...")
@@ -289,6 +307,9 @@ def benchmark_mha(provider, num_tokens):
     print("Running MHA benchmark...")
     benchmark_mha.run(print_data=False, show_plots=True, save_path=".")
 
+    print("Running PDL benchmark...")
+    benchmark_pdl.run(print_data=False, show_plots=True, save_path=".")
+
     # Collect results for summary table
     token_counts = (
         [1, 2, 4, 8, 16, 32, 64, 128, 256, 384, 512, 768] if not mode_ncu else [768]
@@ -319,3 +340,4 @@ def benchmark_mha(provider, num_tokens):
     print(" mla-rope-benchmark.png (FlashInfer vs PyTorch)")
     print(" gqa-rope-benchmark.png (FlashInfer vs PyTorch)")
     print(" mha-rope-benchmark.png (FlashInfer vs PyTorch)")
+    print(" rope-pdl-benchmark.png (enable_pdl=False vs enable_pdl=True)")
```
