
Commit 4022983

[xpu] add weight_only_linear (#72602)

1 parent 876366b

File tree: 4 files changed, +158 -13 lines

paddle/fluid/framework/ir/xpu/weight_only_linear_xpu_pass.cc (4 additions, 2 deletions)

@@ -222,8 +222,10 @@ void PermuteINT8WeightOnlyPass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, common::errors::PreconditionNotMet("graph should not be null."));
   Init(name_scope_, graph);
-
-  ApplyPermuteINT8WeightOnly(graph);
+  auto version = phi::backends::xpu::get_xpu_version(-1);
+  if (version == phi::backends::xpu::XPUVersion::XPU2) {
+    ApplyPermuteINT8WeightOnly(graph);
+  }
 }
 
 }  // namespace ir
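The rationale for the gate is not spelled out in the diff, so read this as an inference: ApplyPermuteINT8WeightOnly rewrites INT8 weights into the layout the XPU2 fused kernel expects, while on newer chips the op is served by the generic weight_only_linear kernel added in this commit, so the permutation is skipped there. A self-contained sketch of the same version-gating pattern, with hypothetical stand-ins for phi::backends::xpu::get_xpu_version(-1) and the pass body:

// Sketch only: version-gated pass logic in isolation.
#include <cstdio>

enum class XPUVersion { XPU1, XPU2, XPU3 };

// Stand-in for phi::backends::xpu::get_xpu_version(-1) (queries the device).
XPUVersion xpu_version() { return XPUVersion::XPU3; }

int main() {
  if (xpu_version() == XPUVersion::XPU2) {
    std::puts("XPU2: permute INT8 weights for the fused kernel");
  } else {
    std::puts("XPU3: leave weights alone; the generic kernel handles the op");
  }
  return 0;
}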

paddle/phi/backends/xpu/xpu3_op_list.cc (2 additions, 0 deletions)

@@ -1693,6 +1693,8 @@ XPUOpMap& get_kl3_ops() {
                      phi::DataType::BFLOAT16})},
       {"warpctc_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"warpctc", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"weight_only_linear",
+       XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::BFLOAT16})},
       {"where_index",
        XPUKernelSet({phi::DataType::INT32,
                      phi::DataType::BOOL,
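get_kl3_ops() is the allowlist the XPU3 (KL3) backend consults when deciding whether an op may run on the device with a given dtype, so this entry is what lets FP16/BF16 activations reach the new kernel at all. A minimal model of that lookup, using std containers in place of Paddle's XPUOpMap/XPUKernelSet:

// Illustrative only: real code uses XPUOpMap/XPUKernelSet, not std types.
#include <cstdio>
#include <map>
#include <set>
#include <string>

int main() {
  std::map<std::string, std::set<std::string>> kl3_ops{
      {"weight_only_linear", {"float16", "bfloat16"}},
  };
  const auto& dtypes = kl3_ops["weight_only_linear"];
  std::printf("fp16 supported: %zu\n", dtypes.count("float16"));  // 1
  std::printf("fp32 supported: %zu\n", dtypes.count("float32"));  // 0
  return 0;
}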

paddle/phi/kernels/fusion/xpu/weight_only_linear_kernel_xpu.cc (11 additions, 11 deletions)

@@ -17,20 +17,20 @@
 
 namespace phi {
 template <typename T, typename Context>
-void WeightOnlyLinearKernel(const Context& dev_ctx,
-                            const DenseTensor& x,
-                            const DenseTensor& weight,
-                            const paddle::optional<DenseTensor>& bias,
-                            const DenseTensor& weight_scale,
-                            const std::string& weight_dtype,
-                            const int32_t arch,
-                            const int32_t group_size,
-                            DenseTensor* out) {
+void WeightOnlyLinearXpuKernel(const Context& dev_ctx,
+                               const DenseTensor& x,
+                               const DenseTensor& weight,
+                               const paddle::optional<DenseTensor>& bias,
+                               const DenseTensor& weight_scale,
+                               const std::string& weight_dtype,
+                               const int32_t arch,
+                               const int32_t group_size,
+                               DenseTensor* out) {
   PADDLE_ENFORCE_EQ(
       weight_dtype,
       "int8",
       common::errors::Fatal(
-          "WeightOnlyLinearKernel xpu just support int8 weight only"));
+          "WeightOnlyLinearXpuKernel xpu just support int8 weight only"));
   phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
   auto xpu_ctx = static_cast<const phi::XPUContext*>(&dev_ctx);
   dev_ctx.template Alloc<T>(out);

@@ -134,6 +134,6 @@ void WeightOnlyLinearKernel(const Context& dev_ctx,
 PD_REGISTER_KERNEL(weight_only_linear_xpu,
                    XPU,
                    ALL_LAYOUT,
-                   phi::WeightOnlyLinearKernel,
+                   phi::WeightOnlyLinearXpuKernel,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
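The rename is mechanical: the fused kernel keeps its registration name weight_only_linear_xpu, but its C++ symbol becomes WeightOnlyLinearXpuKernel, presumably to leave the plain WeightOnlyLinearKernel name in namespace phi free for the generic kernel added in the new file below. A sketch of the two coexisting (stand-in signatures; the real ones take the tensor arguments shown above):

// Two distinct symbols in namespace phi, one per registered kernel name.
#include <cstdio>

namespace phi {
void WeightOnlyLinearXpuKernel() {  // fused variant (weight_only_linear_xpu)
  std::puts("fused pass kernel");
}
void WeightOnlyLinearKernel() {  // generic variant (weight_only_linear)
  std::puts("generic op kernel");
}
}  // namespace phi

int main() {
  phi::WeightOnlyLinearXpuKernel();
  phi::WeightOnlyLinearKernel();
  return 0;
}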
New file (141 additions, 0 deletions)

@@ -0,0 +1,141 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <xft/xdnn_plugin.h>
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+#ifdef PADDLE_WITH_XPU_XRE5
+#include "xblas/xblas_legacy_api.h"
+#endif
+
+namespace phi {
+template <typename T, typename Context>
+void WeightOnlyLinearKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const DenseTensor& weight,
+                            const paddle::optional<DenseTensor>& bias,
+                            const DenseTensor& weight_scale,
+                            const std::string& weight_dtype,
+                            const int32_t arch,
+                            const int32_t group_size,
+                            DenseTensor* out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  int64_t n = weight.dims()[0];
+  int64_t k = weight.dims()[1];
+  int64_t m = x.numel() / k;
+  if (weight_dtype == "int4") {
+    n = n * 2;
+  }
+  out->Resize({static_cast<int64_t>(m), static_cast<int64_t>(n)});
+  dev_ctx.template Alloc<T>(out);
+
+  DenseTensor bias_fp32;
+  if (bias.is_initialized() && bias.get().dtype() == phi::DataType::FLOAT16) {
+    bias_fp32.Resize(bias.get().dims());
+    dev_ctx.template Alloc<float>(&bias_fp32);
+    int r = baidu::xpu::api::cast<XPUType, float>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(
+            bias.get().data<phi::dtype::float16>()),
+        bias_fp32.data<float>(),
+        n);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+  }
+  auto input_x = reinterpret_cast<const XPUType*>(x.data<T>());
+  auto input_y = reinterpret_cast<XPUType*>(out->data<T>());
+
+  baidu::xpu::xblas::FcFusionTensor<const XPUType> tensor_x{
+      input_x, nullptr, m, k, k, false};
+  baidu::xpu::xblas::FcFusionTensor<const XPUType> tensor_y_const{
+      input_y, nullptr, m, n, n, false};
+  baidu::xpu::xblas::FcFusionTensor<XPUType> tensor_y{
+      input_y, nullptr, m, n, n, false};
+  baidu::xpu::xblas::FcFusionEpilogue<float, float> epilogue{
+      api::Activation_t::LINEAR,
+      bias.is_initialized() ? (bias.get().dtype() == phi::DataType::FLOAT16
+                                   ? bias_fp32.data<float>()
+                                   : bias.get().data<float>())
+                            : nullptr,
+      nullptr,
+      weight_scale.dims().size() != 0 ? weight_scale.data<float>() : nullptr,
+      0,
+      1,
+      nullptr};
+
+  if (weight_dtype == "int8") {
+    // using TGEMM=int8_wo_t;
+    using TGEMM = float;
+    baidu::xpu::xblas::FcFusionDesc<TGEMM, float, float> desc{1.0f, 0.0f};
+    baidu::xpu::xblas::FcFusionTensor<const int8_t> tensor_w{
+        reinterpret_cast<const int8_t*>(weight.data<int8_t>()),
+        nullptr,
+        n,
+        k,
+        k,
+        true};
+    int r1 = baidu::xpu::xblas::fc_fusion<XPUType,
+                                          int8_t,
+                                          XPUType,
+                                          XPUType,
+                                          TGEMM,
+                                          float,
+                                          float,
+                                          float,
+                                          float>(dev_ctx.x_context(),
+                                                 tensor_x,
+                                                 tensor_w,
+                                                 tensor_y_const,
+                                                 tensor_y,
+                                                 desc,
+                                                 epilogue);
+    PD_CHECK(r1 == 0, "xblas::fc_fusion failed");
+  } else if (weight_dtype == "int4") {
+    // baidu::xpu::xblas::FcFusionDesc<int4_wo_int15, float, XPUType>
+    //     desc{1.0f, 0.0f};
+    // baidu::xpu::xblas::FcFusionTensor<const int4_t> tensor_w{
+    //     reinterpret_cast<const int4_t*>(weight.data<int8_t>()),
+    //     nullptr,
+    //     n,
+    //     k,
+    //     k,
+    //     true};
+    // int r1 = baidu::xpu::xblas::fc_fusion<XPUType,
+    //                                       int4_t,
+    //                                       XPUType,
+    //                                       XPUType,
+    //                                       int4_wo_int15,  // int8_wo_t
+    //                                       float,
+    //                                       XPUType,
+    //                                       float,
+    //                                       float>(dev_ctx.x_context(),
+    //                                              tensor_x,
+    //                                              tensor_w,
+    //                                              tensor_y_const,
+    //                                              tensor_y,
+    //                                              desc,
+    //                                              epilogue);
+    // PD_CHECK(r1 == 0, "xblas::fc_fusion failed");
+    PD_THROW("unsupported weight_dtype=int4");
+  } else {
+    PD_THROW("unsupported weight_dtype: ", weight_dtype);
+  }
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(weight_only_linear,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::WeightOnlyLinearKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
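To pin down what the int8 branch computes, here is a CPU reference for the fc_fusion call with the LINEAR epilogue. The semantics are the usual weight-only-quantization convention (weight stored [n, k] row-major as int8, one float scale per output row n, bias added after the matmul, no activation); that convention is an assumption for illustration, not something read out of the xblas headers:

// CPU reference: out[m][n] = sum_k x[m][k] * (w[n][k] * scale[n]) + bias[n]
#include <cstdint>
#include <cstdio>
#include <vector>

void weight_only_linear_ref(const std::vector<float>& x,      // [m, k]
                            const std::vector<int8_t>& w,     // [n, k]
                            const std::vector<float>& scale,  // [n]
                            const std::vector<float>& bias,   // [n]
                            std::vector<float>* out,          // [m, n]
                            int64_t m, int64_t k, int64_t n) {
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int64_t p = 0; p < k; ++p) {
        // Dequantize on the fly: int8 weight times per-channel scale.
        acc += x[i * k + p] * static_cast<float>(w[j * k + p]) * scale[j];
      }
      (*out)[i * n + j] = acc + bias[j];  // LINEAR epilogue: bias add only
    }
  }
}

int main() {
  // One row of input against two output channels; checks scale/bias wiring.
  std::vector<float> x{1.0f, 2.0f};           // m=1, k=2
  std::vector<int8_t> w{10, 20, -10, 5};      // n=2, k=2
  std::vector<float> scale{0.1f, 0.2f};
  std::vector<float> bias{0.5f, -0.5f};
  std::vector<float> out(2);
  weight_only_linear_ref(x, w, scale, bias, &out, 1, 2, 2);
  std::printf("%.2f %.2f\n", out[0], out[1]);  // expect 5.50 -0.50
  return 0;
}

The int4 path would halve the stored rows by packing two nibbles per byte, which is why the kernel doubles n when weight_dtype == "int4"; that branch is left commented out and throws for now.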
