diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_sigmoid_gating_delta_rule.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_sigmoid_gating_delta_rule.py index fd469ef3704..abfbcc20367 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_sigmoid_gating_delta_rule.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_sigmoid_gating_delta_rule.py @@ -16,7 +16,6 @@ def test_triton_fusion_ops(): b = torch.tensor( [[0.4277, 0.8906, 1.6875, 2.3750, 4.1562, 0.3809, 1.0625, 3.6719]]).bfloat16().npu() - ssm_state = torch.randn(1, 8, 128, 128, dtype=torch.bfloat16).npu() non_spec_state_indices_tensor = torch.tensor([2]).int().npu() non_spec_query_start_loc = torch.tensor([0, 1]).int().npu() a_log = torch.tensor([ @@ -25,6 +24,7 @@ def test_triton_fusion_ops(): dt_bias = torch.tensor( [-4.7812, -5.0938, -5.5000, 9.4375, 7.6250, -4.3750, -3.0938, 0.9688]).bfloat16().npu() + ssm_state1 = torch.ones(1, 8, 128, 128, dtype=torch.bfloat16).npu() core_attn_out_non_spec_fused = fused_sigmoid_gating_delta_rule_update( A_log=a_log.contiguous(), @@ -34,7 +34,7 @@ def test_triton_fusion_ops(): v=v.contiguous(), a=a.contiguous(), b=b.contiguous(), - initial_state_source=ssm_state, + initial_state_source=ssm_state1, initial_state_indices=non_spec_state_indices_tensor, cu_seqlens=non_spec_query_start_loc, use_qk_l2norm_in_kernel=True, @@ -42,6 +42,7 @@ def test_triton_fusion_ops(): softplus_threshold=20.0, ) + ssm_state2 = torch.ones(1, 8, 128, 128, dtype=torch.bfloat16).npu() g, beta = fused_gdn_gating(a_log, a, b, dt_bias) g_non_spec = g beta_non_spec = beta @@ -52,7 +53,7 @@ def test_triton_fusion_ops(): v=v, g=g_non_spec, beta=beta_non_spec, - initial_state=ssm_state, + initial_state=ssm_state2, inplace_final_state=True, cu_seqlens=non_spec_query_start_loc, ssm_state_indices=non_spec_state_indices_tensor,