Stable Diffusion Enhancements #2491

Merged (5 commits) on Nov 10, 2022
148 changes: 148 additions & 0 deletions csrc/spatial/csrc/opt_bias_add.cu
@@ -0,0 +1,148 @@
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include <cassert>
#include "memory_access_utils.h"
#include "spatial_cuda_layers.h"

/*
Fused bias add variants
*/

namespace badd_opt {
constexpr int threads = 256;
constexpr int steps = 2;
constexpr int granularity = 16;
constexpr int vals_per_h = granularity / sizeof(__half);
constexpr int vals_per_h2 = granularity / sizeof(__half2);
constexpr int vals_per_block = threads * steps * vals_per_h;
constexpr int stride = vals_per_h * threads;
} // namespace badd_opt
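
For reference, these constants encode the vectorization scheme: each thread moves one 16-byte vector (8 __half values, viewed as 4 __half2) per step, so a 256-thread block covers 2048 values per step and 4096 values across its 2 steps. A standalone sanity-check sketch follows; the badd_opt_check namespace is illustrative and not part of this PR.

// Standalone mirror of the badd_opt constants above, for checking the derived values.
#include <cuda_fp16.h>

namespace badd_opt_check {
constexpr int threads = 256;
constexpr int steps = 2;
constexpr int granularity = 16;                             // bytes per vectorized access
constexpr int vals_per_h = granularity / sizeof(__half);    // 8 halfs per access
constexpr int vals_per_h2 = granularity / sizeof(__half2);  // 4 half2s per access
constexpr int vals_per_block = threads * steps * vals_per_h;
constexpr int stride = vals_per_h * threads;
}  // namespace badd_opt_check

static_assert(badd_opt_check::vals_per_h == 8, "16B granularity / 2B __half");
static_assert(badd_opt_check::vals_per_h2 == 4, "16B granularity / 4B __half2");
static_assert(badd_opt_check::stride == 2048, "values covered per step by one block");
static_assert(badd_opt_check::vals_per_block == 4096, "values covered per block");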

__global__ void opt_bias_add(__half* result,
                             const __half* activation,
                             const __half* bias,
                             int seq_len,
                             int channels)
{
    const int id = blockIdx.x * badd_opt::vals_per_block + threadIdx.x * badd_opt::vals_per_h;
    const int stride = badd_opt::vals_per_h * badd_opt::threads;

    for (int i = 0; i < badd_opt::steps; i++) {
        if (id + i * badd_opt::stride < seq_len * channels) {
            __half2 act_buffer[badd_opt::vals_per_h2];
            __half2 bias_buffer[badd_opt::vals_per_h2];

            mem_access::load_global<badd_opt::granularity>(act_buffer,
                                                           activation + id + i * stride);
            mem_access::load_global<badd_opt::granularity>(
                bias_buffer, bias + ((id + i * stride) % channels));

            for (int j = 0; j < badd_opt::vals_per_h2; j++) { act_buffer[j] += bias_buffer[j]; }

            mem_access::store_global<badd_opt::granularity>(result + id + i * stride, act_buffer);
        }
    }
}

__global__ void opt_bias_add_add(__half* result,
                                 const __half* activation,
                                 const __half* bias,
                                 const __half* other,
                                 int seq_len,
                                 int channels)
{
    const int id = blockIdx.x * badd_opt::vals_per_block + threadIdx.x * badd_opt::vals_per_h;
    const int stride = badd_opt::vals_per_h * badd_opt::threads;

    for (int i = 0; i < badd_opt::steps; i++) {
        if (id + i * badd_opt::stride < seq_len * channels) {
            __half2 act_buffer[badd_opt::vals_per_h2];
            __half2 bias_buffer[badd_opt::vals_per_h2];
            __half2 other_buffer[badd_opt::vals_per_h2];

            mem_access::load_global<badd_opt::granularity>(act_buffer,
                                                           activation + id + i * stride);
            mem_access::load_global<badd_opt::granularity>(
                bias_buffer, bias + ((id + i * stride) % channels));
            mem_access::load_global<badd_opt::granularity>(other_buffer, other + id + i * stride);

            for (int j = 0; j < badd_opt::vals_per_h2; j++) {
                act_buffer[j] += bias_buffer[j] + other_buffer[j];
            }

            mem_access::store_global<badd_opt::granularity>(result + id + i * stride, act_buffer);
        }
    }
}

__global__ void opt_bias_add_bias_add(__half* result,
                                      const __half* activation,
                                      const __half* bias,
                                      const __half* other,
                                      const __half* other_bias,
                                      int seq_len,
                                      int channels)
{
    const int id = blockIdx.x * badd_opt::vals_per_block + threadIdx.x * badd_opt::vals_per_h;
    const int stride = badd_opt::vals_per_h * badd_opt::threads;

    for (int i = 0; i < badd_opt::steps; i++) {
        if (id + i * badd_opt::stride < seq_len * channels) {
            __half2 act_buffer[badd_opt::vals_per_h2];
            __half2 bias_buffer[badd_opt::vals_per_h2];
            __half2 other_buffer[badd_opt::vals_per_h2];
            __half2 other_bias_buffer[badd_opt::vals_per_h2];

            mem_access::load_global<badd_opt::granularity>(act_buffer,
                                                           activation + id + i * stride);
            mem_access::load_global<badd_opt::granularity>(
                bias_buffer, bias + ((id + i * stride) % channels));
            mem_access::load_global<badd_opt::granularity>(other_buffer, other + id + i * stride);
            mem_access::load_global<badd_opt::granularity>(
                other_bias_buffer, other_bias + ((id + i * stride) % channels));

            for (int j = 0; j < badd_opt::vals_per_h2; j++) {
                act_buffer[j] =
                    (act_buffer[j] + bias_buffer[j]) + (other_buffer[j] + other_bias_buffer[j]);
            }

            mem_access::store_global<badd_opt::granularity>(result + id + i * stride, act_buffer);
        }
    }
}

void launch_opt_bias_add(__half* result,
                         const __half* activation,
                         const __half* bias,
                         const __half* other,
                         const __half* other_bias,
                         int batch_size,
                         int seq_len,
                         int channels,
                         cudaStream_t stream)
{
    // Should hold for any reasonable channel count: fp16 requires channels % 8 == 0
    // so each thread can load a full 16-byte vector from the bias
    assert(channels % badd_opt::vals_per_h == 0);

    const int effective_seq_len = batch_size * seq_len;
    const int vals = effective_seq_len * channels;

    dim3 block(badd_opt::threads);
    dim3 grid((vals + badd_opt::vals_per_block - 1) / badd_opt::vals_per_block);

    if (!other) {
        // A secondary bias only makes sense if there is a secondary (residual) tensor
        assert(!other_bias);

        opt_bias_add<<<grid, block, 0, stream>>>(
            result, activation, bias, effective_seq_len, channels);
    } else if (!other_bias) {
        opt_bias_add_add<<<grid, block, 0, stream>>>(
            result, activation, bias, other, effective_seq_len, channels);
    } else {
        opt_bias_add_bias_add<<<grid, block, 0, stream>>>(
            result, activation, bias, other, other_bias, effective_seq_len, channels);
    }
}
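
A minimal host-side sketch of driving the launcher for the bias-only path; this is illustrative only, assuming spatial_cuda_layers.h from this PR is on the include path and that the buffers would be filled with real data before the launch.

// Hypothetical driver for launch_opt_bias_add; passing nullptr for `other` and
// `other_bias` selects the plain bias-add kernel.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "spatial_cuda_layers.h"

int main()
{
    const int batch_size = 2, seq_len = 64 * 64, channels = 320;
    const size_t act_bytes = (size_t)batch_size * seq_len * channels * sizeof(__half);

    __half *activation, *bias, *result;
    cudaMalloc(&activation, act_bytes);
    cudaMalloc(&bias, channels * sizeof(__half));
    cudaMalloc(&result, act_bytes);

    cudaStream_t stream = nullptr;  // default stream

    launch_opt_bias_add(
        result, activation, bias, nullptr, nullptr, batch_size, seq_len, channels, stream);

    cudaDeviceSynchronize();
    cudaFree(activation);
    cudaFree(bias);
    cudaFree(result);
    return 0;
}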
111 changes: 111 additions & 0 deletions csrc/spatial/csrc/pt_binding.cpp
@@ -0,0 +1,111 @@
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <cstdio>
#include <vector>
#include "spatial_cuda_layers.h"

ChannelsLastProblem dimension_problem(at::Tensor& input)
{
    ChannelsLastProblem dims;

    if (input.dim() == 4) {
        // This check is not fully reliable (it reflects assumptions baked into the C10
        // options checker): there is no robust way to confirm a tensor is channels-last,
        // since a 1x1 image appears channels-last even when it is not.
        assert(input.is_contiguous(at::MemoryFormat::ChannelsLast));
        dims.batch_size = input.size(0);
        dims.seq_len = input.size(2) * input.size(3);
        dims.channels = input.size(1);
    } else {
        assert(input.is_contiguous());
        dims.batch_size = input.size(0);
        dims.seq_len = input.size(1);
        dims.channels = input.size(2);
    }

    return dims;
}
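
To make the mapping concrete, a hedged example of what dimension_problem produces for a channels-last activation; the shapes are illustrative and the snippet assumes the helper above is visible in the same translation unit.

// Illustrative only: a [2, 320, 64, 64] half tensor stored channels-last flattens to
// batch_size = 2, seq_len = 64 * 64 = 4096, channels = 320.
#include <torch/torch.h>

void example_channels_last_mapping()
{
    auto x = torch::empty({2, 320, 64, 64}, torch::dtype(torch::kHalf).device(torch::kCUDA))
                 .contiguous(at::MemoryFormat::ChannelsLast);
    ChannelsLastProblem dims = dimension_problem(x);
    // dims.batch_size == 2, dims.seq_len == 4096, dims.channels == 320
}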

at::Tensor seq_unroll_bias_add(at::Tensor& input, at::Tensor& bias)
{
    assert(input.dtype() == at::kHalf);

    // TODO(cmikeh2): Should probably refactor this into a more portable
    // description, since it does generalize for channels-last
    ChannelsLastProblem problem = dimension_problem(input);

    auto output = at::empty_like(input);

    launch_opt_bias_add((__half*)output.data_ptr(),
                        (const __half*)input.data_ptr(),
                        (const __half*)bias.data_ptr(),
                        nullptr,
                        nullptr,
                        problem.batch_size,
                        problem.seq_len,
                        problem.channels,
                        at::cuda::getCurrentCUDAStream());

    return output;
}

at::Tensor seq_bias_add_add(at::Tensor& input, at::Tensor& bias, at::Tensor& other)
{
    assert(input.dtype() == at::kHalf);

    // TODO(cmikeh2): Should probably refactor this into a more portable
    // description, since it does generalize for channels-last
    ChannelsLastProblem problem = dimension_problem(input);

    auto output = at::empty_like(input);

    launch_opt_bias_add((__half*)output.data_ptr(),
                        (const __half*)input.data_ptr(),
                        (const __half*)bias.data_ptr(),
                        (const __half*)other.data_ptr(),
                        nullptr,
                        problem.batch_size,
                        problem.seq_len,
                        problem.channels,
                        at::cuda::getCurrentCUDAStream());

    return output;
}

at::Tensor seq_bias_add_bias_add(at::Tensor& input,
                                 at::Tensor& bias,
                                 at::Tensor& other,
                                 at::Tensor& other_bias)
{
    assert(input.dtype() == at::kHalf);

    // TODO(cmikeh2): Should probably refactor this into a more portable
    // description, since it does generalize for channels-last
    ChannelsLastProblem problem = dimension_problem(input);

    auto output = at::empty_like(input);

    launch_opt_bias_add((__half*)output.data_ptr(),
                        (const __half*)input.data_ptr(),
                        (const __half*)bias.data_ptr(),
                        (const __half*)other.data_ptr(),
                        (const __half*)other_bias.data_ptr(),
                        problem.batch_size,
                        problem.seq_len,
                        problem.channels,
                        at::cuda::getCurrentCUDAStream());

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("nhwc_bias_add", &seq_unroll_bias_add);
    m.def("nhwc_bias_add_add", &seq_bias_add_add);
    m.def("nhwc_bias_add_bias_add", &seq_bias_add_bias_add);
}
31 changes: 31 additions & 0 deletions csrc/spatial/includes/spatial_cuda_layers.h
@@ -0,0 +1,31 @@
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#pragma once

#if __CUDA_ARCH__ >= 530
#define HALF_PRECISION_AVAILABLE 1
#endif

#include <cooperative_groups.h>
#include <cuda.h>
#include <cuda_fp16.h>

/*********** Channels-Last Spatial Kernels, Structs, and Helpers ************/

typedef struct {
    int64_t batch_size;
    int64_t seq_len;
    int64_t channels;
} ChannelsLastProblem;

void launch_opt_bias_add(__half* result,
                         const __half* activation,
                         const __half* bias,
                         const __half* other,
                         const __half* other_bias,
                         int batch_size,
                         int seq_len,
                         int channels,
                         cudaStream_t stream);
103 changes: 103 additions & 0 deletions csrc/transformer/inference/csrc/gelu.cu
@@ -593,3 +593,106 @@ template void pad_head_seq(float* padded_output,
int head_size,
int padded_head_size,
cudaStream_t stream);

// TODO(cmikeh2): evaluate different GeLU performance
__device__ __forceinline__ float old_gelu(float val)
{
    // 1 / sqrt(2)
    constexpr float rsqrt_2 = 0.707106769084930419922;
    return val * 0.5f * (1.0f + erff(val * rsqrt_2));
}
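
For reference, old_gelu implements the exact erf-based form of GeLU rather than the tanh approximation:

\mathrm{GELU}(x) = \frac{x}{2}\left(1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right)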

namespace fused_geglu {
constexpr int threads = 256;
constexpr int steps = 2;
constexpr int granularity = 16;
} // namespace fused_geglu

template <typename T>
__global__ void fused_bias_geglu(T* output,
                                 const T* activation,
                                 const T* bias,
                                 int base_channels,
                                 int total_elems)
{
    constexpr int T_per_access = fused_geglu::granularity / sizeof(T);
    constexpr int T_per_step = T_per_access * fused_geglu::threads;
    constexpr int T_per_block = T_per_step * fused_geglu::steps;

    const int id = blockIdx.x * T_per_block + threadIdx.x * T_per_access;

#pragma unroll
    for (int i = 0; i < fused_geglu::steps; i++) {
        T activation_buffer_1[T_per_access];
        T activation_buffer_2[T_per_access];
        T bias_buffer_1[T_per_access];
        T bias_buffer_2[T_per_access];

        const int iter_id = id + T_per_step * i;
        if (iter_id < total_elems) {
            const int channel_id = iter_id % base_channels;
            const int seq_id = iter_id / base_channels;
            const int seq_offset = seq_id * base_channels * 2;

            mem_access::load_global<fused_geglu::granularity>(
                activation_buffer_1, activation + seq_offset + channel_id);
            mem_access::load_global<fused_geglu::granularity>(
                activation_buffer_2, activation + seq_offset + channel_id + base_channels);
            mem_access::load_global<fused_geglu::granularity>(bias_buffer_1, bias + channel_id);
            mem_access::load_global<fused_geglu::granularity>(bias_buffer_2,
                                                              bias + channel_id + base_channels);

            // Since the GeLU is going to happen at float, might as well convert
#pragma unroll
            for (int v = 0; v < T_per_access; v++) {
                T hidden_state = activation_buffer_1[v] + bias_buffer_1[v];
                T pre_gate = activation_buffer_2[v] + bias_buffer_2[v];
                float gate_f = old_gelu(conversion::to<float>(pre_gate));
                T gate = conversion::to<T>(gate_f);
                activation_buffer_1[v] = hidden_state * gate;
            }

            mem_access::store_global<fused_geglu::granularity>(output + iter_id,
                                                               activation_buffer_1);
        }
    }
}

template <typename T>
void launch_fused_bias_geglu(T* output,
                             const T* activation,
                             const T* bias,
                             int rows,
                             int elems_per_row,
                             cudaStream_t stream)
{
    /*
    Fused bias GEGLU is a variant of the gated activation functions.
    The input here is a matrix of [batch, seq_len, 2 * intermediate_dim]
    where the second half of the channels act as GeLU gates for the first
    half.
    */

    // Re-derive the above figures
    constexpr int T_per_access = fused_geglu::granularity / sizeof(T);
    constexpr int T_per_step = T_per_access * fused_geglu::threads;
    constexpr int T_per_block = T_per_step * fused_geglu::steps;

    const int base_channels = elems_per_row / 2;
    const int total_elems = base_channels * rows;

    dim3 block(fused_geglu::threads);
    dim3 grid((total_elems + T_per_block - 1) / T_per_block);

    fused_bias_geglu<<<grid, block, 0, stream>>>(
        output, activation, bias, base_channels, total_elems);
}

template void launch_fused_bias_geglu(__half*,
                                      const __half*,
                                      const __half*,
                                      int,
                                      int,
                                      cudaStream_t);
template void launch_fused_bias_geglu(float*, const float*, const float*, int, int, cudaStream_t);
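
As a cross-check of the layout described in the launcher comment, a minimal CPU reference of the same bias-GEGLU computation; it is float-only, the function name is hypothetical, and it is not part of this PR.

// Hypothetical float reference for fused_bias_geglu: per row, the second half of the
// channels (plus its bias) gates the biased first half through the exact GeLU.
#include <cmath>
#include <vector>

std::vector<float> bias_geglu_reference(const std::vector<float>& activation,  // [rows, 2 * dim]
                                        const std::vector<float>& bias,        // [2 * dim]
                                        int rows,
                                        int dim)
{
    std::vector<float> output((size_t)rows * dim);
    const float rsqrt_2 = 0.7071067811865475f;
    for (int r = 0; r < rows; r++) {
        for (int c = 0; c < dim; c++) {
            const float hidden = activation[(size_t)r * 2 * dim + c] + bias[c];
            const float pre_gate = activation[(size_t)r * 2 * dim + dim + c] + bias[dim + c];
            const float gate = pre_gate * 0.5f * (1.0f + std::erf(pre_gate * rsqrt_2));
            output[(size_t)r * dim + c] = hidden * gate;
        }
    }
    return output;
}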