diff --git a/tests/kernels/core/test_layernorm.py b/tests/kernels/core/test_layernorm.py
index f8f9660942af..6d799fbc1072 100644
--- a/tests/kernels/core/test_layernorm.py
+++ b/tests/kernels/core/test_layernorm.py
@@ -14,9 +14,11 @@
 HIDDEN_SIZES = [8, 768, 769, 5120, 5125, 8192] # Arbitrary values for testing
 ADD_RESIDUAL = [False, True]
 SEEDS = [0]
-CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
-]
+DEVICES = ["cpu"]
+if torch.cuda.is_available():
+    DEVICES = [
+        f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
+    ]
 
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -24,7 +26,7 @@
 @pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("strided_input", [False, True])
 @torch.inference_mode()
 def test_rms_norm(
@@ -46,12 +48,19 @@ def test_rms_norm(
     x = torch.randn(num_tokens, last_dim, dtype=dtype)
     x = x[..., :hidden_size]
     assert x.is_contiguous() != strided_input
-    x *= scale
+    x = x * scale if add_residual else x
     residual = torch.randn_like(x) * scale if add_residual else None
 
     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_out = layer.forward_native(x, residual)
+    ref_out = layer.forward_static(
+        x,
+        variance_epsilon=layer.variance_epsilon,
+        orig_dtype=x.dtype,
+        weight=layer.weight,
+        hidden_size=layer.hidden_size,
+        residual=residual,
+    )
     out = layer(x, residual)
     # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
     # numerical errors than other operators because they involve reductions.
@@ -79,7 +88,7 @@ def test_rms_norm(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("quant_scale", [0.01, 1.0, 10.0])
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("strided_input", [False, True])
 def test_fused_rms_norm_quant(
     num_tokens: int,