diff --git a/tester/base.py b/tester/base.py index 59caac64..55557ea4 100644 --- a/tester/base.py +++ b/tester/base.py @@ -594,6 +594,7 @@ def get_arg(api_config, arg_pos, arg_name, default=None): # some accuracy error can be considered tolerable special_accuracy_atol_rtol = { # "API": (atol, rtol), + "paddle.incubate.nn.functional.fused_bias_act": (1, 1e-2) } torch_error_skip = frozenset( diff --git a/tester/paddle_to_torch/rules.py b/tester/paddle_to_torch/rules.py index 7d579ae9..d97bc1af 100644 --- a/tester/paddle_to_torch/rules.py +++ b/tester/paddle_to_torch/rules.py @@ -1866,6 +1866,17 @@ def fused_bias_act( ) -> torch.Tensor: import torch.nn.functional as F + def swiglu(x): + x, gate = x.chunk(2, dim=-1) + return x * torch.sigmoid(x) * gate + + def geglu(x): + x, gate = x.chunk(2, dim=-1) + return F.gelu(x) * gate + + if dequant_scales is not None: + x = x * dequant_scales + if compute_dtype != 'default': if compute_dtype == 'fp16': compute_dtype = 'float16' @@ -1877,30 +1888,10 @@ def fused_bias_act( x = x.to(getattr(torch, compute_dtype)) else: x = x.float() if not x.is_floating_point() else x - if dequant_scales is not None: - dequant_scales = dequant_scales.to(x.dtype) - x = x * dequant_scales + if bias is not None: bias = bias.to(x.dtype) x = x + bias - if shift is not None: - repeat_factor = x.shape[-1] // shift.shape[-1] - shift = shift.repeat(repeat_factor) - shift = shift.to(x.dtype) - x = x + shift - if smooth is not None: - repeat_factor = x.shape[-1] // smooth.shape[-1] - smooth = smooth.repeat(repeat_factor) - smooth = smooth.to(x.dtype) - x = x * smooth - - def swiglu(x): - x, gate = x.chunk(2, dim=-1) - return x * torch.sigmoid(x) * gate - - def geglu(x): - x, gate = x.chunk(2, dim=-1) - return F.gelu(x) * gate act_method = act_method.lower() if act_method == 'gelu': @@ -1917,17 +1908,31 @@ def geglu(x): x = geglu(x) else: raise ValueError(f"Unsupported activation method: {act_method}") + + if shift is not None: + repeat_factor = x.shape[-1] // shift.shape[-1] + shift = shift.repeat(repeat_factor) + shift = shift.to(x.dtype) + x = x + shift + + if smooth is not None: + repeat_factor = x.shape[-1] // smooth.shape[-1] + smooth = smooth.repeat(repeat_factor) + smooth = smooth.to(x.dtype) + x = x * smooth if quant_scale > 0: - x = x / quant_scale + x = quant_max_bound * quant_scale * x if quant_round_type == 0: - x = torch.round(x) # Round to nearest, ties to even + x = torch.round(x) elif quant_round_type == 1: x = torch.where(x >= 0, torch.ceil(x - 0.5), torch.floor(x + 0.5)) else: raise ValueError(f"Unsupported quant_round_type: {quant_round_type}") - x = x * quant_scale x = torch.clamp(x, min=quant_min_bound, max=quant_max_bound) + + x = x.to(torch.int8) + return x """ core = "result = fused_bias_act(**kwargs)"