diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py
index f4e5f7fc2..2e9c1ad4c 100644
--- a/wenet/transformer/attention.py
+++ b/wenet/transformer/attention.py
@@ -425,8 +425,7 @@ def forward(
             assert mask.dtype != torch.bool
             mask = mask.unsqueeze(1)
             # matrix_bd as a mask bias
-            mask = torch.where(mask == get_dtype_min(mask.dtype), mask,
-                               matrix_bd / math.sqrt(self.d_k))
+            mask = (matrix_bd + mask) / math.sqrt(self.d_k)
         output = torch.nn.functional.scaled_dot_product_attention(
             q_with_bias_u,
             k,
diff --git a/wenet/utils/common.py b/wenet/utils/common.py
index a6607c57f..7e01f2c7d 100644
--- a/wenet/utils/common.py
+++ b/wenet/utils/common.py
@@ -310,21 +310,9 @@ def log_add(*args) -> float:
     return a_max + lsp
 
 
-def get_dtype_min(
-    dtype: torch.dtype,
-    eps16: float = torch.finfo(torch.float16).min,
-    eps32: float = torch.finfo(torch.float32).min,
-    eps64: float = torch.finfo(torch.float64).min,
-    epsbf16: float = torch.finfo(torch.bfloat16).min,
-):
-    if dtype == torch.float16:
-        return eps16
-    elif dtype == torch.float32:
-        return eps32
-    elif dtype == torch.float64:
-        return eps64
-    elif dtype == torch.bfloat16:
-        return epsbf16
+def get_dtype_min(dtype: torch.dtype, ):
+    if dtype in [torch.float32, torch.bfloat16, torch.float16]:
+        return -1e+10
     else:
         raise RuntimeError(f"expected x to be floating-point, got {dtype}")
 
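
Illustrative note (not part of the patch): the sketch below is a minimal, standalone check of what the new formulation does, assuming a plain SDPA call with no dropout. It builds a float padding mask with the patch's -1e+10 sentinel, folds it into a random stand-in for the relative-position bias matrix_bd, and checks that passing (matrix_bd + mask) / sqrt(d_k) as attn_mask reproduces softmax((QK^T + matrix_bd + mask) / sqrt(d_k)) V, since scaled_dot_product_attention only applies its scale to QK^T. Only the names q, k, v, matrix_bd and mask mirror the patch; the shapes, the valid tensor and the random data are made up for the example.

import math

import torch

torch.manual_seed(0)
B, H, T, d_k = 1, 2, 4, 8                        # illustrative shapes
q = torch.randn(B, H, T, d_k)
k = torch.randn(B, H, T, d_k)
v = torch.randn(B, H, T, d_k)
matrix_bd = torch.randn(B, H, T, T)              # stand-in for the rel-pos bias term

# Float mask in the style of the patch: 0.0 on valid key positions,
# -1e+10 (the new get_dtype_min sentinel) on padded ones.
valid = torch.tensor([[True, True, True, False]])               # (B, T)
mask = torch.zeros(B, T, T).masked_fill(~valid.unsqueeze(1), -1e+10)
mask = mask.unsqueeze(1)                         # (B, 1, T, T), broadcasts over heads

# Reference: attention computed by hand as softmax((QK^T + bias + mask) / sqrt(d_k)) V.
scores = (q @ k.transpose(-2, -1) + matrix_bd + mask) / math.sqrt(d_k)
ref = torch.softmax(scores, dim=-1) @ v

# Patched path: SDPA applies its scale to QK^T only, so the additive
# bias-plus-mask term has to be pre-scaled before being passed as attn_mask.
# The explicit scale equals SDPA's default of 1/sqrt(d_k); it is spelled out
# here only to make the relationship visible.
attn_mask = (matrix_bd + mask) / math.sqrt(d_k)
out = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=attn_mask, dropout_p=0.0, scale=1 / math.sqrt(d_k))

print(torch.allclose(ref, out, atol=1e-5))       # expected: True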