
add fp8 linear test code #9903

Draft · wants to merge 1 commit into base: develop
198 changes: 197 additions & 1 deletion paddlenlp/transformers/llama/modeling.py
@@ -77,7 +77,7 @@ def swiglu(x, y=None):
from paddlenlp.utils.tools import get_env_device

from .. import linear_utils
from ..linear_utils import Linear
#from ..linear_utils import Linear
from ..segment_parallel_utils import ReshardLayer
from ..utils import caculate_llm_flops
from .configuration import (
@@ -107,6 +107,202 @@ def swiglu(x, y=None):
]
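
# NOTE: the FP8 test helpers below assume Transformer Engine's paddle bindings are
# importable. The exact import paths listed here are an assumption (they vary across
# TE releases) and are shown only so the snippet reads self-contained, e.g.:
#
#   from typing import Optional, Union
#   import transformer_engine_paddle as tex   # TE's paddle C++ extension module
#   from paddle.autograd import PyLayer
#   from transformer_engine.paddle.constants import TE_DType, FP8FwdTensors, FP8BwdTensors
#   from transformer_engine.paddle.fp8 import FP8TensorMeta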


def create_fp8_meta(num_gemms=1, amax_history_len=10):
    """
    Create and initialize FP8TensorMeta
    """
    fp8_meta = FP8TensorMeta(is_forward=True)
    fp8_meta.prepare(num_gemms, amax_history_len)
    return fp8_meta

def fp8_gemm(
    A: paddle.Tensor,
    A_scale_inv: paddle.Tensor,
    A_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    A_dtype: tex.DType,
    B: paddle.Tensor,
    B_scale_inv: paddle.Tensor,
    B_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
    B_dtype: tex.DType,
    out_dtype: paddle.dtype,
    workspace: paddle.Tensor,
    transa: bool = True,
    gelu: bool = False,
    accumulate: bool = False,
    out: Optional[paddle.Tensor] = None,
    out_index=None,
    fp8_meta_tensor: FP8TensorMeta = None,
    bias: Optional[paddle.Tensor] = None,
    use_bias: bool = False,
    use_split_accumulator: bool = False,
    D_dtype: Optional[tex.DType] = None,
) -> paddle.Tensor:
    """TN layout GEMM with fp8 inputs."""

    if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
        assert fp8_meta_tensor is not None and out_index is not None

    if out is None:
        if accumulate:
            out = paddle.zeros(
                shape=[
                    B.shape[0],
                    A.shape[0],
                ],
                dtype=out_dtype,
            )
        else:
            out = paddle.empty(
                shape=[
                    B.shape[0],
                    A.shape[0],
                ],
                dtype=out_dtype,
            )

    # Use bfloat16 as the default bias dtype
    bias_dtype = paddle.bfloat16 if bias is None else bias.dtype
    if gelu:
        gelu_input = paddle.empty_like(out, dtype=bias_dtype)
    else:
        gelu_input = None
    bias_dtype = TE_DType[bias_dtype]

    out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype

    tex.te_gemm(
        A,
        A_scale_inv,
        B,
        B_scale_inv,
        bias if use_bias else None,
        out,
        None if out_index is None else fp8_meta_tensor.scale,
        None if out_index is None else fp8_meta_tensor.amax_history,
        gelu_input,  # this is pre_gelu_out
        workspace,
        A_fp8_tensor.value,
        B_fp8_tensor.value,
        0 if out_index is None else out_index,
        int(A_dtype),
        int(B_dtype),
        int(out_dtype),
        int(bias_dtype),
        True,  # transa (TN layout is always used; the `transa` argument of this wrapper is ignored)
        False,  # transb
        False,  # grad
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
        0,  # math_sm_count
    )

    return out, gelu_input
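
# Note on layout: te_gemm above is always invoked with transa=True / transb=False (the
# TN layout), so for row-major tensors the result is effectively out = B @ A.T; callers
# below pass the weight-like operand as A and the activation-like operand as B.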



# Module-level defaults for the FP8 test path
fp8_dtype = tex.DType.kFloat8E4M3
out_dtype = paddle.bfloat16
fp8_meta = create_fp8_meta(num_gemms=1)



def quant_fn(x):
    # Per-tensor quantization: scale is the max absolute value, so x / scale lies in [-1, 1]
    scale = paddle.max(paddle.abs(x.cast("float32"))).reshape([1])
    out_fp32 = x.cast("float32") / scale

    # Clamp to the float8_e4m3 representable range before casting
    out_fp32 = paddle.clip(out_fp32, min=-448, max=448)

    return out_fp32.cast("float8_e4m3fn"), scale
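
# Rough round-trip sketch (illustrative only): `scale` returned by quant_fn is the
# max-abs value, so the original tensor is recovered as x_fp8 * scale; that is why it
# is later handed to fp8_gemm as the scale_inv (dequantization) factor.
#
#   x = paddle.randn([16, 32])
#   x_fp8, scale = quant_fn(x)
#   x_restored = x_fp8.cast("float32") * scale  # close to x, up to FP8 rounding error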

def fp8_matmul(x, y):
    # Computes x @ y.T with FP8 inputs: both operands are quantized per-tensor and the
    # concatenated scales are passed to the TN GEMM as dequantization (scale_inv) factors.
    x_fp8, scale_x = quant_fn(x)
    y_fp8, scale_y = quant_fn(y)
    all_scale = paddle.concat([scale_x, scale_y])

    # Workspace buffer (32 MiB) required by te_gemm
    workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")

    actual_out, _ = fp8_gemm(
        y_fp8,
        all_scale,
        FP8FwdTensors.GEMM1_WEIGHT,
        fp8_dtype,
        x_fp8,
        all_scale,
        FP8FwdTensors.GEMM1_INPUT,
        fp8_dtype,
        out_dtype,
        workspace,
        False,
    )

    return actual_out
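
# Quick numerical sanity check (illustrative sketch only; needs an FP8-capable GPU):
# fp8_matmul(x, y) computes x @ y.T, so it can be compared against a plain matmul:
#
#   x = paddle.randn([128, 256]).cast("bfloat16")
#   y = paddle.randn([512, 256]).cast("bfloat16")
#   ref = paddle.matmul(x, y, transpose_y=True)  # [128, 512] reference result
#   out = fp8_matmul(x, y)                       # [128, 512] FP8 result
#   max_err = paddle.max(paddle.abs(out.cast("float32") - ref.cast("float32")))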



class LinearFunc(PyLayer):
    @staticmethod
    def forward(ctx, x, y):
        # Reference (non-FP8) implementation:
        # with paddle.no_grad():
        #     out = paddle.matmul(x, y)

        # Flatten leading dims, run the FP8 GEMM against y.T, then restore the shape
        x_orig_shape = x.shape
        out = fp8_matmul(x.reshape([-1, x_orig_shape[-1]]), y.T)
        out = out.reshape(x_orig_shape[:-1] + [-1])

        ctx.save_for_backward(x, y)
        return out

    @staticmethod
    def backward(ctx, dout):
        x, weight = ctx.saved_tensor()

        # dx = dout @ weight.T
        # (reference: dx = paddle.matmul(dout, weight, False, True))
        x_orig_shape = x.shape
        dx = fp8_matmul(dout.reshape([-1, weight.shape[-1]]), weight)
        dx = dx.reshape(x_orig_shape)

        # dw = x.T @ dout
        # (reference: dw = paddle.matmul(x.reshape([-1, x.shape[-1]]), dout.reshape([-1, dout.shape[-1]]), True, False))
        x_2d = x.reshape([-1, x.shape[-1]]).T
        dout_2d = dout.reshape([-1, dout.shape[-1]]).T
        dw = fp8_matmul(x_2d, dout_2d)

        return dx, dw


class Linear(paddle.nn.Layer):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias_attr: bool = False,
    ) -> None:
        super().__init__()
        self._dtype = self._helper.get_default_dtype()

        # NOTE: bias_attr is accepted for API compatibility but no bias is created;
        # this test layer only exercises the FP8 weight GEMM.
        self.weight = self.create_parameter(
            shape=[in_features, out_features],
            dtype=self._dtype,
            is_bias=False,
        )

        print("Using self-defined FP8 test Linear, weight dtype:", self.weight.dtype)

    def forward(self, x):
        return LinearFunc.apply(x, self.weight)
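
# Illustrative usage sketch for the test Linear above (shapes and names are examples
# only, not part of the model code):
#
#   layer = Linear(in_features=256, out_features=512)
#   x = paddle.randn([4, 16, 256]).cast(layer.weight.dtype)
#   x.stop_gradient = False
#   out = layer(x)        # [4, 16, 512], produced by the FP8 GEMM in LinearFunc.forward
#   out.sum().backward()  # fills x.grad and layer.weight.grad via LinearFunc.backward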


def _get_interleave(n):
    def _get_interleave_power_of_2(n):
        start = 2 ** (-(2 ** -(np.log2(n) - 3)))