3 changes: 3 additions & 0 deletions ci/jax.sh
@@ -71,13 +71,16 @@ run_test_config_mgpu() {
*0.4.35*)
# Workaround for distributed tests hang with xla_flag
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
Collaborator

This will run it with AOTriton too

Contributor Author

Updated with a guard in the JAX CI script.
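A hypothetical illustration of the kind of guard being described; the variable used here to detect the CK backend is an assumption for illustration only and is not shown in this diff:

    # Only enable the CK V3 kernels when the CK fused-attention backend is in use,
    # so the extra run is skipped when tests run with AOTriton.
    # NVTE_FUSED_ATTN_CK is an assumed variable name, not taken from this PR.
    if [ "${NVTE_FUSED_ATTN_CK:-1}" = "1" ]; then
        NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 test_distributed_fused_attn.py
    fi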

Collaborator

With those changes the env variables are not seen by the run method; they are applied to the test call only.
Use run_default_fa_lbl. All V3 calls should be labelled with "v3" to distinguish them from the regular test_distributed_fused_attn calls.
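A minimal sketch of the suggested labelling, assuming run_lbl takes a label followed by the same arguments as run (as in the "parallel_ring" call below); the exact behaviour of run_default_fa_lbl is not shown in this diff:

    # Label the V3 run so it is reported separately from the regular
    # test_distributed_fused_attn run (sketch only, not the final wording of the PR).
    XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 \
        run_lbl "v3" 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'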


# Test ring attention with xla_flag --xla_experimental_ignore_channel_id only
XLA_FLAGS="--xla_experimental_ignore_channel_id" run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn
XLA_FLAGS="--xla_experimental_ignore_channel_id" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn
;;
*)
# Workaround for distributed tests hang with xla_flag
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 test_distributed_fused_attn.py
;;
esac

1 change: 1 addition & 0 deletions ci/pytorch.sh
@@ -94,6 +94,7 @@ run_test_config_mgpu(){
run 3 distributed/test_numerics.py
run 3 distributed/test_torch_fsdp2.py
run 3 fused_attn/test_fused_attn_with_cp.py
NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 fused_attn/test_fused_attn_with_cp.py
fi
}
