diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 641adef4846..deb40f7939d 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -295,7 +295,14 @@ def __call__( softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, + mdQ_semaphore: Optional[cute.Tensor] = None, + mdK_semaphore: Optional[cute.Tensor] = None, + mdV_semaphore: Optional[cute.Tensor] = None, ): + assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, ( + "determinism not supported yet for Sm90" + ) + self._check_type( *( t.element_type if t is not None else None