665 changes: 665 additions & 0 deletions csrc/includes/memory_access_utils.h

Large diffs are not rendered by default.
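Since the new header's diff isn't rendered, here is a minimal sketch of the interface the kernels in this PR call into. Only the names mem_access::load_global / mem_access::store_global and the 16-byte granularity template parameter are taken from the call sites in gelu.cu below; the body is an assumption, not the actual 665-line implementation (which presumably also covers other granularities, alignment handling, and predicated accesses).

// Hypothetical sketch only -- NOT the real csrc/includes/memory_access_utils.h.
// It reproduces just the contract the call sites below rely on: move
// `granularity` bytes between registers and global memory in one vectorized
// transaction.
#include <cuda_runtime.h>

namespace mem_access {

template <int granularity, typename T>
__device__ __forceinline__ void load_global(T* dst, const T* src)
{
    static_assert(granularity == 16, "sketch covers only 16-byte accesses");
    // One float4-sized load keeps the access coalesced across the warp.
    *reinterpret_cast<float4*>(dst) = *reinterpret_cast<const float4*>(src);
}

template <int granularity, typename T>
__device__ __forceinline__ void store_global(T* dst, const T* src)
{
    static_assert(granularity == 16, "sketch covers only 16-byte accesses");
    *reinterpret_cast<float4*>(dst) = *reinterpret_cast<const float4*>(src);
}

}  // namespace mem_access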

6 changes: 5 additions & 1 deletion csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
@@ -1,4 +1,8 @@
#include "custom_cuda_layers.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include "inference_cuda_layers.h"

#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
6 changes: 5 additions & 1 deletion csrc/transformer/inference/csrc/dequantize.cu
@@ -1,4 +1,8 @@
#include "custom_cuda_layers.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include "inference_cuda_layers.h"

#define MAX_QUANTIZE_GROUPING 1024

92 changes: 42 additions & 50 deletions csrc/transformer/inference/csrc/gelu.cu
@@ -1,4 +1,9 @@
#include "custom_cuda_layers.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include "inference_cuda_layers.h"
#include "memory_access_utils.h"

namespace cg = cooperative_groups;
#define MAX_CAP 4
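The gelu() device helper called in the hunks below sits in the unexpanded context between these hunks. DeepSpeed's inference kernels use the tanh approximation of GELU; a sketch of that form follows — an assumption here, since the definition isn't part of this diff.

// Assumed shape of the gelu() helper (tanh approximation); not shown in this diff.
__device__ __forceinline__ float gelu(const float x)
{
    const float sqrt_param = 0.79788456080286535587989211986876f;  // sqrt(2/pi)
    const float mul_param = 0.044715f;
    return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}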
@@ -16,25 +21,21 @@ __global__ void fused_bias_gelu(float* input,
                                 int total_count,
                                 int intermediate_size)
 {
-    float4* input_cast = reinterpret_cast<float4*>(input);
-    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    // Input restriction: intermediate_size % vals_per_access == 0
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(float);
+    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
 
     if (offset < total_count) {
-        float4 data = input_cast[offset];
-        float4 bias_data = bias_cast[offset % intermediate_size];
-
-        data.x += bias_data.x;
-        data.y += bias_data.y;
-        data.z += bias_data.z;
-        data.w += bias_data.w;
+        float data[vals_per_access];
+        float data_bias[vals_per_access];
+        mem_access::load_global<granularity>(data, input + offset);
+        mem_access::load_global<granularity>(data_bias, bias + (offset % intermediate_size));
 
-        data.x = gelu(data.x);
-        data.y = gelu(data.y);
-        data.z = gelu(data.z);
-        data.w = gelu(data.w);
+#pragma unroll
+        for (int i = 0; i < vals_per_access; i++) { data[i] = gelu(data[i] + data_bias[i]); }
 
-        input_cast[offset] = data;
+        mem_access::store_global<granularity>(input + offset, data);
     }
 }
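Why the divisibility restriction above is needed (a worked example with assumed sizes, not from the PR):

// vals_per_access = 4 floats, so thread offsets are 0, 4, 8, ...
// With intermediate_size = 4096 (a multiple of 4), the bias index
// offset % intermediate_size is always a multiple of 4, i.e. 16-byte
// aligned, and each vectorized bias load stays within one float4.
// With intermediate_size = 4094 the index could be 2, splitting the
// 16-byte access across an unaligned boundary -- hence the restriction
// intermediate_size % vals_per_access == 0.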

@@ -43,40 +44,28 @@ __global__ void fused_bias_gelu(__half* input,
                                 int total_count,
                                 int intermediate_size)
 {
+    // Input restriction: intermediate_size % vals_per_access == 0
+    // This kernel doubles the per-thread ALU workload as compared to the float implementation
 #ifdef HALF_PRECISION_AVAILABLE
 
-    float2* input_cast = reinterpret_cast<float2*>(input);
-    const float2* bias_cast = reinterpret_cast<const float2*>(bias);
-
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(__half);
+    int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
 
     if (offset < total_count) {
-        float2 vals_vec = input_cast[offset];
-        float2 bias_vec = bias_cast[offset % intermediate_size];
-
-        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
-        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
-
-        float2 low_data = __half22float2(vals_half[0]);
-        float2 high_data = __half22float2(vals_half[1]);
-
-        float2 low_bias = __half22float2(bias_half[0]);
-        float2 high_bias = __half22float2(bias_half[1]);
-
-        low_data.x += low_bias.x;
-        low_data.y += low_bias.y;
-        high_data.x += high_bias.x;
-        high_data.y += high_bias.y;
-
-        low_data.x = gelu(low_data.x);
-        low_data.y = gelu(low_data.y);
-        high_data.x = gelu(high_data.x);
-        high_data.y = gelu(high_data.y);
-
-        vals_half[0] = __float22half2_rn(low_data);
-        vals_half[1] = __float22half2_rn(high_data);
-
-        input_cast[offset] = vals_vec;
+        // Divide by 2 since we store two values per __half2
+        __half2 data[vals_per_access / 2];
+        __half2 bias_data[vals_per_access / 2];
+        mem_access::load_global<granularity>(data, input + offset);
+        mem_access::load_global<granularity>(bias_data, bias + (offset % intermediate_size));
+
+#pragma unroll
+        for (int i = 0; i < vals_per_access / 2; i++) {
+            float2 data_f = __half22float2(data[i]);
+            float2 bias_f = __half22float2(bias_data[i]);
+            data[i] = __floats2half2_rn(gelu(data_f.x + bias_f.x), gelu(data_f.y + bias_f.y));
+        }
+
+        mem_access::store_global<granularity>(input + offset, data);
     }
 #endif
 }
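To make the __half2 packing concrete, here is a host-side replica of the per-thread element counts (illustrative only, not part of the PR):

// Illustrative arithmetic: element counts implied by a 16-byte access granularity.
#include <cstdio>

int main()
{
    const int granularity = 16;              // bytes per vectorized access
    const int fp32_vals = granularity / 4;   // sizeof(float)  -> 4 values/thread
    const int fp16_vals = granularity / 2;   // sizeof(__half) -> 8 values/thread
    const int half2_regs = fp16_vals / 2;    // packed as 4 __half2 registers
    printf("fp32: %d vals; fp16: %d vals in %d __half2\n",
           fp32_vals, fp16_vals, half2_regs);
    return 0;
}

Eight activations per thread versus four is what the "doubles the per-thread ALU workload" comment refers to.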
@@ -88,13 +77,16 @@ void launch_bias_gelu(T* input,
                       int batch_size,
                       cudaStream_t stream)
 {
-    int total_count = batch_size * (intermediate_size / 4);
-    int threads = 1024;  // intermediate_size / iterations / 4;
+    constexpr int threads = 1024;
+    constexpr int granularity = 16;
+
+    const int total_count = batch_size * intermediate_size;
+    const int elems_per_block = threads * (granularity / sizeof(T));
     dim3 block_dims(threads);
-    dim3 grid_dims(((total_count - 1) / 1024 + 1));  // (batch_size);
+    dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block);
 
     fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(
-        input, bias, total_count, intermediate_size / 4);
+        input, bias, total_count, intermediate_size);
 }
 
 template void launch_bias_gelu<float>(float*, const float*, int, int, cudaStream_t);
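A quick sanity check of the new launch math, with assumed example shapes (batch_size here counts tokens):

// Host-side replica of the launcher's grid sizing; the sizes are assumptions.
#include <cstdio>

int grid_blocks(int total_count, int elem_bytes)
{
    const int threads = 1024, granularity = 16;
    const int elems_per_block = threads * (granularity / elem_bytes);
    return (total_count + elems_per_block - 1) / elems_per_block;  // ceil-div
}

int main()
{
    const int total_count = 8 * 4096;  // batch_size * intermediate_size
    printf("float: %d blocks, __half: %d blocks\n",
           grid_blocks(total_count, /*sizeof(float)*/ 4),
           grid_blocks(total_count, /*sizeof(__half)*/ 2));
    return 0;
}

Note the launcher now passes intermediate_size in elements rather than float4 groups, so the same code sizes correctly for both element widths.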
6 changes: 5 additions & 1 deletion csrc/transformer/inference/csrc/normalize.cu
@@ -1,5 +1,9 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
 #include <limits>
-#include "custom_cuda_layers.h"
+#include "inference_cuda_layers.h"
 
 #ifndef __HIP_PLATFORM_HCC__
 #include <cuda_profiler_api.h>
9 changes: 6 additions & 3 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -1,10 +1,13 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
 #include <c10/cuda/CUDAStream.h>
 #include <torch/extension.h>
 #include <vector>
-#include "context.h"
-#include "cublas_wrappers.h"
-#include "custom_cuda_layers.h"
+#include "inference_context.h"
+#include "inference_cublas_wrappers.h"
+#include "inference_cuda_layers.h"
 
 std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});
 
6 changes: 5 additions & 1 deletion csrc/transformer/inference/csrc/relu.cu
@@ -1,4 +1,8 @@
#include "custom_cuda_layers.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include "inference_cuda_layers.h"

#define MAX_CAP 4
#define MAX_SEQ 2048
6 changes: 5 additions & 1 deletion csrc/transformer/inference/csrc/softmax.cu
@@ -1,5 +1,9 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
 #include <limits>
-#include "custom_cuda_layers.h"
+#include "inference_cuda_layers.h"
 
 #ifndef __HIP_PLATFORM_HCC__
 #include <cuda_profiler_api.h>
6 changes: 5 additions & 1 deletion csrc/transformer/inference/csrc/transform.cu
@@ -1,7 +1,11 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
 #ifndef __HIP_PLATFORM_HCC__
 #include <cuda_profiler_api.h>
 #endif
-#include "custom_cuda_layers.h"
+#include "inference_cuda_layers.h"
 namespace cg = cooperative_groups;
 
 // Bias add
@@ -1,3 +1,7 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
 #pragma once
 
 #include <c10/cuda/CUDAStream.h>
@@ -1,3 +1,7 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
 #pragma once
 
 #include <assert.h>
@@ -1,3 +1,7 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
 #pragma once
 
 #ifdef __HIP_PLATFORM_HCC__
2 changes: 1 addition & 1 deletion op_builder/transformer_inference.py
@@ -51,4 +51,4 @@ def extra_ldflags(self):
         return []
 
     def include_paths(self):
-        return ['csrc/transformer/inference/includes']
+        return ['csrc/transformer/inference/includes', 'csrc/includes']
4 changes: 4 additions & 0 deletions tests/unit/ops/transformer/inference/test_bias_gelu.py
@@ -1,3 +1,7 @@
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""

import pytest
import torch
import deepspeed