Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def test_triton_fusion_ops():
b = torch.tensor(
[[0.4277, 0.8906, 1.6875, 2.3750, 4.1562, 0.3809, 1.0625,
3.6719]]).bfloat16().npu()
ssm_state = torch.randn(1, 8, 128, 128, dtype=torch.bfloat16).npu()
non_spec_state_indices_tensor = torch.tensor([2]).int().npu()
non_spec_query_start_loc = torch.tensor([0, 1]).int().npu()
a_log = torch.tensor([
Expand All @@ -25,6 +24,7 @@ def test_triton_fusion_ops():
dt_bias = torch.tensor(
[-4.7812, -5.0938, -5.5000, 9.4375, 7.6250, -4.3750, -3.0938,
0.9688]).bfloat16().npu()
ssm_state1 = torch.ones(1, 8, 128, 128, dtype=torch.bfloat16).npu()

core_attn_out_non_spec_fused = fused_sigmoid_gating_delta_rule_update(
A_log=a_log.contiguous(),
Expand All @@ -34,14 +34,15 @@ def test_triton_fusion_ops():
v=v.contiguous(),
a=a.contiguous(),
b=b.contiguous(),
initial_state_source=ssm_state,
initial_state_source=ssm_state1,
initial_state_indices=non_spec_state_indices_tensor,
cu_seqlens=non_spec_query_start_loc,
use_qk_l2norm_in_kernel=True,
softplus_beta=1.0,
softplus_threshold=20.0,
)

ssm_state2 = torch.ones(1, 8, 128, 128, dtype=torch.bfloat16).npu()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This line duplicates the creation of the state tensor from line 27. This introduces a maintainability risk, as any future changes to the tensor's properties (e.g., shape, dtype) must be manually synchronized in two places, which is error-prone. To improve maintainability and make the intent clearer, you should create the tensor once and clone it for the second use.

Suggested change
ssm_state2 = torch.ones(1, 8, 128, 128, dtype=torch.bfloat16).npu()
ssm_state2 = ssm_state1.clone()

g, beta = fused_gdn_gating(a_log, a, b, dt_bias)
g_non_spec = g
beta_non_spec = beta
Expand All @@ -52,7 +53,7 @@ def test_triton_fusion_ops():
v=v,
g=g_non_spec,
beta=beta_non_spec,
initial_state=ssm_state,
initial_state=ssm_state2,
inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc,
ssm_state_indices=non_spec_state_indices_tensor,
Expand Down
Loading