From 8c68a4795e1e1f2c1cede4f57ef4cfd3b5f9e093 Mon Sep 17 00:00:00 2001
From: Pawel Gadzinski
Date: Thu, 23 Oct 2025 21:15:45 +0000
Subject: [PATCH 1/4] fix

Signed-off-by: Pawel Gadzinski
---
 transformer_engine/jax/flax/transformer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py
index 1eafed4131..42c9451245 100644
--- a/transformer_engine/jax/flax/transformer.py
+++ b/transformer_engine/jax/flax/transformer.py
@@ -197,6 +197,7 @@ def __call__(
         fused_scale_factor = scale_factor
         if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
             attn_weights += bias
+            bias = None
 
         def apply_swa_mask(original_mask: Array) -> Array:
             """Apply the sliding window mask to a given mask"""

From 108ba4c6df1e88c3b5ace9893cac66656aae4b43 Mon Sep 17 00:00:00 2001
From: Pawel Gadzinski
Date: Thu, 23 Oct 2025 21:45:06 +0000
Subject: [PATCH 2/4] fix

Signed-off-by: Pawel Gadzinski
---
 tests/jax/test_layer.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py
index d1b2535c4c..614ea0d74b 100644
--- a/tests/jax/test_layer.py
+++ b/tests/jax/test_layer.py
@@ -311,6 +311,10 @@ def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
         variables = {"params": params, **others}
         output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
         return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)
+
+    def _output_fn(self, params, others, model, diff_xs, no_diff_xs):
+        variables = {"params": params, **others}
+        return model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
 
     def _sync_params(self, ref, target):
         """Copy the reference params to target"""
@@ -334,11 +338,14 @@ def test_forward(
         test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
         ref_params, test_params = self._sync_params(ref_params, test_params)
 
-        ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer)
-        test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer)
+        ref_out = self._output_fn(ref_params, ref_others, ref_layer, inputs, ref_masks)
+        test_out = self._output_fn(test_params, test_others, test_layer, inputs, test_masks)
 
         tols = dtype_tols(dtype, rtol=rtol, atol=atol)
-        assert_allclose(ref_out, test_out, **tols)
+        if not get_quantize_config().is_fp8_enabled():
+            assert_allclose(ref_out, test_out, **tols)
+        else:
+            assert_allclose(ref_out.mean(), test_out.mean(), **tols)
 
     def test_backward(
         self,

From c0a0947b35d362bdb38821fbb648247f396cd8aa Mon Sep 17 00:00:00 2001
From: Pawel Gadzinski
Date: Fri, 31 Oct 2025 09:49:09 +0000
Subject: [PATCH 3/4] fix

Signed-off-by: Pawel Gadzinski
---
 tests/jax/test_layer.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py
index 614ea0d74b..6a4eb6a7e0 100644
--- a/tests/jax/test_layer.py
+++ b/tests/jax/test_layer.py
@@ -311,10 +311,6 @@ def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
         variables = {"params": params, **others}
         output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
         return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)
-
-    def _output_fn(self, params, others, model, diff_xs, no_diff_xs):
-        variables = {"params": params, **others}
-        return model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
 
     def _sync_params(self, ref, target):
         """Copy the reference params to target"""
@@ -338,14 +334,11 @@ def test_forward(
         test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
         ref_params, test_params = self._sync_params(ref_params, test_params)
 
-        ref_out = self._output_fn(ref_params, ref_others, ref_layer, inputs, ref_masks)
-        test_out = self._output_fn(test_params, test_others, test_layer, inputs, test_masks)
+        ref_out = self._loss_fn(ref_params, ref_others, ref_layer, inputs, ref_masks)
+        test_out = self._loss_fn(test_params, test_others, test_layer, inputs, test_masks)
 
         tols = dtype_tols(dtype, rtol=rtol, atol=atol)
-        if not get_quantize_config().is_fp8_enabled():
-            assert_allclose(ref_out, test_out, **tols)
-        else:
-            assert_allclose(ref_out.mean(), test_out.mean(), **tols)
+        assert_allclose(ref_out.mean(), test_out.mean(), **tols)
 
     def test_backward(
         self,

From b80b7d6c5a2d949ccef89505dc75e758adc33aae Mon Sep 17 00:00:00 2001
From: Pawel Gadzinski
Date: Fri, 31 Oct 2025 09:50:36 +0000
Subject: [PATCH 4/4] fix

Signed-off-by: Pawel Gadzinski
---
 tests/jax/test_layer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py
index 6a4eb6a7e0..d1b2535c4c 100644
--- a/tests/jax/test_layer.py
+++ b/tests/jax/test_layer.py
@@ -334,11 +334,11 @@ def test_forward(
         test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
         ref_params, test_params = self._sync_params(ref_params, test_params)
 
-        ref_out = self._loss_fn(ref_params, ref_others, ref_layer, inputs, ref_masks)
-        test_out = self._loss_fn(test_params, test_others, test_layer, inputs, test_masks)
+        ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer)
+        test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer)
 
         tols = dtype_tols(dtype, rtol=rtol, atol=atol)
-        assert_allclose(ref_out.mean(), test_out.mean(), **tols)
+        assert_allclose(ref_out, test_out, **tols)
 
     def test_backward(
         self,