7 changes: 6 additions & 1 deletion .github/workflows/integration_test_4gpu_rl.yaml
@@ -66,7 +66,7 @@ jobs:
pip install uv

# 1. Install Monarch and TorchStore
-uv pip install torchmonarch==0.3.0
+uv pip install torchmonarch==0.4.1
uv pip install --no-deps "git+https://github.com/meta-pytorch/torchstore.git@main"
uv pip install pygtrie portpicker
@@ -91,6 +91,11 @@ jobs:
sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded"
sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded"

+# Install nvcc so FlashInfer can JIT compile its CUDA kernels.
+# vLLM uses FlashInfer sampler by default (vllm-project/vllm#40376)
+# but the CI docker image only has CUDA runtime, not the toolkit.
+uv pip install nvidia-cuda-nvcc

# Run E2E RL integration tests (TP=2 on 4 GPUs)
python -m torchtitan.experiments.rl.tests.integration_tests \
  $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 4 \
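For context, a minimal sketch of how one could verify that the `nvidia-cuda-nvcc` wheel actually makes `nvcc` discoverable before FlashInfer's JIT path needs it. This is not part of the PR, and the site-packages location is an assumption about how the wheel lays out its files, not something taken from the diff.

```python
# Hypothetical pre-flight check (not part of this PR): confirm nvcc is
# reachable after `uv pip install nvidia-cuda-nvcc` in the CI image.
import os
import shutil
import sysconfig
from typing import Optional


def find_nvcc() -> Optional[str]:
    # First, see whether nvcc is already on PATH (e.g. a full CUDA toolkit).
    nvcc = shutil.which("nvcc")
    if nvcc:
        return nvcc
    # Otherwise look inside site-packages, where the pip wheel is assumed
    # to place its binaries (the exact layout may differ across wheel versions).
    candidate = os.path.join(
        sysconfig.get_paths()["purelib"], "nvidia", "cuda_nvcc", "bin", "nvcc"
    )
    return candidate if os.path.exists(candidate) else None


if __name__ == "__main__":
    path = find_nvcc()
    print(f"nvcc: {path or 'not found -- FlashInfer JIT compilation would fail'}")
```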
5 changes: 5 additions & 0 deletions .github/workflows/integration_test_8gpu_rl_h100.yaml
@@ -84,6 +84,11 @@ jobs:
sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded"
sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded"
+# Install nvcc so FlashInfer can JIT compile its CUDA kernels.
+# vLLM uses FlashInfer sampler by default (vllm-project/vllm#40376)
+# but the CI docker image only has CUDA runtime, not the toolkit.
+uv pip install nvidia-cuda-nvcc
# Run bitwise parity tests (TP=2, batch-invariant mode)
HF_ASSETS_PATH="$MODEL_PATH" torchrun --nproc-per-node=2 -m pytest \
  torchtitan/experiments/rl/tests/test_bitwise_parity.py -v
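As a rough illustration of what "bitwise parity" means for the test invoked above: the sketch below runs the same module twice on identical inputs and compares with `torch.equal`, which only passes when every bit of the outputs matches. The module and shapes are placeholders, not the repo's actual test.

```python
# Illustrative only: a bitwise-parity style check, assuming deterministic
# kernels (e.g. num_splits=1) so repeated runs produce identical bits.
import torch


def check_bitwise_parity(module: torch.nn.Module, x: torch.Tensor) -> bool:
    with torch.no_grad():
        out_a = module(x)
        out_b = module(x)
    # torch.equal is exact equality, not allclose: any differing bit fails.
    return torch.equal(out_a, out_b)


if __name__ == "__main__":
    torch.manual_seed(0)
    mlp = torch.nn.Sequential(
        torch.nn.Linear(16, 32), torch.nn.GELU(), torch.nn.Linear(32, 16)
    )
    sample = torch.randn(4, 16)
    print("bitwise identical:", check_bitwise_parity(mlp, sample))
```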
14 changes: 8 additions & 6 deletions torchtitan/experiments/rl/models/attention.py
@@ -168,14 +168,16 @@ def forward(
            num_seqs + 1, dtype=torch.int32, device=query.device
        )
        cu_seqlens_k[1:] = torch.cumsum(seqused_k, dim=0)
-        # FA3 + batch-invariant: fix num_splits=1 to prevent non-deterministic
-        # split-k reductions. FA2 is automatically batch-invariant and does
-        # not accept num_splits.
        extra_kwargs = {}

-        # Disable split_kv in Flash Attention to ensure bitwise identical output.
-        # see https://github.com/pytorch/pytorch/pull/176905
-        if is_in_batch_invariant_mode() and current_flash_attention_impl() == "FA3":
+        # TODO(pytorch/pytorch#179760): FA2's auto num_splits heuristic
+        # produces NaN intermittently with paged KV (block_table). Force
+        # num_splits=1 as a workaround until the root cause is fixed
+        # upstream. current_flash_attention_impl() returns None when FA2
+        # is the implicit default (SM < 9.0). For FA3, only force
+        # num_splits=1 in batch-invariant mode (determinism).
+        fa_impl = current_flash_attention_impl()
+        if fa_impl in (None, "FA2") or is_in_batch_invariant_mode():
            extra_kwargs["num_splits"] = 1

        if self.enable_gqa:
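Condensed, the new control flow in both attention modules reduces to the helper below. `current_flash_attention_impl` and `is_in_batch_invariant_mode` are the repo's helpers as described in the diff comments; they are stubbed here only so the sketch runs standalone.

```python
# Sketch of the num_splits decision introduced by this PR. The two helpers
# are stand-ins for the repo's real ones (behavior assumed from the diff).
from typing import Optional


def current_flash_attention_impl() -> Optional[str]:
    # Repo helper: returns "FA3" on SM >= 9.0, or None when FA2 is the
    # implicit default (SM < 9.0). Stubbed for illustration.
    return None


def is_in_batch_invariant_mode() -> bool:
    # Repo helper: whether deterministic / batch-invariant mode is enabled.
    return False


def flash_varlen_kwargs() -> dict:
    kwargs: dict = {}
    fa_impl = current_flash_attention_impl()
    # FA2 (impl is None or "FA2"): always force num_splits=1 to work around
    # the intermittent NaN with paged KV (pytorch/pytorch#179760).
    # FA3: force num_splits=1 only in batch-invariant mode, to avoid
    # non-deterministic split-k reductions.
    if fa_impl in (None, "FA2") or is_in_batch_invariant_mode():
        kwargs["num_splits"] = 1
    return kwargs


print(flash_varlen_kwargs())  # -> {'num_splits': 1} with the stubs above
```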
15 changes: 9 additions & 6 deletions torchtitan/models/common/attention.py
@@ -132,12 +132,15 @@ def forward(

        varlen_kwargs = dict()

-        if is_in_batch_invariant_mode():
-            if current_flash_attention_impl() == "FA3":
-                # Fix split count to 1 to prevent non-deterministic split-k
-                # reductions that vary with batch composition.
-                # Only needed for FA3; FA2 is automatically batch-invariant.
-                varlen_kwargs["num_splits"] = 1
+        # TODO(pytorch/pytorch#179760): FA2's auto num_splits heuristic
+        # produces NaN intermittently with paged KV (block_table). Force
+        # num_splits=1 as a workaround. current_flash_attention_impl()
+        # returns None when FA2 is the implicit default (SM < 9.0).
+        # For FA3, only force num_splits=1 in batch-invariant mode
+        # to prevent non-deterministic split-k reductions.
+        fa_impl = current_flash_attention_impl()
+        if fa_impl in (None, "FA2") or is_in_batch_invariant_mode():
+            varlen_kwargs["num_splits"] = 1

        # Forward enable_gqa from GQAttention when Q and KV head counts differ
        if kwargs.get("enable_gqa", False):