[CP] fuse fwd/bwd kernels and fix IMA in long context #733

Conversation
Walkthrough

Adds two new benchmark scripts, fuses CP pre-processing into merged Triton kernels for forward/backward, widens kernel index arithmetic to 64-bit in several Triton kernels, and simplifies the test memory guard by replacing canary/padding logic with direct NaN-filling for floating/complex tensors.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Driver as Driver (launch)
    participant Rank as Worker Rank
    participant CP as CP Group
    participant TP as TP Group
    participant GPU as GPU / Triton Kernel
    Driver->>Rank: start distributed benchmark (args)
    Rank->>CP: join CP group
    Rank->>TP: join TP group
    Rank->>GPU: run local fused pre-process kernel
    GPU-->>Rank: return local shard outputs
    Rank->>CP: intra-CP all-gather / all-reduce (assemble/reduce shards)
    CP->>TP: optional inter-group coordination
    TP->>GPU: run TP-side kernels / execute remainder
    GPU-->>Rank: return final results
    Rank->>Driver: report timings & profiler data
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
Summary of Changes

Hello @zhiyuan1i, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request significantly enhances the performance and stability of the context-parallel (CP) implementation.
Force-pushed from d85e659 to f49979b
Code Review
This pull request introduces two significant improvements: the fusion of forward and backward pre-processing kernels for better performance in context parallelism, and a fix for potential Illegal Memory Access (IMA) errors in long context scenarios. The kernel fusion is well-implemented through new merged Triton kernels, which are correctly integrated and benchmarked, demonstrating a clear performance benefit. The IMA fixes are applied consistently across multiple kernels by ensuring pointer arithmetic is performed using 64-bit integers, which is a critical correction for handling long sequences. Furthermore, the inclusion of new, comprehensive benchmark scripts and detailed documentation for the context parallel implementation are excellent additions that enhance the library's testability and maintainability. Overall, the changes are of high quality and significantly improve the robustness and performance of the codebase.
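To make the 64-bit fix concrete, here is a minimal hedged sketch of the widening pattern the review describes; the kernel and its names are illustrative inventions, not the PR's actual kernels:

```python
import triton
import triton.language as tl


@triton.jit
def copy_kernel(x, y, T, K: tl.constexpr, BK: tl.constexpr):
    # Hypothetical kernel, for illustration only.
    i_k = tl.program_id(0)
    # Widen to int64 BEFORE the multiply: in int32 arithmetic, i_bh * T * K
    # wraps past 2**31 - 1 for long contexts, producing an illegal memory access.
    i_bh = tl.program_id(1).to(tl.int64)
    o_k = i_k * BK + tl.arange(0, BK)
    mask = o_k < T * K
    b_x = tl.load(x + i_bh * T * K + o_k, mask=mask, other=0.)
    tl.store(y + i_bh * T * K + o_k, b_x, mask=mask)
```

The same idea applies wherever a batch/head index multiplies a sequence-length term: widening the leftmost factor propagates int64 through the whole offset expression.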
Force-pushed from e4d1d49 to da166ba
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py`:
- Around line 259-349: The custom-CP benchmark (kda_with_custom_cp) calls
chunk_kda with different kernel/config params than the simple-CP and baseline
variants, making timings incomparable; update the chunk_kda call inside
kda_with_custom_cp to pass the same kernel/config arguments used by
kda_with_simple_cp/kda_baseline_single_gpu (A_log, dt_bias,
use_gate_in_kernel=True, safe_gate=True, lower_bound=-5,
use_qk_l2norm_in_kernel=True) while keeping its cu_seqlens=cp_cu_seqlens and
existing TP all-reduce/backward handling so the computation is functionally
equivalent across kda_with_custom_cp, kda_with_simple_cp, and
kda_baseline_single_gpu.
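For orientation, a hedged sketch of what that alignment might look like; the keyword arguments are taken verbatim from the comment above, while the surrounding tensor names are assumptions about the benchmark's local variables, not verified code:

```python
# Sketch only: inside kda_with_custom_cp, mirror the kernel/config flags used by
# kda_with_simple_cp / kda_baseline_single_gpu so the three timings stay comparable.
# q, k, v, g, beta are hypothetical local tensor names; the chunk_kda return
# convention is likewise assumed.
out = chunk_kda(
    q, k, v, g, beta,
    A_log=A_log,
    dt_bias=dt_bias,
    use_gate_in_kernel=True,
    safe_gate=True,
    lower_bound=-5,
    use_qk_l2norm_in_kernel=True,
    cu_seqlens=cp_cu_seqlens,  # keep the custom-CP sequence layout
)
```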
🧹 Nitpick comments (7)
fla/ops/cp/chunk_delta_h.py (1)
838-843: Redundant if-else branches: both paths execute identical code.

The `if USE_EXP2` / `else` branches perform the same tensor load operation. This pattern repeats at lines 850-853, 862-865, and 872-875.

♻️ Simplify by removing the redundant conditional

```diff
 if USE_GK:
     o_k1 = tl.arange(0, 64)
-    if USE_EXP2:
-        b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.).to(tl.float32)
-    else:
-        b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.).to(tl.float32)
+    b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.).to(tl.float32)
     b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype))
```

Apply the same simplification to the other redundant branches (o_k2, o_k3, o_k4).
benchmarks/cp/benchmark_chunk_delta_h_kernels.py (2)
70-71: Consider using `torch.int64` for cu_seqlens to match kernel expectations.

The kernels load `cu_seqlens` values with `.to(tl.int64)` (e.g., line 57 in chunk_delta_h.py). While the current `int32` works due to the explicit cast in the kernel, using `int64` here would be more consistent with the 64-bit indexing changes in this PR.

♻️ Proposed change

```diff
 # cu_seqlens for 1 chunk
-cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32)
+cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int64)
```
425-428: Minor: guard against division by zero in the FWD/BWD ratio.

While unlikely in practice, `bwd_time` could theoretically be zero.

♻️ Proposed change

```diff
 print(f"Forward pass total: {fwd_time * 1000:.2f} us ({fwd_time/total_time*100:.1f}%)")
 print(f"Backward pass total: {bwd_time * 1000:.2f} us ({bwd_time/total_time*100:.1f}%)")
-print(f"FWD/BWD ratio: {fwd_time/bwd_time:.2f}")
+print(f"FWD/BWD ratio: {fwd_time/bwd_time:.2f}" if bwd_time > 0 else "FWD/BWD ratio: N/A")
 print()
```

benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py (4)
48-48: Redundant argument configuration.

Using `action="store_true"` with `default=True` makes this flag always True. If you want to allow disabling the benchmark, consider using `--no-bench` with `action="store_false"`.

♻️ Suggested fix

```diff
-parser.add_argument("--bench", action="store_true", default=True, help="Run benchmark")
+parser.add_argument("--no-bench", action="store_true", default=False, help="Skip benchmark")
```

Then update line 365:

```diff
-    if args.bench:
+    if not args.no_bench:
```
77-77: Rename unused loop variable to `_`.

The loop variable `i` is not used within the loop body. Per Python convention, rename to `_` to indicate it's intentionally unused.

♻️ Suggested fix

```diff
-    for i in range(warm_up):
+    for _ in range(warm_up):
         fn()
         if grad_to_none is not None:
             for x in grad_to_none:
                 x.grad = None
     # Benchmark
     torch.cuda.synchronize()
     start_event.record()
-    for i in range(step):
+    for _ in range(step):
         fn()
```

Also applies to: 86-86
95-95: Remove unused `rank` parameter.

The `rank` parameter is passed to `profile_kernels()` but never used within the function.

♻️ Suggested fix

```diff
-def profile_kernels(fn, rank, steps=5, warmup=2):
+def profile_kernels(fn, steps=5, warmup=2):
```

And update the call site at line 360:

```diff
-    kernel_stats = profile_kernels(kda_with_custom_cp, rank, steps=5, warmup=2)
+    kernel_stats = profile_kernels(kda_with_custom_cp, steps=5, warmup=2)
```
237-243: Consider defining `do_base` unconditionally to avoid a potential `NameError`.

When `run_backward=True` and `test_baseline=False`, `do_base` is never defined. While the current control flow prevents accessing it, this structure is fragile. Defining it unconditionally improves clarity.

♻️ Suggested fix

```diff
 # Output gradient for backward
+do_base = None
 if run_backward:
     do = torch.randn(B, T_local, H_local, V, dtype=DTYPE, device=device)
     if test_baseline:
         do_base = torch.randn(B, T_baseline, H_local, V, dtype=DTYPE, device=device)
 else:
     do = None
-    do_base = None
```
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py`:
- Around line 188-198: Add explicit divisibility checks immediately after
computing T_total, T_local, T_baseline and H_local: verify T_total % cp_size ==
0 and H % tp_size == 0 (and that cp_size and tp_size are > 0), and raise or
assert with a clear explanatory message if they are not divisible so the user
sees which dimension is invalid; update the logic around variables T_local,
T_baseline and H_local to rely on these validated values (variables referenced:
T_total, T_local, T_baseline, H_local, cp_size, tp_size, H).
- Around line 187-225: The CP paths assume B==1 which breaks when batch size >1;
fix by making cu_seqlens and the all-gather logic batch-aware and removing any
unconditional squeeze(0) that collapses the batch dim. Specifically, ensure
cu_seqlens passed to build_cp_context is shaped per-batch (one cu_seqlens
sequence per sample) instead of a single [0, T_total], update the
all-gather/concat that previously used squeeze(0) to preserve the B dimension
when gathering q/k/v across CP ranks, and adjust chunk_kda inputs/reshaping to
accept (B, T_local, H_local, ...) tensors rather than relying on B==1. Update
references: cu_seqlens, build_cp_context, chunk_kda, and the squeeze(0) usage to
handle arbitrary B.
- Around line 418-435: The script binds CUDA using the global process rank which
breaks multi-node torchrun; in main() replace uses of rank for device-selection
and device-name queries with the LOCAL_RANK provided by torchrun: read
local_rank = int(os.environ.get("LOCAL_RANK", rank)) (or fallback to rank), call
torch.cuda.set_device(local_rank), and use
torch.cuda.get_device_name(local_rank); also use local_rank when seeding CUDA
(torch.cuda.manual_seed) so per-process seeds are per-local-device while keeping
distributed init via dist.init_process_group() and world_size/rank logic
unchanged.
🧹 Nitpick comments (2)
benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py (2)
71-88: Avoid grad accumulation during timed iterations (and fix unused loop vars).

Grad accumulation can skew timing/memory, and Ruff flagged unused loop vars. Clear grads inside the timed loop and rename the index to `_`.

♻️ Proposed refactor

```diff
-    for i in range(warm_up):
+    for _ in range(warm_up):
         fn()
         if grad_to_none is not None:
             for x in grad_to_none:
                 x.grad = None
@@
-    for i in range(step):
+    for _ in range(step):
         fn()
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.grad = None
```
95-98: Remove unused `rank` argument from `profile_kernels`.

`rank` isn't used; simplify the signature and call site.

♻️ Proposed refactor

```diff
-def profile_kernels(fn, rank, steps=5, warmup=2):
+def profile_kernels(fn, steps=5, warmup=2):
@@
-    kernel_stats = profile_kernels(kda_with_cp, rank, steps=5, warmup=2)
+    kernel_stats = profile_kernels(kda_with_cp, steps=5, warmup=2)
```

Also applies to: 366-367
```python
def main():
    dist.init_process_group()
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    torch.manual_seed(rank + 42)
    torch.cuda.manual_seed(rank + 42)
    random.seed(42)
    torch.cuda.set_device(rank)

    args = get_args()

    # Print environment info
    if rank == 0:
        print(f"\nPyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        print(f"CUDA device: {torch.cuda.get_device_name(rank)}")
        print(f"World size: {world_size}\n")
```
🧩 Analysis chain
🏁 Script executed:

```shell
# Check if the file exists and examine the code at the specified lines
fd "benchmark_kda_cp8_vs_cp2tp.py" --type f
```

Repository: fla-org/flash-linear-attention
Length of output: 117

🏁 Script executed:

```shell
# Also check the full imports and context of the file
fd "benchmark_kda_cp8_vs_cp2tp.py" -x wc -l {}
```

Repository: fla-org/flash-linear-attention
Length of output: 123
🌐 Web query:
torchrun LOCAL_RANK environment variable torch.cuda.set_device best practice multi-node
💡 Result:

Best practice (multi-node, torchrun)

- Read ranks from env (don't invent them): `LOCAL_RANK` = GPU index on this node for this process (0…LOCAL_WORLD_SIZE-1); `RANK`/`WORLD_SIZE` = global process id / total processes across all nodes. (docs.pytorch.org)
- Set the CUDA device early (before creating CUDA tensors/modules):
  - The recommended pattern is to bind each process to exactly one GPU using `LOCAL_RANK`, e.g. `torch.cuda.set_device(local_rank)` (or a `with torch.cuda.device(local_rank): ...` scope). (docs.pytorch.org)
  - This avoids every process defaulting to `cuda:0` unless you explicitly change it. (discuss.pytorch.org)
- Initialize distributed with `env://` and wrap the model with per-process device ids:
  - Use `init_method="env://"` (officially supported approach). (docs.pytorch.org)
  - For `DistributedDataParallel`, ensure each process uses only its GPU: `DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)`. (docs.pytorch.org)
- Don't treat `LOCAL_RANK` as globally unique: it's only unique within a node; don't use it for "rank-0 does logging/checkpointing" decisions; use global `RANK == 0` for that. (docs.pytorch.org)
Minimal, solid template

```python
import os
import torch
import torch.distributed as dist

def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)  # do this before any CUDA allocations
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device("cuda", local_rank)
    model = build_model().to(device)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank
    )
```

torch.cuda.set_device(local_rank) vs CUDA_VISIBLE_DEVICES

Either works as long as each process exclusively operates on its assigned GPU:

- `set_device(local_rank)`: the process can still "see" all GPUs, but default ops go to the selected device. (discuss.pytorch.org)
- `CUDA_VISIBLE_DEVICES` per process: each process only sees one GPU (so it often reports `cuda:0`, mapped to a different physical GPU per process). (discuss.pytorch.org)
Citations:
- 1: https://docs.pytorch.org/docs/stable/elastic/run.html
- 2: https://docs.pytorch.org/docs/stable/distributed
- 3: https://discuss.pytorch.org/t/should-local-rank-be-equal-to-torch-cuda-current-device/150873
- 4: https://docs.pytorch.org/docs/stable/distributed
- 5: https://docs.pytorch.org/docs/2.8/distributed.html
- 6: https://docs.pytorch.org/docs/2.8/distributed.html
- 7: https://discuss.pytorch.org/t/should-local-rank-be-equal-to-torch-cuda-current-device/150873
- 8: https://discuss.pytorch.org/t/should-local-rank-be-equal-to-torch-cuda-current-device/150873
🏁 Script executed:

```shell
# Read the file to see the actual implementation at lines 418-435
sed -n '410,443p' benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py
```

Repository: fla-org/flash-linear-attention
Length of output: 1075

🏁 Script executed:

```shell
# Check imports at the top of the file
head -30 benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py
```

Repository: fla-org/flash-linear-attention
Length of output: 1143

🏁 Script executed:

```shell
# Check if os is imported anywhere in the file
rg "^import os|^from os" benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py
```

Repository: fla-org/flash-linear-attention
Length of output: 56
Use LOCAL_RANK for device selection (multi-node safe).
torch.cuda.set_device(rank) uses global rank and breaks multi-node torchrun. This script explicitly uses torchrun --nproc_per_node=8, which sets LOCAL_RANK for each process. Use LOCAL_RANK instead of global rank for device binding, and apply it to both device-setting and device-name queries.
Proposed fix

```diff
+import os
 import argparse
 import random
@@
 def main():
     dist.init_process_group()
     world_size = dist.get_world_size()
     rank = dist.get_rank()
@@
-    torch.cuda.set_device(rank)
+    local_rank = int(os.environ.get("LOCAL_RANK", rank))
+    torch.cuda.set_device(local_rank)
@@
-    print(f"CUDA device: {torch.cuda.get_device_name(rank)}")
+    print(f"CUDA device: {torch.cuda.get_device_name(local_rank)}")
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def main():
    dist.init_process_group()
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    torch.manual_seed(rank + 42)
    torch.cuda.manual_seed(rank + 42)
    random.seed(42)
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)

    args = get_args()

    # Print environment info
    if rank == 0:
        print(f"\nPyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        print(f"CUDA device: {torch.cuda.get_device_name(local_rank)}")
        print(f"World size: {world_size}\n")
```
🤖 Prompt for AI Agents
In `@benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py` around lines 418-435: the script binds CUDA using the global process rank, which breaks multi-node torchrun. In main(), replace uses of rank for device selection and device-name queries with the LOCAL_RANK provided by torchrun: read local_rank = int(os.environ.get("LOCAL_RANK", rank)) (or fall back to rank), call torch.cuda.set_device(local_rank), and use torch.cuda.get_device_name(local_rank); also use local_rank when seeding CUDA (torch.cuda.manual_seed) so per-process seeds are per-local-device, while keeping distributed init via dist.init_process_group() and the world_size/rank logic unchanged.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@tests/conftest.py`:
- Around line 90-97: Remove the trailing-whitespace-only lines in the zeros-like
replacement block to satisfy lint: edit the function that calls
_ORIGINAL_ZEROS_LIKE so the lines around the result assignment and the
subsequent if/elif (where result.is_floating_point(), result.fill_(...),
result.is_complex(), result.fill_(...)) contain no trailing spaces or
blank-space-only lines (or run the repo's pre-commit hook) so the file has no
trailing whitespace.
🧹 Nitpick comments (1)
tests/conftest.py (1)
55-75: Align `_guarded_zeros` with the compile bypass used elsewhere.

`_guarded_zeros_like` and `_guarded_new_zeros` skip poisoning when `torch.compiler.is_compiling()` is true, but `_guarded_zeros` does not. If `torch.zeros` is hit during compilation, NaN poisoning could leak into compiled graphs. Consider short-circuiting on `is_compiling()` here as well for consistency.

♻️ Suggested tweak

```diff
 def _guarded_zeros(*args, **kwargs):
     """Create a tensor filled with NaN instead of zeros to detect incorrect dependency on zero initialization."""
-    dtype = kwargs.get('dtype') or torch.get_default_dtype()
+    if is_compiling() or not _is_called_from_fla():
+        return _ORIGINAL_ZEROS(*args, **kwargs)
+
+    dtype = kwargs.get('dtype') or torch.get_default_dtype()
     if not (dtype.is_floating_point or dtype.is_complex):
         return _ORIGINAL_ZEROS(*args, **kwargs)
-
-    # Check if this call is from fla package or from test directly
-    if not _is_called_from_fla():
-        # Direct call from test file - don't guard
-        return _ORIGINAL_ZEROS(*args, **kwargs)
```
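As an aside, the NaN-poisoning idea being aligned here is easy to reproduce standalone. A minimal sketch (not the repo's actual conftest code): wrap `torch.zeros` so fresh float/complex buffers start as NaN, which makes any kernel that silently depends on zero initialization fail loudly under test.

```python
import torch

_ORIGINAL_ZEROS = torch.zeros

def nan_poisoned_zeros(*args, **kwargs):
    # Create the tensor normally, then poison float/complex buffers with NaN
    # so code that wrongly relies on zero initialization surfaces in tests.
    out = _ORIGINAL_ZEROS(*args, **kwargs)
    if out.is_floating_point() or out.is_complex():
        out.fill_(float("nan"))
    return out

torch.zeros = nan_poisoned_zeros
buf = torch.zeros(4)           # now a tensor of NaNs
assert torch.isnan(buf).all()
torch.zeros = _ORIGINAL_ZEROS  # restore the real torch.zeros
```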
Force-pushed from f82f78f to 7c66d96
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py`:
- Around line 174-184: The divisibility checks use undefined variables T_total
and H (they are defined later), causing a NameError; move the validation block
that checks T_total % cp_size and H % tp_size so it runs after T_total and H are
assigned (after the assignments currently at lines where T_total and H are set
inside run_benchmark or the surrounding scope), or alternatively compute/assign
T_total and H before performing the checks; ensure the checks reference the same
cp_size and tp_size variables used elsewhere (e.g., in run_benchmark) and remove
the original validation lines from their current location.
- Around line 257-259: The dt_bias tensor is created with shape [H_local] but
KDA expects a per-head, per-head-dimension bias of shape (H_local, K); change
the dt_bias construction so it uses both H_local and K (e.g.,
torch.randn(H_local, K, device=device, dtype=torch.float32)) to match KDA
requirements and avoid shape mismatches when used alongside A_log and in
chunk_kda computations.
🧹 Nitpick comments (2)
benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py (2)
95-95: Unused function parameter `rank`.

The `rank` parameter is never used within `profile_kernels`. Consider removing it, or prefixing it with an underscore if it is reserved for future use.

♻️ Proposed fix

```diff
-def profile_kernels(fn, rank, steps=5, warmup=2):
+def profile_kernels(fn, steps=5, warmup=2):
```

And update the call site at line 378:

```diff
-    kernel_stats = profile_kernels(kda_with_cp, rank, steps=5, warmup=2)
+    kernel_stats = profile_kernels(kda_with_cp, steps=5, warmup=2)
```
71-92: Unused loop variable `i`.

Loop control variables at lines 77 and 86 are not used within the loop bodies. Rename to `_` per Python convention.

♻️ Proposed fix

```diff
 # Warmup
-    for i in range(warm_up):
+    for _ in range(warm_up):
         fn()
         if grad_to_none is not None:
             for x in grad_to_none:
                 x.grad = None
     # Benchmark
     torch.cuda.synchronize()
     start_event.record()
-    for i in range(step):
+    for _ in range(step):
         fn()
```
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py`:
- Around line 172-173: The runtime check using assert on cp_size * tp_size is
unsafe because asserts can be disabled; replace it with an explicit validation
that checks if cp_size * tp_size != world_size and then raise a clear exception
(e.g., ValueError) or call sys.exit with the same descriptive message
referencing cp_size, tp_size, and world_size so invalid CP/TP sizing fails
immediately and deterministically.
In `@tests/conftest.py`:
- Around line 55-75: _guarded_zeros is missing the is_compiling() bypass present
in _guarded_zeros_like and _guarded_new_zeros; add an early guard that calls
is_compiling() and if True returns _ORIGINAL_ZEROS(*args, **kwargs) (place this
check after determining dtype and before calling _is_called_from_fla()) so the
function short-circuits during compilation and preserves the existing
inspect/NaN behavior only at runtime.
```python
assert cp_size * tp_size == world_size, \
    f"CP size ({cp_size}) * TP size ({tp_size}) must equal world size ({world_size})"
```
Avoid assert for runtime config validation.
`assert` statements can be stripped with `python -O`, allowing invalid CP/TP sizing to proceed and fail later. Prefer an explicit check.
🔧 Proposed fix

```diff
-    assert cp_size * tp_size == world_size, \
-        f"CP size ({cp_size}) * TP size ({tp_size}) must equal world size ({world_size})"
+    if cp_size * tp_size != world_size:
+        raise ValueError(
+            f"CP size ({cp_size}) * TP size ({tp_size}) must equal world size ({world_size})"
+        )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if cp_size * tp_size != world_size:
    raise ValueError(
        f"CP size ({cp_size}) * TP size ({tp_size}) must equal world size ({world_size})"
    )
```
🤖 Prompt for AI Agents
In `@benchmarks/cp/benchmark_kda_cp8_vs_cp2tp.py` around lines 172-173: the runtime check using assert on cp_size * tp_size is unsafe because asserts can be disabled; replace it with an explicit validation that checks if cp_size * tp_size != world_size and then raises a clear exception (e.g., ValueError) or calls sys.exit with the same descriptive message referencing cp_size, tp_size, and world_size, so invalid CP/TP sizing fails immediately and deterministically.
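A one-file illustration of the failure mode (hypothetical snippet, not from the repo): run it once normally and once with `python -O`, and the assert silently disappears in the second run.

```python
# save as check_sizes.py; compare `python check_sizes.py` vs `python -O check_sizes.py`
cp_size, tp_size, world_size = 4, 4, 8

assert cp_size * tp_size == world_size, \
    f"CP size ({cp_size}) * TP size ({tp_size}) must equal world size ({world_size})"

# Under -O the assert above is stripped, so this line is reached
# even though 4 * 4 != 8.
print("proceeding with invalid CP/TP sizing")
```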
- Fuse forward and backward kernels for better performance
- Fix IMA (Illegal Memory Access) issue in long context scenarios
Force-pushed from 7c66d96 to 9b2af83
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
fla/ops/cp/chunk_delta_h.py (1)
1051-1059: ⚠️ Potential issue | 🟠 Major

Guard `cu_seqlens` slicing when it can be None.

The fixed-length CP path allows `cu_seqlens=None`; slicing will now raise. Preserve previous behavior by passing `None` through.

🛠️ Suggested fix

```diff
-    pre_process_fwd_kernel_merged[grid](
+    local_cu_seqlens = cu_seqlens[-2:] if cu_seqlens is not None else None
+    pre_process_fwd_kernel_merged[grid](
         k=k, v=u, w=w, g=g, gk=gk, hm=hm,
-        cu_seqlens=cu_seqlens[-2:],
+        cu_seqlens=local_cu_seqlens,
         T=T, H=H, K=K, V=V, BT=BT, BK1=BK,
         USE_EXP2=use_exp2, BLOCK_SIZE=BLOCK_SIZE,
     )
```

```diff
-    pre_process_bwd_kernel_merged[grid](
+    local_cu_seqlens = cu_seqlens[:2] if cu_seqlens is not None else None
+    pre_process_bwd_kernel_merged[grid](
         q=q, k=k, w=w, g=g, gk=gk, do=do, dhm=dhm, dv=dv,
-        cu_seqlens=cu_seqlens[:2],
+        cu_seqlens=local_cu_seqlens,
         scale=scale,
         T=T, H=H, K=K, V=V, BT=BT, BK1=BK,
         USE_EXP2=use_exp2, BLOCK_SIZE=BLOCK_SIZE,
     )
```

Also applies to: 1120-1129
🤖 Fix all issues with AI agents
In `@benchmarks/cp/benchmark_chunk_delta_h_kernels.py`:
- Around line 110-127: The benchmark_all_kernels function runs kernels that
hardcode i_n = 0 (see generate_chunk_delta_h_cache kernels), so if B > 1 the
benchmark only measures the first batch and yields misleading results; add an
explicit guard at the start of benchmark_all_kernels (before create_tensors)
that checks B and either assert/raise if B != 1 or normalize B to 1 for these
single-batch kernels (or skip running these kernels with a clear warning),
referencing benchmark_all_kernels, generate_chunk_delta_h_cache, create_tensors
and the i_n usage so reviewers can locate the change.
- Around line 203-205: The grid sizing uses triton.cdiv(V + K, BLOCK_SIZE), which undercounts blocks because the kernel expects separate block counts; change the calculation for grid_merged (and the other occurrences at lines ~333-335) to sum triton.cdiv(V, BLOCK_SIZE) + triton.cdiv(K, BLOCK_SIZE) for the first dimension so the kernel's mapping (i_col, i_k_col = i_col - tl.cdiv(V, BLOCK_SIZE)) is valid; update any related variables that build grid_merged accordingly (referencing BLOCK_SIZE, grid_merged, V, K, i_col, i_k_col, tl.cdiv) so the K portion is fully scheduled (a numeric illustration of the under-count follows this list).
In `@fla/ops/cp/chunk_delta_h.py`:
- Around line 1049-1050: The grid sizing is under-allocating K tiles: change the
merged-kernel grid computation from using triton.cdiv(V + K, BLOCK_SIZE) to
summing separate cdivs so both V and K remainders are covered (e.g., grid =
(triton.cdiv(V, BLOCK_SIZE) + triton.cdiv(K, BLOCK_SIZE), H)); update the same
pattern at the other occurrence around lines 1118-1119, keeping BLOCK_SIZE, grid
and the kernel logic that uses is_h_part = i_col * BLOCK_SIZE < V intact so the
kernel can correctly partition V and K blocks.
- Around line 893-897: The merged backward kernel pre_process_bwd_kernel_merged
differs from pre_process_bwd_kernel_stage1 by applying exp2 when USE_EXP2 is
true to GK (b_gk_last1 and other GK uses); if backward parity with
pre_process_bwd_kernel_stage1 is required, remove the USE_EXP2 conditional for
GK operations in pre_process_bwd_kernel_merged so GK always uses exp (i.e.,
replace conditional exp2/exp usages on b_gk_* with a single exp call); otherwise
confirm intent to mirror the forward merged kernel and leave USE_EXP2 in
place—update tests or a comment clarifying which parity (stage1 vs
forward-merged) is intended.
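To see why the grid prompts above matter, a quick numeric check in plain Python (the `cdiv` here mirrors `triton.cdiv`'s ceiling division): with `V = K = 96` and `BLOCK_SIZE = 64`, `cdiv(V + K, BLOCK_SIZE)` schedules 3 blocks, but the merged kernel's V-tiles-then-K-tiles mapping needs 4, so the last K tile is never launched.

```python
def cdiv(a, b):
    # Ceiling division, same semantics as triton.cdiv
    return -(-a // b)

V, K, BLOCK_SIZE = 96, 96, 64
n_v, n_k = cdiv(V, BLOCK_SIZE), cdiv(K, BLOCK_SIZE)

print("scheduled:", cdiv(V + K, BLOCK_SIZE))  # 3
print("needed:   ", n_v + n_k)                # 4 (2 V-tiles + 2 K-tiles)

# The merged kernel maps block i_col to a V-tile while i_col * BLOCK_SIZE < V,
# and to K-tile (i_col - n_v) otherwise.
for i_col in range(n_v + n_k):
    kind = f"V-tile {i_col}" if i_col * BLOCK_SIZE < V else f"K-tile {i_col - n_v}"
    print(f"block {i_col}: {kind}")
```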
```python
def benchmark_all_kernels(B, T, H, K, V):
    """Benchmark all 6 kernels from generate_chunk_delta_h_cache()."""
    print(f"\n{'=' * 80}")
    print("Benchmarking chunk_delta_h kernels (using triton.testing.do_bench)")
    print(f"{'=' * 80}")
    print("Configuration:")
    print(f"  Batch size: {B}")
    print(f"  Heads: {H}")
    print(f"  Head dim (K): {K}")
    print(f"  Head dim (V): {V}")
    print(f"  SeqLen per rank: {T}")
    print(f"  Dtype: {DTYPE}")
    print(f"  Device: {device}")
    print(f"{'=' * 80}\n")

    # Create tensors
    tensors = create_tensors(B, T, H, K, V, device, DTYPE)
    BT = 64
```
Guard batch size for single-batch kernels.
The kernels hardcode i_n = 0, so B > 1 benchmarks only the first batch. Add an explicit guard to avoid misleading results.
🛠️ Suggested fix

```diff
 def benchmark_all_kernels(B, T, H, K, V):
     """Benchmark all 6 kernels from generate_chunk_delta_h_cache()."""
+    if B != 1:
+        raise ValueError("This benchmark assumes batch size = 1 (kernels use i_n=0).")
```

🤖 Prompt for AI Agents
In `@benchmarks/cp/benchmark_chunk_delta_h_kernels.py` around lines 110-127: the benchmark_all_kernels function runs kernels that hardcode i_n = 0 (see the generate_chunk_delta_h_cache kernels), so if B > 1 the benchmark only measures the first batch and yields misleading results; add an explicit guard at the start of benchmark_all_kernels (before create_tensors) that checks B and either asserts/raises if B != 1 or normalizes B to 1 for these single-batch kernels (or skips running these kernels with a clear warning), referencing benchmark_all_kernels, generate_chunk_delta_h_cache, create_tensors and the i_n usage so reviewers can locate the change.
- Fuse forward and backward kernels for better performance
- Fix IMA (Illegal Memory Access) issue in long context scenarios

Summary by CodeRabbit
- New Features
- Performance Improvements
- Stability / Bug Fixes
- Tests