@@ -245,6 +245,16 @@ static void QuantFP32ToIntX(const float* src_ptr,
   LOG(FATAL) << "Not support.";
 }
 
+template <>
+void QuantFP32ToIntX<float>(const float* src_ptr,
+                            float* dst_ptr,
+                            float max_val,
+                            int numel) {
+  for (int i = 0; i < numel; i++) {
+    dst_ptr[i] = static_cast<float>(src_ptr[i]);
+  }
+}
+
 template <>
 void QuantFP32ToIntX<int16_t>(const float* src_ptr,
                               int16_t* dst_ptr,
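Note: the new float specialization is an identity copy; max_val is accepted for signature compatibility with the int8/int16 specializations but never used. A minimal standalone sketch of that behavior (hypothetical test harness, not part of the patch):

```cpp
#include <cassert>
#include <vector>

// Forward declaration mirroring the patched function's shape.
template <typename T>
void QuantFP32ToIntX(const float* src_ptr, T* dst_ptr, float max_val, int numel);

// Identity pass-through, as in the hunk above: max_val is ignored, so
// "quantizing" to float leaves every value bit-identical.
template <>
void QuantFP32ToIntX<float>(const float* src_ptr,
                            float* dst_ptr,
                            float /*max_val*/,
                            int numel) {
  for (int i = 0; i < numel; i++) {
    dst_ptr[i] = src_ptr[i];
  }
}

int main() {
  std::vector<float> src = {0.5f, -1.25f, 3.0f};
  std::vector<float> dst(src.size());
  QuantFP32ToIntX<float>(src.data(), dst.data(), 3.0f,
                         static_cast<int>(src.size()));
  assert(dst[1] == -1.25f);  // unchanged, unlike the int8/int16 paths
  return 0;
}
```

Keeping the unused parameter lets the fp32 path reuse the same QuantFP32ToIntX call site as the integer paths.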
@@ -364,16 +374,16 @@ void ConvertWithoutQuant(phi::DenseTensor* weight,
                          phi::DenseTensor* scale_max,
                          bool transpose,
                          const std::vector<float>& weight_scales) {
-  PADDLE_ENFORCE_EQ(
-      weight_scales.empty(),
-      false,
-      platform::errors::InvalidArgument(
-          "ConvertWithoutQuant is not allowed weight scales is empty!"));
   if (transpose) {
     Transpose2D(weight);
   }
   bool per_tensor_quant = weight_scales.size() == 1;
   if (std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value) {
+    PADDLE_ENFORCE_EQ(
+        weight_scales.empty(),
+        false,
+        platform::errors::InvalidArgument(
+            "ConvertWithoutQuant is not allowed weight scales is empty!"));
     auto* cpu_ctx = static_cast<phi::CPUContext*>(
         platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
     if (per_tensor_quant) {
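Note: this hunk relocates the empty-scales check from the function entry into the integer branch, so the new fp32 path below can legally run with an empty weight_scales vector. A distilled sketch of the resulting dispatch (hypothetical names, not part of the patch):

```cpp
#include <cstdint>
#include <stdexcept>
#include <type_traits>
#include <vector>

// Only the integer quantization branch requires non-empty scales;
// the float branch has no scale requirement.
template <typename T>
void CheckScalesForConvert(const std::vector<float>& weight_scales) {
  if (std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value) {
    if (weight_scales.empty()) {
      throw std::invalid_argument(
          "ConvertWithoutQuant is not allowed weight scales is empty!");
    }
  }
  // T == float: empty scales are accepted.
}

int main() {
  std::vector<float> empty;
  CheckScalesForConvert<float>(empty);  // fine after this change
  // CheckScalesForConvert<int8_t>(empty);  // would throw
  return 0;
}
```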
@@ -400,8 +410,32 @@ void ConvertWithoutQuant(phi::DenseTensor* weight,
              weight_scales.data(),
              weight_scales.size() * sizeof(float));
     }
+  } else if (std::is_same<T, float>::value) {
+    // Convert fp16 to fp32
+    phi::DenseTensor weight_fp32;
+    CastToFp32(weight, &weight_fp32);
+    // Find max
+    int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1);
+    int size = weight_fp32.numel();
+    auto* weight_data = weight_fp32.data<float>();
+    float max_val = FindMaxAbs(weight_data, size);
+    std::vector<float> max_vec(max_ptr_size, max_val);
+    weight_max->set_type(phi::DataType::FLOAT32);
+    weight_max->Resize({max_ptr_size});
+    auto* cpu_ctx = static_cast<phi::CPUContext*>(
+        platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+    memcpy(cpu_ctx->Alloc<float>(weight_max),
+           max_vec.data(),
+           max_ptr_size * sizeof(float));
+
+    // Quant
+    weight->set_type(phi::DataType::FLOAT32);
+    weight->Resize(weight_fp32.dims());
+    QuantFP32ToIntX<float>(
+        weight_data, cpu_ctx->Alloc<float>(weight), max_val, size);
   } else {
-    LOG(FATAL) << "Only support int8<->int8 and int16<->int16 convert.";
+    LOG(FATAL)
+        << "Only support float<->int31, int8<->int8 and int16<->int16 convert.";
   }
 }
 
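Note: the fp32 branch computes a single abs-max over the cast weights and replicates it max_ptr_size times before copying into weight_max, since the XPU runtime expects the scale in a small buffer of identical entries. A simplified standalone sketch of that step (FindMaxAbsSketch and the fixed max_ptr_size are stand-ins for the real FindMaxAbs and get_xpu_max_ptr_size):

```cpp
#include <cmath>
#include <vector>

// Stand-in for FindMaxAbs: plain abs-max reduction over the buffer.
static float FindMaxAbsSketch(const float* data, int numel) {
  float max_val = 0.0f;
  for (int i = 0; i < numel; i++) {
    max_val = std::fmax(max_val, std::fabs(data[i]));
  }
  return max_val;
}

int main() {
  // Assumed value; the real code queries get_xpu_max_ptr_size(-1) at runtime.
  const int max_ptr_size = 6;
  std::vector<float> weight = {0.5f, -2.5f, 1.0f};
  float max_val =
      FindMaxAbsSketch(weight.data(), static_cast<int>(weight.size()));
  // Broadcast the single abs-max into the buffer that weight_max receives.
  std::vector<float> max_vec(max_ptr_size, max_val);  // {2.5f, six times}
  return max_vec.front() == 2.5f ? 0 : 1;
}
```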
@@ -424,6 +458,13 @@ template void ConvertWithoutQuant<int8_t>(
     bool transpose,
     const std::vector<float>& weight_scales);
 
+template void ConvertWithoutQuant<float>(
+    phi::DenseTensor* weight,
+    phi::DenseTensor* weight_max,
+    phi::DenseTensor* scale_max,
+    bool transpose,
+    const std::vector<float>& weight_scales);
+
 bool IsPerTensorQuant(const std::vector<float>& weight_max) {
   bool per_tensor = true;
   PADDLE_ENFORCE_GT(
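Note: with the explicit float instantiation above, a pass can widen fp16 weights to fp32 without supplying quantization scales. A hypothetical call-site fragment (assumes the Paddle headers, a populated tensor, and that scale_max may be null on the float path; not from this commit):

```cpp
// Hypothetical usage of the new instantiation.
phi::DenseTensor weight;       // fp16 weights, filled elsewhere
phi::DenseTensor weight_max;   // receives the replicated abs-max buffer
std::vector<float> no_scales;  // empty is legal for T == float
ConvertWithoutQuant<float>(
    &weight, &weight_max, /*scale_max=*/nullptr, /*transpose=*/false, no_scales);
```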