diff --git a/python/sgl_kernel_npu/sgl_kernel_npu/fla/chunk.py b/python/sgl_kernel_npu/sgl_kernel_npu/fla/chunk.py index b7d50b673..220bfcc2e 100644 --- a/python/sgl_kernel_npu/sgl_kernel_npu/fla/chunk.py +++ b/python/sgl_kernel_npu/sgl_kernel_npu/fla/chunk.py @@ -234,7 +234,7 @@ def chunk_gated_delta_rule_npu( beta: torch.Tensor, scale: float = None, initial_state: torch.Tensor = None, - output_final_state: bool = False, + output_final_state: bool = True, cu_seqlens: Optional[torch.LongTensor] = None, head_first: bool = False, use_qk_l2norm_in_kernel: bool = False,