diff --git a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py index 68398d2813..53fe44ce55 100644 --- a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py +++ b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py @@ -3333,8 +3333,7 @@ def compute_group_1( gate_handle = load_gate_consumer.wait_and_advance() - max_coord = tTR_tCcShared[cute.size(tTR_tCcShared) - 1] - cumprod_total = sCumprod[max_coord[1], 0, gate_handle.index] + cumprod_total = sCumprod[sCumprod.shape[0] - 1, 0, gate_handle.index] valid_state = not is_first_chunk or self.use_initial_state if cutlass.const_expr(valid_state): diff --git a/tests/gdn/test_prefill_delta_rule.py b/tests/gdn/test_prefill_delta_rule.py index 85dac83efd..eab2a6314d 100644 --- a/tests/gdn/test_prefill_delta_rule.py +++ b/tests/gdn/test_prefill_delta_rule.py @@ -145,7 +145,7 @@ def _test_prefill_kernel( atol_kv = 5e-3 rtol_kv = 1e-3 else: - atol_o = 2e-3 + atol_o = 1e-3 rtol_o = 1e-3 atol_kv = 1e-3 rtol_kv = 1e-4