diff --git a/.github/workflows/integration_test_4gpu_rl.yaml b/.github/workflows/integration_test_4gpu_rl.yaml
index 93f5c81b0a..da4a9b390d 100644
--- a/.github/workflows/integration_test_4gpu_rl.yaml
+++ b/.github/workflows/integration_test_4gpu_rl.yaml
@@ -66,7 +66,7 @@ jobs:
           pip install uv
 
           # 1. Install Monarch and TorchStore
-          uv pip install torchmonarch==0.3.0
+          uv pip install torchmonarch==0.4.1
           uv pip install --no-deps "git+https://github.com/meta-pytorch/torchstore.git@main"
           uv pip install pygtrie portpicker
 
@@ -91,6 +91,11 @@ jobs:
           sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded"
           sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded"
 
+          # Install nvcc so FlashInfer can JIT compile its CUDA kernels.
+          # vLLM uses FlashInfer sampler by default (vllm-project/vllm#40376)
+          # but the CI docker image only has CUDA runtime, not the toolkit.
+          uv pip install nvidia-cuda-nvcc
+
           # Run E2E RL integration tests (TP=2 on 4 GPUs)
           python -m torchtitan.experiments.rl.tests.integration_tests \
             $RUNNER_TEMP/artifacts-to-be-uploaded --ngpu 4 \
diff --git a/.github/workflows/integration_test_8gpu_rl_h100.yaml b/.github/workflows/integration_test_8gpu_rl_h100.yaml
index 2fe0240110..0dc854e587 100644
--- a/.github/workflows/integration_test_8gpu_rl_h100.yaml
+++ b/.github/workflows/integration_test_8gpu_rl_h100.yaml
@@ -84,6 +84,11 @@ jobs:
           sudo mkdir -p "$RUNNER_TEMP/artifacts-to-be-uploaded"
           sudo chown -R $(id -u):$(id -g) "$RUNNER_TEMP/artifacts-to-be-uploaded"
 
+          # Install nvcc so FlashInfer can JIT compile its CUDA kernels.
+          # vLLM uses FlashInfer sampler by default (vllm-project/vllm#40376)
+          # but the CI docker image only has CUDA runtime, not the toolkit.
+          uv pip install nvidia-cuda-nvcc
+
           # Run bitwise parity tests (TP=2, batch-invariant mode)
           HF_ASSETS_PATH="$MODEL_PATH" torchrun --nproc-per-node=2 -m pytest \
             torchtitan/experiments/rl/tests/test_bitwise_parity.py -v
diff --git a/torchtitan/experiments/rl/models/attention.py b/torchtitan/experiments/rl/models/attention.py
index cc1e9c1267..5aae2c5391 100644
--- a/torchtitan/experiments/rl/models/attention.py
+++ b/torchtitan/experiments/rl/models/attention.py
@@ -168,14 +168,16 @@ def forward(
             num_seqs + 1, dtype=torch.int32, device=query.device
         )
         cu_seqlens_k[1:] = torch.cumsum(seqused_k, dim=0)
 
-        # FA3 + batch-invariant: fix num_splits=1 to prevent non-deterministic
-        # split-k reductions. FA2 is automatically batch-invariant and does
-        # not accept num_splits.
         extra_kwargs = {}
-        # Disable split_kv in Flash Attention to ensure bitwise identical output.
-        # see https://github.com/pytorch/pytorch/pull/176905
-        if is_in_batch_invariant_mode() and current_flash_attention_impl() == "FA3":
+        # TODO(pytorch/pytorch#179760): FA2's auto num_splits heuristic
+        # produces NaN intermittently with paged KV (block_table). Force
+        # num_splits=1 as a workaround until the root cause is fixed
+        # upstream. current_flash_attention_impl() returns None when FA2
+        # is the implicit default (SM < 9.0). For FA3, only force
+        # num_splits=1 in batch-invariant mode (determinism).
+        fa_impl = current_flash_attention_impl()
+        if fa_impl in (None, "FA2") or is_in_batch_invariant_mode():
             extra_kwargs["num_splits"] = 1
 
         if self.enable_gqa:
diff --git a/torchtitan/models/common/attention.py b/torchtitan/models/common/attention.py
index 6ce30b88a6..4e2037ab95 100644
--- a/torchtitan/models/common/attention.py
+++ b/torchtitan/models/common/attention.py
@@ -132,12 +132,15 @@ def forward(
 
         varlen_kwargs = dict()
 
-        if is_in_batch_invariant_mode():
-            if current_flash_attention_impl() == "FA3":
-                # Fix split count to 1 to prevent non-deterministic split-k
-                # reductions that vary with batch composition.
-                # Only needed for FA3; FA2 is automatically batch-invariant.
-                varlen_kwargs["num_splits"] = 1
+        # TODO(pytorch/pytorch#179760): FA2's auto num_splits heuristic
+        # produces NaN intermittently with paged KV (block_table). Force
+        # num_splits=1 as a workaround. current_flash_attention_impl()
+        # returns None when FA2 is the implicit default (SM < 9.0).
+        # For FA3, only force num_splits=1 in batch-invariant mode
+        # to prevent non-deterministic split-k reductions.
+        fa_impl = current_flash_attention_impl()
+        if fa_impl in (None, "FA2") or is_in_batch_invariant_mode():
+            varlen_kwargs["num_splits"] = 1
 
         # Forward enable_gqa from GQAttention when Q and KV head counts differ
         if kwargs.get("enable_gqa", False):
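For reviewers, a minimal standalone sketch of the num_splits dispatch rule that both attention call sites above now share. The helper name `resolve_num_splits` is hypothetical and exists only for illustration; the diff inlines this condition at each call site rather than factoring it out.

# Sketch (not part of the diff) of the shared num_splits dispatch rule.
# `resolve_num_splits` is a hypothetical helper; its inputs mirror
# current_flash_attention_impl() and is_in_batch_invariant_mode().


def resolve_num_splits(fa_impl: str | None, batch_invariant: bool) -> int | None:
    """Return 1 to force a single split-k partition, else None.

    fa_impl of None means FA2 is the implicit default (SM < 9.0).
    """
    # FA2, explicit or implicit: always force num_splits=1 to avoid the
    # intermittent NaN with paged KV (pytorch/pytorch#179760).
    if fa_impl in (None, "FA2"):
        return 1
    # FA3: force num_splits=1 only in batch-invariant mode, preventing
    # non-deterministic split-k reductions.
    if batch_invariant:
        return 1
    return None


# Usage mirroring the call sites: set the kwarg only when a split count
# is forced, so FA3 keeps its split heuristic outside batch-invariant mode.
extra_kwargs = {}
if (num_splits := resolve_num_splits("FA3", batch_invariant=False)) is not None:
    extra_kwargs["num_splits"] = num_splits
assert extra_kwargs == {}  # FA3, not batch-invariant: heuristic retained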