2 changes: 1 addition & 1 deletion DeepSpeedExamples
13 changes: 10 additions & 3 deletions csrc/transformer/ds_transformer_cuda.cpp
100644 → 100755
@@ -14,6 +14,8 @@

static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

const int init_seq_length = 128;

// C++ interface

template <typename T>
@@ -591,7 +593,6 @@ int create_transformer_layer(int layer_id,
int hidden_dim,
int num_heads,
int intermediate_size,
int seq_length,
float attn_dropout_ratio,
float hidden_dropout_ratio,
int seed,
@@ -604,14 +605,14 @@
{
Context::Instance().SetSeed(seed);
Context::Instance().TestGemmFP16(
test_gemm, batch_size, seq_length, num_heads, hidden_dim / num_heads);
test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);

auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
batch_size,
hidden_dim,
num_heads,
intermediate_size,
seq_length,
init_seq_length,
attn_dropout_ratio,
hidden_dropout_ratio,
pre_or_postLayerNorm,
@@ -873,6 +874,12 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

int seq_len = layer->GetSeqLength();
if (g_output.size(1) != seq_len) {
seq_len = g_output.size(1);
layer->SetSeqLength(seq_len, bsz);
}

auto grad_input = torch::empty_like(input);
auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
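Note: with `seq_length` removed from `create_transformer_layer`, every layer is built against the fixed `init_seq_length` of 128 and re-sized at run time from the tensors it actually receives; the backward binding above compares `GetSeqLength()` with `g_output.size(1)` and calls `SetSeqLength()` when they differ. A minimal Python sketch of that check (the snake_case accessors are hypothetical stand-ins for the C++ ones):

```python
# Rough sketch of the per-call sequence-length sync added to the C++ binding:
# the cached length is compared with the incoming tensor, and the layer's
# size-dependent workspaces are assumed to be reallocated when they differ.
def sync_seq_length(layer, tensor, bsz):
    seq_len = layer.get_seq_length()        # stand-in for GetSeqLength()
    if tensor.size(1) != seq_len:           # dim 1 is the sequence dimension
        seq_len = tensor.size(1)
        layer.set_seq_length(seq_len, bsz)  # stand-in for SetSeqLength(seq_len, bsz)
    return seq_len
```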
20 changes: 14 additions & 6 deletions csrc/transformer/softmax_kernels.cu
@@ -80,7 +80,8 @@ __global__ void attn_softmax(float* vals,
#endif

int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);

for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
@@ -113,7 +114,8 @@
#endif

int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);

for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

@@ -216,7 +218,8 @@ __global__ void attn_softmax(__half* vals,
#endif

int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);

for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
@@ -252,7 +255,8 @@ __global__ void attn_softmax(__half* vals,
#endif

int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);

for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

@@ -339,7 +343,9 @@ void launch_attn_softmax<float>(float* vals,
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);

iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
@@ -408,7 +414,9 @@ void launch_attn_softmax<__half>(__half* vals,
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);

iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
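Note: inside `attn_softmax`, the number of warp-shuffle reduction steps is now derived from `max_threads_in_sequence` (defined outside the hunks shown) rather than from the raw `seq_length`, and `launch_attn_softmax` recomputes `iterations` for whatever sequence length arrives on each call. A small Python sketch of that host-side computation; the values of `THREADS` and `MAX_THREAD_ITERATIONS` and the `seq_length4`/`subblock_max_workload` definitions are assumptions for illustration:

```python
# Host-side iteration count mirroring the `iterations = ...` lines added to
# launch_attn_softmax; the constants below are illustrative assumptions.
MAX_THREAD_ITERATIONS = 8
THREADS = 128

def softmax_iterations(sequence_length):
    seq_length4 = sequence_length // 4                            # float4 packs
    subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * THREADS   # assumed definition
    if sequence_length < subblock_max_workload:
        # ceil(seq_length4 / THREADS): one pack per thread per iteration
        return (seq_length4 + THREADS - 1) // THREADS
    return MAX_THREAD_ITERATIONS
```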
28 changes: 22 additions & 6 deletions deepspeed/ops/transformer/transformer.py
@@ -18,7 +18,6 @@
class TransformerConfig():
def __init__(self,
batch_size,
max_seq_length,
hidden_size,
intermediate_size,
heads,
@@ -30,7 +29,6 @@ def __init__(self,
self.batch_size = batch_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.max_seq_length = max_seq_length
self.heads = heads
self.attn_dropout_ratio = attn_dropout_ratio
self.hidden_dropout_ratio = hidden_dropout_ratio
@@ -92,7 +90,6 @@ class DeepSpeedTransformerConfig(TransformerConfig):
"""
def __init__(self,
batch_size=-1,
max_seq_length=-1,
hidden_size=-1,
intermediate_size=-1,
heads=-1,
@@ -112,7 +109,6 @@ def __init__(self,
super(DeepSpeedTransformerConfig,
self).__init__(
batch_size,
max_seq_length,
hidden_size,
(intermediate_size if intermediate_size > 0 else 4 * hidden_size),
heads,
@@ -142,7 +138,7 @@ def from_dict(cls, json_object):

@classmethod
def from_json_file(cls, json_file):
with open(json_file, "r", encoding='utf-8') as reader:
with open(json_file, "r", encoding='utf-16') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))

@@ -177,6 +173,18 @@ def forward(ctx,
cuda_module = stochastic_transformer_cuda_module if config.stochastic_mode else transformer_cuda_module
forward_func = cuda_module.forward_fp16 if config.fp16 else cuda_module.forward_fp32

inp_size = input.size()
if inp_size[1] % 16 != 0:
input = torch.cat((input,
torch.randn((inp_size[0],
(16 - (inp_size[1] % 16)),
inp_size[2]),
device=input.device,
dtype=input.dtype)),
1)
input_mask = torch.cat((input_mask, torch.ones((inp_size[0], input_mask.shape[1], input_mask.shape[2], \
(16 - (inp_size[1] % 16))), device=input_mask.device, dtype=input_mask.dtype) * -10000), 3)

(output,
inp_norm,
qkv_tf,
@@ -303,11 +311,17 @@ def forward(ctx,
ctx.attn_layer_norm_var = attn_layer_norm_var
ctx.layer_norm_var = layer_norm_var

if inp_size[1] % 16 != 0:
output = torch.narrow(output, 1, 0, inp_size[1])
return output

@staticmethod
def backward(ctx, grad_output):
bsz = grad_output.shape[0]
grad_output_shape = grad_output.size()
if grad_output_shape[1] % 16 != 0:
grad_output = torch.cat((grad_output, torch.zeros((bsz, (16 - (grad_output_shape[1] % 16)), \
grad_output_shape[2]), device=grad_output.device, dtype=grad_output.dtype)), 1)

if bsz > ctx.config.batch_size:
raise ValueError('grad_output batch size exceeds the limit.')
@@ -398,6 +412,9 @@ def backward(ctx, grad_output):
norm_w,
norm_b)

if grad_output_shape[1] % 16 != 0:
grad_input = torch.narrow(grad_input, 1, 0, grad_output_shape[1])

return (grad_input,
None,
None,
@@ -501,7 +518,6 @@ def __init__(self, layer_id, config, initial_weights=None, initial_biases=None):
self.config.hidden_size,
self.config.heads,
self.config.intermediate_size,
self.config.max_seq_length,
self.config.attn_dropout_ratio,
self.config.hidden_dropout_ratio,
self.config.seed,
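Note: the CUDA kernels walk the sequence dimension in tiles of 16, so the autograd function above pads any input whose length is not a multiple of 16 (random activations plus a -10000 additive-mask entry so the extra positions attend to nothing) and crops the output back with `torch.narrow`; the backward pass zero-pads `grad_output` and crops `grad_input` the same way. A self-contained sketch of the forward-path padding (the helper name is ours, not DeepSpeed's):

```python
import torch

def pad_seq_to_multiple_of_16(x, mask):
    """Pad (bsz, seq, hidden) activations and the (bsz, 1, 1, seq) additive
    attention mask up to the next multiple of 16, as the forward pass does."""
    bsz, seq, hidden = x.size()
    pad = (16 - seq % 16) % 16
    if pad == 0:
        return x, mask
    # random padding on the activations; it is masked out and cropped later
    x = torch.cat((x, torch.randn(bsz, pad, hidden,
                                  device=x.device, dtype=x.dtype)), dim=1)
    # large negative bias on the padded positions of the additive mask
    mask = torch.cat((mask, torch.ones(bsz, mask.size(1), mask.size(2), pad,
                                       device=mask.device,
                                       dtype=mask.dtype) * -10000), dim=3)
    return x, mask

# After the kernel runs, the output is cropped back to the original length:
#   output = torch.narrow(output, 1, 0, seq)
```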
32 changes: 12 additions & 20 deletions tests/unit/test_cuda_backward.py
@@ -150,7 +150,7 @@ def create_models(ds_config):
hidden_act="gelu",
hidden_dropout_prob=ds_config.hidden_dropout_ratio,
attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
max_position_embeddings=ds_config.max_seq_length,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=ds_config.initializer_range)

@@ -210,25 +210,18 @@ def set_seed(seed):
torch.manual_seed(seed)


def run_backward(ds_config, atol=1e-2, verbose=False):
def run_backward(ds_config, seq_len, atol=1e-2, verbose=False):
set_seed(123)
bert_encoder, ds_encoder = create_models(ds_config)

# prepare test data
kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32
hidden_states = torch.randn(ds_config.batch_size,
ds_config.max_seq_length,
seq_len,
ds_config.hidden_size,
**kwargs)
input_mask = torch.randn(ds_config.batch_size,
1,
1,
ds_config.max_seq_length,
**kwargs)
Y = torch.randn(ds_config.batch_size,
ds_config.max_seq_length,
ds_config.hidden_size,
**kwargs)
input_mask = torch.randn(ds_config.batch_size, 1, 1, seq_len, **kwargs)
Y = torch.randn(ds_config.batch_size, seq_len, ds_config.hidden_size, **kwargs)

# run baseline
base_results = bert_encoder(hidden_states,
@@ -257,12 +250,12 @@ def run_backward(ds_config, atol=1e-2, verbose=False):
#test_backward[3-1024-120-16-24-True-True-0.05]
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
[
(3,1024,120,16,24,True,False, 0.05),
(3,1024,120,16,24,True,True, 0.05),
(3,1024,56,16,24,False,False, 0.1),
(3,1024,56,16,24,False,True, 0.2),
(3,128,56,2,24,False,False, 0.1),
(3,128,56,2,24,False,True, 0.2),
(3,1024,119,16,24,True,False, 0.05),
(3,1024,115,16,24,True,True, 0.05),
(1024,128,10,2,2,False,False, 0.1),
(3,1024,52,16,24,False,True, 0.2),
(3,128,51,2,24,False,False, 0.1),
(3,128,54,2,24,False,True, 0.2),
]) # yapf: disable
def test_backward(batch_size,
hidden_size,
@@ -282,7 +275,6 @@ def test_backward(batch_size,
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.intermediate_size = hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
ds_config.hidden_dropout_ratio = 0.0
@@ -291,7 +283,7 @@
ds_config.initializer_range = 0.02
ds_config.fp16 = use_fp16

run_backward(ds_config, atol=atol)
run_backward(ds_config, seq_len, atol=atol)


#@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
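Note: `run_backward` now takes `seq_len` directly instead of reading `ds_config.max_seq_length`, and the new parametrizations intentionally use sequence lengths that are not multiples of 16 so the padding path in the kernel wrapper is exercised. For reference, a quick way to see the padded length each new case actually runs at:

```python
# Padded length the kernel sees for the new backward test cases
# (rounded up to a multiple of 16, matching the padding in transformer.py).
for seq_len in (119, 115, 10, 52, 51, 54):
    print(seq_len, "->", (seq_len + 15) // 16 * 16)
```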
34 changes: 13 additions & 21 deletions tests/unit/test_cuda_forward.py
@@ -117,7 +117,7 @@ def create_models(ds_config):
hidden_act="gelu",
hidden_dropout_prob=ds_config.hidden_dropout_ratio,
attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
max_position_embeddings=ds_config.max_seq_length,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=ds_config.initializer_range,
fp16=ds_config.fp16)
@@ -186,13 +186,8 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):

# prepare test data
kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32
hidden_states = torch.randn(bsz,
seq_len, #ds_config.max_seq_length,
ds_config.hidden_size,
**kwargs)
input_mask = torch.randn(bsz, 1, 1,
seq_len, #ds_config.max_seq_length,
**kwargs)
hidden_states = torch.randn(bsz, seq_len, ds_config.hidden_size, **kwargs)
input_mask = torch.randn(bsz, 1, 1, seq_len, **kwargs)

# run baseline
base_results = bert_encoder(hidden_states,
@@ -213,25 +208,25 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
# FP16 test cases can only run on devices that support FP16.
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
[
(8,256,128,4,3,True,False),
(8,256,128,4,3,True,True),
(64,1024,128,16,3,True,False),
(64,1024,128,16,3,True,True),
(8,1024,384,16,3,True,False),
(8,256,53,4,3,True,False),
(8,256,52,4,3,True,True),
(3,1024,51,16,3,True,False),
(3,1024,54,16,3,True,True),
(8,1024,381,16,3,True,False),
(8,1024,384,16,3,True,True),
(8,1024,384,16,3,True,True),
(8,1024,120,16,3,True,False),
(8,1024,119,16,3,True,False),
(8,1024,120,16,3,True,True),
(8,1024,512,16,3,True,False),
(8,1024,509,16,3,True,False),
(8,1024,512,16,3,True,True),
(64,1024,56,16,3,False,False),
(64,1024,56,16,3,False,True),
(64,1024,53,16,3,False,True),
(64,1024,24,16,3,False,False),
(64,1024,24,16,3,False,True),
(64,1024,21,16,3,False,True),
(8,1024,384,16,3,False,False),
(8,1024,384,16,3,False,True),
(8,1024,512,16,3,False,False),
(8,1024,512,16,3,False,True),
(8,1024,511,16,3,False,True),
(8,1536,128,24,3,False,False),
(8,1536,128,24,3,False,True),
(8,2048,128,32,3,False,False),
@@ -259,7 +254,6 @@ def test_forward(batch_size,
ds_config.layer_id = None
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.max_seq_length = 128 #seq_len
ds_config.intermediate_size = 4 * hidden_size
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
@@ -297,7 +291,6 @@ def test_forward_with_small_bsz(batch_size,
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.intermediate_size = 4 * hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
ds_config.hidden_dropout_ratio = 0.0
@@ -332,7 +325,6 @@ def test_forward_stochastic(batch_size,
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.intermediate_size = 4 * hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
ds_config.hidden_dropout_ratio = 0.0
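Note: with `max_seq_length` gone from `DeepSpeedTransformerConfig`, callers no longer fix a sequence length up front; the forward tests above simply feed tensors of whatever length they want, including odd ones. A minimal usage sketch, assuming the top-level `deepspeed` import path and the constructor/call signatures shown in the code above (layer id first); it needs a CUDA device and the transformer op built:

```python
import torch
from deepspeed import DeepSpeedTransformerConfig, DeepSpeedTransformerLayer

# No max_seq_length field any more; only settings exercised by the tests above.
config = DeepSpeedTransformerConfig(batch_size=8,
                                    hidden_size=1024,
                                    intermediate_size=4 * 1024,
                                    heads=16,
                                    attn_dropout_ratio=0.0,
                                    hidden_dropout_ratio=0.0,
                                    initializer_range=0.02,
                                    fp16=False)

layer = DeepSpeedTransformerLayer(0, config).cuda()  # layer_id=0, as in the diff
hidden = torch.randn(8, 51, 1024, device='cuda')     # 51: not a multiple of 16
mask = torch.zeros(8, 1, 1, 51, device='cuda')       # additive attention mask
out = layer(hidden, mask)                            # cropped back to length 51
```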