Conversation
**Summary of Changes**

Hello @xueweilnvidia, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances FlashInfer's capabilities by integrating parallel attention mechanisms. It provides a framework for distributing attention computations across multiple devices using Ulysses and Ring parallelism, which is crucial for scaling large language models. The changes introduce core components for managing attention operations, configuring parallel strategies, and handling the necessary inter-device communication and data transformations, thereby improving efficiency and enabling larger model training.
**Note: Reviews paused.** This branch is under active development; to avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
**Walkthrough**

Introduces a distributed parallel attention framework with a registry-driven backend architecture. Adds `AttentionOpManager` for pluggable attention implementations and a `ParallelAttention` wrapper applying Ulysses and Ring distributed strategies, supporting FlashAttn3 and CutlassFmha backends. Includes configuration classes for uneven and variable-length parallelism, wrapper decorators for data distribution and output aggregation, utility functions for tensor layout conversion and sequence splitting, and comprehensive distributed tests.
**Sequence Diagram**

```mermaid
sequenceDiagram
    actor User
    participant PA as ParallelAttention<br/>(Orchestrator)
    participant UW as ulysses_wrapper<br/>(Decorator)
    participant RW as ring_wrapper<br/>(Decorator)
    participant AOM as AttentionOpManager<br/>(Backend Registry)
    participant Backend as FlashAttn3/<br/>CutlassFmha
    participant Dist as torch.distributed<br/>(Communication)
    User->>PA: run(query, key, value,<br/>tensor_layout, ...)
    PA->>UW: call wrapped function
    UW->>UW: validate inputs & layout
    UW->>Dist: ulysses_a2a_in()<br/>(all-to-all exchange)
    Dist-->>UW: distributed Q/K/V
    UW->>RW: call wrapped function
    RW->>RW: prepare ring topology
    loop For each ring iteration
        RW->>Dist: ring P2P communicate()<br/>(KV exchange)
        Dist-->>RW: received K/V
        RW->>AOM: get_impl(attn_type)
        AOM-->>RW: backend instance
        RW->>Backend: __call__(Q, K, V)
        Backend-->>RW: local attention output
        RW->>RW: aggregate outputs &<br/>softmax_lse corrections
    end
    RW-->>UW: aggregated output
    UW->>Dist: ulysses_a2a_out()<br/>(reorder to original)
    Dist-->>UW: reordered output
    UW-->>PA: final output
    PA-->>User: result
```
**Estimated code review effort:** 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 **Pre-merge checks:** ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces a parallel attention wrapper for flashinfer, implementing both Ulysses and Ring parallelism. The overall structure is well-designed, with clear separation of concerns for configuration, operations, and parallel wrappers. The addition of comprehensive tests covering various scenarios is also a great strength. I've identified a critical issue regarding the use of __del__ for resource cleanup, which could lead to instability. I've also made several suggestions to improve code clarity and robustness, particularly around floating-point arithmetic for integer-based calculations and simplifying some convoluted logic. Addressing these points will make the implementation more robust and maintainable.
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/parallel_attention/attention_ops.py`:
- Around line 62-84: The __call__ method currently accepts attn_mask but ignores
it; update the implementation in attention_ops.AttentionOp.__call__ to
explicitly reject non-null attn_mask by raising a clear error (e.g., raise
NotImplementedError or ValueError) when attn_mask is not None so callers aren't
silently given unmasked attention; locate the check near the existing
tensor_layout and flash_attn_interface guards and add the attn_mask validation
there (or alternatively wire attn_mask through to the flash_attn_interface call
if you intend to support it).
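A minimal sketch of the explicit-rejection approach the comment suggests. The class and parameter names below are placeholders standing in for the real `AttentionOp` in `attention_ops.py`, which has more parameters and dispatches into `flash_attn_interface`:

```python
class AttentionOp:
    """Sketch only: illustrates the attn_mask guard, not the real backend."""

    def __call__(self, query, key, value, tensor_layout="NHD", attn_mask=None, **kwargs):
        # Reject attn_mask explicitly instead of silently ignoring it,
        # so callers are never handed unmasked attention by accident.
        if attn_mask is not None:
            raise NotImplementedError(
                "attn_mask is not supported by this backend; got a non-None mask"
            )
        # ... dispatch to the actual attention kernel here ...
        return query  # placeholder return for the sketch
```

The key point is that the check runs before any kernel call, so the failure is loud and early rather than a silent correctness bug.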
In `@flashinfer/parallel_attention/parallel_config.py`:
- Around line 382-392: from_dict currently constructs a new AttnParallelConfig
via cls(**config_dict) but __init__ ignores kwargs due to the singleton pattern,
so the supplied sizes are never applied; instead retrieve the singleton instance
(e.g., via the class's singleton getter such as
AttnParallelConfig.get_instance()), call its set_config(...) with config_dict to
apply the provided values, and return that singleton instance; update from_dict
to use the singleton getter and set_config rather than instantiating with
cls(**config_dict).
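A toy model of the singleton issue and the suggested fix. This is a simplified stand-in, not the real `AttnParallelConfig` (which has more fields and validation); it only shows why `cls(**config_dict)` drops kwargs once an instance exists, and how routing through `get_instance()` + `set_config()` avoids that:

```python
class AttnParallelConfig:
    """Simplified singleton sketch; field names mirror the review, not the real class."""
    _instance = None

    def __new__(cls, *args, **kwargs):
        # Singleton: once an instance exists, constructor kwargs are ignored,
        # which is exactly why cls(**config_dict) silently drops the sizes.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.ulysses_size = 1
            cls._instance.ring_size = 1
        return cls._instance

    @classmethod
    def get_instance(cls):
        return cls()

    def set_config(self, ulysses_size=1, ring_size=1):
        self.ulysses_size = ulysses_size
        self.ring_size = ring_size

    @classmethod
    def from_dict(cls, config_dict):
        # The fix: mutate the singleton via set_config instead of
        # relying on cls(**config_dict), whose kwargs would be dropped.
        instance = cls.get_instance()
        instance.set_config(**config_dict)
        return instance
```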
In `@flashinfer/parallel_attention/parallel_wrapper.py`:
- Around line 393-548: The wrapper always forces kwargs["return_lse"]=True but
discards that result; capture the original return_lse flag before overwriting
(e.g., orig_return_lse = kwargs.get("return_lse", False)), then continue to set
kwargs["return_lse"]=True for internal calls, and at the end of
ring_wrapper.wrapper return (out, softmax_lse) when orig_return_lse is True
otherwise return out; ensure you reference the local symbols wrapper, func,
kwargs, return_lse, out, and softmax_lse when making this change so callers
requesting LSE get the tuple back.
In `@flashinfer/parallel_attention/utils.py`:
- Around line 69-104: The code currently treats seq_len_list like a tensor which
breaks when a Python list is passed and yields tensor-valued pad sizes;
convert/normalize seq_len_list to integer scalars before doing divisions and
padding. Specifically, in the block computing seq_len_padded,
total_seq_len_padded and seq_len_padded_cur_rank (symbols: seq_len_padded,
total_seq_len_padded, seq_len_padded_cur_rank) coerce seq_len_list into a CPU
integer tensor or compute per-element integer math (e.g. integer ceil via (n +
world_size - 1) // world_size) so seq_len_padded_cur_rank and pad_len are plain
Python ints, and ensure pad_shape uses int values before calling torch.zeros;
update downstream uses (the for-loop and final padding) to rely on those integer
values.
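The integer math the comment asks for can be sketched in pure Python. The helper name below is hypothetical and the padding scheme is simplified (each sequence padded to a multiple of `world_size`); the real `utils.py` logic differs in detail, but the point is the same: coerce to plain ints up front and use integer ceil division so no tensor-valued shapes reach `torch.zeros`:

```python
def normalize_seq_lens(seq_len_list, world_size):
    """Hypothetical helper: accept a Python list (or iterable of 0-d tensors)
    of per-sequence lengths and return plain-int padded lengths."""
    seq_lens = [int(n) for n in seq_len_list]  # plain ints, even if tensors were passed
    # Integer ceil: pad each sequence up to a multiple of world_size.
    seq_len_padded = [((n + world_size - 1) // world_size) * world_size for n in seq_lens]
    total_seq_len_padded = sum(seq_len_padded)
    # Exact division is guaranteed because every entry is a multiple of world_size.
    seq_len_padded_cur_rank = total_seq_len_padded // world_size
    return seq_len_padded, total_seq_len_padded, seq_len_padded_cur_rank
```

Downstream, `pad_len = seq_len_padded_cur_rank - res.shape[chunk_dim]` is then a plain `int` and `pad_shape` can be handed to `torch.zeros` safely.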
In `@tests/attention/test_parallel_attention.py`:
- Around line 26-33: In the dist_setup fixture, before calling
dist.init_process_group, import and call the flashinfer.utils helper(s) (e.g.,
get_compute_capability and is_sm90a_supported or the appropriate support-check
helper) and call pytest.skip with a clear message if the current GPU arch does
not meet FlashAttn3/BF16 requirements; ensure this skip check runs before any
distributed initialization so dist.init_process_group("nccl") is only called on
supported architectures and keep the fixture name dist_setup unchanged.
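The check described above can be factored into a pure helper so it is testable without a GPU. The function name is hypothetical; `compute_capability` is assumed to be a `(major, minor)` tuple as returned by `flashinfer.utils.get_compute_capability`:

```python
def flash_attn3_skip_reason(attn_type, compute_capability):
    """Return a pytest.skip message if this backend can't run on the
    current GPU, or None if it can. Sketch only; the real fixture would
    call flashinfer.utils helpers to obtain compute_capability."""
    if attn_type == "flash-attn3" and compute_capability < (9, 0):
        return f"flash-attn3 requires SM90a+, got {compute_capability}"
    return None
```

In the `dist_setup` fixture, the idea is to call `pytest.skip(reason)` when the helper returns a message, before `dist.init_process_group("nccl")` ever runs.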
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- flashinfer/parallel_attention/attention_ops.py
- flashinfer/parallel_attention/parallel_attention.py
- flashinfer/parallel_attention/parallel_config.py
- flashinfer/parallel_attention/parallel_wrapper.py
- flashinfer/parallel_attention/utils.py
- tests/attention/test_parallel_attention.py
Actionable comments posted: 8
♻️ Duplicate comments (2)
flashinfer/parallel_attention/parallel_wrapper.py (1)
**398-551:** `return_lse` is now consistently rejected — past concern is addressed. The `ulysses_wrapper` (lines 263-266) explicitly raises if `return_lse=True`, and the `ring_wrapper` internally forces `return_lse=True` for softmax merging but only returns `out`. This is internally consistent. The previous review concern about discarding LSE is resolved by the user-facing rejection.

🤖 Prompt for AI Agents: Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/parallel_wrapper.py` around lines 398-551: `ring_wrapper` forces `kwargs["return_lse"]=True`, which collides with `ulysses_wrapper`'s public rejection of `return_lse`; fix by using a private internal flag instead. Change `ring_wrapper` to set a private key (e.g., `kwargs["_internal_return_lse"]=True`) instead of `"return_lse"`, and update `ulysses_wrapper` to check for and consume `kwargs.pop("_internal_return_lse", False)` when deciding to compute/return LSE (keep the public rejection of `"return_lse"` intact); reference `ring_wrapper`, `ulysses_wrapper`, and the `"return_lse"` flag when making this change.

flashinfer/parallel_attention/utils.py (1)
**73-106:** ⚠️ Potential issue | 🟡 Minor — Tensor values used where `torch.zeros` expects plain ints — fragile shape construction. `seq_len_padded_cur_rank` (lines 75-77) and `pad_len` (line 100) are 0-d `torch.Tensor`s. Inserting them into `pad_shape` (line 102) and comparing with `res.shape[chunk_dim]` (line 99) works in some PyTorch versions but is fragile and can break with stricter type checking. Convert to Python ints.

🛠️ Suggested fix

```diff
 seq_len_padded_cur_rank = (
     (total_seq_len_padded + world_size - 1) // world_size
-).to(torch.int32)
+).to(torch.int32).item()
 ...
-if res.shape[chunk_dim] < seq_len_padded_cur_rank:
-    pad_len = seq_len_padded_cur_rank - res.shape[chunk_dim]
+if res.shape[chunk_dim] < int(seq_len_padded_cur_rank):
+    pad_len = int(seq_len_padded_cur_rank) - res.shape[chunk_dim]
 pad_shape = list(res.shape)
 pad_shape[chunk_dim] = pad_len
```

🤖 Prompt for AI Agents: Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/utils.py` around lines 73-106, `seq_len_padded_cur_rank` and `pad_len` are 0-d `torch.Tensor`s being used as Python ints (in the comparison `res.shape[chunk_dim] < seq_len_padded_cur_rank` and when placed into `pad_shape`), which is fragile; convert them to Python ints (e.g., `int(seq_len_padded_cur_rank)` or `.item()`) before comparisons and before inserting into `pad_shape`, then use those ints when constructing the zero padding passed to `torch.zeros` so `pad_shape` contains plain ints and the comparison against `res.shape[chunk_dim]` is between ints.
🧹 Nitpick comments (5)

tests/attention/test_parallel_attention.py (2)

**135-142:** Singleton `AttnParallelConfig` shared across tests — consider adding cleanup. `AttnParallelConfig()` returns the same singleton instance for all tests. Since different tests call `set_config` with different `ulysses_size`/`ring_size`, test order matters. If one test fails mid-way, the config remains in an unexpected state for subsequent tests. Consider adding a fixture that calls `clear_instance()` or resets the config after each test to improve isolation.

♻️ Example cleanup fixture

```python
@pytest.fixture(autouse=True)
def reset_attn_config():
    yield
    AttnParallelConfig.clear_instance()
```

🤖 Prompt for AI Agents: Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_parallel_attention.py` around lines 135-142, tests instantiate `AttnParallelConfig()`, a singleton, and call `set_config(...)`, causing shared mutable state across tests; reset the singleton after each test by adding an autouse fixture that calls `AttnParallelConfig.clear_instance()` (or call `clear_instance()` in teardown) so each test gets a fresh config; reference `AttnParallelConfig`, `set_config`, and `clear_instance` to locate the change, and ensure the fixture runs after each test to avoid order-dependent failures.

**51-57:** Good addition of the hardware-skip fixture, but fragile `callspec` access. The `skip_if_unsupported` fixture correctly satisfies the coding guideline for skipping unsupported GPU architectures. However, `request.node.callspec` raises `AttributeError` if a non-parametrized test is ever added to this file. A safer access pattern:

♻️ Safer pattern

```diff
 @pytest.fixture(autouse=True)
 def skip_if_unsupported(request):
     """Skip test if the attention backend requires unsupported hardware."""
-    attn_type = request.node.callspec.params.get("attn_type", None)
+    callspec = getattr(request.node, "callspec", None)
+    attn_type = callspec.params.get("attn_type", None) if callspec else None
     if attn_type == "flash-attn3" and not is_sm90a_supported(torch.device("cuda")):
         cc = get_compute_capability(torch.device("cuda"))
         pytest.skip(f"flash-attn3 requires SM90a+, got {cc}")
```

As per coding guidelines: `tests/**/*.py`: Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures.

🤖 Prompt for AI Agents: Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_parallel_attention.py` around lines 51-57, the fixture `skip_if_unsupported` currently assumes `request.node.callspec` exists and will raise `AttributeError` for non-parametrized tests; update it to safely retrieve the callspec (e.g., `callspec = getattr(request.node, "callspec", None)`) and only read `attn_type` from `callspec.params` if callspec is not None; keep the existing checks that use `get_compute_capability` and `is_sm90a_supported` and the `attn_type == "flash-attn3"` condition, and call `pytest.skip` with the same message when appropriate.

flashinfer/parallel_attention/parallel_config.py (1)
**84-99:** Cache check is placed after partial `mesh_dims`/`mesh_sizes` construction. The cache lookup at line 97 runs after the "redundant" dimension is conditionally appended (lines 87-95). On a cache hit, the partially built `mesh_dims`/`mesh_sizes` are wasted; on a cache miss, the dims built before the check are extended further. Moving the cache check earlier (before building any dims) would be cleaner and would avoid potential inconsistencies if the caching key doesn't capture all mesh topology aspects.

♻️ Suggested reorder

```diff
+if str(self) in self._cached_device_mesh:
+    self._device_mesh = self._cached_device_mesh[str(self)]
+    return self._device_mesh
+
 mesh_dims = []
 mesh_sizes = []
 if world_size != total_parallel_size:
     mesh_dims.append("redundant")
     mesh_sizes.append(world_size // total_parallel_size)
 ...
-if str(self) in self._cached_device_mesh:
-    self._device_mesh = self._cached_device_mesh[str(self)]
-    return self._device_mesh
-
 if self._ring_size > 1:
```

🤖 Prompt for AI Agents: Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/parallel_config.py` around lines 84-99, the cache lookup happens after partially mutating `mesh_dims`/`mesh_sizes`; move the check so the `self._cached_device_mesh` lookup (using `str(self)` as key) occurs before any mutation to `mesh_dims` or `mesh_sizes`, short-circuiting with `self._device_mesh = self._cached_device_mesh[str(self)]; return self._device_mesh` at the start of the method (before the "redundant" branch) to avoid wasted work or inconsistent state when a cached device mesh exists.

flashinfer/parallel_attention/parallel_wrapper.py (2)
**10-10:** `tensor_layout` parameter is accepted but never used in `all_to_all`. The `tensor_layout` parameter is declared in the signature and passed by all callers but never referenced in the function body. Either remove it or document why it's reserved for future use.

🤖 Prompt for AI Agents: Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/parallel_wrapper.py` at line 10, the function `all_to_all` currently accepts a `tensor_layout` parameter that is never used; either remove `tensor_layout` from the `all_to_all` signature (and update all callers that pass it) or explicitly document and/or implement its intended behavior. Locate the `all_to_all` definition and every call site that passes `tensor_layout`, then either (A) delete `tensor_layout` from the signature and remove the argument from callers, or (B) add a short explanatory comment and/or implement handling for `tensor_layout` (e.g., selecting layout-specific scatter/gather behavior) so the parameter is justified.

**430-517:** Ring loop modifies the `kwargs` dict in place across iterations — potential stale-key accumulation. Lines 475-476, 484-489, and 491 set keys directly on `kwargs` (e.g., `cur_rank_cu_seqlens_q`, `cur_rank_cu_seqlens_k`, `return_lse`). Since `kwargs` is the same dict object across all `ring_size` iterations, keys set in one iteration persist into the next. This is mostly harmless here because the same keys are overwritten each iteration, but if the varlen branch is entered in one iteration and not another (e.g., due to conditional logic changes), stale `cu_seqlens` from a previous iteration could leak through. Consider making a shallow copy of `kwargs` at the start of each iteration for safety.

🤖 Prompt for AI Agents: Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/parallel_wrapper.py` around lines 430-517, the ring loop mutates the shared `kwargs` dict across iterations (keys like `cur_rank_cu_seqlens_q`, `cur_rank_cu_seqlens_k`, `cur_rank_max_seqlen_q`, `cur_rank_max_seqlen_k`, `return_lse`), which can leak stale values between iterations; fix by creating a shallow copy of `kwargs` at the top of each loop iteration (e.g., `local_kwargs = kwargs.copy()`), set/modify keys on that copy, and pass `local_kwargs` to the call `func(self, query, kv_inputs[0], kv_inputs[1], tensor_layout, attn_mask, **local_kwargs)` so the original `kwargs` is not mutated across iterations.
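The stale-key hazard and the shallow-copy fix can be demonstrated with a toy model of the ring iteration. All names here are illustrative, not FlashInfer API:

```python
def ring_loop(base_kwargs, ring_size, enter_varlen):
    """Toy ring loop: copy kwargs per iteration so keys set in one
    iteration can't leak into the next. enter_varlen(i) says whether
    iteration i takes the (hypothetical) varlen branch."""
    seen = []
    for i in range(ring_size):
        local_kwargs = base_kwargs.copy()  # shallow copy per iteration
        if enter_varlen(i):
            local_kwargs["cur_rank_cu_seqlens_k"] = f"cu_seqlens_iter_{i}"
        # ... func(..., **local_kwargs) would be called here ...
        seen.append(local_kwargs.get("cur_rank_cu_seqlens_k"))
    return seen, base_kwargs
```

Without the copy, the key set in iteration 0 would still be visible in iterations 1 and 2; with it, each iteration starts from the caller's original kwargs, and the caller's dict is never mutated.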
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/parallel_attention/__init__.py`:
- Around line 1-13: Top-level package __init__ is missing re-exports for the
parallel_attention API; add explicit imports for ParallelAttention,
AttnParallelConfig, UnevenCPConfig, VarlenCPConfig, and split_varlen_input into
the top-level module (using the same "X as X" pattern) so users can access
flashinfer.ParallelAttention, flashinfer.AttnParallelConfig,
flashinfer.UnevenCPConfig, flashinfer.VarlenCPConfig, and
flashinfer.split_varlen_input; import these names from the parallel_attention
submodule and include them in the top-level __all__ list.
In `@flashinfer/parallel_attention/attention_ops.py`:
- Around line 44-53: The get_impl method assumes a class attribute cls.attn_type
that doesn't exist on AttentionOpManager; update get_impl to safely resolve a
default by using getattr(cls, "attn_type", None) when name is None and if that
still yields None raise a clear ValueError asking the caller to supply a name;
keep the rest of the lookup logic (use cls._attn_registry.get(name), raise if
attn_class is None, and return attn_class()) so callers that pass name (e.g.,
ParallelAttention) continue to work and the latent AttributeError is avoided.
- Around line 119-129: Remove the non-existent return_attn_probs argument when
calling flash_attn_func and flash_attn_varlen_func (they always return (out,
softmax_lse)); call those functions without that kwarg and treat their return as
a 2-tuple. Also simplify the post-call handling: do not pre-wrap output into a
tuple before the isinstance(output, tuple) check—after calling
flash_attn_func/flash_attn_varlen_func, squeeze the returned tensors (use
torch.squeeze on output[0] and output[1] when the function returns a tuple) and
then construct the final (output, lse) tuple once; update places referencing
output and lse accordingly to remove the redundant pack/unpack cycle.
In `@flashinfer/parallel_attention/parallel_config.py`:
- Around line 211-235: ring_ranks() and ulysses_ranks() unconditionally index
instance._device_mesh and crash when that parallelism is disabled; update each
to mirror the guard used in ring_rank()/ulysses_rank(): fetch instance =
cls.get_instance(), if instance._ring_size <= 1 return [0] for ring_ranks (and
if instance._ulysses_size <= 1 return [0] for ulysses_ranks), otherwise return
instance._device_mesh["ring"].mesh.flatten().tolist() (and
instance._device_mesh["ulysses"].mesh.flatten().tolist() respectively).
- Around line 280-285: The validation in set_config uses
torch.cuda.device_count() to compare total_size (ulysses_size * ring_size),
which is wrong for multi-node setups; replace that local-GPU check by using
torch.distributed.get_world_size() (or remove the check entirely) so total_size
is compared against the global world size, or simply rely on the existing
check_parallel_size() call; update the logic in set_config (referencing
total_size, set_config, torch.cuda.device_count(), and check_parallel_size()) to
use dist.get_world_size() or drop the local device_count check.
- Around line 406-408: The code uses torch.distributed.get_rank() as a CUDA
device index (e.g., rank = torch.distributed.get_rank(); device =
torch.device(f"cuda:{rank}")), which fails on multi-node setups; change these
sites (the occurrence in ParallelConfig and the ones in
VarlenCPConfig.set_ulysses_varlen_config and set_ring_varlen_config) to use the
local device instead — e.g., obtain the already-set CUDA device via
torch.cuda.current_device() (or use a provided local_rank) and construct the
device from that value so the CUDA index maps to the local GPU rather than the
global rank.
In `@flashinfer/parallel_attention/parallel_wrapper.py`:
- Around line 222-252: The code uses torch.distributed.P2POp with the group_peer
argument and torch.distributed.batch_isend_irecv which require PyTorch >= 2.6.0;
update the project dependency (pyproject.toml or requirements.txt) to require
torch>=2.6.0 and add a runtime guard in parallel_wrapper.py that checks
torch.__version__ (or tries to access torch.distributed.P2POp and
batch_isend_irecv) and raises a clear error if the installed PyTorch is older;
reference the use sites torch.distributed.P2POp and
torch.distributed.batch_isend_irecv (the block building send_op/recv_op based on
rank) so the failure happens early with a helpful message.
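A sketch of the early runtime guard described above. The helper name is hypothetical; it parses `torch.__version__`-style strings (e.g., `"2.6.0"`, `"2.5.1+cu124"`) and raises with a helpful message when the installed PyTorch is too old for `group_peer`/`batch_isend_irecv`:

```python
def check_torch_version(version_string, minimum=(2, 6, 0)):
    """Raise early if version_string (torch.__version__ style) is below minimum."""
    numeric = version_string.split("+")[0]          # drop local build tag like +cu124
    parts = tuple(int(p) for p in numeric.split(".")[:3])
    if parts < minimum:
        raise RuntimeError(
            f"parallel_attention ring communication requires torch>="
            f"{'.'.join(map(str, minimum))}, found {version_string}"
        )
```

The idea is to call this once at import time in `parallel_wrapper.py` with `torch.__version__`, so the failure happens before any `P2POp` is built.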
In `@flashinfer/parallel_attention/utils.py`:
- Around line 24-36: The comments in convert_output_layout are swapped: for the
branch where src_layout == "HND" and dst_layout == "NHD" update the comment to
"[H, S, D] -> [S, H, D]" and for the branch where src_layout == "NHD" and
dst_layout == "HND" update the comment to "[S, H, D] -> [H, S, D] (the
permute(1, 0, 2) calls are correct so only change the comment text to match the
actual layout conversion) in the convert_output_layout function.
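The correct comment pairing can be pinned down with a shape-level model of the conversion. This pure-Python stand-in mirrors what `permute(1, 0, 2)` does to a 3-D shape (HND is `[H, S, D]`, NHD is `[S, H, D]`); the function name is illustrative, not the real `convert_output_layout`:

```python
def convert_layout_shape(shape, src_layout, dst_layout):
    """Shape-level model of the permute(1, 0, 2) layout conversion."""
    if src_layout == dst_layout:
        return shape
    if {src_layout, dst_layout} == {"HND", "NHD"}:
        # HND -> NHD: [H, S, D] -> [S, H, D]
        # NHD -> HND: [S, H, D] -> [H, S, D]
        # Both directions are the same index swap, i.e. permute(1, 0, 2).
        return (shape[1], shape[0], shape[2])
    raise ValueError(f"unsupported layouts {src_layout} -> {dst_layout}")
```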
---
Duplicate comments:
In `@flashinfer/parallel_attention/parallel_wrapper.py`:
- Around line 398-551: Summary: ring_wrapper forces kwargs["return_lse"]=True
which collides with ulysses_wrapper's public rejection of return_lse; fix by
using a private internal flag instead. Change ring_wrapper to set a private key
(e.g., kwargs["_internal_return_lse"]=True) instead of "return_lse", and update
ulysses_wrapper to check for and consume kwargs.pop("_internal_return_lse",
False) when deciding to compute/return LSE (keep public rejection for
"return_lse" intact); reference ring_wrapper and ulysses_wrapper and the
"return_lse" flag when making this change.
In `@flashinfer/parallel_attention/utils.py`:
- Around line 73-106: seq_len_padded_cur_rank and pad_len are 0-d torch.Tensors
and are being used as Python ints (in the comparison res.shape[chunk_dim] <
seq_len_padded_cur_rank and when placed into pad_shape), which is fragile;
convert seq_len_padded_cur_rank and pad_len to Python ints (e.g.,
int(seq_len_padded_cur_rank) or seq_len_padded_cur_rank.item()) before
comparisons and before inserting into pad_shape, then use those ints when
constructing the zero padding passed to torch.zeros so pad_shape contains plain
ints and the comparison against res.shape[chunk_dim] is between ints.
---
Nitpick comments:
In `@flashinfer/parallel_attention/parallel_config.py`:
- Around line 84-99: The cache lookup happens after partially mutating
mesh_dims/mesh_sizes; move the check so self._cached_device_mesh lookup (using
str(self) as key) occurs before any mutation to mesh_dims or mesh_sizes to avoid
wasting work or inconsistent state. Specifically, in parallel_config.py ensure
you test "if str(self) in self._cached_device_mesh: self._device_mesh =
self._cached_device_mesh[str(self)]; return self._device_mesh" at the start of
the method (or before modifying mesh_dims/mesh_sizes and before adding the
"redundant" branch), so the functions and attributes mesh_dims, mesh_sizes,
self._cached_device_mesh, __str__, and _device_mesh are used in that order to
short-circuit when a cached device mesh exists.
In `@flashinfer/parallel_attention/parallel_wrapper.py`:
- Line 10: The function all_to_all currently accepts a tensor_layout parameter
that is never used; either remove tensor_layout from the all_to_all signature
(and update all callers that pass it) or explicitly document and/or implement
its intended behavior. Locate the all_to_all definition and every call site that
passes tensor_layout, then either (A) delete tensor_layout from the signature
and remove the argument from callers, or (B) add a short explanatory comment in
parallel_wrapper.all_to_all and/or implement handling for tensor_layout (e.g.,
selecting layout-specific scatter/gather behavior) so the parameter is
justified.
- Around line 430-517: The ring loop mutates the shared kwargs dict across
iterations (keys like cur_rank_cu_seqlens_q, cur_rank_cu_seqlens_k,
cur_rank_max_seqlen_q, cur_rank_max_seqlen_k, return_lse), which can leak stale
values between iterations; fix by creating a shallow copy of kwargs at the top
of each loop iteration (e.g., local_kwargs = kwargs.copy()), set/modify keys on
that copy (local_kwargs["..."] = ...), and pass local_kwargs to the call to
func(self, query, kv_inputs[0], kv_inputs[1], tensor_layout, attn_mask,
**local_kwargs) so the original kwargs is not mutated across iterations.
In `@tests/attention/test_parallel_attention.py`:
- Around line 135-142: Tests instantiate AttnParallelConfig() which is a
singleton and tests call set_config(...) causing shared mutable state across
tests; update the test suite to reset the singleton after each test by adding an
autouse fixture that calls AttnParallelConfig.clear_instance() (or call
clear_instance() in teardown) so each test gets a fresh config; reference
AttnParallelConfig, set_config, and clear_instance to locate the change and
ensure the fixture runs after each test to avoid order-dependent failures.
- Around line 51-57: The fixture skip_if_unsupported currently assumes
request.node.callspec exists and will raise AttributeError for non-parametrized
tests; update it to safely retrieve the callspec (e.g., callspec =
getattr(request.node, "callspec", None) or similar) and only proceed to read
attn_type from callspec.params if callspec is not None; keep the existing checks
that use get_compute_capability and is_sm90a_supported and the attn_type ==
"flash-attn3" condition, and call pytest.skip with the same message when
appropriate.
📒 Files selected for processing (7)
- flashinfer/parallel_attention/__init__.py
- flashinfer/parallel_attention/attention_ops.py
- flashinfer/parallel_attention/parallel_attention.py
- flashinfer/parallel_attention/parallel_config.py
- flashinfer/parallel_attention/parallel_wrapper.py
- flashinfer/parallel_attention/utils.py
- tests/attention/test_parallel_attention.py
flashinfer/parallel_attention/__init__.py (reviewed hunk):

```python
from .parallel_attention import ParallelAttention as ParallelAttention
from .parallel_config import AttnParallelConfig as AttnParallelConfig
from .parallel_config import UnevenCPConfig as UnevenCPConfig
from .parallel_config import VarlenCPConfig as VarlenCPConfig
from .utils import split_varlen_input as split_varlen_input


__all__ = [
    "AttnParallelConfig",
    "ParallelAttention",
    "UnevenCPConfig",
    "VarlenCPConfig",
    "split_varlen_input",
]
```
🧩 Analysis chain — 🏁 Script executed:

```bash
#!/bin/bash
# Check if parallel_attention is exported from the top-level flashinfer __init__
echo "=== Searching for parallel_attention in top-level flashinfer/__init__.py ==="
rg -n 'parallel_attention' flashinfer/__init__.py
echo ""
echo "=== Showing full content of flashinfer/__init__.py ==="
cat -n flashinfer/__init__.py
```

Repository: flashinfer-ai/flashinfer
**Add parallel_attention re-exports to top-level flashinfer/__init__.py.**

The subpackage `__init__.py` is well-structured with explicit re-exports using the `X as X` pattern. However, parallel_attention classes and utilities are not re-exported from the top-level `flashinfer/__init__.py`, making them inaccessible at package level (e.g., `flashinfer.ParallelAttention`). Add the following to the top-level module to align with the pattern used for other subpackages:

```python
from .parallel_attention import ParallelAttention as ParallelAttention
from .parallel_attention import AttnParallelConfig as AttnParallelConfig
from .parallel_attention import UnevenCPConfig as UnevenCPConfig
from .parallel_attention import VarlenCPConfig as VarlenCPConfig
from .parallel_attention import split_varlen_input as split_varlen_input
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/parallel_attention/__init__.py` around lines 1 - 13, Top-level
package __init__ is missing re-exports for the parallel_attention API; add
explicit imports for ParallelAttention, AttnParallelConfig, UnevenCPConfig,
VarlenCPConfig, and split_varlen_input into the top-level module (using the same
"X as X" pattern) so users can access flashinfer.ParallelAttention,
flashinfer.AttnParallelConfig, flashinfer.UnevenCPConfig,
flashinfer.VarlenCPConfig, and flashinfer.split_varlen_input; import these names
from the parallel_attention submodule and include them in the top-level __all__
list.
```python
@classmethod
def get_impl(cls, name=None):
    if name is None:
        name = cls.attn_type
    attn_class = cls._attn_registry.get(name)
    if attn_class is None:
        raise ValueError(
            f"Attention function {name} not found in registry"
        )
    return attn_class()  # Create and return an instance
```
get_impl fallback references undefined cls.attn_type.
When name is None, line 47 falls back to cls.attn_type, but AttentionOpManager never defines an attn_type class attribute. This would raise AttributeError at runtime. Currently ParallelAttention always passes a name, but this is a latent bug.
🛠️ Suggested fix

```diff
 @classmethod
 def get_impl(cls, name=None):
     if name is None:
-        name = cls.attn_type
+        raise ValueError("Attention type name must be provided")
     attn_class = cls._attn_registry.get(name)
```

🧰 Tools
🪛 Ruff (0.15.2)
[warning] 50-52: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/parallel_attention/attention_ops.py` around lines 44 - 53, The
get_impl method assumes a class attribute cls.attn_type that doesn't exist on
AttentionOpManager; update get_impl to safely resolve a default by using
getattr(cls, "attn_type", None) when name is None and if that still yields None
raise a clear ValueError asking the caller to supply a name; keep the rest of
the lookup logic (use cls._attn_registry.get(name), raise if attn_class is None,
and return attn_class()) so callers that pass name (e.g., ParallelAttention)
continue to work and the latent AttributeError is avoided.
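The getattr-based fallback the prompt describes can be sketched in isolation. The class below is a minimal reconstruction for illustration only — it is not the actual AttentionOpManager from the PR, and the register helper and DummyAttention class are invented for the demo:

```python
class AttentionOpManager:
    """Minimal sketch of the registry pattern under review (not the real class)."""

    _attn_registry = {}
    # Note: no attn_type class attribute is defined by default -- the latent bug.

    @classmethod
    def register(cls, name, attn_class):
        cls._attn_registry[name] = attn_class

    @classmethod
    def get_impl(cls, name=None):
        if name is None:
            # Safe fallback: getattr avoids AttributeError when attn_type is unset.
            name = getattr(cls, "attn_type", None)
        if name is None:
            raise ValueError("Attention type name must be provided")
        attn_class = cls._attn_registry.get(name)
        if attn_class is None:
            raise ValueError(f"Attention function {name} not found in registry")
        return attn_class()  # Create and return an instance


class DummyAttention:
    pass


AttentionOpManager.register("dummy", DummyAttention)
```

With this shape, get_impl("dummy") returns a DummyAttention instance, while get_impl() raises a clear ValueError instead of a confusing AttributeError.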
```python
@classmethod
def ring_ranks(cls) -> List[int]:
    """Get all the local ranks in the ring attention group."""
    instance = cls.get_instance()
    return instance._device_mesh["ring"].mesh.flatten().tolist()

@classmethod
def ulysses_rank(cls) -> int:
    """Get the local rank in the Ulysses group.

    Returns:
        int: Local rank in the Ulysses group (0 to ulysses_size-1)
    """
    instance = cls.get_instance()
    if instance._ulysses_size <= 1:
        return 0
    rank = cls.get_local_rank("ulysses")
    logger.debug(f"[{cls.__name__}] Ulysses rank: {rank}")
    return rank

@classmethod
def ulysses_ranks(cls) -> List[int]:
    """Get all the local ranks in the Ulysses group."""
    instance = cls.get_instance()
    return instance._device_mesh["ulysses"].mesh.flatten().tolist()
```
ring_ranks() and ulysses_ranks() crash when the corresponding parallelism is disabled.
ring_ranks() unconditionally accesses instance._device_mesh["ring"], which raises KeyError if ring_size <= 1 (the mesh won't have a "ring" dimension). The same applies to ulysses_ranks() with "ulysses". Add a guard matching the pattern used in ring_rank() / ulysses_rank().
🐛 Proposed fix

```diff
 @classmethod
 def ring_ranks(cls) -> List[int]:
     """Get all the local ranks in the ring attention group."""
     instance = cls.get_instance()
+    if instance._ring_size <= 1 or instance._device_mesh is None:
+        return [0]
     return instance._device_mesh["ring"].mesh.flatten().tolist()
 ...
 @classmethod
 def ulysses_ranks(cls) -> List[int]:
     """Get all the local ranks in the Ulysses group."""
     instance = cls.get_instance()
+    if instance._ulysses_size <= 1 or instance._device_mesh is None:
+        return [0]
     return instance._device_mesh["ulysses"].mesh.flatten().tolist()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/parallel_attention/parallel_config.py` around lines 211 - 235,
ring_ranks() and ulysses_ranks() unconditionally index instance._device_mesh and
crash when that parallelism is disabled; update each to mirror the guard used in
ring_rank()/ulysses_rank(): fetch instance = cls.get_instance(), if
instance._ring_size <= 1 return [0] for ring_ranks (and if
instance._ulysses_size <= 1 return [0] for ulysses_ranks), otherwise return
instance._device_mesh["ring"].mesh.flatten().tolist() (and
instance._device_mesh["ulysses"].mesh.flatten().tolist() respectively).
```python
if rank % 2 == 0:
    send_op = torch.distributed.P2POp(
        torch.distributed.isend,
        send_tensor,
        group_peer=send_dst,
        group=ring_group,
    )
    recv_op = torch.distributed.P2POp(
        torch.distributed.irecv,
        recv_tensor,
        group_peer=recv_src,
        group=ring_group,
    )
    send_recv_ops.append(send_op)
    send_recv_ops.append(recv_op)
else:
    recv_op = torch.distributed.P2POp(
        torch.distributed.irecv,
        recv_tensor,
        group_peer=recv_src,
        group=ring_group,
    )
    send_op = torch.distributed.P2POp(
        torch.distributed.isend,
        send_tensor,
        group_peer=send_dst,
        group=ring_group,
    )
    send_recv_ops.append(recv_op)
    send_recv_ops.append(send_op)
send_recv_reqs = torch.distributed.batch_isend_irecv(send_recv_ops)
```
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
# Check minimum required PyTorch version in project config
echo "=== Checking PyTorch version requirements ==="
for file in pyproject.toml setup.py setup.cfg requirements.txt requirements-dev.txt; do
  if [ -f "$file" ]; then
    echo "--- $file ---"
    cat "$file" | grep -i torch || echo "No torch found"
  fi
done
echo ""
echo "=== Searching for group_peer usage in Python files ==="
rg -n 'group_peer' --type py
echo ""
echo "=== Checking torch.distributed usage patterns ==="
rg -n 'torch\.distributed\.' --type py | head -30
```

Repository: flashinfer-ai/flashinfer
Length of output: 3847
🌐 Web query:
PyTorch 2.4 release notes group_peer P2POp parameter introduction
💡 Result:
group_peer was not introduced in the PyTorch 2.4 release notes / API.
- In PyTorch 2.4, torch.distributed.P2POp is documented as P2POp(op, tensor, peer, group=None, tag=0) (no group_peer).[1]
- Support for group-relative ranks in P2POp / batch_isend_irecv (which is where the group_peer parameter comes in) is called out in the PyTorch 2.6.0 release notes as: "Supported group ranks in P2POp and batch_isend_irecv".[2] This is tied to PR #141054 ("Support group ranks in P2POp and batch_isend_irecv").[3]
Sources:
[1] PyTorch 2.4 distributed docs (P2POp signature) https://pytorch.com.tw/docs/2.4/distributed.html
[2] PyTorch 2.6.0 release notes (Distributed / c10d bullet)
[3] PR #141054 (feature implementation)
Citations:
- 1: https://pytorch.com.tw/docs/2.4/distributed.html?utm_source=openai
- 2: https://github.com/pytorch/pytorch/releases/tag/v2.6.0
- 3: [C10D] Support group ranks in P2POp and batch_isend_irecv pytorch/pytorch#141054
Ensure PyTorch 2.6.0+ is required — group_peer parameter support added in 2.6.0, not 2.4.
The group_peer parameter in torch.distributed.P2POp was introduced in PyTorch 2.6.0 (not 2.4), as part of support for group-relative ranks in batch_isend_irecv. This code will fail on any PyTorch version earlier than 2.6.0. The project does not specify a minimum PyTorch version in pyproject.toml or requirements.txt, leaving this dependency implicit.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/parallel_attention/parallel_wrapper.py` around lines 222 - 252,
The code uses torch.distributed.P2POp with the group_peer argument and
torch.distributed.batch_isend_irecv which require PyTorch >= 2.6.0; update the
project dependency (pyproject.toml or requirements.txt) to require torch>=2.6.0
and add a runtime guard in parallel_wrapper.py that checks torch.__version__ (or
tries to access torch.distributed.P2POp and batch_isend_irecv) and raises a
clear error if the installed PyTorch is older; reference the use sites
torch.distributed.P2POp and torch.distributed.batch_isend_irecv (the block
building send_op/recv_op based on rank) so the failure happens early with a
helpful message.
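The runtime guard the prompt suggests can be sketched as a pure-string version check, so it is shown here without importing torch; the helper name is invented for illustration, and the 2.6.0 threshold comes from the release-note finding above:

```python
def supports_group_peer(torch_version: str) -> bool:
    """True if this PyTorch version has P2POp(group_peer=...).

    Per the release notes cited above, group-relative ranks in
    P2POp / batch_isend_irecv landed in PyTorch 2.6.0
    (pytorch/pytorch#141054).
    """
    base = torch_version.split("+")[0]   # drop local tags like "+cu121"
    parts = base.split(".")
    major, minor = int(parts[0]), int(parts[1])
    return (major, minor) >= (2, 6)
```

In parallel_wrapper.py this could gate a clear early error, e.g. `if not supports_group_peer(torch.__version__): raise RuntimeError("ring attention requires PyTorch >= 2.6.0 for P2POp(group_peer=...)")`, so the failure happens before the confusing TypeError.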
```python
def convert_output_layout(out, src_layout, dst_layout):
    if src_layout == "HND" and dst_layout == "NHD":
        # [S, H, D] -> [H, S, D]
        out = out.permute(1, 0, 2).contiguous()
    elif src_layout == "NHD" and dst_layout == "HND":
        # [H, S, D] -> [S, H, D]
        out = out.permute(1, 0, 2).contiguous()
    else:
        raise NotImplementedError(
            f"Unsupported tensor layout conversion: "
            f"{src_layout} -> {dst_layout}"
        )
    return out
```
Comments in convert_output_layout are swapped.
Line 26 says [S, H, D] -> [H, S, D] for the HND → NHD branch, but HND is [H, S, D], so the conversion is [H, S, D] → [S, H, D]. Line 29 has the mirror mistake. The permute(1, 0, 2) code is correct — only the comments are wrong.
📝 Fix comments

```diff
 if src_layout == "HND" and dst_layout == "NHD":
-    # [S, H, D] -> [H, S, D]
+    # [H, S, D] -> [S, H, D]
     out = out.permute(1, 0, 2).contiguous()
 elif src_layout == "NHD" and dst_layout == "HND":
-    # [H, S, D] -> [S, H, D]
+    # [S, H, D] -> [H, S, D]
     out = out.permute(1, 0, 2).contiguous()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/parallel_attention/utils.py` around lines 24 - 36, The comments in
convert_output_layout are swapped: for the branch where src_layout == "HND" and
dst_layout == "NHD" update the comment to "[H, S, D] -> [S, H, D]" and for the
branch where src_layout == "NHD" and dst_layout == "HND" update the comment to
"[S, H, D] -> [H, S, D]" (the permute(1, 0, 2) calls are correct, so only change
the comment text to match the actual layout conversion) in the
convert_output_layout function.
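To see why the corrected comments are the right ones, the same axis swap can be demonstrated with NumPy's transpose, which performs the identical (1, 0, 2) permutation as torch's permute (the shapes below are arbitrary demo values):

```python
import numpy as np

# NumPy's transpose(1, 0, 2) is the same axis swap as torch's permute(1, 0, 2).
H, S, D = 2, 3, 4
hnd = np.arange(H * S * D).reshape(H, S, D)   # "HND" layout: [H, S, D]

nhd = hnd.transpose(1, 0, 2)                   # HND -> NHD: [H, S, D] -> [S, H, D]
assert nhd.shape == (S, H, D)

back = nhd.transpose(1, 0, 2)                  # NHD -> HND: [S, H, D] -> [H, S, D]
assert back.shape == (H, S, D)
assert (back == hnd).all()                     # the swap is its own inverse
```

Because the permutation is an involution, the same permute(1, 0, 2) serves both directions — which is exactly why only the comments, not the code, need fixing.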
Actionable comments posted: 4
♻️ Duplicate comments (9)
flashinfer/parallel_attention/utils.py (1)

22-33: Comments in convert_output_layout are still swapped. Line 24 says [S, H, D] -> [H, S, D] for the HND → NHD branch, but the conversion is actually [H, S, D] → [S, H, D]. Line 27 has the mirror error. The permute calls are correct — only the comments are wrong.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/utils.py` around lines 22 - 33, The comments in convert_output_layout are reversed: for the branch src_layout == "HND" and dst_layout == "NHD" the code permutes with out.permute(1, 0, 2) which converts [H, S, D] -> [S, H, D], and for the branch src_layout == "NHD" and dst_layout == "HND" the permute converts [S, H, D] -> [H, S, D]; update the inline comments to match those actual conversions (refer to function convert_output_layout and the out.permute(1, 0, 2) calls) so the comments correctly describe the source and destination layouts.

flashinfer/parallel_attention/parallel_config.py (3)

261-266: set_config still validates against device_count() instead of world_size. In multi-node setups, total_size can legitimately exceed the local GPU count. Use dist.get_world_size() (when initialized) or remove this check since check_parallel_size() already validates compatibility.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/parallel_config.py` around lines 261 - 266, The set_config validation currently compares total_size (computed from ulysses_size and ring_size) against torch.cuda.device_count(), which fails for multi-node runs; change the check to use torch.distributed.get_world_size() when the distributed process group is initialized (or remove the local-GPU check entirely since check_parallel_size() already enforces compatibility). Locate the check in set_config where total_size, ulysses_size and ring_size are computed and replace the device_count() comparison with a get_world_size() lookup guarded by torch.distributed.is_initialized(), or drop this block and rely on check_parallel_size() to validate sizes.

385-386: Global rank used as CUDA device index — breaks multi-node. torch.device(f"cuda:{rank}") where rank = torch.distributed.get_rank() will exceed the local GPU count on multi-node clusters. Use torch.cuda.current_device() instead at all three sites (lines 385–386, 506–507, 549–550).
Also applies to: 506-507, 549-550
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/parallel_config.py` around lines 385 - 386, The code uses global rank to pick a CUDA device (rank = torch.distributed.get_rank(); device = torch.device(f"cuda:{rank}")), which fails on multi-node setups; replace uses of torch.distributed.get_rank() with the local-device API and derive device from torch.cuda.current_device() (e.g., device = torch.device(f"cuda:{torch.cuda.current_device()}")) at each site mentioning rank/device (the occurrences that currently call torch.distributed.get_rank() and set device), ensuring you remove or stop using the global rank variable and use torch.cuda.current_device() (or torch.distributed.get_local_rank() if you prefer) consistently in those code paths.

202-206: ring_ranks() and ulysses_ranks() still crash when parallelism is disabled. ring_ranks() unconditionally accesses _device_mesh["ring"], which raises KeyError when ring_size <= 1. Same for ulysses_ranks() with "ulysses". Add a guard matching ring_rank() / ulysses_rank().
Also applies to: 222-226
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/parallel_config.py` around lines 202 - 206, ring_ranks() and ulysses_ranks() access instance._device_mesh["ring"] / ["ulysses"] unconditionally which raises KeyError when parallelism is disabled; mirror the guard used in ring_rank()/ulysses_rank() by checking the corresponding size or presence first (e.g., instance.ring_size <= 1 or "ring" not in instance._device_mesh) and return an empty list (or appropriate default) when disabled; update both ring_ranks() and ulysses_ranks() (and the duplicate block around the other function at the 222-226 region) to perform this guard before accessing _device_mesh.

flashinfer/parallel_attention/attention_ops.py (3)

118-123: Redundant tuple pack/unpack in the non-varlen path. Lines 118–121 squeeze and re-wrap the FA3 output as (output, lse), only for lines 152–155 to immediately unpack it again. Consider squeezing and assigning lse directly without the intermediate re-wrapping.
Also applies to: 152-155
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/attention_ops.py` around lines 118 - 123, The code is unnecessarily packing the squeezed FA3 results into a tuple and immediately unpacking them later; instead of creating (output, lse) and re-unpacking, call torch.squeeze on the returned tensors and assign lse directly (e.g., lse = torch.squeeze(output[1], dim=0); output = torch.squeeze(output[0], dim=0)) so you avoid the intermediate tuple; apply the same change in the later identical block that currently repacks/unpacks output and lse to remove redundant tuple creation.

42-49: get_impl fallback to undefined cls.attn_type is still present. When name is None, line 45 references cls.attn_type, which AttentionOpManager never defines. This would raise AttributeError at runtime. Currently mitigated because ParallelAttention always passes a name, but it's a latent bug.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/attention_ops.py` around lines 42 - 49, get_impl currently references an undefined class attribute (cls.attn_type) when name is None, causing an AttributeError; change the fallback to use getattr(cls, "attn_type", None) and if that returns None raise a clear ValueError instructing callers to pass a name or define AttentionOpManager.attn_type, then look up the registry and return an instance as before (i.e., update get_impl to safely obtain a name via getattr and validate it before using cls._attn_registry and instantiating the class).

94-150: return_attn_probs and several other parameters may not exist in the FA3 API. Lines 115 and 149 pass return_attn_probs=return_lse, and lines 104/134/138 pass qv, seqused_q, seqused_k, attention_chunk, sm_margin, pack_gqa — none of which are in the standard FA3 flash_attn_interface signature. If these parameters don't exist in the installed version, this will raise a TypeError. If the project uses a custom/internal FA3 fork that accepts these, please document or pin that dependency.

```bash
#!/bin/bash
# Check if flash_attn_interface is importable and inspect its signature
rg -n "flash_attn_interface" --type py -C2
# Check requirements/dependencies for flash-attn version
fd -e txt -e toml -e cfg -e yaml "requirements\|pyproject\|setup" --exec cat {}
```

flashinfer/parallel_attention/parallel_wrapper.py (2)

388-530: Ring wrapper's double-buffered P2P and output merging are correct; return_lse is still discarded. The double-buffering scheme is safe: the send buffer from iteration i is waited on at iteration i+2 before being overwritten, and concurrent reads (send + compute) on the same buffer are race-free. The output accumulation and padding zero-out logic are sound.
However, the wrapper still always returns only out, discarding softmax_lse even though it's fully computed and merged. This is mitigated by ulysses_wrapper rejecting return_lse=True at the API level, but it limits future extensibility.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/parallel_wrapper.py` around lines 388 - 530, The wrapper currently computes and merges softmax_lse but always returns only out; modify ring_wrapper so it preserves the func API by returning (out, softmax_lse) when return_lse=True (i.e. when kwargs.get("return_lse", False) is set) and returning out otherwise; locate the return near the end of ring_wrapper and change it to conditionally return softmax_lse alongside out (use the existing softmax_lse variable) so callers that expect LSEs receive them.

217-254: group_peer in P2POp requires PyTorch ≥ 2.6.0. The group_peer parameter was introduced in PyTorch 2.6.0. The project does not specify a minimum PyTorch version in its dependencies, so this will fail on older installations with a confusing TypeError.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/parallel_wrapper.py` around lines 217 - 254, The P2POp usage in ring_attn_p2p_communicate uses the group_peer keyword which only exists in PyTorch ≥2.6; update the function to detect/support older PyTorch by constructing P2POp with group_peer when available and falling back to the older API (pass send_dst/recv_src as the positional dst argument or use the older keyword name) when it is not—wrap creation of the P2POp objects (the send_op/recv_op construction in ring_attn_p2p_communicate) in a small helper or try/except that attempts P2POp(..., group_peer=...) and on TypeError recreates the op as P2POp(..., send_dst) / P2POp(..., recv_src) (preserving group=ring_group) before calling torch.distributed.batch_isend_irecv so the code works across PyTorch versions.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/parallel_attention/parallel_config.py`:
- Around line 542-606: In set_ring_varlen_config, fix incorrect use of global
world_size and rank: obtain ring_size = attn_parallel_config.ring_size() and
replace uses of world_size when computing padded_seq_lens_q/padded_seq_lens_kv,
padded_seq_len_q_cur_rank/padded_seq_len_kv_cur_rank and the per-rank chunking
logic (the values used to build padded_seq_lens_q, padded_seq_lens_kv,
padded_seq_len_q_cur_rank, padded_seq_len_kv_cur_rank and cu_seqlens_* lists) so
sequences are chunked by ring_size instead of world_size; also use device =
torch.device(f"cuda:{torch.cuda.current_device()}") (or
torch.cuda.current_device() to get the CUDA index) rather than
torch.distributed.get_rank() to choose the device before creating
cu_seqlens_q_all_ranks / cu_seqlens_kv_all_ranks, then call
self.set_varlen_cp_config with the stacked cu_seqlens and the same
max_seq_len_q/max_seq_len_kv.
In `@flashinfer/parallel_attention/utils.py`:
- Around line 68-101: The variable seq_len_padded_cur_rank (and any upstream
sums like total_seq_len_padded/seq_len_padded) is a 0-d tensor but is later used
to compute pad_len and to build pad_shape for torch.zeros; convert it to a
Python int before using in shape arithmetic and list construction (e.g., use
int(...) or .item()) so pad_len and pad_shape contain plain ints; update the
code around seq_len_padded_cur_rank, the pad_len calculation, and the pad_shape
assignment in the block that pads res (where res.shape[chunk_dim] is compared
and torch.zeros is called) to ensure torch.zeros receives integer dimensions
rather than tensors.
In `@tests/attention/test_parallel_attention.py`:
- Around line 51-57: The fixture skip_if_unsupported accesses
request.node.callspec which raises AttributeError for non-parametrized tests;
change the guard to safely handle missing callspec (e.g., check
hasattr(request.node, "callspec") or use getattr(request.node, "callspec",
None)) and only attempt to read attn_type when callspec is present, then proceed
with the existing flash-attn3 / is_sm90a_supported / get_compute_capability
logic; update references to attn_type lookup to tolerate None so
non-parametrized tests do not error.
- Around line 164-235: The test constructs seq_len_cur_rank as a 1-element CUDA
tensor but UnevenCPConfig.set_uneven_cp_config expects a scalar; convert
seq_len_cur_rank to a Python int before calling set_uneven_cp_config to avoid
creating a (1,1) tensor inside set_uneven_cp_config—e.g. extract the scalar from
seq_len_cur_rank (using int(...) or .item()) and pass that to
UnevenCPConfig.set_uneven_cp_config so seq_len_all_ranks and downstream
ring_wrapper comparisons receive the correct shape.
---
Duplicate comments:
In `@flashinfer/parallel_attention/attention_ops.py`:
- Around line 118-123: The code is unnecessarily packing the squeezed FA3
results into a tuple and immediately unpacking them later; instead of creating
(output, lse) and re-unpacking, call torch.squeeze on the returned tensors and
assign lse directly (e.g., lse = torch.squeeze(output[1], dim=0); output =
torch.squeeze(output[0], dim=0)) so you avoid the intermediate tuple; apply the
same change in the later identical block that currently repacks/unpacks output
and lse to remove redundant tuple creation.
- Around line 42-49: get_impl currently references an undefined class attribute
(cls.attn_type) when name is None, causing an AttributeError; change the
fallback to use getattr(cls, "attn_type", None) and if that returns None raise a
clear ValueError instructing callers to pass a name or define
AttentionOpManager.attn_type, then look up the registry and return an instance
as before (i.e., update get_impl to safely obtain a name via getattr and
validate it before using cls._attn_registry and instantiating the class).
In `@flashinfer/parallel_attention/parallel_config.py`:
- Around line 261-266: The set_config validation currently compares total_size
(computed from ulysses_size and ring_size) against torch.cuda.device_count(),
which fails for multi-node runs; change the check to use
torch.distributed.get_world_size() when the distributed process group is
initialized (or remove the local-GPU check entirely since check_parallel_size()
already enforces compatibility). Locate the check in set_config where
total_size, ulysses_size and ring_size are computed and replace the
device_count() comparison with a get_world_size() lookup guarded by
torch.distributed.is_initialized(), or drop this block and rely on
check_parallel_size() to validate sizes.
- Around line 385-386: The code uses global rank to pick a CUDA device (rank =
torch.distributed.get_rank(); device = torch.device(f"cuda:{rank}")), which
fails on multi-node setups; replace uses of torch.distributed.get_rank() with
the local-device API and derive device from torch.cuda.current_device() (e.g.,
device = torch.device(f"cuda:{torch.cuda.current_device()}")) at each site
mentioning rank/device (the occurrences that currently call
torch.distributed.get_rank() and set device), ensuring you remove or stop using
the global rank variable and use torch.cuda.current_device() (or
torch.distributed.get_local_rank() if you prefer) consistently in those code
paths.
- Around line 202-206: ring_ranks() and ulysses_ranks() access
instance._device_mesh["ring"] / ["ulysses"] unconditionally which raises
KeyError when parallelism is disabled; mirror the guard used in
ring_rank()/ulysses_rank() by checking the corresponding size or presence first
(e.g., instance.ring_size <= 1 or "ring" not in instance._device_mesh) and
return an empty list (or appropriate default) when disabled; update both
ring_ranks() and ulysses_ranks() (and the duplicate block around the other
function at the 222-226 region) to perform this guard before accessing
_device_mesh.
In `@flashinfer/parallel_attention/parallel_wrapper.py`:
- Around line 388-530: The wrapper currently computes and merges softmax_lse but
always returns only out; modify ring_wrapper so it preserves the func API by
returning (out, softmax_lse) when return_lse=True (i.e. when
kwargs.get("return_lse", False) is set) and returning out otherwise; locate the
return near the end of ring_wrapper and change it to conditionally return
softmax_lse alongside out (use the existing softmax_lse variable) so callers
that expect LSEs receive them.
- Around line 217-254: The P2POp usage in ring_attn_p2p_communicate uses the
group_peer keyword which only exists in PyTorch ≥2.6; update the function to
detect/support older PyTorch by constructing P2POp with group_peer when
available and falling back to the older API (pass send_dst/recv_src as the
positional dst argument or use the older keyword name) when it is not—wrap
creation of the P2POp objects (the send_op/recv_op construction in
ring_attn_p2p_communicate) in a small helper or try/except that attempts
P2POp(..., group_peer=...) and on TypeError recreates the op as P2POp(...,
send_dst) / P2POp(..., recv_src) (preserving group=ring_group) before calling
torch.distributed.batch_isend_irecv so the code works across PyTorch versions.
In `@flashinfer/parallel_attention/utils.py`:
- Around line 22-33: The comments in convert_output_layout are reversed: for the
branch src_layout == "HND" and dst_layout == "NHD" the code permutes with
out.permute(1, 0, 2) which converts [H, S, D] -> [S, H, D], and for the branch
src_layout == "NHD" and dst_layout == "HND" the permute converts [S, H, D] ->
[H, S, D]; update the inline comments to match those actual conversions (refer
to function convert_output_layout and the out.permute(1, 0, 2) calls) so the
comments correctly describe the source and destination layouts.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- flashinfer/parallel_attention/attention_ops.py
- flashinfer/parallel_attention/parallel_attention.py
- flashinfer/parallel_attention/parallel_config.py
- flashinfer/parallel_attention/parallel_wrapper.py
- flashinfer/parallel_attention/utils.py
- tests/attention/test_parallel_attention.py
```python
seq_len_padded = (seq_len_list + world_size - 1) // world_size * world_size
total_seq_len_padded = sum(seq_len_padded)
seq_len_padded_cur_rank = (
    (total_seq_len_padded + world_size - 1) // world_size
).to(torch.int32)

chunks = []
offset = 0
for seq_len in seq_len_list:
    seq_len = int(seq_len)
    # First (world_size - 1) ranks get ceil(seq_len / world_size),
    # last rank gets whatever is left.
    base = (seq_len + world_size - 1) // world_size
    if rank < world_size - 1:
        chunk_len = base
        start = offset + base * rank
    else:
        # Last rank gets the remainder
        start = offset + base * (world_size - 1)
        chunk_len = seq_len - base * (world_size - 1)

    chunks.append(tensor.narrow(chunk_dim, start, chunk_len))
    offset += seq_len

res = torch.cat(chunks, dim=chunk_dim)

if res.shape[chunk_dim] < seq_len_padded_cur_rank:
    pad_len = seq_len_padded_cur_rank - res.shape[chunk_dim]
    pad_shape = list(res.shape)
    pad_shape[chunk_dim] = pad_len
    res = torch.cat(
        [res, torch.zeros(pad_shape, device=res.device, dtype=res.dtype)],
        dim=chunk_dim,
    )
```
seq_len_padded_cur_rank is a 0-d tensor — convert to int before using in shapes.
sum(seq_len_padded) returns a 0-d tensor, so seq_len_padded_cur_rank is also a 0-d tensor after the floor-division. This propagates into pad_len (line 95) and then into pad_shape (line 97), meaning torch.zeros(pad_shape, ...) receives a list containing a tensor element instead of a plain int. While recent PyTorch may handle this, it's fragile and undocumented.
🛠️ Proposed fix

```diff
 seq_len_padded_cur_rank = (
     (total_seq_len_padded + world_size - 1) // world_size
-).to(torch.int32)
+)
+seq_len_padded_cur_rank = int(seq_len_padded_cur_rank.item())
```

🤖 Prompt for AI Agents
+ seq_len_padded_cur_rank = int(seq_len_padded_cur_rank.item())🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/parallel_attention/utils.py` around lines 68 - 101, The variable
seq_len_padded_cur_rank (and any upstream sums like
total_seq_len_padded/seq_len_padded) is a 0-d tensor but is later used to
compute pad_len and to build pad_shape for torch.zeros; convert it to a Python
int before using in shape arithmetic and list construction (e.g., use int(...)
or .item()) so pad_len and pad_shape contain plain ints; update the code around
seq_len_padded_cur_rank, the pad_len calculation, and the pad_shape assignment
in the block that pads res (where res.shape[chunk_dim] is compared and
torch.zeros is called) to ensure torch.zeros receives integer dimensions rather
than tensors.
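The fix pattern can be demonstrated with NumPy as a stand-in for torch: summing an integer array likewise returns a NumPy scalar rather than a plain Python int, and wrapping the ceil-division in int(...) keeps downstream shape lists int-typed. The variable names mirror the review, but the snippet is illustrative, not the library code:

```python
import numpy as np

# Summing per-sequence padded lengths yields a NumPy scalar, not a Python int.
seq_len_padded = np.array([8, 8, 16], dtype=np.int32)
world_size = 4

total_seq_len_padded = seq_len_padded.sum()            # NumPy scalar
seq_len_padded_cur_rank = int(
    (total_seq_len_padded + world_size - 1) // world_size
)                                                      # collapse to plain int

assert isinstance(seq_len_padded_cur_rank, int)        # safe for shape math now
pad_shape = [seq_len_padded_cur_rank, 2, 4]            # all-int shape list
assert np.zeros(pad_shape).shape == (8, 2, 4)          # ceil(32 / 4) == 8
```

Without the int(...) call, pad_shape would carry a NumPy/torch scalar element, which is exactly the fragile case the review flags for torch.zeros.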
```python
@pytest.fixture(autouse=True)
def skip_if_unsupported(request):
    """Skip test if the attention backend requires unsupported hardware."""
    attn_type = request.node.callspec.params.get("attn_type", None)
    if attn_type == "flash-attn3" and not is_sm90a_supported(torch.device("cuda")):
        cc = get_compute_capability(torch.device("cuda"))
        pytest.skip(f"flash-attn3 requires SM90a+, got {cc}")
```
Guard against AttributeError on non-parametrized tests.
request.node.callspec raises AttributeError when the test is not parametrized. Since this fixture is autouse=True, any future non-parametrized test in this module would fail with a confusing error.
🛠️ Proposed fix

```diff
 @pytest.fixture(autouse=True)
 def skip_if_unsupported(request):
     """Skip test if the attention backend requires unsupported hardware."""
-    attn_type = request.node.callspec.params.get("attn_type", None)
+    callspec = getattr(request.node, "callspec", None)
+    attn_type = callspec.params.get("attn_type", None) if callspec else None
     if attn_type == "flash-attn3" and not is_sm90a_supported(torch.device("cuda")):
         cc = get_compute_capability(torch.device("cuda"))
         pytest.skip(f"flash-attn3 requires SM90a+, got {cc}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/attention/test_parallel_attention.py` around lines 51 - 57, The fixture
skip_if_unsupported accesses request.node.callspec which raises AttributeError
for non-parametrized tests; change the guard to safely handle missing callspec
(e.g., check hasattr(request.node, "callspec") or use getattr(request.node,
"callspec", None)) and only attempt to read attn_type when callspec is present,
then proceed with the existing flash-attn3 / is_sm90a_supported /
get_compute_capability logic; update references to attn_type lookup to tolerate
None so non-parametrized tests do not error.
Actionable comments posted: 7
♻️ Duplicate comments (5)
tests/attention/test_parallel_attention.py (1)

54-60: ⚠️ Potential issue | 🟡 Minor
Guard against `AttributeError` on non-parametrized tests.

`request.node.callspec` raises `AttributeError` when the test is not parametrized. Since this fixture is `autouse=True`, any future non-parametrized test in this module would fail.

🛠️ Proposed fix

```diff
 @pytest.fixture(autouse=True)
 def skip_if_unsupported(request):
     """Skip test if the attention backend requires unsupported hardware."""
-    attn_type = request.node.callspec.params.get("attn_type", None)
+    callspec = getattr(request.node, "callspec", None)
+    attn_type = callspec.params.get("attn_type", None) if callspec else None
     if attn_type == "flash-attn3" and not is_sm90a_supported(torch.device("cuda")):
         cc = get_compute_capability(torch.device("cuda"))
         pytest.skip(f"flash-attn3 requires SM90a+, got {cc}")
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_parallel_attention.py` around lines 54 - 60, the fixture skip_if_unsupported accesses request.node.callspec, which raises AttributeError for non-parametrized tests; guard that access and only read attn_type when callspec exists. Update skip_if_unsupported to check hasattr(request.node, "callspec") or wrap the access in try/except AttributeError, then proceed with the existing flash-attn3 / is_sm90a_supported / get_compute_capability logic so non-parametrized tests do not error.

flashinfer/parallel_attention/utils.py (2)
26-37: ⚠️ Potential issue | 🟡 Minor
Comments in `convert_output_layout` are swapped.

The comments describe the opposite direction. HND is [H, S, D], so converting HND→NHD should be [H, S, D] → [S, H, D], not the reverse.

📝 Fix comments

```diff
 def convert_output_layout(out, src_layout, dst_layout):
     if src_layout == "HND" and dst_layout == "NHD":
-        # [S, H, D] -> [H, S, D]
+        # [H, S, D] -> [S, H, D]
         out = out.permute(1, 0, 2).contiguous()
     elif src_layout == "NHD" and dst_layout == "HND":
-        # [H, S, D] -> [S, H, D]
+        # [S, H, D] -> [H, S, D]
         out = out.permute(1, 0, 2).contiguous()
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/utils.py` around lines 26 - 37, the comments in convert_output_layout are reversed; update them to accurately describe the tensor shapes: HND corresponds to [H, S, D] and NHD to [S, H, D], so the HND→NHD branch should comment "[H, S, D] -> [S, H, D]" and the NHD→HND branch "[S, H, D] -> [H, S, D]". Leave the code (permute/contiguous) unchanged and only correct the comment strings.
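As a sanity check, the corrected comment directions can be verified with a standalone sketch. The `convert_output_layout` below is a re-implementation for illustration only, not the library function itself; shapes follow the H = heads, S = sequence, D = head-dim convention used above.

```python
import torch

def convert_output_layout(out, src_layout, dst_layout):
    # HND is [H, S, D]; NHD is [S, H, D]. Both directions swap dims 0 and 1.
    if src_layout == "HND" and dst_layout == "NHD":
        # [H, S, D] -> [S, H, D]
        out = out.permute(1, 0, 2).contiguous()
    elif src_layout == "NHD" and dst_layout == "HND":
        # [S, H, D] -> [H, S, D]
        out = out.permute(1, 0, 2).contiguous()
    return out

hnd = torch.zeros(8, 128, 64)              # [H=8, S=128, D=64]
nhd = convert_output_layout(hnd, "HND", "NHD")
print(tuple(nhd.shape))                    # (128, 8, 64)
```

Because both conversions are a single swap of the first two dimensions, applying them back to back is a round trip.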
72-106: ⚠️ Potential issue | 🟡 Minor
Convert `seq_len_padded_cur_rank` to int before using in shapes.

`sum(seq_len_padded)` returns a 0-d tensor, so `seq_len_padded_cur_rank` remains a tensor. This propagates to `pad_len` and `pad_shape`, causing `torch.zeros` to receive a tensor element instead of a plain int.

🛠️ Proposed fix

```diff
     seq_len_padded_cur_rank = (
         (total_seq_len_padded + world_size - 1) // world_size
-    ).to(torch.int32)
+    )
+    seq_len_padded_cur_rank = int(seq_len_padded_cur_rank.item())
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/utils.py` around lines 72 - 106, seq_len_padded_cur_rank is a 0-d tensor (from sum(seq_len_padded)) and must be converted to a Python int before it is used for shape arithmetic and tensor creation; update its computation (later used to compute pad_len and pad_shape) to yield an int (e.g., call int(...) or .item()), and ensure pad_len and any size values passed into torch.zeros are plain ints rather than tensor objects.

flashinfer/parallel_attention/parallel_wrapper.py (1)
217-254: ⚠️ Potential issue | 🟠 Major
Ensure PyTorch 2.6.0+ is required; the `group_peer` parameter was added in 2.6.0.

The `group_peer` parameter in `torch.distributed.P2POp` was introduced in PyTorch 2.6.0. This code will fail on earlier versions. Consider adding a runtime version check or updating the project's PyTorch dependency.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/parallel_wrapper.py` around lines 217 - 254, the P2POp usage in ring_attn_p2p_communicate relies on the group_peer kwarg introduced in PyTorch 2.6.0; add a runtime check of torch.__version__ (e.g., via packaging.version.parse) before constructing torch.distributed.P2POp and either (a) raise a clear RuntimeError instructing users to upgrade to PyTorch >= 2.6.0, or (b) implement a fallback path that builds P2POp calls without the group_peer argument for older versions, ensuring both send_op and recv_op creation (and the subsequent batch_isend_irecv) still work; update the logic that creates send_op/recv_op in ring_attn_p2p_communicate to choose the kwarg-aware or kwarg-less construction accordingly.

flashinfer/parallel_attention/__init__.py (1)
1-8: ⚠️ Potential issue | 🔴 Critical
Add parallel_attention symbols to top-level flashinfer/__init__.py re-exports.

`ParallelAttention`, `UnevenCPConfig`, `VarlenCPConfig`, and the utility functions (`split_varlen_input`, `ulysses_varlen_config`, `ring_varlen_config`, `uneven_cp_config`, `get_parallel_groups`) are not currently re-exported from the top-level package, inconsistent with the export pattern for other modules like attention, cascade, mla, and sparse.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/__init__.py` around lines 1 - 8, the top-level package __init__ is missing re-exports for the parallel_attention API; import and re-export ParallelAttention, UnevenCPConfig, VarlenCPConfig and the utility functions (split_varlen_input, ulysses_varlen_config, ring_varlen_config, uneven_cp_config, get_parallel_groups) from flashinfer.parallel_attention into the package __init__, and add them to the module exports / __all__ alongside other exported symbols so consumers can import them directly from the top-level package.
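The requested re-export pattern would look roughly like the excerpt below. This is a hypothetical sketch of `flashinfer/__init__.py`: the symbol names come from the review, but the file's actual layout and existing exports are assumptions.

```python
# flashinfer/__init__.py (hypothetical excerpt)
from .parallel_attention import (
    ParallelAttention,
    UnevenCPConfig,
    VarlenCPConfig,
    get_parallel_groups,
    ring_varlen_config,
    split_varlen_input,
    ulysses_varlen_config,
    uneven_cp_config,
)

__all__ = [
    # ... existing exports (attention, cascade, mla, sparse, ...) ...
    "ParallelAttention",
    "UnevenCPConfig",
    "VarlenCPConfig",
    "get_parallel_groups",
    "ring_varlen_config",
    "split_varlen_input",
    "ulysses_varlen_config",
    "uneven_cp_config",
]
```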
🧹 Nitpick comments (2)
flashinfer/parallel_attention/__init__.py (1)
10-19: Consider sorting `__all__` alphabetically.

Static analysis flags that `__all__` is unsorted. Sorting improves readability and reduces merge conflicts.

♻️ Proposed fix

```diff
 __all__ = [
     "ParallelAttention",
+    "get_parallel_groups",
+    "ring_varlen_config",
+    "split_varlen_input",
+    "ulysses_varlen_config",
+    "uneven_cp_config",
     "UnevenCPConfig",
     "VarlenCPConfig",
-    "split_varlen_input",
-    "ulysses_varlen_config",
-    "ring_varlen_config",
-    "uneven_cp_config",
-    "get_parallel_groups",
 ]
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/__init__.py` around lines 10 - 19, the __all__ list is unsorted; reorder its entries alphabetically to improve readability and reduce merge conflicts, and update the __all__ declaration in parallel_attention.__init__ accordingly.

flashinfer/parallel_attention/utils.py (1)
284-300: Parameters `seq_len` and `seq_len_padded` are documented but unused.

Static analysis flags these as unused. If they're for future use or documentation purposes, consider prefixing with an underscore or adding a TODO.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/utils.py` around lines 284 - 300, the parameters seq_len and seq_len_padded in uneven_cp_config are declared and documented but never used; either remove them from the signature and docstring if they are dead, or preserve the API by renaming them to _seq_len and _seq_len_padded and noting in the docstring that they are intentionally unused / reserved for future use so static analysis stops flagging them. Make the change only within the uneven_cp_config definition and its docstring to keep callers unaffected.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/parallel_attention/parallel_attention.py`:
- Around line 30-39: The docstring example wrongly references a non-existent
AttnParallelConfig; update the example to call get_parallel_groups() to obtain
ulysses_group and ring_group and pass them into ParallelAttention (the
ParallelAttention class and its run() example should remain), e.g., replace the
AttnParallelConfig setup with a brief call to get_parallel_groups() and use its
returned groups when constructing ParallelAttention so the example matches the
actual implementation.
In `@flashinfer/parallel_attention/utils.py`:
- Around line 225-226: The code uses torch.distributed.get_rank() to pick a CUDA
device (rank = torch.distributed.get_rank(); device =
torch.device(f"cuda:{rank}")), which breaks on multi-node setups; change it to
use the local CUDA device index via torch.cuda.current_device() so device is
constructed from torch.cuda.current_device() instead of the global distributed
rank (update the code around the rank and device variables where they are
defined/used).
- Around line 134-135: The code uses torch.distributed.get_rank() to derive a
CUDA device index (rank = torch.distributed.get_rank(); device =
torch.device(f"cuda:{rank}")), which can be incorrect on multi-node setups;
replace the device selection to use the local CUDA device
(torch.cuda.current_device()) or map global rank to local GPU index (e.g.,
local_rank = torch.cuda.current_device() or local_rank =
torch.distributed.get_rank() % torch.cuda.device_count()) and set device =
torch.device(f"cuda:{local_rank}") so that the device reflects the local GPU
rather than the global distributed rank.
- Around line 302-303: In uneven_cp_config, the code uses
torch.distributed.get_rank() to derive the CUDA device (variables rank and
device); replace that by using the local CUDA index via
torch.cuda.current_device() to support multi-node setups — update where
rank/device are set in uneven_cp_config to call torch.cuda.current_device() (or
use local_rank = torch.cuda.current_device() and set device =
torch.device(f"cuda:{local_rank}")) so the CUDA device reflects the current
process's GPU rather than the global distributed rank.
In `@tests/attention/test_parallel_attention.py`:
- Around line 86-105: The helper _sample_ring_varlen_tensors incorrectly uses
the global process rank to pick a CUDA device (device =
torch.device(f"cuda:{rank}")), which breaks multi-node setups; change it to
compute a local GPU index (e.g., read os.environ["LOCAL_RANK"] when present or
fall back to rank % torch.cuda.device_count()) and use that as the device
(device = torch.device(f"cuda:{local_rank}")). Update the device selection in
_sample_ring_varlen_tensors so q/k/v tensors are created on the correct local
GPU before broadcasts and keep the rest of the function (split_varlen_input,
dist.broadcast calls) unchanged.
- Around line 66-83: The helper _sample_tensors uses dist.get_rank() to build
the CUDA device string which breaks on multi-node setups; change the device
selection to use torch.cuda.current_device() (or torch.cuda.device_of when
appropriate) when creating q/k/v so tensors are allocated on the actual local
CUDA device, keeping dist.get_rank() only for indexing the chunked tensors
(local_q/local_k/local_v).
- Around line 49-51: The device fixture uses the global distributed rank which
can exceed the local GPU count in multi-node runs; update the device fixture
(def device(rank)) to derive the local CUDA index instead of using
f"cuda:{rank}" — for example call torch.cuda.current_device() (or compute
local_rank = rank % torch.cuda.device_count() if current_device() isn't set) and
return torch.device(f"cuda:{local_index}") so the fixture uses the node-local
GPU index.
---
Duplicate comments:
In `@flashinfer/parallel_attention/__init__.py`:
- Around line 1-8: Top-level package __init__ is missing re-exports for the
parallel_attention API; import and re-export ParallelAttention, UnevenCPConfig,
VarlenCPConfig and the utility functions (split_varlen_input,
ulysses_varlen_config, ring_varlen_config, uneven_cp_config,
get_parallel_groups) from flashinfer.parallel_attention (or the corresponding
submodule names used in the diff) into the package __init__ and ensure they are
exposed via the package's public API (add them to the module exports / __all__
alongside other exported symbols like attention, cascade, mla, sparse) so
consumers can import these symbols directly from the top-level package.
In `@flashinfer/parallel_attention/parallel_wrapper.py`:
- Around line 217-254: The P2POp usage in ring_attn_p2p_communicate relies on
the group_peer kwarg introduced in PyTorch 2.6.0; add a runtime check of
torch.__version__ (or torch.version.cuda / packaging.version.parse) before
constructing torch.distributed.P2POp and either (a) raise a clear RuntimeError
instructing users to upgrade to PyTorch >=2.6.0, or (b) implement a fallback
path that builds P2POp calls without the group_peer argument for older versions,
ensuring both send_op and recv_op creation (and subsequent batch_isend_irecv)
still work; update the logic that creates send_op/recv_op in
ring_attn_p2p_communicate to choose the kwarg-aware or kwarg-less construction
accordingly.
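That version gate can be expressed with a small stdlib-only helper. This is a sketch: the 2.6.0 threshold is the one stated in the review, and the helper name is made up for illustration.

```python
def supports_group_peer(torch_version: str) -> bool:
    """True when torch.distributed.P2POp accepts group_peer (PyTorch >= 2.6.0)."""
    base = torch_version.split("+")[0]               # strip local tags like "+cu121"
    major, minor = (int(x) for x in base.split(".")[:2])
    return (major, minor) >= (2, 6)

# Callers can fail fast instead of crashing inside batch_isend_irecv:
for v in ("2.5.1+cu121", "2.6.0", "2.7.0+cu126"):
    print(v, supports_group_peer(v))
```

Checking the parsed tuple rather than the raw string avoids false negatives on versions like "2.10.0".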
In `@flashinfer/parallel_attention/utils.py`:
- Around line 26-37: The comments in convert_output_layout are reversed; update
them to accurately describe the tensor shapes for the HND and NHD layouts and
their conversions: HND corresponds to [H, S, D] and NHD corresponds to [S, H,
D], so the branch for src_layout == "HND" and dst_layout == "NHD" should comment
"[H, S, D] -> [S, H, D]" and the branch for src_layout == "NHD" and dst_layout
== "HND" should comment "[S, H, D] -> [H, S, D]"; leave the code
(permute/contiguous) unchanged and only correct the comment strings in
convert_output_layout.
- Around line 72-106: The variable seq_len_padded_cur_rank is a 0-d tensor (from
sum(seq_len_padded)) and must be converted to a Python int before using it for
shape arithmetic and tensor creation; update the computation of
seq_len_padded_cur_rank (used later to compute pad_len and pad_shape) to yield
an int (e.g., call int(...) or .item()), and ensure pad_len and any size values
passed into torch.zeros are plain ints so torch.zeros receives integer sizes
rather than tensor objects.
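The underlying pitfall can be reproduced without the surrounding code. A minimal sketch, with made-up values and variable names mirroring the review:

```python
import torch

seq_len_padded = torch.tensor([128, 256, 384])
world_size = 4

total_seq_len_padded = seq_len_padded.sum()           # 0-d tensor, not an int
per_rank = (total_seq_len_padded + world_size - 1) // world_size
print(type(per_rank))                                 # still a torch.Tensor

seq_len_padded_cur_rank = int(per_rank.item())        # plain Python int
pad_shape = (8, seq_len_padded_cur_rank, 64)          # safe to pass to torch.zeros
print(torch.zeros(pad_shape).shape)
```

Converting once at the source keeps every downstream shape computation (`pad_len`, `pad_shape`) in plain ints.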
In `@tests/attention/test_parallel_attention.py`:
- Around line 54-60: The fixture skip_if_unsupported accesses
request.node.callspec which raises AttributeError for non-parametrized tests;
guard that access and only read attn_type when callspec exists. Update
skip_if_unsupported to check for hasattr(request.node, "callspec") or wrap
request.node.callspec access in try/except AttributeError, then proceed to
evaluate attn_type and call
is_sm90a_supported/get_compute_capability/pytest.skip as before if attn_type ==
"flash-attn3".
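The guarded lookup pattern can be exercised outside pytest. A toy sketch in which `DummyNode` stands in for pytest's `request.node`:

```python
class DummyCallspec:
    def __init__(self, params):
        self.params = params

class DummyNode:
    """Stands in for request.node; may or may not carry a callspec."""
    pass

def resolve_attn_type(node):
    # Safe on non-parametrized tests: a missing callspec degrades to None.
    callspec = getattr(node, "callspec", None)
    return callspec.params.get("attn_type", None) if callspec else None

plain = DummyNode()                      # non-parametrized test node
print(resolve_attn_type(plain))          # None, no AttributeError

param = DummyNode()
param.callspec = DummyCallspec({"attn_type": "flash-attn3"})
print(resolve_attn_type(param))          # flash-attn3
```

The direct attribute access, by contrast, would raise `AttributeError` on `plain`.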
---
Nitpick comments:
In `@flashinfer/parallel_attention/__init__.py`:
- Around line 10-19: The __all__ list is unsorted; reorder its entries
alphabetically to improve readability and reduce merge conflicts by sorting the
identifiers ("ParallelAttention", "UnevenCPConfig", "VarlenCPConfig",
"split_varlen_input", "ulysses_varlen_config", "ring_varlen_config",
"uneven_cp_config", "get_parallel_groups") into ascending lexicographic order
and update the __all__ declaration in parallel_attention.__init__ accordingly.
In `@flashinfer/parallel_attention/utils.py`:
- Around line 284-300: The parameters seq_len and seq_len_padded in function
uneven_cp_config are declared and documented but never used; either remove them
from the signature and docstring if they're dead, or preserve the API by
renaming them to _seq_len and _seq_len_padded (or prefix with an underscore) and
update the docstring to note they are intentionally unused/TODO for future use
so static analysis stops flagging them; make the change only within the
uneven_cp_config definition and its docstring to keep callers unaffected.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 18bf785e-0c44-412f-b951-7d72dbf467c3
📒 Files selected for processing (6)
flashinfer/parallel_attention/__init__.py
flashinfer/parallel_attention/parallel_attention.py
flashinfer/parallel_attention/parallel_config.py
flashinfer/parallel_attention/parallel_wrapper.py
flashinfer/parallel_attention/utils.py
tests/attention/test_parallel_attention.py
```python
Example::

    config = AttnParallelConfig()
    config.set_config(ulysses_size=2, ring_size=2)
    attn = ParallelAttention(
        attn_type="flash-attn3",
        ulysses_group=ulysses_group,
        ring_group=ring_group,
    )
    output = attn.run(query, key, value, tensor_layout="HND")
```
Fix docstring example: AttnParallelConfig does not exist.
The example references AttnParallelConfig which is not defined in this module. Based on the actual implementation, the example should use get_parallel_groups() to obtain process groups directly.
📝 Proposed fix

```diff
 Example::

-    config = AttnParallelConfig()
-    config.set_config(ulysses_size=2, ring_size=2)
+    ring_group, ulysses_group = get_parallel_groups(
+        ulysses_size=2, ring_size=2
+    )
     attn = ParallelAttention(
         attn_type="flash-attn3",
         ulysses_group=ulysses_group,
         ring_group=ring_group,
     )
     output = attn.run(query, key, value, tensor_layout="HND")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/parallel_attention/parallel_attention.py` around lines 30 - 39,
The docstring example wrongly references a non-existent AttnParallelConfig;
update the example to call get_parallel_groups() to obtain ulysses_group and
ring_group and pass them into ParallelAttention (the ParallelAttention class and
its run() example should remain), e.g., replace the AttnParallelConfig setup
with a brief call to get_parallel_groups() and use its returned groups when
constructing ParallelAttention so the example matches the actual implementation.
```python
rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{rank}")
```
Global rank as CUDA device index may fail in multi-node setups.
torch.distributed.get_rank() returns the global rank, which can exceed local GPU count in multi-node training. Use torch.cuda.current_device() instead.
🛠️ Proposed fix

```diff
-    rank = torch.distributed.get_rank()
-    device = torch.device(f"cuda:{rank}")
+    device = torch.device(f"cuda:{torch.cuda.current_device()}")
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
device = torch.device(f"cuda:{torch.cuda.current_device()}")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/parallel_attention/utils.py` around lines 134 - 135, The code uses
torch.distributed.get_rank() to derive a CUDA device index (rank =
torch.distributed.get_rank(); device = torch.device(f"cuda:{rank}")), which can
be incorrect on multi-node setups; replace the device selection to use the local
CUDA device (torch.cuda.current_device()) or map global rank to local GPU index
(e.g., local_rank = torch.cuda.current_device() or local_rank =
torch.distributed.get_rank() % torch.cuda.device_count()) and set device =
torch.device(f"cuda:{local_rank}") so that the device reflects the local GPU
rather than the global distributed rank.
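The two device-selection strategies the reviews converge on can be wrapped in small helpers. A sketch: `local_cuda_device` and `local_rank_fallback` are illustrative names rather than FlashInfer APIs, and the `LOCAL_RANK` convention is the one torchrun-style launchers set.

```python
import os
import torch

def local_cuda_device() -> torch.device:
    """Prefer the node-local GPU index over the global distributed rank."""
    if torch.cuda.is_available():
        # Reflects whatever torch.cuda.set_device(local_rank) ran at startup.
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    return torch.device("cpu")

def local_rank_fallback(global_rank: int, gpus_per_node: int) -> int:
    """Map a global rank to a node-local index when current_device() isn't set."""
    if "LOCAL_RANK" in os.environ:           # set by torchrun and most launchers
        return int(os.environ["LOCAL_RANK"])
    return global_rank % gpus_per_node

print(local_rank_fallback(11, 8))
```

On a single node the two approaches coincide; on multiple nodes only the local index stays within `torch.cuda.device_count()`.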
```python
rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{rank}")
```
Global rank as CUDA device index — same multi-node issue.
Same issue as ulysses_varlen_config. Use torch.cuda.current_device().
🛠️ Proposed fix

```diff
-    rank = torch.distributed.get_rank()
-    device = torch.device(f"cuda:{rank}")
+    device = torch.device(f"cuda:{torch.cuda.current_device()}")
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
device = torch.device(f"cuda:{torch.cuda.current_device()}")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/parallel_attention/utils.py` around lines 225 - 226, The code uses
torch.distributed.get_rank() to pick a CUDA device (rank =
torch.distributed.get_rank(); device = torch.device(f"cuda:{rank}")), which
breaks on multi-node setups; change it to use the local CUDA device index via
torch.cuda.current_device() so device is constructed from
torch.cuda.current_device() instead of the global distributed rank (update the
code around the rank and device variables where they are defined/used).
```python
rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{rank}")
```
Global rank as CUDA device index in uneven_cp_config.
Same multi-node issue. Use torch.cuda.current_device().
🛠️ Proposed fix

```diff
-    rank = torch.distributed.get_rank()
-    device = torch.device(f"cuda:{rank}")
+    device = torch.device(f"cuda:{torch.cuda.current_device()}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/parallel_attention/utils.py` around lines 302 - 303, In
uneven_cp_config, the code uses torch.distributed.get_rank() to derive the CUDA
device (variables rank and device); replace that by using the local CUDA index
via torch.cuda.current_device() to support multi-node setups — update where
rank/device are set in uneven_cp_config to call torch.cuda.current_device() (or
use local_rank = torch.cuda.current_device() and set device =
torch.device(f"cuda:{local_rank}")) so the CUDA device reflects the current
process's GPU rather than the global distributed rank.
```python
@pytest.fixture
def device(rank):
    return torch.device(f"cuda:{rank}")
```
Global rank as CUDA device index may fail in multi-node setups.
Using torch.device(f"cuda:{rank}") where rank is the global distributed rank can fail in multi-node training where global rank exceeds local GPU count. Consider using torch.cuda.current_device() instead.
🛠️ Proposed fix

```diff
 @pytest.fixture
 def device(rank):
-    return torch.device(f"cuda:{rank}")
+    return torch.device(f"cuda:{torch.cuda.current_device()}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/attention/test_parallel_attention.py` around lines 49 - 51, The device
fixture uses the global distributed rank which can exceed the local GPU count in
multi-node runs; update the device fixture (def device(rank)) to derive the
local CUDA index instead of using f"cuda:{rank}" — for example call
torch.cuda.current_device() (or compute local_rank = rank %
torch.cuda.device_count() if current_device() isn't set) and return
torch.device(f"cuda:{local_index}") so the fixture uses the node-local GPU
index.
```python
def _sample_tensors(num_heads, seq_len, head_dim, world_size):
    """Create sample tensors for attention testing."""
    shape = (num_heads, seq_len, head_dim)
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")

    q = torch.randn(shape, device=device, dtype=torch.bfloat16)
    k = torch.randn(shape, device=device, dtype=torch.bfloat16)
    v = torch.randn(shape, device=device, dtype=torch.bfloat16)

    dist.broadcast(q, src=0)
    dist.broadcast(k, src=0)
    dist.broadcast(v, src=0)

    local_q = q.chunk(world_size, dim=1)[rank]
    local_k = k.chunk(world_size, dim=1)[rank]
    local_v = v.chunk(world_size, dim=1)[rank]
    return q, k, v, local_q, local_k, local_v
```
Helper _sample_tensors also uses global rank as device index.
Same issue as the device fixture — in multi-node setups this will fail. Consider using torch.cuda.current_device().
🛠️ Proposed fix

```diff
 def _sample_tensors(num_heads, seq_len, head_dim, world_size):
     """Create sample tensors for attention testing."""
     shape = (num_heads, seq_len, head_dim)
     rank = dist.get_rank()
-    device = torch.device(f"cuda:{rank}")
+    device = torch.device(f"cuda:{torch.cuda.current_device()}")
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
def _sample_tensors(num_heads, seq_len, head_dim, world_size):
    """Create sample tensors for attention testing."""
    shape = (num_heads, seq_len, head_dim)
    rank = dist.get_rank()
    device = torch.device(f"cuda:{torch.cuda.current_device()}")

    q = torch.randn(shape, device=device, dtype=torch.bfloat16)
    k = torch.randn(shape, device=device, dtype=torch.bfloat16)
    v = torch.randn(shape, device=device, dtype=torch.bfloat16)

    dist.broadcast(q, src=0)
    dist.broadcast(k, src=0)
    dist.broadcast(v, src=0)

    local_q = q.chunk(world_size, dim=1)[rank]
    local_k = k.chunk(world_size, dim=1)[rank]
    local_v = v.chunk(world_size, dim=1)[rank]
    return q, k, v, local_q, local_k, local_v
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/attention/test_parallel_attention.py` around lines 66 - 83, The helper
_sample_tensors uses dist.get_rank() to build the CUDA device string which
breaks on multi-node setups; change the device selection to use
torch.cuda.current_device() (or torch.cuda.device_of when appropriate) when
creating q/k/v so tensors are allocated on the actual local CUDA device, keeping
dist.get_rank() only for indexing the chunked tensors (local_q/local_k/local_v).
```python
def _sample_ring_varlen_tensors(num_heads, head_dim, world_size, seq_len_list):
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")

    total_seq_len = sum(seq_len_list)
    shape = (num_heads, total_seq_len, head_dim)

    q = torch.randn(shape, device=device, dtype=torch.bfloat16)
    k = torch.randn(shape, device=device, dtype=torch.bfloat16)
    v = torch.randn(shape, device=device, dtype=torch.bfloat16)

    dist.broadcast(q, src=0)
    dist.broadcast(k, src=0)
    dist.broadcast(v, src=0)

    local_q = split_varlen_input(q, seq_len_list, world_size, rank)
    local_k = split_varlen_input(k, seq_len_list, world_size, rank)
    local_v = split_varlen_input(v, seq_len_list, world_size, rank)

    return q, k, v, local_q, local_k, local_v
```
Helper _sample_ring_varlen_tensors also uses global rank as device index.
Same multi-node issue as above.
🛠️ Proposed fix

```diff
 def _sample_ring_varlen_tensors(num_heads, head_dim, world_size, seq_len_list):
     rank = dist.get_rank()
-    device = torch.device(f"cuda:{rank}")
+    device = torch.device(f"cuda:{torch.cuda.current_device()}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/attention/test_parallel_attention.py` around lines 86 - 105, The helper
_sample_ring_varlen_tensors incorrectly uses the global process rank to pick a
CUDA device (device = torch.device(f"cuda:{rank}")), which breaks multi-node
setups; change it to compute a local GPU index (e.g., read
os.environ["LOCAL_RANK"] when present or fall back to rank %
torch.cuda.device_count()) and use that as the device (device =
torch.device(f"cuda:{local_rank}")). Update the device selection in
_sample_ring_varlen_tensors so q/k/v tensors are created on the correct local
GPU before broadcasts and keep the rest of the function (split_varlen_input,
dist.broadcast calls) unchanged.
/bot run

[FAILED] Pipeline #45775002: 6/20 passed

/bot run
♻️ Duplicate comments (10)
tests/comm/test_parallel_attention.py (4)
60-62: ⚠️ Potential issue | 🟡 Minor
Global rank as CUDA device index may fail in multi-node setups.

Using `torch.device(f"cuda:{rank}")` where `rank` is the global distributed rank can fail in multi-node training. Consider using `torch.cuda.current_device()` instead.

🛠️ Proposed fix

```diff
 @pytest.fixture
 def device(rank):
-    return torch.device(f"cuda:{rank}")
+    return torch.device(f"cuda:{torch.cuda.current_device()}")
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/comm/test_parallel_attention.py` around lines 60 - 62, the device fixture uses the global distributed rank, which can be incorrect on multi-node setups; update the fixture to select the CUDA device via the local/current CUDA index instead of the global rank (e.g., torch.cuda.current_device() or a local_rank fixture) so the device becomes torch.device(f"cuda:{torch.cuda.current_device()}"), and ensure any tests depending on the fixture use the updated version.
101-120: ⚠️ Potential issue | 🟡 Minor
Helper `_sample_ring_varlen_tensors` also uses global rank as device index.

Same multi-node issue as above.

🛠️ Proposed fix

```diff
 def _sample_ring_varlen_tensors(num_heads, head_dim, world_size, seq_len_list):
     rank = dist.get_rank()
-    device = torch.device(f"cuda:{rank}")
+    device = torch.device(f"cuda:{torch.cuda.current_device()}")
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/comm/test_parallel_attention.py` around lines 101 - 120, the helper _sample_ring_varlen_tensors uses the global rank to pick a CUDA device (device = torch.device(f"cuda:{rank}")), which breaks on multi-node setups; derive a local GPU index instead. Compute local_rank from the LOCAL_RANK environment variable if present (or a CUDA_VISIBLE_DEVICES mapping), or fall back to rank % torch.cuda.device_count(), then set device = torch.device(f"cuda:{local_rank}"); update imports if needed (os) and keep the rest of the function the same.
65-76: ⚠️ Potential issue | 🟡 Minor

Guard against `AttributeError` on non-parametrized tests.

`request.node.callspec` raises `AttributeError` when the test is not parametrized. Since this fixture is `autouse=True`, any future non-parametrized test in this module would fail with a confusing error.

🛠️ Proposed fix
```diff
 @pytest.fixture(autouse=True)
 def skip_if_unsupported(request):
     """Skip test if the attention backend requires unsupported hardware."""
-    attn_type = request.node.callspec.params.get("attn_type", None)
+    callspec = getattr(request.node, "callspec", None)
+    attn_type = callspec.params.get("attn_type", None) if callspec else None
     if attn_type == "flash-attn3" and not is_sm90a_supported(torch.device("cuda")):
         cc = get_compute_capability(torch.device("cuda"))
         pytest.skip(f"flash-attn3 requires SM90a+, got {cc}")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/comm/test_parallel_attention.py` around lines 65 - 76, The autouse fixture skip_if_unsupported assumes request.node.callspec exists and raises AttributeError for non-parametrized tests; guard access by checking for callspec first (e.g., return early if not hasattr(request.node, "callspec") or use getattr(request.node, "callspec", None) and then safe .params lookup) before extracting attn_type so attn_type = ... .get("attn_type", None) doesn't blow up; retain the existing checks for "flash-attn3" and "cutlass" and their use of is_sm90a_supported/is_sm100a_supported/get_compute_capability.
81-98: ⚠️ Potential issue | 🟡 Minor

Helper `_sample_tensors` also uses global rank as device index. Same multi-node issue as the `device` fixture. Consider using `torch.cuda.current_device()`.

🛠️ Proposed fix
```diff
 def _sample_tensors(num_heads, seq_len, head_dim, world_size):
     """Create sample tensors for attention testing."""
     shape = (num_heads, seq_len, head_dim)
     rank = dist.get_rank()
-    device = torch.device(f"cuda:{rank}")
+    device = torch.device(f"cuda:{torch.cuda.current_device()}")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/comm/test_parallel_attention.py` around lines 81 - 98, The helper _sample_tensors uses dist.get_rank() as the CUDA device index (device = torch.device(f"cuda:{rank}")), which breaks in multi-node setups; change it to use the local CUDA index via torch.cuda.current_device() when creating tensors (e.g., device = torch.device(f"cuda:{torch.cuda.current_device()}")) while still using rank = dist.get_rank() for the chunking/indexing (local_q = q.chunk(world_size, dim=1)[rank], etc.) so tensors are allocated on the correct local GPU but sharding logic still uses the distributed rank.

flashinfer/parallel_attention/utils.py (5)
134-135: ⚠️ Potential issue | 🟡 Minor

Global rank as CUDA device index may fail in multi-node setups.

`torch.distributed.get_rank()` returns the global rank, which can exceed the local GPU count in multi-node training. Use `torch.cuda.current_device()` instead.

🛠️ Proposed fix
```diff
-    rank = torch.distributed.get_rank()
-    device = torch.device(f"cuda:{rank}")
+    device = torch.device(f"cuda:{torch.cuda.current_device()}")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/utils.py` around lines 134 - 135, The code uses torch.distributed.get_rank() to pick a CUDA device (rank = torch.distributed.get_rank(); device = torch.device(f"cuda:{rank}")), which fails on multi-node setups; replace use of global rank with the local CUDA device index (e.g., torch.cuda.current_device()) when selecting the device so device = torch.device(f"cuda:{torch.cuda.current_device()}") or otherwise obtain the local GPU index; update any surrounding logic in utils.py that assumes global rank maps 1:1 to local GPU ids.
26-37: ⚠️ Potential issue | 🟡 Minor

Comments in `convert_output_layout` are swapped.

Line 28 says `[S, H, D] -> [H, S, D]` for the `HND → NHD` branch, but HND is `[H, S, D]`, so the conversion is `[H, S, D] → [S, H, D]`. Line 31 has the mirror mistake. The `permute(1, 0, 2)` code is correct; only the comments are wrong.

📝 Fix comments
```diff
 if src_layout == "HND" and dst_layout == "NHD":
-    # [S, H, D] -> [H, S, D]
+    # [H, S, D] -> [S, H, D]
     out = out.permute(1, 0, 2).contiguous()
 elif src_layout == "NHD" and dst_layout == "HND":
-    # [H, S, D] -> [S, H, D]
+    # [S, H, D] -> [H, S, D]
     out = out.permute(1, 0, 2).contiguous()
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/utils.py` around lines 26 - 37, The comments in convert_output_layout are swapped: update the comment for the branch if src_layout == "HND" and dst_layout == "NHD" to say "[H, S, D] -> [S, H, D]" and update the other branch (src_layout == "NHD" and dst_layout == "HND") to say "[S, H, D] -> [H, S, D]"; leave the implementation (out.permute(1, 0, 2).contiguous()) unchanged and ensure the comments correctly describe the layout change for the function convert_output_layout and its permute(1, 0, 2) usage.
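To see why the corrected comments are the right way around, here is a small pure-Python sketch (illustrative shapes only) of how a `(1, 0, 2)` permutation acts on a shape tuple, mirroring what `torch.Tensor.permute` does to tensor dimensions:

```python
def permute_shape(shape, perm):
    # Apply a dimension permutation to a shape tuple, mirroring
    # how torch.Tensor.permute reorders tensor dimensions.
    return tuple(shape[i] for i in perm)

# HND layout is [H, S, D]; permute(1, 0, 2) yields NHD, i.e. [S, H, D].
hnd = (8, 128, 64)   # (num_heads, seq_len, head_dim)
nhd = permute_shape(hnd, (1, 0, 2))  # (128, 8, 64)
```

The same permutation applied to an NHD shape returns HND, which is why the code is identical in both branches and only the comments differ.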
225-226: ⚠️ Potential issue | 🟡 Minor

Global rank as CUDA device index; same multi-node issue.

Same issue as `ulysses_varlen_config`. Use `torch.cuda.current_device()`.

🛠️ Proposed fix
```diff
-    rank = torch.distributed.get_rank()
-    device = torch.device(f"cuda:{rank}")
+    device = torch.device(f"cuda:{torch.cuda.current_device()}")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/utils.py` around lines 225 - 226, The code sets device using the global distributed rank (rank = torch.distributed.get_rank()) which breaks multi-node setups; change the device selection in utils.py to use the local CUDA index from torch.cuda.current_device() (or torch.cuda.device_count()/local_rank if needed) instead of the global rank so device = torch.device(f"cuda:{torch.cuda.current_device()}") (update any places referencing the old rank-based device variable).
302-303: ⚠️ Potential issue | 🟡 Minor

Global rank as CUDA device index in `uneven_cp_config`. Same multi-node issue. Use `torch.cuda.current_device()`.

🛠️ Proposed fix
```diff
-    rank = torch.distributed.get_rank()
-    device = torch.device(f"cuda:{rank}")
+    device = torch.device(f"cuda:{torch.cuda.current_device()}")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/utils.py` around lines 302 - 303, The code in uneven_cp_config currently uses torch.distributed.get_rank() to pick a CUDA device (rank = torch.distributed.get_rank(); device = torch.device(f"cuda:{rank}")), which breaks in multi-node setups; change it to use the local GPU index via torch.cuda.current_device() (e.g., obtain current = torch.cuda.current_device() and set device = torch.device(f"cuda:{current}") ) inside the uneven_cp_config logic so the device maps to the local CUDA device rather than the global process rank.
72-105: ⚠️ Potential issue | 🟡 Minor

`seq_len_padded_cur_rank` is a 0-d tensor; convert to `int` before using in shapes.

`sum(seq_len_padded)` returns a 0-d tensor, so `seq_len_padded_cur_rank` is also a 0-d tensor. This propagates into `pad_len` (line 99) and `pad_shape` (line 101), meaning `torch.zeros(pad_shape, ...)` receives a list containing a tensor element. While recent PyTorch may handle this, it's fragile.

🛠️ Proposed fix
```diff
     seq_len_padded_cur_rank = (
         (total_seq_len_padded + world_size - 1) // world_size
-    ).to(torch.int32)
+    )
+    seq_len_padded_cur_rank = int(seq_len_padded_cur_rank.item())
```
🤖 Prompt for AI Agents
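For illustration, the ceil-division arithmetic behind `seq_len_padded_cur_rank` can be sketched with plain Python ints (the helper names are hypothetical), which is exactly what the `int(...)` conversion guarantees downstream:

```python
def per_rank_padded_len(total_seq_len: int, world_size: int) -> int:
    # Ceil-division: the padded per-rank chunk length, as a plain int.
    return (total_seq_len + world_size - 1) // world_size

def pad_amount(total_seq_len: int, world_size: int, local_len: int) -> int:
    # Number of zero rows a rank must append to reach the padded length.
    return per_rank_padded_len(total_seq_len, world_size) - local_len
```

For example, a total length of 10 over 4 ranks pads each chunk to 3, so a rank holding only 1 row pads by 2. Because these are ints, any shape list built from them contains only ints.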
Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/utils.py` around lines 72 - 105, seq_len_padded_cur_rank is a 0-d tensor and is later used to compute pad_len and pad_shape for torch.zeros, which can produce a list containing a tensor; convert seq_len_padded_cur_rank to a Python int before any shape arithmetic or comparisons (e.g., use int(seq_len_padded_cur_rank) or .item()) so that pad_len = int(seq_len_padded_cur_rank) - res.shape[chunk_dim] yields an int and pad_shape contains only ints; update references in the function where seq_len_padded_cur_rank, pad_len, and pad_shape are computed (the variables named seq_len_padded_cur_rank, pad_len, pad_shape, and the torch.zeros call) to use the converted int value.

flashinfer/parallel_attention/attention_ops.py (1)
45-52: ⚠️ Potential issue | 🟡 Minor

`get_impl` fallback references undefined `cls.attn_type`.

When `name is None`, line 48 falls back to `cls.attn_type`, but `AttentionOpManager` never defines this class attribute. This would raise `AttributeError` at runtime. While `ParallelAttention` always passes a name, this is a latent bug.

🛠️ Suggested fix
```diff
 @classmethod
 def get_impl(cls, name=None):
     if name is None:
-        name = cls.attn_type
+        raise ValueError("Attention type name must be provided")
     attn_class = cls._attn_registry.get(name)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/parallel_attention/attention_ops.py` around lines 45 - 52, The get_impl method uses an undefined class attribute cls.attn_type when name is None, which will raise AttributeError; update get_impl in AttentionOpManager to avoid referencing cls.attn_type—either require a name and raise a clear ValueError if name is None, or introduce and use a defined fallback like a class constant DEFAULT_ATTENTION; then look up in cls._attn_registry and instantiate the class (symbols: get_impl, AttentionOpManager, cls.attn_type, cls._attn_registry).
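A minimal sketch of the registry pattern with the suggested `ValueError` applied (this mirrors the behavior described in the review, not the actual FlashInfer implementation):

```python
class AttentionOpManager:
    """Decorator-based backend registry, sketched with the fix above."""
    _attn_registry: dict = {}

    @classmethod
    def register_attn(cls, name):
        # Decorator that records a backend class under the given name.
        def decorator(attn_class):
            cls._attn_registry[name] = attn_class
            return attn_class
        return decorator

    @classmethod
    def get_impl(cls, name=None):
        if name is None:
            # Fail loudly instead of referencing an undefined attribute.
            raise ValueError("Attention type name must be provided")
        attn_class = cls._attn_registry.get(name)
        if attn_class is None:
            raise KeyError(f"Unknown attention backend: {name}")
        return attn_class

@AttentionOpManager.register_attn("dummy")
class DummyAttention:
    pass
```

With this shape, `get_impl("dummy")` returns `DummyAttention`, and calling `get_impl()` with no name raises `ValueError` immediately rather than a confusing `AttributeError`.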
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@flashinfer/parallel_attention/attention_ops.py`:
- Around line 45-52: The get_impl method uses an undefined class attribute
cls.attn_type when name is None, which will raise AttributeError; update
get_impl in AttentionOpManager to avoid referencing cls.attn_type—either require
a name and raise a clear ValueError if name is None, or introduce and use a
defined fallback like a class constant DEFAULT_ATTENTION; then look up in
cls._attn_registry and instantiate the class (symbols: get_impl,
AttentionOpManager, cls.attn_type, cls._attn_registry).
In `@flashinfer/parallel_attention/utils.py`:
- Around line 134-135: The code uses torch.distributed.get_rank() to pick a CUDA
device (rank = torch.distributed.get_rank(); device =
torch.device(f"cuda:{rank}")), which fails on multi-node setups; replace use of
global rank with the local CUDA device index (e.g., torch.cuda.current_device())
when selecting the device so device =
torch.device(f"cuda:{torch.cuda.current_device()}") or otherwise obtain the
local GPU index; update any surrounding logic in utils.py that assumes global
rank maps 1:1 to local GPU ids.
- Around line 26-37: The comments in convert_output_layout are swapped: update
the comment for the branch if src_layout == "HND" and dst_layout == "NHD" to say
"[H, S, D] -> [S, H, D]" and update the other branch (src_layout == "NHD" and
dst_layout == "HND") to say "[S, H, D] -> [H, S, D]"; leave the implementation
(out.permute(1, 0, 2).contiguous()) unchanged and ensure the comments correctly
describe the layout change for the function convert_output_layout and its
permute(1, 0, 2) usage.
- Around line 225-226: The code sets device using the global distributed rank
(rank = torch.distributed.get_rank()) which breaks multi-node setups; change the
device selection in utils.py to use the local CUDA index from
torch.cuda.current_device() (or torch.cuda.device_count()/local_rank if needed)
instead of the global rank so device =
torch.device(f"cuda:{torch.cuda.current_device()}") (update any places
referencing the old rank-based device variable).
- Around line 302-303: The code in uneven_cp_config currently uses
torch.distributed.get_rank() to pick a CUDA device (rank =
torch.distributed.get_rank(); device = torch.device(f"cuda:{rank}")), which
breaks in multi-node setups; change it to use the local GPU index via
torch.cuda.current_device() (e.g., obtain current = torch.cuda.current_device()
and set device = torch.device(f"cuda:{current}") ) inside the uneven_cp_config
logic so the device maps to the local CUDA device rather than the global process
rank.
- Around line 72-105: seq_len_padded_cur_rank is a 0-d tensor and is later used
to compute pad_len and pad_shape for torch.zeros, which can produce a list
containing a tensor; convert seq_len_padded_cur_rank to a Python int before any
shape arithmetic or comparisons (e.g., use int(seq_len_padded_cur_rank) or
.item()) so that pad_len = int(seq_len_padded_cur_rank) - res.shape[chunk_dim]
yields an int and pad_shape contains only ints; update references in the
function where seq_len_padded_cur_rank, pad_len, and pad_shape are computed (the
variables named seq_len_padded_cur_rank, pad_len, pad_shape, and the torch.zeros
call) to use the converted int value.
In `@tests/comm/test_parallel_attention.py`:
- Around line 60-62: The fixture device uses the global distributed rank which
can be incorrect on multi-node setups; update the pytest fixture named device to
select the CUDA device via the local/current CUDA index instead of the global
rank (e.g., use torch.cuda.current_device() or a local_rank fixture) so the
device becomes torch.device(f"cuda:{torch.cuda.current_device()}") (or derive
from a provided local_rank) and ensure any tests depending on device use that
updated fixture.
- Around line 101-120: The test helper _sample_ring_varlen_tensors uses the
global MPI rank to pick a CUDA device (device = torch.device(f"cuda:{rank}")),
which breaks on multi-node setups; change it to derive a local GPU index (e.g.
local_rank) instead of using the global rank. Compute local_rank from an
environment variable if present (LOCAL_RANK or CUDA_VISIBLE_DEVICES mapping) or
fall back to rank % torch.cuda.device_count(), then set device =
torch.device(f"cuda:{local_rank}"); update imports if needed (os) and keep the
rest of the function (references: _sample_ring_varlen_tensors, dist.get_rank())
the same.
- Around line 65-76: The autouse fixture skip_if_unsupported assumes
request.node.callspec exists and raises AttributeError for non-parametrized
tests; guard access by checking for callspec first (e.g., return early if not
hasattr(request.node, "callspec") or use getattr(request.node, "callspec", None)
and then safe .params lookup) before extracting attn_type so attn_type = ...
.get("attn_type", None) doesn't blow up; retain the existing checks for
"flash-attn3" and "cutlass" and their use of
is_sm90a_supported/is_sm100a_supported/get_compute_capability.
- Around line 81-98: The helper _sample_tensors uses dist.get_rank() as the CUDA
device index (device = torch.device(f"cuda:{rank}")), which breaks in multi-node
setups; change it to use the local CUDA index via torch.cuda.current_device()
when creating tensors (e.g., device =
torch.device(f"cuda:{torch.cuda.current_device()}")) while still using rank =
dist.get_rank() for the chunking/indexing (local_q = q.chunk(world_size,
dim=1)[rank], etc.) so tensors are allocated on the correct local GPU but
sharding logic still uses the distributed rank.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 53475189-cb1a-4178-9901-7897109f2e62
📒 Files selected for processing (4)
flashinfer/parallel_attention/attention_ops.py
flashinfer/parallel_attention/utils.py
scripts/task_test_multi_gpu_comm_kernels.sh
tests/comm/test_parallel_attention.py
[SUCCESS] Pipeline #46044958: 9/20 passed
saltyminty
left a comment
Do the parallel attention tests need to be hooked up into the CI? Doesn't seem like they're running right now.
> - **ulysses_group** (*Optional[ProcessGroup]*): The Ulysses
>   process group, or ``None`` if ``ulysses_size == 1``.
>
> Returns ``None`` (not a tuple) when ``ulysses_size == ring_size == 1``.
Is there a reason for this as opposed to just returning (None, None)?
Always returning (None, None) is better; I changed it.
Yes, these tests need to be hooked up into the CI. I added them to CI in scripts/task_test_multi_gpu_comm_kernels.sh
/bot run

[SUCCESS] Pipeline #46623748: 12/20 passed
Seems like there are some failures. Probably unrelated/expected, but could you verify? Looks good otherwise.
This failure is caused by the @torch.compile decorator in my code. @torch.compile will run pwd.getpwuid(), but USER is not set. I removed @torch.compile and will rerun the tests.
/bot run

[SUCCESS] Pipeline #46742992: 14/20 passed
saltyminty
left a comment
CI looks good, approved.
📌 Description

Add a `parallel_attention` module to FlashInfer that enables distributed attention computation using Ulysses (all-to-all head parallelism) and Ring (P2P KV exchange with online softmax merging) strategies, or a combination of both.

New files

- `parallel_attention.py` — `ParallelAttention` class: the main entry point that wraps any registered attention backend and applies Ulysses/Ring parallelism transparently via decorators.
- `parallel_config.py` — Configuration classes:
  - `AttnParallelConfig`: singleton that manages `ulysses_size`, `ring_size`, device mesh creation, and process group accessors.
  - `UnevenCPConfig`: handles uneven context parallelism where the total sequence length is not divisible by `world_size`.
  - `VarlenCPConfig`: handles variable-length (ragged) batching where multiple sequences of different lengths are packed together.
- `parallel_wrapper.py` — Decorator implementations:
  - `ulysses_wrapper`: performs all-to-all communication to split heads across ranks, calls the inner function, then reverses the all-to-all.
  - `ring_wrapper`: implements ring attention with P2P KV exchange and online softmax correction across ring steps.
  - Helpers: `all_to_all`, `ulysses_a2a_in/out`, `ring_fwd_out_correction`, `ring_fwd_softmax_lse_correction`, `ring_attn_p2p_communicate`.
- `attention_ops.py` — `AttentionOpManager` registry with decorator-based backend registration. Includes `FlashAttn3` as the first registered backend.
- `utils.py` — Utility functions: `convert_qkv_layout`, `convert_output_layout`, `split_varlen_input`.
- `__init__.py` — Package API re-exports.

Tests

- `tests/attention/test_parallel_attention.py` — Pytest-based test suite covering parallel attention (`test_attn_parallel`), uneven context parallelism (`test_uneven_attn_parallel`), Ulysses variable-length batching (`test_ulysses_varlen_attn_parallel`), Ring variable-length batching (`test_ring_varlen_attn_parallel`), and both `tensor_layout` values (`"HND"`/`"NHD"`).

Key design decisions

- Attention backends are registered via `@AttentionOpManager.register_attn("name")` and used with the parallel wrappers.
- `@ulysses_wrapper` and `@ring_wrapper` are composable decorators — they can be stacked or used independently.
- `is_causal=True` raises `NotImplementedError`.

🔍 Related Issues
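As background for the `ring_wrapper` described above, the online-softmax correction applied after each ring step can be sketched in scalar form (a simplified illustration; the real code operates on tensors and per-row log-sum-exp values):

```python
import math

def merge_partial_attn(out_a, lse_a, out_b, lse_b):
    """Merge two partial attention outputs via their log-sum-exp values.

    Scalar sketch of the online-softmax correction that ring attention
    performs after each P2P KV exchange step.
    """
    # Combine the two log-sum-exp values in a numerically stable way.
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    # Reweight each partial output by its share of the total softmax mass.
    w_a = math.exp(lse_a - lse)
    w_b = math.exp(lse_b - lse)
    return w_a * out_a + w_b * out_b, lse
```

Merging the partial results of a split softmax this way reproduces the full-sequence result exactly, which is what lets each rank process one KV chunk per ring step and still obtain the global attention output.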
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Tests