6 changes: 5 additions & 1 deletion custom_ops/gpu_ops/machete/machete_mm.cu
@@ -30,10 +30,12 @@ paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
std::optional<paddle::Tensor> const& maybe_token_scales,
std::string maybe_schedule) {
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
std::optional<int64_t> maybe_group_size_opt;
std::optional<int64_t> maybe_group_size_opt = std::optional<int64_t>(maybe_group_size);
std::optional<std::string> maybe_schedule_opt;
if (maybe_schedule == "") {
maybe_schedule_opt = std::nullopt;
} else {
maybe_schedule_opt = std::optional<std::string>(maybe_schedule);
}
return machete::mm_dispatch({.A = A,
.B = B,
@@ -63,6 +65,8 @@ std::vector<paddle::Tensor> MacheteMMKernel(
paddle::DataType maybe_out_type;
if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else if (b_type_str == "uint8b128") {
b_type_id = machete::kU8B128.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}
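The hunk above does two things: it forwards maybe_group_size into the dispatcher instead of leaving the optional unset, and it adds a uint8b128 branch next to uint4b8. A minimal Python-level sketch of the option normalization follows; the helper name is hypothetical and only illustrates the convention, the real entry point is the C++ op above.

def normalize_machete_options(group_size, schedule):
    # Sketch of the normalization in machete_mm.cu: the group size is always
    # forwarded to mm_dispatch, while an empty schedule string means
    # "no schedule requested", letting machete pick one itself.
    schedule_opt = None if schedule == "" else schedule
    return group_size, schedule_opt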
2 changes: 2 additions & 0 deletions custom_ops/gpu_ops/machete/machete_prepack_B.cu
@@ -51,6 +51,8 @@ std::vector<paddle::Tensor> MachetePrepackBKernel(

if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else if (b_type_str == "uint8b128") {
b_type_id = machete::kU8B128.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}
19 changes: 13 additions & 6 deletions fastdeploy/model_executor/layers/quantization/ops/machete_mm.py
@@ -85,7 +85,7 @@ def quantize_weights(
w_s: Scales (None if `group_size` is None).
"""
assert paddle.is_floating_point(w), "w must be float type"
assert quant_type in ["uint4", "uint4b8"], "only support quant_type = uint4, uint4b8"
assert quant_type in ["uint4b8", "uint8b128"], "only support quant_type = uint4b8, uint8b128"

orig_device = w.place
size_k, size_n = w.shape
@@ -103,8 +103,12 @@
max_val = paddle.max(w, axis=0, keepdim=True)
min_val = paddle.min(w, axis=0, keepdim=True)

max_q_val = float(7.0)
min_q_val = float(-8.0)
if quant_type == "uint4b8":
max_q_val = float(7.0)
min_q_val = float(-8.0)
else:
max_q_val = float(127.0)
min_q_val = float(-128.0)

w_s = paddle.ones([1], dtype=paddle.float32) # unscaled case

@@ -124,18 +128,20 @@
# w_q += quant_type.bias
if quant_type == "uint4b8":
w_q += 8
else:
w_q += 128

# Restore original shapes
if group_size is not None and group_size < size_k:

def reshape_w(w_tensor):
w_tensor = w_tensor.reshape([group_size, -1, size_n])
w_tensor = w_tensor.transpose([1, 0, 2])
w_tensor = w_tensor.reshape([size_k, size_n])
w_tensor = w_tensor.reshape([size_k, size_n]).contiguous()
return w_tensor

w_q = reshape_w(w_q)
w_s = w_s.reshape([-1, size_n])
w_s = w_s.reshape([-1, size_n]).contiguous()

# Move tensors back to original device
w_q = w_q.to(orig_device)
@@ -153,7 +159,8 @@ def machete_quantize_and_pack(
group_size: int = -1,
):
w_q, w_s = quantize_weights(w, group_size, quant_type=quant_type)
w_q = pack_rows(w_q, 4, *w_q.shape)
num_bits = 4 if quant_type == "uint4b8" else 8
w_q = pack_rows(w_q, num_bits, *w_q.shape)
w_q_col = w_q.transpose([1, 0]).contiguous() # convert to col major
w_q_prepack = machete_prepack_B(
w_q_col,
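The quantize_weights changes above extend the quantization range and bias to both supported types: uint4b8 clamps to [-8, 7] and shifts by 8, uint8b128 clamps to [-128, 127] and shifts by 128, and pack_rows then packs 4 or 8 bits per value. Below is a minimal NumPy sketch of that mapping, assuming symmetric per-channel scales (group_size of None or -1); the function name is illustrative and the exact scale computation in quantize_weights may differ.

import numpy as np

def quantize_reference(w, quant_type="uint4b8"):
    # Symmetric per-channel quantization sketch (columns are output channels):
    # scale so the largest magnitude hits the largest signed value, round and
    # clip, then add the type's bias to obtain an unsigned code.
    assert quant_type in ("uint4b8", "uint8b128")
    max_q, min_q, bias = (7.0, -8.0, 8) if quant_type == "uint4b8" else (127.0, -128.0, 128)
    scale = np.maximum(np.abs(w).max(axis=0, keepdims=True), 1e-8) / max_q  # [1, size_n]
    w_q = np.clip(np.round(w / scale), min_q, max_q) + bias
    return w_q.astype(np.uint8), scale.astype(np.float32)

Under this scaling, the largest-magnitude entry in each uint8b128 column maps to 255 (if positive) or 1 (if negative); the code point 0 is never produced by symmetric scaling.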
12 changes: 7 additions & 5 deletions fastdeploy/model_executor/layers/quantization/weight_only.py
@@ -142,11 +142,11 @@ def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
)

if (
self.name() == "wint4"
and _ENABLE_MACHETE
_ENABLE_MACHETE
and envs.FD_USE_MACHETE == "1"
and layer.weight_shape[1]
and layer.weight_shape[1] % 128 == 0
and not layer.add_bias
):
return MacheteWeightOnlyLinearMethod(self)
return GPUWeightOnlyLinearMethod(self)
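In words, the condition above now lets machete serve both wint4 and wint8 layers whenever the kernels are compiled in, FD_USE_MACHETE is "1", the second weight dimension is a non-zero multiple of 128, and the layer has no bias. A sketch of that predicate, with illustrative names rather than FastDeploy API:

def machete_eligible(weight_shape, add_bias, machete_compiled, fd_use_machete):
    # Mirrors the updated check in get_quant_method: the wint4-only guard is
    # gone, so any weight-only layer meeting the shape/bias constraints can
    # take the MacheteWeightOnlyLinearMethod path.
    return (machete_compiled
            and fd_use_machete == "1"
            and bool(weight_shape[1])
            and weight_shape[1] % 128 == 0
            and not add_bias)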
@@ -230,6 +230,8 @@ def create_weights(self, layer, **extra_weight_attrs):
weight_scale_shape = [1, layer.weight_shape[1]]
if self.quant_config.name() == "wint4":
layer.weight_shape[0] //= 8
else:
layer.weight_shape[0] //= 4
layer.weight_dtype = "int32"
else:
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
@@ -282,7 +284,7 @@ def process_weights_after_loading(self, layer) -> None:
quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
w=layer.weight,
atype=layer._dtype,
quant_type="uint4b8",
quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
)
else:
quanted_weight_tensor, weight_scale_tensor = weight_quantize(
@@ -387,7 +389,7 @@ def process_loaded_weights(self, layer, weight) -> None:
quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
w=weight,
atype=layer._dtype,
quant_type="uint4b8",
quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
)
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))
@@ -400,7 +402,7 @@ def apply(self, layer, x):
x,
w_prepack=layer.weight,
w_g_s=layer.weight_scale,
weight_dtype="uint4b8",
weight_dtype="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128",
)

return linear_out
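The create_weights hunk above packs quantized values into int32 storage: eight 4-bit values per word for wint4 (weight_shape[0] //= 8) and four 8-bit values per word for the new wint8 path (//= 4). A small sketch of that arithmetic, assuming weight_shape[0] is the reduction dimension; the helper name is illustrative only.

def packed_int32_rows(size_k, quant_name):
    # int32 storage for the quantized weight: 8 values per word for wint4,
    # 4 values per word for wint8.
    values_per_word = 8 if quant_name == "wint4" else 4
    assert size_k % values_per_word == 0
    return size_k // values_per_word

For example, the new int8 test's in_features of 7168 would pack into 7168 // 4 = 1792 int32 rows under this assumption.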
128 changes: 98 additions & 30 deletions tests/operators/test_machete_mm.py
@@ -64,11 +64,11 @@ def convert_uint16_to_float(in_list):
not core.is_compiled_with_cuda() or get_sm_version() < 90,
"machete only support sm90.",
)
class WeightOnlyLinearTestCase(unittest.TestCase):
class WeightOnlyInt4LinearTestCase(unittest.TestCase):
def config(self):
self.dtype = "float16"
self.rtol = 1e-5
self.atol = 1e-2
self.atol = 1.3e-1
self.bias = False
self.batch = 1
self.token = 512
@@ -77,11 +77,10 @@ def config(self):
self.weight_dtype = "int4"
self.static = False
self.group_size = -1
self.machete_group_size = -1

def setUp(self):
self.config()
if self.dtype == "bfloat16" or self.weight_dtype == "int4":
self.atol = 1.3e-1
x = np.random.random((self.token, self.in_features))
self.x = paddle.to_tensor(x, dtype=self.dtype)
if self.bias:
@@ -111,29 +110,30 @@ def get_linear_out(self):
return out.numpy()

def get_weight_only_linear_out(self):
for i in range(10):
out = Q.weight_only_linear(
self.x,
self.weight,
bias=self.bias,
weight_scale=self.weight_scale,
weight_dtype=self.weight_dtype,
group_size=self.group_size,
)
out = Q.weight_only_linear(
self.x,
self.weight,
bias=self.bias,
weight_scale=self.weight_scale,
weight_dtype=self.weight_dtype,
group_size=self.group_size,
)
return out.numpy()

def get_machete_weight_only_linear_out(self):
w_q, w_s = machete_quantize_and_pack(
w=self.float_weight.cuda(),
atype=self.dtype,
quant_type="uint4b8",
quant_type="uint4b8" if self.weight_dtype == "int4" else "uint8b128",
group_size=self.machete_group_size,
)

out = machete_wint_mm(
self.x,
w_prepack=w_q,
w_g_s=w_s, # group scales
weight_dtype="uint4b8", # weight_dtype
weight_dtype="uint4b8" if self.weight_dtype == "int4" else "uint8b128", # weight_dtype
group_size=self.machete_group_size,
)
return out.numpy()

@@ -149,26 +149,94 @@ def test_weight_only_linear(self):
np.testing.assert_allclose(out_paddle, out_machete, rtol=self.rtol, atol=self.atol)


M = [32, 128]
K_N = [[2048, 4096]]
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_sm_version() < 90,
"machete only support sm90.",
)
class WeightOnlyInt8LinearTestCase(unittest.TestCase):
def config(self):
self.dtype = "float16"
self.rtol = 1e-5
self.atol = 1e-1
self.bias = False
self.batch = 1
self.token = 512
self.in_features = 7168
self.out_features = 1024
self.weight_dtype = "int8"
self.static = False
self.group_size = -1
self.machete_group_size = 128

def setUp(self):
self.config()
x = np.random.random((self.token, self.in_features))
self.x = paddle.to_tensor(x, dtype=self.dtype)
if self.bias:
bias_attr = base.ParamAttr(
trainable=False,
regularizer=None,
initializer=paddle.nn.initializer.Constant(value=1.0),
)
else:
bias_attr = None
set_default_dtype(self.dtype)
self.linear = paddle.nn.Linear(self.in_features, self.out_features, bias_attr=bias_attr)

def make_case(m, k, n):
class Case(WeightOnlyLinearTestCase):
def config(self, _m=m, _k=k, _n=n):
super().config()
self.token = m
self.in_features = k
self.out_features = n
self.bias = self.linear.bias
self.weight = self.linear.weight
self.float_weight = self.linear.weight
self.weight_scale = None

self.weight, self.weight_scale = Q.weight_quantize(
(self.float_weight.cuda() if self.weight_dtype == "int8" else self.weight.cpu()),
algo=("weight_only_int8" if self.weight_dtype == "int8" else "weight_only_int4"),
group_size=self.group_size,
)

def get_linear_out(self):
out = self.linear(self.x)
return out.numpy()

def get_weight_only_linear_out(self):
out = Q.weight_only_linear(
self.x,
self.weight,
bias=self.bias,
weight_scale=self.weight_scale,
weight_dtype=self.weight_dtype,
group_size=self.group_size,
)
return out.numpy()

Case.name = f"WeightOnlyLinearTestCase{m}{k}{n}"
return Case
def get_machete_weight_only_linear_out(self):
w_q, w_s = machete_quantize_and_pack(
w=self.float_weight.cuda(),
atype=self.dtype,
quant_type="uint4b8" if self.weight_dtype == "int4" else "uint8b128",
group_size=self.machete_group_size,
)

out = machete_wint_mm(
self.x,
w_prepack=w_q,
w_g_s=w_s, # group scales
weight_dtype="uint4b8" if self.weight_dtype == "int4" else "uint8b128", # weight_dtype
group_size=self.machete_group_size,
)
return out.numpy()

def test_weight_only_linear(self):
out_expect = self.get_linear_out()
# out_paddle = self.get_weight_only_linear_out()
out_machete = self.get_machete_weight_only_linear_out()

if self.dtype == "bfloat16":
# out_paddle = convert_uint16_to_float(out_paddle)
out_expect = convert_uint16_to_float(out_expect)
out_machete = convert_uint16_to_float(out_machete)
np.testing.assert_allclose(out_expect, out_machete, rtol=self.rtol, atol=self.atol)

for k, n in K_N:
for m in M:
cls = make_case(m, k, n)
globals()[cls.name] = cls

if __name__ == "__main__":
unittest.main()