Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 31 additions & 38 deletions csrc/transformer/inference/csrc/gelu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -95,67 +95,60 @@ template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStr
// Not called directly from DeepSpeed, but used in ds_qkv_gemm_int8, ds_linear_layer, etc.
//
// In-place fp32 bias add. Each thread owns one `granularity`-byte chunk (4 floats) of `input`
// starting at element `offset`, and adds to it the bias chunk starting at `offset % hidden_size`.
//
//   input       : buffer that is read and overwritten in place
//   bias        : bias values; the chunk start index wraps every `hidden_size` elements
//   total_count : element count used as the upper bound for `offset`
//   hidden_size : element count after which the bias index wraps
//
// NOTE(review): only the first element of a chunk is bounds-checked, so this presumably relies on
// total_count and hidden_size being multiples of 4 and on both pointers being 16-byte aligned.
// Confirm against the callers and against mem_access::load_global / store_global, which are
// defined outside this view.
__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
{
    constexpr int granularity = 16;  // bytes moved per load/store
    constexpr int vals_per_access = granularity / sizeof(float);
    // Index of the first element handled by this thread (in elements, not in chunks).
    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;

    if (offset < total_count) {
        float data[vals_per_access];
        float bias_data[vals_per_access];
        mem_access::load_global<granularity>(data, input + offset);
        mem_access::load_global<granularity>(bias_data, bias + (offset % hidden_size));

#pragma unroll
        for (int i = 0; i < vals_per_access; i++) { data[i] += bias_data[i]; }

        mem_access::store_global<granularity>(input + offset, data);
    }
}

// In-place fp16 bias add. Each thread owns one `granularity`-byte chunk (8 halves, held as four
// __half2 pairs) of `input` starting at element `offset`, and adds to it the bias chunk starting
// at `offset % hidden_size`. Each pair is widened to float2, summed in fp32, and converted back
// with __floats2half2_rn.
//
//   input       : buffer that is read and overwritten in place
//   bias        : bias values; the chunk start index wraps every `hidden_size` elements
//   total_count : element count used as the upper bound for `offset`
//   hidden_size : element count after which the bias index wraps
//
// The body is compiled only when HALF_PRECISION_AVAILABLE is defined; otherwise the kernel is
// an empty function.
//
// NOTE(review): only the first element of a chunk is bounds-checked, so this presumably relies on
// total_count and hidden_size being multiples of 8 and on both pointers being 16-byte aligned.
// Confirm against the callers and against mem_access::load_global / store_global, which are
// defined outside this view.
__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size)
{
#ifdef HALF_PRECISION_AVAILABLE
    constexpr int granularity = 16;  // bytes moved per load/store
    constexpr int vals_per_access = granularity / sizeof(__half);
    // Index of the first element handled by this thread (in elements, not in chunks).
    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;

    if (offset < total_count) {
        __half2 data[vals_per_access / 2];
        __half2 bias_data[vals_per_access / 2];
        mem_access::load_global<granularity>(data, input + offset);
        mem_access::load_global<granularity>(bias_data, bias + (offset % hidden_size));

#pragma unroll
        for (int i = 0; i < vals_per_access / 2; i++) {
            // Do the addition in fp32, then pack the two sums back into one __half2.
            float2 data_f = __half22float2(data[i]);
            float2 bias_f = __half22float2(bias_data[i]);
            data[i] = __floats2half2_rn(data_f.x + bias_f.x, data_f.y + bias_f.y);
        }

        mem_access::store_global<granularity>(input + offset, data);
    }
#endif
}

// Host-side launcher for fused_bias_add.
//
// Enqueues the kernel on `stream` with 1024-thread blocks. Every thread covers one 16-byte chunk,
// i.e. 16 / sizeof(T) elements, so the grid size is the ceiling of
// (batch_size * hidden_size) / (threads * elements-per-thread).
//
//   input       : device buffer updated in place by the kernel
//   bias        : device bias buffer
//   hidden_size : passed to the kernel as the bias wrap length, in elements
//   batch_size  : multiplied by hidden_size to obtain the total element count
//   stream      : CUDA stream the kernel is launched on
//
// NOTE(review): the total element count is computed in `int`; presumably callers keep
// batch_size * hidden_size below INT_MAX -- confirm.
template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream)
{
    constexpr int threads = 1024;
    constexpr int granularity = 16;  // bytes handled per thread

    const int total_count = batch_size * hidden_size;
    // Nothing to do for an empty input; the ceil-div below would otherwise produce a grid of
    // 0 blocks, which is not a valid launch configuration.
    if (total_count <= 0) { return; }

    const int elems_per_block = threads * (granularity / sizeof(T));
    dim3 block_dims(threads);
    dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block);

    fused_bias_add<<<grid_dims, block_dims, 0, stream>>>(input, bias, total_count, hidden_size);
}

// Explicit instantiation of launch_bias_add for T = float.
template void launch_bias_add<float>(float*, const float*, int, int, cudaStream_t);
Expand Down