|
33 | 33 | #include <ATen/ops/_addmm_activation_native.h>
|
34 | 34 | #include <ATen/ops/_compute_linear_combination_native.h>
|
35 | 35 | #include <ATen/ops/_convert_weight_to_int4pack_for_cpu_native.h>
|
| 36 | +#include <ATen/ops/_dyn_quant_matmul_4bit_native.h> |
| 37 | +#include <ATen/ops/_dyn_quant_pack_4bit_weight_native.h> |
36 | 38 | #include <ATen/ops/_int_mm_native.h>
|
37 | 39 | #include <ATen/ops/_linalg_check_errors.h>
|
38 | 40 | #include <ATen/ops/_linalg_det.h>
|
@@ -3429,6 +3431,8 @@ Tensor kron(const Tensor& self, const Tensor& other) {
|
// Dispatch stubs for the packed/quantized matmul kernels. DEFINE_DISPATCH
// creates the dispatch-table entry here; the per-architecture CPU
// implementations are registered elsewhere via REGISTER_DISPATCH.
DEFINE_DISPATCH(weight_to_int4pack_stub);
DEFINE_DISPATCH(int4pack_mm_stub);
DEFINE_DISPATCH(int8pack_mm_stub);
// Dynamic-quantization 4-bit path: weight packing and matmul entry points.
DEFINE_DISPATCH(dyn_quant_pack_4bit_weight_stub);
DEFINE_DISPATCH(dyn_quant_matmul_4bit_stub);
3432 | 3436 |
|
3433 | 3437 | Tensor _convert_weight_to_int4pack_cpu(
|
3434 | 3438 | const Tensor& in,
|
@@ -3492,6 +3496,69 @@ Tensor _weight_int4pack_mm_cpu(
|
3492 | 3496 | return C;
|
3493 | 3497 | }
|
3494 | 3498 |
|
| 3499 | +Tensor _dyn_quant_pack_4bit_weight_cpu( |
| 3500 | + const Tensor& weights, |
| 3501 | + const Tensor& scales_zeros, |
| 3502 | + const std::optional<Tensor>& bias, |
| 3503 | + const int64_t block_size, |
| 3504 | + const int64_t in_features, |
| 3505 | + const int64_t out_features) { |
| 3506 | + TORCH_CHECK( |
| 3507 | + weights.dtype() == at::kByte, __func__, " : expect weight to be kByte."); |
| 3508 | + TORCH_CHECK( |
| 3509 | + block_size == in_features || |
| 3510 | + (!(block_size % 32) && !(in_features % block_size)), |
| 3511 | + __func__, |
| 3512 | + ": Group size should be multiple of 32, in_features [", |
| 3513 | + in_features, |
| 3514 | + "]. Provided ", |
| 3515 | + block_size); |
| 3516 | + Tensor packed_weights = |
| 3517 | + at::empty(weights.sizes(), weights.options().dtype(at::kByte)); |
| 3518 | + dyn_quant_pack_4bit_weight_stub( |
| 3519 | + kCPU, |
| 3520 | + packed_weights, |
| 3521 | + weights, |
| 3522 | + scales_zeros, |
| 3523 | + bias, |
| 3524 | + out_features, |
| 3525 | + in_features, |
| 3526 | + block_size); |
| 3527 | + return packed_weights; |
| 3528 | +} |
| 3529 | + |
| 3530 | +Tensor _dyn_quant_matmul_4bit_cpu( |
| 3531 | + const Tensor& inp, |
| 3532 | + const Tensor& packed_weights, |
| 3533 | + const int64_t block_size, |
| 3534 | + const int64_t in_features, |
| 3535 | + const int64_t out_features) { |
| 3536 | + auto M = inp.size(0); |
| 3537 | + TORCH_CHECK( |
| 3538 | + inp.dtype() == kFloat, |
| 3539 | + __func__, |
| 3540 | + " : expect input to be 32-bit float tensor."); |
| 3541 | + TORCH_CHECK( |
| 3542 | + block_size == in_features || |
| 3543 | + (!(block_size % 32) && !(in_features % block_size)), |
| 3544 | + __func__, |
| 3545 | + ": Group size should be multiple of 32, in_features [", |
| 3546 | + in_features, |
| 3547 | + "]. Provided ", |
| 3548 | + block_size); |
| 3549 | + auto output = at::empty({M, out_features}, inp.options()); |
| 3550 | + dyn_quant_matmul_4bit_stub( |
| 3551 | + kCPU, |
| 3552 | + output, |
| 3553 | + inp, |
| 3554 | + packed_weights, |
| 3555 | + M, |
| 3556 | + out_features, |
| 3557 | + in_features, |
| 3558 | + block_size); |
| 3559 | + return output; |
| 3560 | +} |
| 3561 | + |
3495 | 3562 | Tensor _weight_int8pack_mm_cpu(
|
3496 | 3563 | const Tensor& A,
|
3497 | 3564 | const Tensor& B,
|
|
0 commit comments