
Commit b89d568

Merge remote-tracking branch 'upstream/main' into jax_attn_use_nonpacked

2 parents 78a340d + d0d4063 · commit b89d568


55 files changed (+3499, -3435 lines)

3rdparty/cudnn-frontend

Submodule cudnn-frontend updated 87 files

examples/jax/encoder/test_multiprocessing_encoder.py

Lines changed: 2 additions & 2 deletions
@@ -672,7 +672,7 @@ def test_te_mxfp8(self):
     def test_te_nvfp4(self):
         """Test Transformer Engine with NVFP4"""
         result = self.exec(True, "NVFP4BlockScaling")
-        assert result[0] < 0.451 and result[1] > 0.788
+        assert result[0] < 0.451 and result[1] > 0.787

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
@@ -710,7 +710,7 @@ def test_te_mxfp8_shardy(self):
     def test_te_nvfp4_shardy(self):
         """Test Transformer Engine with NVFP4"""
         result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True)
-        assert result[0] < 0.451 and result[1] > 0.788
+        assert result[0] < 0.451 and result[1] > 0.787


 if __name__ == "__main__":

tests/cpp/operator/test_normalization.h

Lines changed: 11 additions & 1 deletion
@@ -114,8 +114,18 @@ void compute_ref_output(NormType norm_type,
         tmp = current * rsigma[i] * g;
       }

+      // Write output (scaled only for fp8 paths)
       output[i * H + j] = static_cast<OutputType>(tmp * scale);
-      current_max = fmaxf(current_max, fabsf(tmp));
+
+      // amax semantics:
+      //  - fp8_out (scale != 1): amax on pre-scale compute value 'tmp'
+      //  - non-fp8_out (scale == 1): amax on value converted to OutputType (e.g., bf16)
+      if (scale != 1.f) {
+        current_max = fmaxf(current_max, fabsf(tmp));
+      } else {
+        OutputType out_t_val = static_cast<OutputType>(tmp);
+        current_max = fmaxf(current_max, fabsf(static_cast<compute_t>(out_t_val)));
+      }
     }
   }
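
The change above makes the reference distinguish how amax is accumulated for FP8 versus non-FP8 outputs. A minimal jax.numpy sketch of that convention follows; the function name, arguments, and the bf16 default are illustrative assumptions, not part of the Transformer Engine API.

```python
import jax.numpy as jnp

def reference_row_amax(tmp, scale, out_dtype=jnp.bfloat16):
    """Sketch of the reference amax convention described in the diff above.

    - FP8 output path (scale != 1): amax is taken over the pre-scale compute
      values, i.e. before multiplying by `scale` and casting to FP8.
    - Non-FP8 output path (scale == 1): amax is taken over the values after
      casting to the output dtype (e.g. bf16), matching what is actually stored.
    """
    if scale != 1.0:
        return jnp.max(jnp.abs(tmp))
    return jnp.max(jnp.abs(tmp.astype(out_dtype)).astype(jnp.float32))

# Example: with scale == 1.0 the amax reflects bf16 rounding of the values.
print(reference_row_amax(jnp.array([0.3, -1.7, 2.05]), scale=1.0))
```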

tests/jax/test_custom_call_compute.py

Lines changed: 0 additions & 1 deletion
@@ -45,7 +45,6 @@
 from transformer_engine.jax.activation import activation
 from transformer_engine.jax.dense import dense, grouped_dense
 from transformer_engine.jax.layernorm_dense import layernorm_dense
-from transformer_engine.common import recipe

 GEMM_CASES = [
     (256, 256, 512),

tests/jax/test_distributed_layernorm_mlp.py

Lines changed: 2 additions & 0 deletions
@@ -389,6 +389,7 @@ def _test_layernorm_mlp(
         intermediate_dim=INTERMEDIATE,
         activations=activation_type,
         use_bias=use_bias,
+        return_layernorm_output=True,
     )
     params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
     mlp_out_single, ln_out_single = ln_mlp_single.apply(
@@ -417,6 +418,7 @@ def _test_layernorm_mlp(
         dot_1_input_axes=DOT_1_INPUT_AXES,
         dot_2_input_axes=DOT_2_INPUT_AXES,
         name="mlp",
+        return_layernorm_output=True,
     )
     params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
     mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(

tests/jax/test_helper.py renamed to tests/jax/test_recipe_characteristics.py

Lines changed: 66 additions & 1 deletion
@@ -11,7 +11,7 @@
 import numpy as np
 from flax import linen as nn

-from utils import assert_allclose
+from utils import assert_allclose, pytest_parametrize_wrapper
 from transformer_engine.common.recipe import (
     DelayedScaling,
     MXFP8BlockScaling,
@@ -22,6 +22,7 @@
 from transformer_engine.jax import autocast
 from transformer_engine.jax.quantize import (
     get_quantize_config,
+    get_supported_quantization_recipes,
     is_scaling_mode_supported,
     ScalingMode,
     update_collections,
@@ -32,11 +33,15 @@
 from transformer_engine.jax.quantize.helper import _format2dtypes
 from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
 from transformer_engine.jax.flax.module import TransformerEngineBase
+from transformer_engine.jax import flax as te_flax
+import transformer_engine.jax as te

 is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
 is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
 is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)

+SUPPORTED_RECIPES = get_supported_quantization_recipes()
+

 def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
     """Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries."""
@@ -253,3 +258,63 @@ def test_autocast_nvfp4_block_scaling(self):
         self._compare_nvfp4_scaling_quantizers(bs)

         self._check_default_state()
+
+
+class TestJaxprAndHlo:
+    """Tests to verify Jaxpr and/or HLO of compiled modules apply expected recipe functionality and optimizations."""
+
+    @pytest_parametrize_wrapper(
+        "quantization_recipe",
+        [
+            quantization_recipe
+            for quantization_recipe in SUPPORTED_RECIPES
+            if isinstance(quantization_recipe, NVFP4BlockScaling)
+        ],
+    )
+    def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe):
+        """Tests that layernorm_mlp reuses the amax computed in layernorm and the activation and does not recompute it during quantization."""
+
+        with te.autocast(enabled=True, recipe=quantization_recipe, mesh_resource=te.MeshResource()):
+            model = te_flax.LayerNormMLP(
+                layernorm_type="rmsnorm",
+                return_layernorm_output=False,
+                intermediate_dropout_rate=0.0,
+                dtype=jnp.bfloat16,
+            )
+
+            var_collect = model.init(
+                jax.random.PRNGKey(0),
+                jnp.ones((128, 128), dtype=jnp.bfloat16),
+            )
+
+            def loss_fn(x, rngs):
+                return jnp.mean(model.apply(var_collect, x, rngs=rngs)[0])
+
+            x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16)
+            rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)}
+            jaxpr = jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs)
+
+            rht_amax_eqns = [
+                eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper"
+            ]
+
+            assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}"
+
+            def assert_param(index, tensor_name, expected_value: bool):
+                if expected_value:
+                    assert rht_amax_eqns[index].params["produce_regular_amax"] == True, (
+                        f"Expected produce_regular_amax for {tensor_name} to be True, indicating no"
+                        " reuse of amax as this tensor does not have a previous operation to fuse"
+                        " with"
+                    )
+                else:
+                    assert rht_amax_eqns[index].params["produce_regular_amax"] == False, (
+                        f"Expected produce_regular_amax for {tensor_name} to be False, indicating"
+                        " reuse of amax"
+                    )
+
+            assert_param(0, "fwd ln+q", False)
+            assert_param(1, "fwd act+q", False)
+            # No previous op before incoming dgrad in the backward so amax is not reused
+            assert_param(2, "bwd dgrad", True)
+            assert_param(3, "bwd dact+q", False)
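
The new test hinges on inspecting the traced jaxpr for Transformer Engine's te_rht_amax_ffi_wrapper primitive and its produce_regular_amax parameter. The sketch below shows the same inspection pattern against a standard JAX primitive (dot_general) so it runs without Transformer Engine; the loss function and tensor shapes are arbitrary placeholders.

```python
import jax
import jax.numpy as jnp

def loss_fn(x, w):
    # Arbitrary scalar loss so value_and_grad traces both forward and backward passes.
    return jnp.mean(jnp.dot(x, w) ** 2)

x = jnp.ones((8, 16), dtype=jnp.float32)
w = jnp.ones((16, 4), dtype=jnp.float32)

# make_jaxpr records every primitive equation the traced function would execute.
jaxpr = jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, w)

# Filter equations by primitive name and inspect their static params,
# the same pattern the test applies to "te_rht_amax_ffi_wrapper".
dot_eqns = [eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "dot_general"]
print(f"found {len(dot_eqns)} dot_general equations")
for eqn in dot_eqns:
    print(eqn.primitive.name, eqn.params["dimension_numbers"])
```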

tests/jax/utils.py

Lines changed: 6 additions & 6 deletions
@@ -364,9 +364,9 @@ class MlpBlock(nn.Module):

     transpose_batch_sequence: bool
     intermediate_dim: int = 2048
-    activations: Sequence[Union[str, Callable]] = ("relu",)
+    activations: Sequence[Union[str, Callable]] = ("gelu",)
     kernel_init: Initializer = None
-    intermediate_dropout_rate: float = 0.1
+    intermediate_dropout_rate: float = 0.0
     intermediate_dropout_dims: Sequence[int] = ()
     use_bias: bool = False
     dtype: Any = jnp.float32
@@ -1035,14 +1035,14 @@ class EncoderLayer(nn.Module):
     hidden_dropout: float = 0.1
     hidden_dropout_dims: Sequence[int] = ()
     attention_dropout: float = 0.1
-    intermediate_dropout: float = 0.1
+    intermediate_dropout: float = 0.0
     intermediate_dropout_dims: Sequence[int] = ()
     transpose_batch_sequence: bool = True
     float32_attention_logits: bool = False
     scale_attn_logits: bool = False
     scaled_query_init: bool = True
     mlp_dim: int = 2048
-    mlp_activations: Sequence[str] = ("relu",)
+    mlp_activations: Sequence[str] = ("gelu",)
     use_bias: bool = False
     dtype: Any = jnp.float32
     apply_residual_connection_post_layernorm: bool = False
@@ -1199,14 +1199,14 @@ class DecoderLayer(nn.Module):
     hidden_dropout: float = 0.1
     hidden_dropout_dims: Sequence[int] = ()
     attention_dropout: float = 0.1
-    intermediate_dropout: float = 0.1
+    intermediate_dropout: float = 0.0
     intermediate_dropout_dims: Sequence[int] = ()
     transpose_batch_sequence: bool = True
     float32_attention_logits: bool = False
     scale_attn_logits: bool = False
     scaled_query_init: bool = True
     mlp_dim: int = 2048
-    mlp_activations: Sequence[str] = ("relu",)
+    mlp_activations: Sequence[str] = ("gelu",)
     use_bias: bool = False
     dtype: Any = jnp.float32
     apply_residual_connection_post_layernorm: bool = False
