Merged
Changes from 2 commits
1 change: 1 addition & 0 deletions tester/base.py
@@ -594,6 +594,7 @@ def get_arg(api_config, arg_pos, arg_name, default=None):
# some accuracy error can be considered tolerable
special_accuracy_atol_rtol = {
# "API": (atol, rtol),
"paddle.incubate.nn.functional.fused_bias_act": (1, 1e-2)
}

torch_error_skip = frozenset(
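The new tolerance entry only takes effect where the tester compares Paddle and Torch outputs. A minimal sketch of such a lookup, assuming a numpy-based comparison; the check_accuracy name, its signature, and the default tolerances are illustrative assumptions, not code from this PR:

import numpy

def check_accuracy(api_name, paddle_out, torch_out, special_accuracy_atol_rtol,
                   default_atol=1e-6, default_rtol=1e-6):
    # Use the per-API override when present, e.g.
    # "paddle.incubate.nn.functional.fused_bias_act" -> atol=1, rtol=1e-2,
    # otherwise fall back to the default tolerances.
    atol, rtol = special_accuracy_atol_rtol.get(api_name, (default_atol, default_rtol))
    numpy.testing.assert_allclose(paddle_out, torch_out, atol=atol, rtol=rtol)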
71 changes: 41 additions & 30 deletions tester/paddle_to_torch/rules.py
@@ -1866,6 +1866,29 @@ def fused_bias_act(
) -> torch.Tensor:
import torch.nn.functional as F

def quant_helper_func(input, scale, round_type, max_bound, min_bound):
quant_value = max_bound * scale * input

if round_type == 0:
quant_value = torch.round(quant_value)
else:
quant_value = torch.where(quant_value >= 0, torch.ceil(quant_value - 0.5), torch.floor(quant_value + 0.5))

quant_value = torch.clamp(quant_value, min=min_bound, max=max_bound)

return quant_value

def swiglu(x):
x, gate = x.chunk(2, dim=-1)
return x * torch.sigmoid(x) * gate

def geglu(x):
x, gate = x.chunk(2, dim=-1)
return F.gelu(x) * gate

if dequant_scales is not None:
x = x * dequant_scales

if compute_dtype != 'default':
if compute_dtype == 'fp16':
compute_dtype = 'float16'
@@ -1877,30 +1900,10 @@ def fused_bias_act(
x = x.to(getattr(torch, compute_dtype))
else:
x = x.float() if not x.is_floating_point() else x
if dequant_scales is not None:
dequant_scales = dequant_scales.to(x.dtype)
x = x * dequant_scales

if bias is not None:
bias = bias.to(x.dtype)
x = x + bias
if shift is not None:
repeat_factor = x.shape[-1] // shift.shape[-1]
shift = shift.repeat(repeat_factor)
shift = shift.to(x.dtype)
x = x + shift
if smooth is not None:
repeat_factor = x.shape[-1] // smooth.shape[-1]
smooth = smooth.repeat(repeat_factor)
smooth = smooth.to(x.dtype)
x = x * smooth

def swiglu(x):
x, gate = x.chunk(2, dim=-1)
return x * torch.sigmoid(x) * gate

def geglu(x):
x, gate = x.chunk(2, dim=-1)
return F.gelu(x) * gate

act_method = act_method.lower()
if act_method == 'gelu':
@@ -1917,17 +1920,25 @@ def geglu(x):
x = geglu(x)
else:
raise ValueError(f"Unsupported activation method: {act_method}")

if shift is not None:
repeat_factor = x.shape[-1] // shift.shape[-1]
shift = shift.repeat(repeat_factor)
shift = shift.to(x.dtype)
x = x + shift

if smooth is not None:
repeat_factor = x.shape[-1] // smooth.shape[-1]
smooth = smooth.repeat(repeat_factor)
smooth = smooth.to(x.dtype)
x = x * smooth

if quant_scale > 0:
x = x / quant_scale
if quant_round_type == 0:
x = torch.round(x) # Round to nearest, ties to even
elif quant_round_type == 1:
x = torch.where(x >= 0, torch.ceil(x - 0.5), torch.floor(x + 0.5))
else:
raise ValueError(f"Unsupported quant_round_type: {quant_round_type}")
x = x * quant_scale
x = torch.clamp(x, min=quant_min_bound, max=quant_max_bound)
x = quant_helper_func(x, quant_scale, quant_round_type, quant_max_bound, quant_min_bound)
print("after quant", x)

x = x.to(getattr(torch, "int8"))

return x
"""
core = "result = fused_bias_act(**kwargs)"
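As a standalone sanity check (not part of the PR), the snippet below reproduces quant_helper_func from the rule above and shows how its two rounding branches treat half-way values. The scale and bounds (scale=0.25, max_bound=4.0, min_bound=-4.0) are arbitrary illustrative values chosen so that max_bound * scale == 1 and the half-way inputs stay exact:

import torch

def quant_helper_func(input, scale, round_type, max_bound, min_bound):
    # Same logic as the helper in the converted reference above:
    # scale the input, round according to round_type, then clamp.
    quant_value = max_bound * scale * input
    if round_type == 0:
        quant_value = torch.round(quant_value)  # round to nearest, ties to even
    else:
        quant_value = torch.where(quant_value >= 0,
                                  torch.ceil(quant_value - 0.5),
                                  torch.floor(quant_value + 0.5))
    return torch.clamp(quant_value, min=min_bound, max=max_bound)

halves = torch.tensor([0.5, 1.5, 2.5, -2.5])
print(quant_helper_func(halves, 0.25, 0, 4.0, -4.0))  # -> [0., 2., 2., -2.]
print(quant_helper_func(halves, 0.25, 1, 4.0, -4.0))  # -> [0., 1., 2., -2.]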