Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge LoCo with Zero++ #6730

Merged
merged 22 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions csrc/includes/quantization.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,36 @@ void launch_swizzled_quant(int8_t* q_data,
int devices_per_node,
cudaStream_t stream);

void launch_loco_swizzled_quant(int8_t* quantized_data,
float* quantized_scales,
const __half* uncompressed_data,
__half* error_feedback,
const float err_beta,
int num_bits,
quantize::Type quant_type,
int groups,
int elems_per_group,
int pipelining,
int nodes,
int devices_per_node,
cudaStream_t stream);

void launch_loco_dequant_reduce(int8_t* reduced_data,
float* reduced_scales,
const int8_t* input_data,
const float* input_scales,
int num_gpus,
int num_bits,
quantize::Type quant_type,
int out_groups,
int elems_per_out_group,
int elems_per_in_tensor,
int groups_per_in_tensor,
int elems_per_in_group,
__half* error_feedback,
const float err_beta,
cudaStream_t stream);

void launch_dequant_reduce(int8_t* reduced_data,
float* reduced_scales,
const int8_t* input_data,
Expand Down
128 changes: 51 additions & 77 deletions csrc/includes/quantization_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ constexpr int max_threads = 1024;
Class to hold the quantization parameters for a given tensor.
Holds the implementation of the quantization operation.
*/

template <Type qType, int numBits>
class Params {
public:
Expand Down Expand Up @@ -145,112 +146,85 @@ class Params<Type::Asymmetric, numBits> {
Group stats tracks the necessary statistics about the quantized group
to abstract the particulars for the main loop.
*/
template <Type qType>
class GroupStats {
public:
DS_D_INLINE void update(__half2 val);
// Helper functions
DS_D_INLINE __half h_abs(const __half& val) {
return __habs(val);
}

DS_D_INLINE void reduce(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp);
};
DS_D_INLINE __half2 h_abs(const __half2& val) {
return __habs2(val);
}

DS_D_INLINE float to_max_float(const __half& val) {
return __half2float(val);
}

DS_D_INLINE float to_min_float(const __half& val) {
return __half2float(val);
}

DS_D_INLINE float to_max_float(const __half2& val) {
const float2 partial_max = conversion::to<float2>(val);
return reduce::element<rop::Max>(partial_max.x, partial_max.y);
}

template <>
class GroupStats<Type::Symmetric> {
DS_D_INLINE float to_min_float(const __half2& val) {
const float2 partial_min = conversion::to<float2>(val);
return reduce::element<rop::Min>(partial_min.x, partial_min.y);
}

// GroupStats class template
template <Type qType, typename DataType = __half2>
class GroupStats;

// Symmetric Quantization
template <typename DataType>
class GroupStats<Type::Symmetric, DataType> {
public:
// Symmetric quantization only tracks the maximum absolute value
__half2 cur_max;
float max;
DataType cur_max;

/*
Technically, this would give bad results if there
are 0 values to process since the reduction would
give -inf instead of 0. We do not consider this
to be a reasonable edge case.
*/
DS_D_INLINE GroupStats() { cur_max = reduce::init<rop::Max, __half2>(); }
DS_D_INLINE GroupStats() { cur_max = reduce::init<rop::Max, DataType>(); }

/*
Updated the running absmax used to calculate params.
Function Arguments :
val : The __half2 value to update the running min and max with.
*/
DS_D_INLINE void update(__half2 val)
{
cur_max = reduce::element<rop::Max>(cur_max, __habs2(val));
DS_D_INLINE void update(DataType val) {
cur_max = reduce::element<rop::Max>(cur_max, h_abs(val));
}

/*
Function to return calculated quantization params.
Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
Function Arguments :
tb - Threadblock object. cg::thread_block
warp - Warp object. cg::thread_block_tile<hw_warp_size>
*/
template <int numBits, int threads_per_group>
DS_D_INLINE Params<Type::Symmetric, numBits> get_params(
cg::thread_block& tb,
cg::thread_block_tile<hw_warp_size>& warp)
{
const float2 partial_max = conversion::to<float2>(cur_max);
float max = reduce::element<rop::Max>(partial_max.x, partial_max.y);

cg::thread_block_tile<hw_warp_size>& warp) {
float max = to_max_float(cur_max);
reduce::partitioned_block<rop::Max, threads_per_group>(tb, warp, max);
Params<Type::Symmetric, numBits> params(max);

return params;
}
};

template <>
class GroupStats<Type::Asymmetric> {
// Asymmetric Quantization
template <typename DataType>
class GroupStats<Type::Asymmetric, DataType> {
public:
__half2 cur_max;
__half2 cur_min;
DataType cur_max;
hwchen2017 marked this conversation as resolved.
Show resolved Hide resolved
DataType cur_min;

/*
Initialize cur_max to -inf, cur_min to inf since
we are doing a true range analysis.
*/
DS_D_INLINE GroupStats()
{
cur_max = reduce::init<rop::Max, __half2>();
cur_min = reduce::init<rop::Min, __half2>();
DS_D_INLINE GroupStats() {
cur_max = reduce::init<rop::Max, DataType>();
cur_min = reduce::init<rop::Min, DataType>();
}

/*
Updated the running min and max used to calculate params.
Function Arguments :
val : The __half2 value to update the running min and max with.
*/
DS_D_INLINE void update(__half2 val)
{
DS_D_INLINE void update(DataType val) {
cur_max = reduce::element<rop::Max>(cur_max, val);
cur_min = reduce::element<rop::Min>(cur_min, val);
}

/*
Function to return calculated quantization params.
Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
Function Arguments :
tb - Threadblock object. cg::thread_block
warp - Warp object. cg::thread_block_tile<hw_warp_size>
*/
template <int numBits, int threads_per_group>
DS_D_INLINE Params<Type::Asymmetric, numBits> get_params(
cg::thread_block& tb,
cg::thread_block_tile<hw_warp_size>& warp)
{
const float2 partial_max = conversion::to<float2>(cur_max);
float max = reduce::element<rop::Max>(partial_max.x, partial_max.y);

const float2 partial_min = conversion::to<float2>(cur_min);
float min = reduce::element<rop::Min>(partial_min.x, partial_min.y);

cg::thread_block_tile<hw_warp_size>& warp) {
float max = to_max_float(cur_max);
float min = to_min_float(cur_min);
reduce::partitioned_block<rop::Max, rop::Min, threads_per_group>(tb, warp, max, min);

Params<Type::Asymmetric, numBits> params(max, min);

return params;
}
};
Expand Down
108 changes: 108 additions & 0 deletions csrc/quantization/pt_binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,55 @@ at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in,
return output;
}

std::vector<at::Tensor> ds_loco_swizzle_quant(at::Tensor& input_vals,
at::Tensor& error_feedback,
float err_beta,
int groups,
int num_bits,
quantize::Type quant_type,
int pipeline_size,
int nodes,
int devices_per_node)
{
auto scales_options = at::TensorOptions()
.dtype(at::kFloat)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
auto scales = torch::empty({groups, scales_elems}, scales_options);

auto output_options = at::TensorOptions()
.dtype(at::kChar)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);

const int quantization_scalar = 8 / num_bits;
const int compressed_vals = at::numel(input_vals) / quantization_scalar;

auto output = torch::empty({compressed_vals}, output_options);
const int elems_per_group = at::numel(input_vals) / groups;

launch_loco_swizzled_quant(
reinterpret_cast<int8_t*>(output.data_ptr()),
reinterpret_cast<float*>(scales.data_ptr()),
reinterpret_cast<const __half*>(input_vals.data_ptr()),
reinterpret_cast<__half*>(error_feedback.data_ptr()),
err_beta,
num_bits,
quant_type,
groups,
elems_per_group,
pipeline_size,
nodes,
devices_per_node,
at::cuda::getCurrentCUDAStream()
);

return {output, scales};
}

std::vector<at::Tensor> ds_swizzle_quant(at::Tensor& input_vals,
int groups,
int num_bits,
Expand Down Expand Up @@ -265,6 +314,63 @@ std::vector<at::Tensor> quantized_reduction(at::Tensor& input_vals,
return {output, scales};
}

std::vector<at::Tensor> loco_quantized_reduction(at::Tensor& input_vals,
at::Tensor& input_scales,
at::Tensor& error_feedback,
float err_beta,
int in_groups,
int out_groups,
int num_bits,
quantize::Type quant_type,
int devices_per_node)
{
auto scales_options = at::TensorOptions()
.dtype(at::kFloat)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);

const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;

auto scales = torch::empty({out_groups, scales_elems}, scales_options);

auto output_options = at::TensorOptions()
.dtype(at::kChar)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);

std::vector<int64_t> sz(input_vals.sizes().begin(), input_vals.sizes().end());
sz[sz.size() - 1] = sz.back() / devices_per_node;

const int elems_per_in_tensor = at::numel(input_vals) / devices_per_node;

auto output = torch::empty(sz, output_options);

const int elems_per_in_group = elems_per_in_tensor / (in_groups / devices_per_node);
const int elems_per_out_group = elems_per_in_tensor / out_groups;

launch_loco_dequant_reduce(
(int8_t*)output.data_ptr(),
(float*)scales.data_ptr(),
(const int8_t*)input_vals.data_ptr(),
(const float*)input_scales.data_ptr(),
devices_per_node,
num_bits,
quant_type,
out_groups,
elems_per_out_group,
elems_per_in_tensor,
in_groups / devices_per_node,
elems_per_in_group,
(half*)error_feedback.data_ptr(),
err_beta,
at::cuda::getCurrentCUDAStream()
);

return {output, scales};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
Expand Down Expand Up @@ -295,4 +401,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"Dequantize int8 to half (experimental)");
m.def("swizzle_quant", &ds_swizzle_quant);
m.def("quantized_reduction", &quantized_reduction);
m.def("loco_swizzle_quant", &ds_loco_swizzle_quant, "LoCo Swizzled Quantization Kernel");
m.def("loco_quantized_reduction", &loco_quantized_reduction, "LoCo Quantization and Reduction Kernel");
}
Loading