Skip to content

Commit 6267a2b

Browse files
newway and Travis-Lee authored
[XPU] Support int31 weight dynamic quantization for fc and conv2d (#59981) (#67058)
Co-authored-by: Travis-Lee <[email protected]>
1 parent 736a253 commit 6267a2b

File tree

6 files changed

+89
-6
lines changed

6 files changed

+89
-6
lines changed

paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc

+13
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,19 @@ void Conv2dXPUFusePass::CreateFusionWeightsAndBias(
763763
false,
764764
weight_scale,
765765
true);
766+
} else if (quant_post_type.find("conv2d") != quant_post_type.end() &&
767+
quant_post_type.find("conv2d")->second == 4) {
768+
VLOG(5) << "Use int31 per-tensor weight";
769+
PrepareWeight<float, float>(graph,
770+
scope,
771+
block,
772+
conv_filter_replicated_node,
773+
&filter_intx,
774+
&filter_max,
775+
&scale_max,
776+
false,
777+
weight_scale,
778+
false);
766779
} else if (quant_post_type.find("conv2d") != quant_post_type.end() &&
767780
quant_post_type.find("conv2d")->second == 0 ||
768781
quant_post_type.find("conv2d") != quant_post_type.end() &&

paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc

+13
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,19 @@ void FcXPUFusePass::CreateFusionWeightsAndBias(
572572
!transpose_w,
573573
weight_scale,
574574
true);
575+
} else if (quant_post_type.find("fc") != quant_post_type.end() &&
576+
quant_post_type.find("fc")->second == 4) {
577+
VLOG(5) << "Use int31 per-tensor weight";
578+
PrepareWeight<float, float>(graph,
579+
scope,
580+
block,
581+
mul_w_replicated_node,
582+
&filter_intx,
583+
&filter_max,
584+
&scale_max,
585+
!transpose_w,
586+
weight_scale,
587+
false);
575588
} else if (quant_post_type.find("fc") != quant_post_type.end() &&
576589
quant_post_type.find("fc")->second == 0 ||
577590
quant_post_type.find("fc") != quant_post_type.end() &&

paddle/fluid/framework/ir/xpu/pass_utils.cc

+12
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,18 @@ void PrepareWeight(Graph* graph,
256256
}
257257
}
258258

259+
template void PrepareWeight<float, float>(
260+
Graph* graph,
261+
Scope* scope,
262+
BlockDesc* block,
263+
Node* weight,
264+
Node** dst_weight,
265+
Node** dst_weight_max,
266+
Node** dst_scale_max,
267+
bool transpose,
268+
const std::vector<float>& weight_scales,
269+
bool per_channel_quant = false);
270+
259271
template void PrepareWeight<float, int16_t>(
260272
Graph* graph,
261273
Scope* scope,

paddle/fluid/framework/ir/xpu/quant_utils.cc

+47-6
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,16 @@ static void QuantFP32ToIntX(const float* src_ptr,
245245
LOG(FATAL) << "Not support.";
246246
}
247247

248+
template <>
249+
void QuantFP32ToIntX<float>(const float* src_ptr,
250+
float* dst_ptr,
251+
float max_val,
252+
int numel) {
253+
for (int i = 0; i < numel; i++) {
254+
dst_ptr[i] = static_cast<float>(src_ptr[i]);
255+
}
256+
}
257+
248258
template <>
249259
void QuantFP32ToIntX<int16_t>(const float* src_ptr,
250260
int16_t* dst_ptr,
@@ -364,16 +374,16 @@ void ConvertWithoutQuant(phi::DenseTensor* weight,
364374
phi::DenseTensor* scale_max,
365375
bool transpose,
366376
const std::vector<float>& weight_scales) {
367-
PADDLE_ENFORCE_EQ(
368-
weight_scales.empty(),
369-
false,
370-
platform::errors::InvalidArgument(
371-
"ConvertWithoutQuant is not allowed weight scales is empty!"));
372377
if (transpose) {
373378
Transpose2D(weight);
374379
}
375380
bool per_tensor_quant = weight_scales.size() == 1;
376381
if (std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value) {
382+
PADDLE_ENFORCE_EQ(
383+
weight_scales.empty(),
384+
false,
385+
platform::errors::InvalidArgument(
386+
"ConvertWithoutQuant is not allowed weight scales is empty!"));
377387
auto* cpu_ctx = static_cast<phi::CPUContext*>(
378388
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
379389
if (per_tensor_quant) {
@@ -400,8 +410,32 @@ void ConvertWithoutQuant(phi::DenseTensor* weight,
400410
weight_scales.data(),
401411
weight_scales.size() * sizeof(float));
402412
}
413+
} else if (std::is_same<T, float>::value) {
414+
// Convert fp16 to fp32
415+
phi::DenseTensor weight_fp32;
416+
CastToFp32(weight, &weight_fp32);
417+
// Find max
418+
int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1);
419+
int size = weight_fp32.numel();
420+
auto* weight_data = weight_fp32.data<float>();
421+
float max_val = FindMaxAbs(weight_data, size);
422+
std::vector<float> max_vec(max_ptr_size, max_val);
423+
weight_max->set_type(phi::DataType::FLOAT32);
424+
weight_max->Resize({max_ptr_size});
425+
auto* cpu_ctx = static_cast<phi::CPUContext*>(
426+
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
427+
memcpy(cpu_ctx->Alloc<float>(weight_max),
428+
max_vec.data(),
429+
max_ptr_size * sizeof(float));
430+
431+
// Quant
432+
weight->set_type(phi::DataType::FLOAT32);
433+
weight->Resize(weight_fp32.dims());
434+
QuantFP32ToIntX<float>(
435+
weight_data, cpu_ctx->Alloc<float>(weight), max_val, size);
403436
} else {
404-
LOG(FATAL) << "Only support int8<->int8 and int16<->int16 convert.";
437+
LOG(FATAL)
438+
<< "Only support float<->int31, int8<->int8 and int16<->int16 convert.";
405439
}
406440
}
407441

@@ -424,6 +458,13 @@ template void ConvertWithoutQuant<int8_t>(
424458
bool transpose,
425459
const std::vector<float>& weight_scales);
426460

461+
template void ConvertWithoutQuant<float>(
462+
phi::DenseTensor* weight,
463+
phi::DenseTensor* weight_max,
464+
phi::DenseTensor* scale_max,
465+
bool transpose,
466+
const std::vector<float>& weight_scales);
467+
427468
bool IsPerTensorQuant(const std::vector<float>& weight_max) {
428469
bool per_tensor = true;
429470
PADDLE_ENFORCE_GT(

paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc

+2
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ void Conv2dXPUKernel(const Context& ctx,
221221
DataTypeToString(filter.dtype()),
222222
DataTypeToString(out_dtype)));
223223
}
224+
} else if (filter.dtype() == DataType::FLOAT32) {
225+
CONV2D_XPU_KERNEL_IMPL(float, float, float, int32_t);
224226
} else {
225227
PADDLE_THROW(phi::errors::Unimplemented(
226228
"Not support x_dtype is %s, filter_dtype is %s and out_dtype is %s.",

paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc

+2
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ void FcXPUKernel(const Context& ctx,
165165
DataTypeToString(w.dtype()),
166166
DataTypeToString(out_dtype)));
167167
}
168+
} else if (w.dtype() == DataType::FLOAT32) {
169+
FC_XPU_KERNEL_IMPL(float, float, float, int32_t);
168170
} else {
169171
PADDLE_THROW(phi::errors::Unimplemented(
170172
"Not support x_dtype is %s, w_dtype is %s and out_dtype is %s.",

0 commit comments

Comments (0)