[JAX] Add support for sink attention in JAX #2225
base: main
Changes from 34 commits: 1bd7fd6, 5cbfb7d, a1e0c51, dcb23b3, 90e2d1a, 58c443a, 426ac6a, 06cac3f, df623e5, 5efb88c, b5ffcda, bd53e36, 438953e, 970cf80, 216b46b, f04be9d, e0c7f6b, 176e2e9, cab9659, 2cd1727, 99d67c1, 4c79c34, 1845710, 7ca2377, 6a45be7, b05e8ef, 3abec0b, a5730df, 952d558, c5332a4, 1a7d61b, 32ca266, 6fc0e0b, 553c0bc, bc43dc8
```diff
@@ -16,7 +16,7 @@
 from distributed_test_base import compare_ops
 from utils import make_causal_mask, make_self_mask
 from transformer_engine.jax import autocast
-from transformer_engine.jax.softmax import SoftmaxType, softmax
+from transformer_engine.jax.softmax import SoftmaxFusion, softmax


 DTYPES = [jnp.float16, jnp.bfloat16]
```

nit: For this test, you may see a failure during rendezvous of the parent and child threads when running the CI at test. However, it looks like you are not adding any tests; since you are going to run the softmax tests, I wanted to keep you posted.
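For reference, a minimal usage sketch of the renamed API as it appears in this diff (`SoftmaxFusion` enum and the `softmax_fusion=` keyword). The mask helper and the [batch, heads, seqlen, seqlen] layout come from this test file; the concrete shape and fusion value below are illustrative, not taken from the PR:

```python
import jax
import jax.numpy as jnp

from transformer_engine.jax.softmax import SoftmaxFusion, softmax
from utils import make_self_mask  # test helper imported at the top of this file

# Illustrative [batch, heads, seqlen, seqlen] logits, matching this test's data layout.
x = jax.random.normal(jax.random.PRNGKey(0), (8, 12, 128, 128), dtype=jnp.bfloat16)
mask = make_self_mask(8, 128)

# Old keyword (removed by this diff):
#   out = softmax(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED_MASKED)
# New keyword (added by this diff):
out = softmax(x, mask, scale_factor=1.0, softmax_fusion=SoftmaxFusion.SCALED_MASKED)
```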
```diff
@@ -29,12 +29,12 @@ def generate_collectives_count_ref(self):
         return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

     def generate_inputs(
-        self, shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
+        self, shape, mesh_resource, softmax_fusion, dtype, bad_sharding, broadcast_batch_mask
     ):
         batch, _, sqelen, _ = shape

         x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
-        if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
+        if softmax_fusion == SoftmaxFusion.SCALED_UPPER_TRIANG_MASKED:
             mask = make_causal_mask(batch, sqelen)
         else:
             mask = make_self_mask(1 if broadcast_batch_mask else batch, sqelen)
```
```diff
@@ -56,8 +56,8 @@ def generate_inputs(
         return (x, mask), (x_pspec, mask_pspec)

     @staticmethod
-    def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED):
-        return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
+    def target_func(x, mask, scale_factor=1.0, softmax_fusion=SoftmaxFusion.SCALED):
+        return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_fusion=softmax_fusion))

     @staticmethod
     def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
```
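The body of `ref_func` is not shown in this diff. Purely as an illustration of what the fused `target_func` above is compared against, here is a hypothetical unfused reference in plain JAX; the mask convention and the final mean reduction (mirroring `target_func`) are assumptions, not the file's actual `ref_func`:

```python
import jax
import jax.numpy as jnp

def ref_softmax(x, mask, scale_factor=1.0, dtype=jnp.float16):
    # Hypothetical unfused reference: scale, apply the mask, then a plain JAX softmax.
    logits = x.astype(jnp.float32) * scale_factor
    if mask is not None:
        # Assumption: nonzero mask entries mark positions to exclude.
        logits = jnp.where(mask > 0, jnp.finfo(jnp.float32).min, logits)
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.mean(probs.astype(dtype))  # mirrors the jnp.mean reduction in target_func
```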
```diff
@@ -80,24 +80,24 @@ def impl_test_softmax(
         mesh_axes,
         mesh_resource,
         data_shape,
-        softmax_type,
+        softmax_fusion,
         scale_factor,
         dtype,
         bad_sharding,
         broadcast_batch_mask,
         use_shardy,
     ):
-        if broadcast_batch_mask and softmax_type != SoftmaxType.SCALED_MASKED:
+        if broadcast_batch_mask and softmax_fusion != SoftmaxFusion.SCALED_MASKED:
             pytest.skip("Softmax type has no mask.")

         jax.config.update("jax_use_shardy_partitioner", use_shardy)
         target_func = partial(
-            self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
+            self.target_func, scale_factor=scale_factor, softmax_fusion=softmax_fusion
         )
         ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)

         (x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
-            data_shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
+            data_shape, mesh_resource, softmax_fusion, dtype, bad_sharding, broadcast_batch_mask
         )
         collective_count_ref = self.generate_collectives_count_ref()
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
```
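The `devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)` line above feeds a device mesh over which the sharded test runs. A minimal sketch of that setup with the standard JAX sharding APIs is below; the 2x2 mesh shape and axis names are assumed example values, not taken from this file:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

device_count, mesh_shape, mesh_axes = 4, (2, 2), ("dp", "tp")  # assumed example values

# Same device-reshaping pattern as in impl_test_softmax above.
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)

# A PartitionSpec like x_pspec/mask_pspec shards the [batch, heads, seq, seq] input over the mesh.
x_sharding = NamedSharding(mesh, PartitionSpec("dp", "tp", None, None))
```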
```diff
@@ -139,8 +139,12 @@ def impl_test_softmax(
     @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
     @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]])
     @pytest.mark.parametrize(
-        "softmax_type",
-        [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
+        "softmax_fusion",
+        [
+            SoftmaxFusion.SCALED,
+            SoftmaxFusion.SCALED_MASKED,
+            SoftmaxFusion.SCALED_UPPER_TRIANG_MASKED,
+        ],
     )
     @pytest.mark.parametrize("scale_factor", [1.0, 3.0])
     @pytest.mark.parametrize("dtype", DTYPES)
```
```diff
@@ -153,7 +157,7 @@ def test_softmax(
         mesh_axes,
         mesh_resource,
         data_shape,
-        softmax_type,
+        softmax_fusion,
         scale_factor,
         dtype,
         bad_sharding,
```
```diff
@@ -165,7 +169,7 @@ def test_softmax(
             mesh_axes,
             mesh_resource,
             data_shape,
-            softmax_type,
+            softmax_fusion,
             scale_factor,
             dtype,
             bad_sharding,
```
```diff
@@ -174,7 +178,7 @@ def test_softmax(
         )

     @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
-    @pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
+    @pytest.mark.parametrize("softmax_fusion", [SoftmaxFusion.SCALED, SoftmaxFusion.SCALED_MASKED])
     @pytest.mark.parametrize("bad_sharding", [False, True])
     @pytest.mark.parametrize("broadcast_batch_mask", [False, True])
     def test_softmax_gspmd(
```
```diff
@@ -183,7 +187,7 @@ def test_softmax_gspmd(
         mesh_shape,
         mesh_axes,
         mesh_resource,
-        softmax_type,
+        softmax_fusion,
         bad_sharding,
         broadcast_batch_mask,
     ):
```
```diff
@@ -193,7 +197,7 @@ def test_softmax_gspmd(
             mesh_axes,
             mesh_resource,
             data_shape=[32, 12, 128, 128],
-            softmax_type=softmax_type,
+            softmax_fusion=softmax_fusion,
             scale_factor=1.0,
             dtype=DTYPES[0],
             bad_sharding=bad_sharding,
```
Thanks for adding this comment here