diff --git a/csrc/includes/context.h b/csrc/includes/context.h old mode 100644 new mode 100755 index 1e4820177c5d..e05c41dc1d0a --- a/csrc/includes/context.h +++ b/csrc/includes/context.h @@ -29,7 +29,7 @@ for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y) #define DS_CUDA_NUM_THREADS 512 -#define DS_MAXIMUM_NUM_BLOCKS 4096 +#define DS_MAXIMUM_NUM_BLOCKS 262144 inline int DS_GET_BLOCKS(const int N) { @@ -69,7 +69,7 @@ class Context { if (!_workspace) { assert(_workspace == nullptr); cudaMalloc(&_workspace, size); - } else if (_workSpaceSize != size) { + } else if (_workSpaceSize < size) { cudaFree(_workspace); cudaMalloc(&_workspace, size); } diff --git a/csrc/includes/custom_cuda_layers.h b/csrc/includes/custom_cuda_layers.h index 2e72a35292c6..764687da4ff4 100644 --- a/csrc/includes/custom_cuda_layers.h +++ b/csrc/includes/custom_cuda_layers.h @@ -29,7 +29,6 @@ void launch_bias_gelu(const T* input, T* output, int intermediate_size, int batch_size, - int sequence_length, cudaStream_t stream); template @@ -37,7 +36,6 @@ void launch_gelu(const T* input, T* output, int intermediate_size, int batch_size, - int sequence_length, cudaStream_t stream); template @@ -46,7 +44,6 @@ void launch_d_gelu(T* d_output, const T* bias, int intermediate_size, int batch_size, - int sequence_length, cudaStream_t stream); // Custom fused bias add with layer normalization @@ -57,14 +54,12 @@ void launch_bias_residual_layer_norm(T* vals, const T* beta, float epsilon, int batch_size, - int sequence_length, int hidden_dim, cudaStream_t stream, bool preLayerNorm, - bool training = false, - T* vars = nullptr, - T* means = nullptr, - T* vals_hat = nullptr); + bool training, + T* vars, + T* means); template void launch_bias_residual_layer_norm(T* vals, @@ -73,14 +68,11 @@ void launch_bias_residual_layer_norm(T* vals, const T* beta, float epsilon, int batch_size, - int sequence_length, int hidden_dim, cudaStream_t stream, bool preLayerNorm, - bool training = false, - T* vars = nullptr, - T* vals_hat = nullptr, - bool save_vals = false); + bool training, + T* vars); template void launch_layerNorm_backward_fused_add(const T* out_grad1, @@ -93,7 +85,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1, T* betta_grad, T* inp_grad, int batch_size, - int sequence_length, int hidden_dim, cudaStream_t stream[2]); template @@ -106,7 +97,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1, T* betta_grad, T* inp_grad, int batch_size, - int sequence_length, int hidden_dim, cudaStream_t stream[2], bool invertible = false, @@ -122,7 +112,6 @@ void launch_layerNorm_backward(const T* out_grad, T* betta_grad, T* inp_grad, int batch_size, - int sequence_length, int hidden_dim, cudaStream_t stream[2]); @@ -135,7 +124,6 @@ void launch_layerNorm_backward(const T* out_grad, T* betta_grad, T* inp_grad, int batch_size, - int sequence_length, int hidden_dim, cudaStream_t stream[2], bool invertible = false, @@ -153,7 +141,6 @@ void launch_layerNorm_backward_nreversible(const T* out_grad, T* betta_grad, T* inp_grad, int batch_size, - int sequence_length, int hidden_dim, cudaStream_t stream[2]); diff --git a/csrc/includes/dropout.h b/csrc/includes/dropout.h index 090df3a0abf8..f6e32af5608d 100644 --- a/csrc/includes/dropout.h +++ b/csrc/includes/dropout.h @@ -9,15 +9,13 @@ class Dropout { public: struct Config { float ratio; - uint32_t batch, dim; + uint32_t dim; bool training; - Config(float r, uint32_t batch, uint32_t dim) - : ratio(r), batch(batch), dim(dim), training(true) - { - 
} + Config(float r, uint32_t d) : ratio(r), dim(d), training(true) {} float RATIO() const { return training ? ratio : 0.0; } + inline void SetDim(uint32_t d) { dim = d; } }; Dropout(const Config& config) : _config(config), _mask(nullptr) {} @@ -70,6 +68,8 @@ class Dropout { Config GetConfig() const { return _config; } + inline void SetDimension(uint32_t dim) { _config.SetDim(dim); } + private: uint8_t* _mask; Config _config; diff --git a/csrc/includes/ds_transformer_cuda.h b/csrc/includes/ds_transformer_cuda.h old mode 100755 new mode 100644 index 896dce8c26db..3fac43e4c6a5 --- a/csrc/includes/ds_transformer_cuda.h +++ b/csrc/includes/ds_transformer_cuda.h @@ -121,11 +121,17 @@ class BertTransformerLayer { void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr, uint8_t* attn_output_dropout_mask_ptr, - uint8_t* layer_output_dropout_mask_ptr); + uint8_t* layer_output_dropout_mask_ptr, + T* layer_norm_var, + T* layer_norm_mean, + T* attn_layer_norm_var, + T* attn_layer_norm_mean); inline int GetBatchSize() const { return _batch_size; } inline int GetNumHeads() const { return _heads; } inline int GetSeqLength() const { return _seq_length; } + + void SetSeqLength(int seq_len, int bsz); inline int GetHiddenSize() const { return _hidden_size; } void SetTrainingMode(bool training); @@ -150,8 +156,8 @@ class BertTransformerLayer { // layers FeedForward _qkv_linear; FeedForward _attn_out_linear; - Normalize_Layer _norm_layer2; - Normalize_Layer _norm_layer3; + Normalize_Layer _attn_layer_norm; + Normalize_Layer _layer_norm; Normalize_Layer* _last_normalize; FeedForward _ff1, _ff2; Softmax _softmax; diff --git a/csrc/includes/gelu.h b/csrc/includes/gelu.h index 247bfb273de0..41cf6f2a68a7 100644 --- a/csrc/includes/gelu.h +++ b/csrc/includes/gelu.h @@ -9,13 +9,8 @@ template class Gelu { public: struct Config { - uint32_t batch_size; - uint32_t seq_length; uint32_t intermediate_size; - Config(uint32_t batch, uint32_t seq, uint32_t inter_size) - : batch_size(batch), seq_length(seq), intermediate_size(inter_size) - { - } + Config(uint32_t inter_size) : intermediate_size(inter_size) {} }; Gelu(const Config& config) : _config(config) {} @@ -28,14 +23,12 @@ class Gelu { T* output, cudaStream_t stream) { - launch_bias_gelu( - input_buf, bias, output, _config.intermediate_size, bsz, _config.seq_length, stream); + launch_bias_gelu(input_buf, bias, output, _config.intermediate_size, bsz, stream); } void Backward(int bsz, T* d_output, const T* input_buf, const T* bias, cudaStream_t stream) { - launch_d_gelu( - d_output, input_buf, bias, _config.intermediate_size, bsz, _config.seq_length, stream); + launch_d_gelu(d_output, input_buf, bias, _config.intermediate_size, bsz, stream); } private: diff --git a/csrc/includes/normalize_layer.h b/csrc/includes/normalize_layer.h index 37ee752c88b5..bfe84636ddb9 100644 --- a/csrc/includes/normalize_layer.h +++ b/csrc/includes/normalize_layer.h @@ -16,57 +16,27 @@ class Normalize_Layer { uint32_t seqLength; uint32_t hiddenDim; float epsilon; - bool training, save_vals; - bool allocateGrad; + bool training; bool useMean; - Config(uint32_t batch, - uint32_t seq, - uint32_t h, - bool training, - bool save_vals = true, - bool allocateGrad = true, - bool useMean = true) + Config(uint32_t batch, uint32_t seq, uint32_t h, bool training, bool useMean = true) : batchSize(batch), seqLength(seq), hiddenDim(h), epsilon(1e-12), training(training), - save_vals(save_vals), - allocateGrad(allocateGrad), useMean(useMean) { } }; - Normalize_Layer(Config config) : config_(config), 
vars(nullptr), vals_hat(nullptr) + Normalize_Layer(Config config) + : config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr) { - if (config_.training) { - cudaMalloc((void**)&vars, config_.batchSize * config_.seqLength * sizeof(T)); - - if (config_.useMean) - cudaMalloc((void**)&means, config_.batchSize * config_.seqLength * sizeof(T)); - - if (config_.save_vals) - cudaMalloc((void**)&vals_hat, - config_.batchSize * config_.seqLength * config_.hiddenDim * sizeof(T)); - - if (config_.allocateGrad) - cudaMalloc((void**)&inp_grad, - config_.batchSize * config_.seqLength * config_.hiddenDim * sizeof(T)); - } } - ~Normalize_Layer() - { - if (config_.training) { - cudaFree(vars); - if (config_.useMean) cudaFree(means); - if (config_.save_vals) cudaFree(vals_hat); - if (config_.allocateGrad) cudaFree(inp_grad); - } - } + ~Normalize_Layer() {} - void ForwardCheckpoint(int bsz, + void ForwardCheckpoint(int bsz, // batch * seq T* vals, const T* residual, const T* gamma, @@ -80,14 +50,12 @@ class Normalize_Layer { betta, config_.epsilon, bsz, - config_.seqLength, config_.hiddenDim, stream, preLayerNorm, config_.training, vars, - means, - vals_hat); + means); } void Forward(int bsz, @@ -104,14 +72,11 @@ class Normalize_Layer { betta, config_.epsilon, bsz, - config_.seqLength, config_.hiddenDim, stream, preLayerNorm, config_.training, - vars, - vals_hat, - config_.save_vals); + vars); } void Backward(int bsz, @@ -120,7 +85,7 @@ class Normalize_Layer { T* gamma_grad, T* betta_grad, cudaStream_t stream[2], - T* inp_grad_out = nullptr, + T* inp_grad_out, const T* norm_in = nullptr) { launch_layerNorm_backward(out_grad, @@ -130,9 +95,8 @@ class Normalize_Layer { gamma, gamma_grad, betta_grad, - (config_.allocateGrad ? inp_grad : inp_grad_out), + inp_grad_out, bsz, - config_.seqLength, config_.hiddenDim, stream); } @@ -144,21 +108,20 @@ class Normalize_Layer { T* gamma_grad, T* betta_grad, cudaStream_t stream[2], - T* inp_grad_out = nullptr, - const T* norm_out = nullptr) + T* inp_grad_out, + const T* norm_out) { launch_layerNorm_backward(out_grad, - (config_.save_vals ? vals_hat : norm_out), + norm_out, vars, gamma, gamma_grad, betta_grad, - (config_.allocateGrad ? inp_grad : inp_grad_out), + inp_grad_out, bsz, - config_.seqLength, config_.hiddenDim, stream, - config_.save_vals, + !config_.useMean, betta); } @@ -169,7 +132,7 @@ class Normalize_Layer { T* gamma_grad, T* betta_grad, cudaStream_t stream[2], - T* inp_grad_out = nullptr, + T* inp_grad_out, const T* norm_in = nullptr) { launch_layerNorm_backward_fused_add(out_grad1, @@ -180,9 +143,8 @@ class Normalize_Layer { gamma, gamma_grad, betta_grad, - (config_.allocateGrad ? inp_grad : inp_grad_out), + inp_grad_out, bsz, - config_.seqLength, config_.hiddenDim, stream); } @@ -195,33 +157,41 @@ class Normalize_Layer { T* gamma_grad, T* betta_grad, cudaStream_t stream[2], - T* inp_grad_out = nullptr, - const T* norm_out = nullptr) + T* inp_grad_out, + const T* norm_out) { launch_layerNorm_backward_fused_add(out_grad1, out_grad2, - (config_.save_vals ? vals_hat : norm_out), + norm_out, vars, gamma, gamma_grad, betta_grad, - (config_.allocateGrad ? 
inp_grad : inp_grad_out), + inp_grad_out, bsz, - config_.seqLength, config_.hiddenDim, stream, - config_.save_vals, + !config_.useMean, betta); } - inline T* GetInputGrad() const { return inp_grad; } - inline bool UseMean() const { return config_.useMean; } + inline void SetVar(T* variance) + { + if (!variance) { throw std::runtime_error("Normalize variance is null."); } + vars = variance; + } + + inline void SetMean(T* mean) + { + if (!mean) { throw std::runtime_error("Normalize mean is null."); } + means = mean; + } + private: Config config_; T* vars; T* means; T* vals_hat; - T* inp_grad; }; diff --git a/csrc/includes/softmax.h b/csrc/includes/softmax.h old mode 100644 new mode 100755 index 2a18daee0b78..2bc2f67059cf --- a/csrc/includes/softmax.h +++ b/csrc/includes/softmax.h @@ -45,13 +45,15 @@ class Softmax { out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream); } - inline int GetProbDepth() const { return config_.prob_depth; } + inline size_t GetProbDepth() const { return config_.prob_depth; } - inline int GetBatchSize() const { return config_.batchSize; } + inline size_t GetBatchSize() const { return config_.batchSize; } - inline int GetNumHeads() const { return config_.heads; } + inline size_t GetNumHeads() const { return config_.heads; } - inline int GetSeqLength() const { return config_.seq_length; } + inline size_t GetSeqLength() const { return config_.seq_length; } + + inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; } private: Config config_; diff --git a/csrc/includes/strided_batch_gemm.h b/csrc/includes/strided_batch_gemm.h index 8c43608e2ecf..44a1b313b986 100644 --- a/csrc/includes/strided_batch_gemm.h +++ b/csrc/includes/strided_batch_gemm.h @@ -3,6 +3,7 @@ #include #include #include +#include "context.h" template class StridedBatchGemm { @@ -38,6 +39,12 @@ class StridedBatchGemm { gemm_algos(algos) { } + void SetConfig(int mm, int nn, int kk) + { + m = mm; + n = nn; + k = kk; + } }; StridedBatchGemm(const Config& config) : _config(config) {} @@ -163,6 +170,8 @@ class StridedBatchGemm { inline const T* GetBufferB() const { return q_buf; } + inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); } + private: Config _config; const T* q_buf; diff --git a/csrc/transformer/cublas_wrappers.cu b/csrc/transformer/cublas_wrappers.cu index 7b0016bcae5e..3128e6c0fd64 100644 --- a/csrc/transformer/cublas_wrappers.cu +++ b/csrc/transformer/cublas_wrappers.cu @@ -34,7 +34,12 @@ int cublas_gemm_ex(cublasHandle_t handle, algo); if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, "!!!! kernel execution error.\n"); + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); return EXIT_FAILURE; } return 0; @@ -74,7 +79,12 @@ int cublas_gemm_ex(cublasHandle_t handle, algo); if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, "!!!! kernel execution error.\n"); + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); return EXIT_FAILURE; } return 0; @@ -122,7 +132,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, algo); if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, "!!!! kernel execution error.\n"); + fprintf(stderr, + "!!!! kernel execution error. 
(m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); return EXIT_FAILURE; } return 0; @@ -170,7 +185,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, algo); if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, "!!!! kernel execution error.\n"); + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); return EXIT_FAILURE; } diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp index e36c3786944f..9d275916c68e 100644 --- a/csrc/transformer/ds_transformer_cuda.cpp +++ b/csrc/transformer/ds_transformer_cuda.cpp @@ -78,20 +78,16 @@ BertTransformerLayer::BertTransformerLayer(int layer_id, hidden_size, hidden_size, gemm_algos[0])), - _norm_layer2(typename Normalize_Layer::Config(batch_size, - seq_length, - hidden_size, - true, - false, - false, - !normalize_invertible)), - _norm_layer3(typename Normalize_Layer::Config(batch_size, - seq_length, - hidden_size, - true, - false, - false, - !normalize_invertible)), + _attn_layer_norm(typename Normalize_Layer::Config(batch_size, + seq_length, + hidden_size, + true, + !normalize_invertible)), + _layer_norm(typename Normalize_Layer::Config(batch_size, + seq_length, + hidden_size, + true, + !normalize_invertible)), _ff1(typename FeedForward::Config(batch_size * seq_length, intermediate_size, hidden_size, @@ -101,16 +97,10 @@ BertTransformerLayer::BertTransformerLayer(int layer_id, intermediate_size, gemm_algos[2])), _softmax(typename Softmax::Config(batch_size, num_heads, seq_length)), - _gelu(typename Gelu::Config(_batch_size, _seq_length, intermediate_size)), - _attn_prob_dropout(typename Dropout::Config(attn_prob_dropout_ratio, - _batch_size * _heads * _seq_length, - _seq_length)), - _attn_output_dropout(typename Dropout::Config(hidden_output_dropout_ratio, - _batch_size * _seq_length, - _hidden_size)), - _layer_output_dropout(typename Dropout::Config(hidden_output_dropout_ratio, - _batch_size * _seq_length, - _hidden_size)), + _gelu(typename Gelu::Config(_intermediate_size)), + _attn_prob_dropout(typename Dropout::Config(attn_prob_dropout_ratio, _seq_length)), + _attn_output_dropout(typename Dropout::Config(hidden_output_dropout_ratio, _hidden_size)), + _layer_output_dropout(typename Dropout::Config(hidden_output_dropout_ratio, _hidden_size)), _attn_scores(typename StridedBatchGemm::Config(_batch_size * _heads, _seq_length, _seq_length, @@ -196,18 +186,18 @@ void BertTransformerLayer::Forward(int bsz, if (_normalize_invertible) add_res_ptr = buf_1 + 3 * small_buf_size; if (_attn_dropout_checkpoint) ctx_bufB_ptr = buf_1 + 4 * small_buf_size; + int bsz_seq = bsz * _seq_length; + if (_pre_or_postLayerNorm) { - if (_norm_layer3.UseMean()) - _norm_layer3.ForwardCheckpoint( - bsz, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true); + if (_layer_norm.UseMean()) + _layer_norm.ForwardCheckpoint( + bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true); else - _norm_layer3.Forward( - bsz, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true); + _layer_norm.Forward( + bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true); } - int bsz_seq = bsz * _seq_length; - if (_pre_or_postLayerNorm) _qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle); else @@ -247,19 +237,19 @@ void BertTransformerLayer::Forward(int bsz, bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream); if (_pre_or_postLayerNorm) { - if (_norm_layer2.UseMean()) - 
_norm_layer2.ForwardCheckpoint( - bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); + if (_attn_layer_norm.UseMean()) + _attn_layer_norm.ForwardCheckpoint( + bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); else - _norm_layer2.Forward( - bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); + _attn_layer_norm.Forward( + bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); } else { - if (_norm_layer2.UseMean()) - _norm_layer2.ForwardCheckpoint( - bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); + if (_attn_layer_norm.UseMean()) + _attn_layer_norm.ForwardCheckpoint( + bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); else - _norm_layer2.Forward( - bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); + _attn_layer_norm.Forward( + bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); } _ff1.Forward(bsz_seq, @@ -268,7 +258,7 @@ void BertTransformerLayer::Forward(int bsz, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), _cublasHandle); - _gelu.ForwardWithBiasAdd(bsz, + _gelu.ForwardWithBiasAdd(bsz_seq, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, (_gelu_checkpoint ? ctx_bufB_ptr : ff2_inp_ptr), @@ -289,11 +279,12 @@ void BertTransformerLayer::Forward(int bsz, bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream); if (!_pre_or_postLayerNorm) { - if (_norm_layer3.UseMean()) - _norm_layer3.ForwardCheckpoint( - bsz, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true); + if (_layer_norm.UseMean()) + _layer_norm.ForwardCheckpoint( + bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true); else - _norm_layer3.Forward(bsz, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true); + _layer_norm.Forward( + bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true); } } @@ -358,26 +349,26 @@ void BertTransformerLayer::Backward(int bsz, int bsz_heads = bsz * _heads; if (!_pre_or_postLayerNorm) { - if (_norm_layer3.UseMean()) - _norm_layer3.Backward(bsz, - grad_output_ptr, - norm_w_ptr, - grad_norm_w_ptr, - grad_norm_b_ptr, - streams, - buf_1, - inp_norm_ptr); + if (_layer_norm.UseMean()) + _layer_norm.Backward(bsz_seq, + grad_output_ptr, + norm_w_ptr, + grad_norm_w_ptr, + grad_norm_b_ptr, + streams, + buf_1, + inp_norm_ptr); else - _norm_layer3.Backward(bsz, - grad_output_ptr, - norm_w_ptr, - norm_b_ptr, - grad_norm_w_ptr, - grad_norm_b_ptr, - streams, - buf_1, - output_ptr); + _layer_norm.Backward(bsz_seq, + grad_output_ptr, + norm_w_ptr, + norm_b_ptr, + grad_norm_w_ptr, + grad_norm_b_ptr, + streams, + buf_1, + output_ptr); } if (_pre_or_postLayerNorm) @@ -389,7 +380,8 @@ void BertTransformerLayer::Backward(int bsz, ? buf_0 : (_pre_or_postLayerNorm ? grad_output_ptr : buf_1); - if (_gelu_checkpoint) _gelu.ForwardWithBiasAdd(bsz, ff2_inp_ptr, inter_b_ptr, buf_2, _stream); + if (_gelu_checkpoint) + _gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream); _ff2.Backward(bsz_seq, layer_dropout_buf, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), @@ -401,7 +393,7 @@ void BertTransformerLayer::Backward(int bsz, ff2_buf); _gelu.Backward( - bsz, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream); + bsz_seq, ff2_buf, (_gelu_checkpoint ? 
ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream); _ff1.Backward(bsz_seq, ff2_buf, @@ -417,49 +409,49 @@ void BertTransformerLayer::Backward(int bsz, launch_fused_add2(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream); if (_pre_or_postLayerNorm) { - if (_norm_layer2.UseMean()) - _norm_layer2.BackwardFusedAdd(bsz, - buf_3, - grad_output_ptr, - attn_nw_ptr, - grad_attn_nw_ptr, - grad_attn_nb_ptr, - streams, - buf_0, - add_res_ptr); + if (_attn_layer_norm.UseMean()) + _attn_layer_norm.BackwardFusedAdd(bsz_seq, + buf_3, + grad_output_ptr, + attn_nw_ptr, + grad_attn_nw_ptr, + grad_attn_nb_ptr, + streams, + buf_0, + add_res_ptr); else - _norm_layer2.BackwardFusedAdd(bsz, - buf_3, - grad_output_ptr, - attn_nw_ptr, - attn_nb_ptr, - grad_attn_nw_ptr, - grad_attn_nb_ptr, - streams, - buf_0, - ff1_inp_ptr); + _attn_layer_norm.BackwardFusedAdd(bsz_seq, + buf_3, + grad_output_ptr, + attn_nw_ptr, + attn_nb_ptr, + grad_attn_nw_ptr, + grad_attn_nb_ptr, + streams, + buf_0, + ff1_inp_ptr); } else { - if (_norm_layer2.UseMean()) - _norm_layer2.Backward(bsz, - buf_2, - attn_nw_ptr, - grad_attn_nw_ptr, - grad_attn_nb_ptr, - streams, - buf_0, - add_res_ptr); + if (_attn_layer_norm.UseMean()) + _attn_layer_norm.Backward(bsz_seq, + buf_2, + attn_nw_ptr, + grad_attn_nw_ptr, + grad_attn_nb_ptr, + streams, + buf_0, + add_res_ptr); else - _norm_layer2.Backward(bsz, - buf_2, - attn_nw_ptr, - attn_nb_ptr, - grad_attn_nw_ptr, - grad_attn_nb_ptr, - streams, - buf_0, - ff1_inp_ptr); + _attn_layer_norm.Backward(bsz_seq, + buf_2, + attn_nw_ptr, + attn_nb_ptr, + grad_attn_nw_ptr, + grad_attn_nb_ptr, + streams, + buf_0, + ff1_inp_ptr); } _attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream); @@ -524,28 +516,28 @@ void BertTransformerLayer::Backward(int bsz, buf_2); if (_pre_or_postLayerNorm) { - if (_norm_layer3.UseMean()) - _norm_layer3.BackwardFusedAdd(bsz, - buf_2, - buf_0, - norm_w_ptr, - grad_norm_w_ptr, - grad_norm_b_ptr, - streams, - grad_input_ptr, - input_ptr); + if (_layer_norm.UseMean()) + _layer_norm.BackwardFusedAdd(bsz_seq, + buf_2, + buf_0, + norm_w_ptr, + grad_norm_w_ptr, + grad_norm_b_ptr, + streams, + grad_input_ptr, + input_ptr); else - _norm_layer3.BackwardFusedAdd(bsz, - buf_2, - buf_0, - norm_w_ptr, - norm_b_ptr, - grad_norm_w_ptr, - grad_norm_b_ptr, - streams, - grad_input_ptr, - inp_norm_ptr); + _layer_norm.BackwardFusedAdd(bsz_seq, + buf_2, + buf_0, + norm_w_ptr, + norm_b_ptr, + grad_norm_w_ptr, + grad_norm_b_ptr, + streams, + grad_input_ptr, + inp_norm_ptr); } else launch_fused_add2(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream); } @@ -562,11 +554,34 @@ void BertTransformerLayer::SetTrainingMode(bool training) template void BertTransformerLayer::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr, uint8_t* attn_output_dropout_mask_ptr, - uint8_t* layer_output_dropout_mask_ptr) + uint8_t* layer_output_dropout_mask_ptr, + T* attn_layer_norm_var, + T* attn_layer_norm_mean, + T* layer_norm_var, + T* layer_norm_mean) { _attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr); _attn_output_dropout.SetMask(attn_output_dropout_mask_ptr); _layer_output_dropout.SetMask(layer_output_dropout_mask_ptr); + + _attn_layer_norm.SetVar(attn_layer_norm_var); + _attn_layer_norm.SetMean(attn_layer_norm_mean); + _layer_norm.SetVar(layer_norm_var); + _layer_norm.SetMean(layer_norm_mean); +} + +template +void BertTransformerLayer::SetSeqLength(int seq_len, int bsz) +{ + _seq_length = seq_len; + + _softmax.SetSeqLength(_seq_length); + 
_attn_prob_dropout.SetDimension(_seq_length); + _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads); + _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length); + + Context::Instance().GenWorkSpace( + get_workspace_size(bsz, _seq_length, _hidden_size, _heads, _training, _gelu_checkpoint)); } template @@ -687,54 +702,61 @@ std::vector ds_transformer_forward(int layer_id, std::shared_ptr> layer = std::static_pointer_cast>(s_transformer_layers[layer_id]); + int seq_len = layer->GetSeqLength(); + if (input.size(1) != seq_len) { + seq_len = input.size(1); + layer->SetSeqLength(seq_len, bsz); + } + auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output); auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input)); auto attn_o_inp = torch::empty_like(input); - auto qkv_tf = torch::empty({(bsz * layer->GetSeqLength()), output_w.size(0) * 3}, options); + auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options); auto attn_prob_dropout_mask = - torch::empty({(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()}, - uint8_options); + torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options); auto attn_output_dropout_mask = - torch::empty({(bsz * layer->GetSeqLength()), layer->GetHiddenSize()}, uint8_options); + torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options); auto layer_output_dropout_mask = - torch::empty({(bsz * layer->GetSeqLength()), layer->GetHiddenSize()}, uint8_options); + torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options); + + auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options); + auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options); + auto layer_norm_var = torch::empty({(bsz * seq_len)}, options); + auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options); T* inp_norm_ptr = (T*)inp_norm.data_ptr(); T* add_res_ptr = (T*)add_res.data_ptr(); T* q_tf_ptr = (T*)qkv_tf.data_ptr(); - T* k_tf_ptr = - q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(T*)k_tf.data_ptr(); - T* v_tf_ptr = - k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(T*)v_tf.data_ptr(); + T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)k_tf.data_ptr(); + T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)v_tf.data_ptr(); T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr(); - torch::Tensor ff2_inp = - torch::empty({(bsz * layer->GetSeqLength()), output_w.size(1)}, options); + torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options); torch::Tensor gelu_inp = - (gelu_checkpoint - ? ff2_inp - : torch::empty({(bsz * layer->GetSeqLength()), output_w.size(1)}, options)); + (gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options)); auto ff1_inp = torch::empty_like(input); T* ff2_inp_ptr = (T*)ff2_inp.data_ptr(); T* gelu_inp_ptr = (T*)gelu_inp.data_ptr(); T* ff1_inp_ptr = (T*)ff1_inp.data_ptr(); - torch::Tensor soft_out = torch::empty( - {(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()}, options); + torch::Tensor soft_out = + torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options); torch::Tensor ctx_bufB = (attn_dropout_checkpoint ? 
soft_out - : torch::empty( - {(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()}, - options)); + : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options)); T* soft_out_ptr = (T*)soft_out.data_ptr(); T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr(); layer->SetTrainingMode(training_mode); layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(), (uint8_t*)attn_output_dropout_mask.data_ptr(), - (uint8_t*)layer_output_dropout_mask.data_ptr()); + (uint8_t*)layer_output_dropout_mask.data_ptr(), + (T*)attn_layer_norm_var.data_ptr(), + (T*)attn_layer_norm_mean.data_ptr(), + (T*)layer_norm_var.data_ptr(), + (T*)layer_norm_mean.data_ptr()); layer->Forward(bsz, input_ptr, @@ -776,7 +798,11 @@ std::vector ds_transformer_forward(int layer_id, ff2_inp, attn_prob_dropout_mask, attn_output_dropout_mask, - layer_output_dropout_mask}; + layer_output_dropout_mask, + attn_layer_norm_var, + attn_layer_norm_mean, + layer_norm_var, + layer_norm_mean}; } template @@ -795,6 +821,10 @@ std::vector ds_transformer_backward(int layer_id, const torch::Tensor& attn_prob_dropout_mask, const torch::Tensor& attn_output_dropout_mask, const torch::Tensor& layer_output_dropout_mask, + const torch::Tensor& attn_layer_norm_var, + const torch::Tensor& attn_layer_norm_mean, + const torch::Tensor& layer_norm_var, + const torch::Tensor& layer_norm_mean, const torch::Tensor& input, const torch::Tensor& input_mask, const torch::Tensor& attn_qkvw, @@ -838,6 +868,7 @@ std::vector ds_transformer_backward(int layer_id, CHECK_INPUT(norm_b); int bsz = g_output.size(0); + std::shared_ptr> layer = std::static_pointer_cast>(s_transformer_layers[layer_id]); @@ -900,7 +931,11 @@ std::vector ds_transformer_backward(int layer_id, layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(), (uint8_t*)attn_output_dropout_mask.data_ptr(), - (uint8_t*)layer_output_dropout_mask.data_ptr()); + (uint8_t*)layer_output_dropout_mask.data_ptr(), + (T*)attn_layer_norm_var.data_ptr(), + (T*)attn_layer_norm_mean.data_ptr(), + (T*)layer_norm_var.data_ptr(), + (T*)layer_norm_mean.data_ptr()); layer->Backward(bsz, grad_output_ptr, diff --git a/csrc/transformer/gelu_kernels.cu b/csrc/transformer/gelu_kernels.cu old mode 100755 new mode 100644 index f0e65e3829b5..209b64a90902 --- a/csrc/transformer/gelu_kernels.cu +++ b/csrc/transformer/gelu_kernels.cu @@ -279,13 +279,12 @@ void launch_bias_gelu(const T* input, T* output, int intermediate_size, int batch_size, - int sequence_length, cudaStream_t stream) { int iterations = (intermediate_size + 1023) / 1024; int threads = intermediate_size / iterations / 4; dim3 block_dims(threads); - dim3 grid_dims(sequence_length * batch_size); + dim3 grid_dims(batch_size); fused_bias_gelu<<>>(input, bias, output, intermediate_size); } @@ -295,24 +294,26 @@ void launch_gelu(const T* input, T* output, int intermediate_size, int batch_size, - int sequence_length, cudaStream_t stream) { int iterations = (intermediate_size + 1023) / 1024; int threads = intermediate_size / iterations / 4; dim3 block_dims(threads); - dim3 grid_dims(sequence_length * batch_size); + dim3 grid_dims(batch_size); gelu_kernel<<>>(input, output, intermediate_size); } -template void -launch_bias_gelu(const float*, const float*, float*, int, int, int, cudaStream_t); -template void -launch_bias_gelu<__half>(const __half*, const __half*, __half*, int, int, int, cudaStream_t); +template void launch_bias_gelu(const float*, const float*, float*, int, int, cudaStream_t); +template void 
launch_bias_gelu<__half>(const __half*, + const __half*, + __half*, + int, + int, + cudaStream_t); -template void launch_gelu(const float*, float*, int, int, int, cudaStream_t); -template void launch_gelu<__half>(const __half*, __half*, int, int, int, cudaStream_t); +template void launch_gelu(const float*, float*, int, int, cudaStream_t); +template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t); template void launch_d_gelu(T* d_output, @@ -320,17 +321,15 @@ void launch_d_gelu(T* d_output, const T* bias, int intermediate_size, int batch_size, - int sequence_length, cudaStream_t stream) { int iterations = (intermediate_size + 1023) / 1024; int threads = intermediate_size / iterations / 4; dim3 block_dims(threads); - dim3 grid_dims(sequence_length * batch_size); + dim3 grid_dims(batch_size); d_gelu_func<<>>(d_output, input, bias, intermediate_size); } -template void launch_d_gelu(float*, const float*, const float*, int, int, int, cudaStream_t); -template void -launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, int, cudaStream_t); +template void launch_d_gelu(float*, const float*, const float*, int, int, cudaStream_t); +template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t); diff --git a/csrc/transformer/general_kernels.cu b/csrc/transformer/general_kernels.cu index 0ce280a702ab..ad97a9a4793c 100644 --- a/csrc/transformer/general_kernels.cu +++ b/csrc/transformer/general_kernels.cu @@ -14,15 +14,18 @@ __global__ void column_sum_reduce(const T* __restrict__ inp, cg::thread_block_tile g = cg::tiled_partition(b); int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; float localSum = 0; // Loop across matrix height - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - localSum += (float)inp[offset]; - offset += y_stride; + if (idx < width) { + int offset = threadIdx.y * width + idx; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + localSum += (float)inp[offset]; + offset += y_stride; + } } tile[threadIdx.x][threadIdx.y] = localSum; @@ -40,7 +43,7 @@ __global__ void column_sum_reduce(const T* __restrict__ inp, if (threadIdx.x == 0) { int pos = blockIdx.x * TILE_DIM + threadIdx.y; - out[pos] = sum; + if (pos < (rows * width)) out[pos] = sum; } } @@ -58,10 +61,10 @@ void launch_fuse_transpose_bias_kernel(const float* inp, int cols, cudaStream_t stream) { - assert(rows % TILE_DIM == 0); - assert(cols % TILE_DIM == 0); + // assert(rows % TILE_DIM == 0); + // assert(cols % TILE_DIM == 0); - dim3 grid_dim(cols / TILE_DIM); + dim3 grid_dim((cols - 1) / TILE_DIM + 1); dim3 block_dim(TILE_DIM, TILE_DIM); column_sum_reduce<<>>(inp, out, rows, cols); @@ -74,10 +77,10 @@ void launch_fuse_transpose_bias_kernel<__half>(const __half* inp, int cols, cudaStream_t stream) { - assert(rows % TILE_DIM == 0); - assert(cols % TILE_DIM == 0); + // assert(rows % TILE_DIM == 0); + // assert(cols % TILE_DIM == 0); - dim3 grid_dim(cols / TILE_DIM); + dim3 grid_dim((cols - 1) / TILE_DIM + 1); dim3 block_dim(TILE_DIM, TILE_DIM); column_sum_reduce<__half><<>>(inp, out, rows, cols); diff --git a/csrc/transformer/normalize_kernels.cu b/csrc/transformer/normalize_kernels.cu old mode 100755 new mode 100644 index 7345175694bf..4431aeb3d8e5 --- a/csrc/transformer/normalize_kernels.cu +++ b/csrc/transformer/normalize_kernels.cu @@ -27,10 +27,9 @@ __global__ void fused_bias_residual_layer_norm(float* vals, const float* beta, float epsilon, bool preLayerNorm, - bool 
training = false, - float* vars = nullptr, - float* means = nullptr, - float* vals_hat = nullptr) + bool training, + float* vars, + float* means) { constexpr int iteration_stride = row_stride / iterations; @@ -108,10 +107,9 @@ __global__ void fused_bias_residual_layer_norm(__half* vals, const __half* beta, float epsilon, bool preLayerNorm, - bool training = false, - __half* vars = nullptr, - __half* means = nullptr, - __half* vals_hat = nullptr) + bool training, + __half* vars, + __half* means) { #if __CUDA_ARCH__ >= 700 constexpr int iteration_stride = row_stride / iterations; @@ -204,14 +202,12 @@ void launch_bias_residual_layer_norm(T* vals, const T* beta, float epsilon, int batch_size, - int sequence_length, int hidden_dim, cudaStream_t stream, bool preLayerNorm, bool training, T* vars, - T* means, - T* vals_hat); + T* means); template <> void launch_bias_residual_layer_norm(float* vals, @@ -220,40 +216,38 @@ void launch_bias_residual_layer_norm(float* vals, const float* beta, float epsilon, int batch_size, - int sequence_length, int hidden_dim, cudaStream_t stream, bool preLayerNorm, bool training, float* vars, - float* means, - float* vals_hat) + float* means) { constexpr int threads = THREADS; - dim3 grid_dim(batch_size * sequence_length); + dim3 grid_dim(batch_size); dim3 block_dim(threads); // There are some limitations to call below functions, now just enumerate the situations. if (hidden_dim == 768) fused_bias_residual_layer_norm<768, 3><<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means); else if (hidden_dim == 512) fused_bias_residual_layer_norm<512, 2><<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means); else if (hidden_dim == 1024) fused_bias_residual_layer_norm<1024, 4><<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means); else if (hidden_dim == 1536) fused_bias_residual_layer_norm<1536, 6><<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means); else if (hidden_dim == 2048) fused_bias_residual_layer_norm<2048, 8><<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means); else if (hidden_dim == 2560) fused_bias_residual_layer_norm<2560, 10><<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means); else throw std::runtime_error("Unsupport hidden_dim."); } @@ -265,39 +259,37 @@ void launch_bias_residual_layer_norm<__half>(__half* vals, const __half* beta, float epsilon, int batch_size, - int sequence_length, int hidden_dim, cudaStream_t stream, bool preLayerNorm, bool training, __half* vars, - __half* means, - __half* vals_hat) + __half* means) { constexpr int threads = 128; - dim3 grid_dim(batch_size * sequence_length); + dim3 grid_dim(batch_size); dim3 block_dim(threads); // There are some limitations to call below functions, now just enumerate the situations. 
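// Illustrative sketch (assumption, not part of the patch): the __half dispatch
// below instantiates the kernel with row_stride = hidden_dim / 2, presumably
// because each row is processed as packed __half2 (768 -> <384, 3>,
// 1024 -> <512, 4>, 2560 -> <1280, 10>), while iterations stays
// hidden_dim / 256 as in the float path. A small helper expressing that
// assumed mapping:
struct LnHalfLaunchShape {
    int row_stride;  // template row_stride argument, in __half2 elements
    int iterations;  // template iterations argument
};

static inline LnHalfLaunchShape ln_half_launch_shape(int hidden_dim)
{
    // only valid for the hidden sizes enumerated below (512 ... 2560)
    return LnHalfLaunchShape{hidden_dim / 2, hidden_dim / 256};
}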
if (hidden_dim == 768) fused_bias_residual_layer_norm<384, 3><<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means); else if (hidden_dim == 512) fused_bias_residual_layer_norm<256, 2><<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means); else if (hidden_dim == 1024) fused_bias_residual_layer_norm<512, 4><<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means); else if (hidden_dim == 1536) fused_bias_residual_layer_norm<768, 6><<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means); else if (hidden_dim == 2048) fused_bias_residual_layer_norm<1024, 8><<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means); else if (hidden_dim == 2560) fused_bias_residual_layer_norm<1280, 10><<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means); else throw std::runtime_error("Unsupport hidden_dim."); } @@ -309,10 +301,8 @@ __global__ void fused_bias_residual_layer_norm(float* vals, const float* beta, float epsilon, bool preLayerNorm, - bool training = false, - float* vars = nullptr, - float* vals_hat = nullptr, - bool save_vals = false) + bool training, + float* vars) { constexpr int iteration_stride = row_stride / iterations; @@ -388,10 +378,8 @@ __global__ void fused_bias_residual_layer_norm(__half* vals, const __half* beta, float epsilon, bool preLayerNorm, - bool training = false, - __half* vars = nullptr, - __half* vals_hat = nullptr, - bool save_vals = false) + bool training, + __half* vars) { #if __CUDA_ARCH__ >= 700 constexpr int iteration_stride = row_stride / iterations; @@ -481,14 +469,11 @@ void launch_bias_residual_layer_norm(T* vals, const T* beta, float epsilon, int batch_size, - int sequence_length, int hidden_dim, cudaStream_t stream, bool preLayerNorm, bool training, - T* vars, - T* vals_hat, - bool save_vals); + T* vars); /* To tune this launch the following restrictions must be met: @@ -512,88 +497,37 @@ void launch_bias_residual_layer_norm(float* vals, const float* beta, float epsilon, int batch_size, - int sequence_length, int hidden_dim, cudaStream_t stream, bool preLayerNorm, bool training, - float* vars, - float* vals_hat, - bool save_vals) + float* vars) { constexpr int threads = THREADS; - dim3 grid_dim(batch_size * sequence_length); + dim3 grid_dim(batch_size); dim3 block_dim(threads); // There are some limitations to call below functions, now just enumerate the situations. 
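// Illustrative sketch (assumption, not part of the patch): with
// sequence_length removed from this launcher, the batch_size argument is the
// flattened row count (the bsz_seq = batch * seq_length value built in
// ds_transformer_cuda.cpp), so grid_dim above launches one block per
// normalized row. A helper capturing that assumed convention:
static inline dim3 layer_norm_grid(int batch, int seq_length)
{
    // one thread block per (sample, token) row; the hidden dimension is
    // handled entirely inside the templated kernel chosen below
    return dim3(batch * seq_length);
}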
if (hidden_dim == 768) - fused_bias_residual_layer_norm<768, 3><<>>(vals, - residual, - gamma, - beta, - epsilon, - preLayerNorm, - training, - vars, - vals_hat, - save_vals); + fused_bias_residual_layer_norm<768, 3><<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); else if (hidden_dim == 512) - fused_bias_residual_layer_norm<512, 2><<>>(vals, - residual, - gamma, - beta, - epsilon, - preLayerNorm, - training, - vars, - vals_hat, - save_vals); + fused_bias_residual_layer_norm<512, 2><<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); else if (hidden_dim == 1024) - fused_bias_residual_layer_norm<1024, 4><<>>(vals, - residual, - gamma, - beta, - epsilon, - preLayerNorm, - training, - vars, - vals_hat, - save_vals); + fused_bias_residual_layer_norm<1024, 4><<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); else if (hidden_dim == 1536) - fused_bias_residual_layer_norm<1536, 6><<>>(vals, - residual, - gamma, - beta, - epsilon, - preLayerNorm, - training, - vars, - vals_hat, - save_vals); + fused_bias_residual_layer_norm<1536, 6><<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); else if (hidden_dim == 2048) - fused_bias_residual_layer_norm<2048, 8><<>>(vals, - residual, - gamma, - beta, - epsilon, - preLayerNorm, - training, - vars, - vals_hat, - save_vals); + fused_bias_residual_layer_norm<2048, 8><<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); else if (hidden_dim == 2560) - fused_bias_residual_layer_norm<2560, 10><<>>(vals, - residual, - gamma, - beta, - epsilon, - preLayerNorm, - training, - vars, - vals_hat, - save_vals); + fused_bias_residual_layer_norm<2560, 10><<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); else throw std::runtime_error("Unsupport hidden_dim."); } @@ -605,87 +539,36 @@ void launch_bias_residual_layer_norm<__half>(__half* vals, const __half* beta, float epsilon, int batch_size, - int sequence_length, int hidden_dim, cudaStream_t stream, bool preLayerNorm, bool training, - __half* vars, - __half* vals_hat, - bool save_vals) + __half* vars) { constexpr int threads = 128; - dim3 grid_dim(batch_size * sequence_length); + dim3 grid_dim(batch_size); dim3 block_dim(threads); // There are some limitations to call below functions, now just enumerate the situations. 
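// Hedged note (assumption, not part of the patch): this no-mean variant backs
// the normalize_invertible path (config_.useMean == false). The forward pass
// stores only the variance, and the backward pass receives the layer output
// together with betta and the !config_.useMean flag, which suggests the
// normalized value is recovered by inverting the affine step rather than being
// read from a saved vals_hat buffer:
static inline float invert_layer_norm_output(float y, float gamma, float beta)
{
    // x_hat = (y - beta) / gamma, assuming gamma != 0
    return (y - beta) / gamma;
}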
if (hidden_dim == 768) - fused_bias_residual_layer_norm<384, 3><<>>(vals, - residual, - gamma, - beta, - epsilon, - preLayerNorm, - training, - vars, - vals_hat, - save_vals); + fused_bias_residual_layer_norm<384, 3><<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); else if (hidden_dim == 512) - fused_bias_residual_layer_norm<256, 2><<>>(vals, - residual, - gamma, - beta, - epsilon, - preLayerNorm, - training, - vars, - vals_hat, - save_vals); + fused_bias_residual_layer_norm<256, 2><<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); else if (hidden_dim == 1024) - fused_bias_residual_layer_norm<512, 4><<>>(vals, - residual, - gamma, - beta, - epsilon, - preLayerNorm, - training, - vars, - vals_hat, - save_vals); + fused_bias_residual_layer_norm<512, 4><<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); else if (hidden_dim == 1536) - fused_bias_residual_layer_norm<768, 6><<>>(vals, - residual, - gamma, - beta, - epsilon, - preLayerNorm, - training, - vars, - vals_hat, - save_vals); + fused_bias_residual_layer_norm<768, 6><<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); else if (hidden_dim == 2048) - fused_bias_residual_layer_norm<1024, 8><<>>(vals, - residual, - gamma, - beta, - epsilon, - preLayerNorm, - training, - vars, - vals_hat, - save_vals); + fused_bias_residual_layer_norm<1024, 8><<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); else if (hidden_dim == 2560) - fused_bias_residual_layer_norm<1280, 10><<>>(vals, - residual, - gamma, - beta, - epsilon, - preLayerNorm, - training, - vars, - vals_hat, - save_vals); + fused_bias_residual_layer_norm<1280, 10><<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); else throw std::runtime_error("Unsupport hidden_dim."); } @@ -1037,15 +920,13 @@ void launch_layerNorm_backward(const float* out_grad, float* gamma_grad, float* betta_grad, float* inp_grad, - int batch_size, - int sequence_length, + int batch, int hidden_dim, cudaStream_t stream[2], bool invertible, const float* betta) { constexpr int threads = THREADS; - int batch = batch_size * sequence_length; dim3 grid_dim(hidden_dim / TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM); @@ -1086,15 +967,13 @@ void launch_layerNorm_backward<__half>(const __half* out_grad, __half* gamma_grad, __half* betta_grad, __half* inp_grad, - int batch_size, - int sequence_length, + int batch, int hidden_dim, cudaStream_t stream[2], bool invertible, const __half* betta) { constexpr int threads = THREADS; - int batch = batch_size * sequence_length; dim3 grid_dim(hidden_dim / TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM); @@ -1336,13 +1215,11 @@ void launch_layerNorm_backward(const float* out_grad, float* gamma_grad, float* betta_grad, float* inp_grad, - int batch_size, - int sequence_length, + int batch, int hidden_dim, cudaStream_t stream[2]) { constexpr int threads = THREADS; - int batch = batch_size * sequence_length; dim3 grid_dim(hidden_dim / TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM); @@ -1384,13 +1261,11 @@ void launch_layerNorm_backward<__half>(const __half* out_grad, __half* gamma_grad, __half* betta_grad, __half* inp_grad, - int batch_size, - int sequence_length, + int batch, int hidden_dim, cudaStream_t stream[2]) { constexpr int threads = THREADS; - int batch = batch_size * sequence_length; dim3 grid_dim(hidden_dim / TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM); @@ -1759,15 +1634,13 @@ void launch_layerNorm_backward_fused_add(const float* 
out_grad1, float* gamma_grad, float* betta_grad, float* inp_grad, - int batch_size, - int sequence_length, + int batch, int hidden_dim, cudaStream_t stream[2], bool invertible, const float* betta) { constexpr int threads = THREADS; - int batch = batch_size * sequence_length; dim3 grid_dim(hidden_dim / TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM); @@ -1808,15 +1681,13 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, __half* gamma_grad, __half* betta_grad, __half* inp_grad, - int batch_size, - int sequence_length, + int batch, int hidden_dim, cudaStream_t stream[2], bool invertible, const __half* betta) { constexpr int threads = THREADS; - int batch = batch_size * sequence_length; dim3 grid_dim(hidden_dim / TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM); @@ -2070,13 +1941,11 @@ void launch_layerNorm_backward_fused_add(const float* out_grad1, float* gamma_grad, float* betta_grad, float* inp_grad, - int batch_size, - int sequence_length, + int batch, int hidden_dim, cudaStream_t stream[2]) { constexpr int threads = THREADS; - int batch = batch_size * sequence_length; dim3 grid_dim(hidden_dim / TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM); @@ -2119,13 +1988,11 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, __half* gamma_grad, __half* betta_grad, __half* inp_grad, - int batch_size, - int sequence_length, + int batch, int hidden_dim, cudaStream_t stream[2]) { constexpr int threads = THREADS; - int batch = batch_size * sequence_length; dim3 grid_dim(hidden_dim / TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM); diff --git a/csrc/transformer/softmax_kernels.cu b/csrc/transformer/softmax_kernels.cu index 8e2b86901609..582da4829f47 100644 --- a/csrc/transformer/softmax_kernels.cu +++ b/csrc/transformer/softmax_kernels.cu @@ -1,3 +1,4 @@ +#include #include "custom_cuda_layers.h" #include "general_kernels.h" @@ -282,7 +283,7 @@ __global__ void attn_softmax(__half* vals, } template -void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t, bool); +void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t); template <> void launch_attn_softmax(float* vals, @@ -294,11 +295,10 @@ void launch_attn_softmax(float* vals, { const int threads = 128; int seq_length4 = sequence_length / 4; - int seq2 = sequence_length * seq_length4; int block_compute_size = - (seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4); - dim3 grid_dim(batch_size, heads * seq2 / block_compute_size); + (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1); + dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size); int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; @@ -330,8 +330,9 @@ void launch_attn_softmax(float* vals, else { const int threads = 256; block_compute_size = - (seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4); - dim3 grid_dim(batch_size, heads * seq2 / block_compute_size); + (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) + : 1); + dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size); int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; @@ -362,11 +363,10 @@ void launch_attn_softmax<__half>(__half* vals, { const int threads = 128; int seq_length4 = sequence_length / 4; - int seq2 = sequence_length * seq_length4; int block_compute_size = - (seq_length4 < threads ? 
((threads / seq_length4) * seq_length4) : seq_length4); - dim3 grid_dim(batch_size, heads * seq2 / block_compute_size); + (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1); + dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size); int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; @@ -399,8 +399,9 @@ void launch_attn_softmax<__half>(__half* vals, else { const int threads = 256; block_compute_size = - (seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4); - dim3 grid_dim(batch_size, heads * seq2 / block_compute_size); + (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) + : 1); + dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size); int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; @@ -531,55 +532,41 @@ void launch_attn_softmax_backward_v2(T* out_grad, int seq_length, cudaStream_t stream) { - if ((seq_length % WARP_SIZE) != 0 || seq_length > 2048) - throw std::runtime_error("Invalid sequence length found in softmax backward."); - const int warps_per_block = 4; dim3 grid_dim(batch_size * heads * seq_length / warps_per_block); dim3 block_dim(WARP_SIZE, warps_per_block); - switch (seq_length) { - case 32: - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - break; - case 64: - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - break; - case 128: - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - break; - case 256: - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - break; - case 384: - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - break; - case 512: - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - break; - case 768: - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - break; - case 1024: - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - break; - case 2048: - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - break; - default: - throw std::runtime_error( - std::string("Special sequence length found in softmax backward, seq_length: ") + - std::to_string(seq_length)); - } + if (seq_length <= 32) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 64) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 128) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 256) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 384) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 512) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 768) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 1024) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 2048) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else + throw std::runtime_error( + std::string("Special sequence length found in softmax backward, seq_length: ") + + std::to_string(seq_length)); } template void launch_attn_softmax_backward_v2<__half>(__half* out_grad, diff --git a/deepspeed/ops/transformer/transformer.py b/deepspeed/ops/transformer/transformer.py index 97a0beefc305..7dc66c562e51 100755 --- a/deepspeed/ops/transformer/transformer.py +++ 
b/deepspeed/ops/transformer/transformer.py @@ -184,26 +184,30 @@ def forward(ctx, ff2_inp, attn_prob_dropout_mask, attn_output_dropout_mask, - layer_output_dropout_mask) = forward_func(config.layer_id, - input, - input_mask, - attn_qkvw, - attn_qkvb, - attn_ow, - attn_ob, - attn_nw, - attn_nb, - inter_w, - inter_b, - output_w, - output_b, - norm_w, - norm_b, - config.training, - config.pre_layer_norm, - config.attn_dropout_checkpoint, - config.normalize_invertible, - config.gelu_checkpoint) + layer_output_dropout_mask, + attn_layer_norm_var, + attn_layer_norm_mean, + layer_norm_var, + layer_norm_mean) = forward_func(config.layer_id, + input, + input_mask, + attn_qkvw, + attn_qkvb, + attn_ow, + attn_ob, + attn_nw, + attn_nb, + inter_w, + inter_b, + output_w, + output_b, + norm_w, + norm_b, + config.training, + config.pre_layer_norm, + config.attn_dropout_checkpoint, + config.normalize_invertible, + config.gelu_checkpoint) # For testing only. if grads is not None: @@ -280,6 +284,9 @@ def forward(ctx, if not config.normalize_invertible: ctx.add_res = add_res + ctx.attn_layer_norm_mean = attn_layer_norm_mean + ctx.layer_norm_mean = layer_norm_mean + ctx.ff1_inp = ff1_inp if not config.gelu_checkpoint: ctx.gelu_inp = gelu_inp @@ -288,6 +295,8 @@ def forward(ctx, ctx.attn_prob_dropout_mask = attn_prob_dropout_mask ctx.attn_output_dropout_mask = attn_output_dropout_mask ctx.layer_output_dropout_mask = layer_output_dropout_mask + ctx.attn_layer_norm_var = attn_layer_norm_var + ctx.layer_norm_var = layer_norm_var return output @@ -364,6 +373,10 @@ def backward(ctx, grad_output): ctx.attn_prob_dropout_mask, ctx.attn_output_dropout_mask, ctx.layer_output_dropout_mask, + ctx.attn_layer_norm_var, + ctx.attn_layer_norm_mean, + ctx.layer_norm_var, + ctx.layer_norm_mean, (ctx.inp_norm if (ctx.config.pre_layer_norm and ctx.config.normalize_invertible) else input), input_mask, diff --git a/tests/unit/test_cuda_backward.py b/tests/unit/test_cuda_backward.py index 8e678334ab91..3d114fa0c8dc 100755 --- a/tests/unit/test_cuda_backward.py +++ b/tests/unit/test_cuda_backward.py @@ -252,10 +252,10 @@ def run_backward(ds_config, atol=1e-2, verbose=False): @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol', [ - (3,1024,128,16,24,True,False, 0.05), - (3,1024,128,16,24,True,True, 0.05), - (3,1024,128,16,24,False,False, 0.1), - (3,1024,128,16,24,False,True, 0.2), + (3,1024,120,16,24,True,False, 0.05), + (3,1024,120,16,24,True,True, 0.05), + (3,1024,56,16,24,False,False, 0.1), + (3,1024,56,16,24,False,True, 0.2), ]) # yapf: disable def test_backward(batch_size, hidden_size, diff --git a/tests/unit/test_cuda_forward.py b/tests/unit/test_cuda_forward.py index fc8b8cc7e210..5c21a73dbe91 100755 --- a/tests/unit/test_cuda_forward.py +++ b/tests/unit/test_cuda_forward.py @@ -109,7 +109,7 @@ def create_models(ds_config): num_hidden_layers=ds_config.num_hidden_layers, num_attention_heads=ds_config.heads, batch_size=ds_config.batch_size, - intermediate_size=ds_config.intermediate_size, + intermediate_size=4 * ds_config.hidden_size, hidden_act="gelu", hidden_dropout_prob=ds_config.hidden_dropout_ratio, attention_probs_dropout_prob=ds_config.attn_dropout_ratio, @@ -130,12 +130,12 @@ def create_models(ds_config): weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size))) weights[4].data.fill_(1.0) weights.append( - nn.Parameter(torch.Tensor(ds_config.intermediate_size, + nn.Parameter(torch.Tensor(4 * ds_config.hidden_size, ds_config.hidden_size))) 
weights[5].data.normal_(mean=0.0, std=ds_config.initializer_range) weights.append( nn.Parameter(torch.Tensor(ds_config.hidden_size, - ds_config.intermediate_size))) + 4 * ds_config.hidden_size))) weights[6].data.normal_(mean=0.0, std=ds_config.initializer_range) weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size))) weights[7].data.fill_(1.0) @@ -145,7 +145,7 @@ def create_models(ds_config): for i in range(4): biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size))) biases[i + 1].data.zero_() - biases.append(nn.Parameter(torch.Tensor(ds_config.intermediate_size))) + biases.append(nn.Parameter(torch.Tensor(4 * ds_config.hidden_size))) biases[5].data.zero_() biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size))) biases[6].data.zero_() @@ -174,7 +174,7 @@ def set_seed(seed): torch.manual_seed(seed) -def run_forward(ds_config, atol=1e-2, verbose=False, test_bsz=None): +def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None): set_seed(123) bert_encoder, ds_encoder = create_models(ds_config) @@ -183,10 +183,12 @@ def run_forward(ds_config, atol=1e-2, verbose=False, test_bsz=None): # prepare test data kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32 hidden_states = torch.randn(bsz, - ds_config.max_seq_length, + seq_len, #ds_config.max_seq_length, ds_config.hidden_size, **kwargs) - input_mask = torch.randn(bsz, 1, 1, ds_config.max_seq_length, **kwargs) + input_mask = torch.randn(bsz, 1, 1, + seq_len, #ds_config.max_seq_length, + **kwargs) # run baseline base_results = bert_encoder(hidden_states, @@ -211,10 +213,15 @@ def run_forward(ds_config, atol=1e-2, verbose=False, test_bsz=None): (64,1024,128,16,3,True,True), (8,1024,384,16,3,True,False), (8,1024,384,16,3,True,True), + (8,1024,384,16,3,True,True), + (8,1024,120,16,3,True,False), + (8,1024,120,16,3,True,True), (8,1024,512,16,3,True,False), (8,1024,512,16,3,True,True), - (64,1024,128,16,3,False,False), - (64,1024,128,16,3,False,True), + (64,1024,56,16,3,False,False), + (64,1024,56,16,3,False,True), + (64,1024,24,16,3,False,False), + (64,1024,24,16,3,False,True), (8,1024,384,16,3,False,False), (8,1024,384,16,3,False,True), (8,1024,512,16,3,False,False), @@ -242,8 +249,7 @@ def test_forward(batch_size, ds_config.layer_id = None ds_config.batch_size = batch_size ds_config.hidden_size = hidden_size - ds_config.intermediate_size = 4 * hidden_size - ds_config.max_seq_length = seq_len + ds_config.max_seq_length = 128 #seq_len ds_config.heads = heads ds_config.attn_dropout_ratio = 0.0 ds_config.hidden_dropout_ratio = 0.0 @@ -252,7 +258,7 @@ def test_forward(batch_size, ds_config.initializer_range = 0.02 ds_config.fp16 = use_fp16 - run_forward(ds_config, atol=2e-2) + run_forward(ds_config, seq_len, atol=2e-2) @pytest.mark.parametrize('batch_size, small_bsz, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16', @@ -279,7 +285,6 @@ def test_forward_with_small_bsz(batch_size, ds_config.layer_id = None ds_config.batch_size = batch_size ds_config.hidden_size = hidden_size - ds_config.intermediate_size = 4 * hidden_size ds_config.max_seq_length = seq_len ds_config.heads = heads ds_config.attn_dropout_ratio = 0.0 @@ -289,7 +294,7 @@ def test_forward_with_small_bsz(batch_size, ds_config.initializer_range = 0.02 ds_config.fp16 = use_fp16 - run_forward(ds_config, atol=2e-2, test_bsz=small_bsz) + run_forward(ds_config, seq_len, atol=2e-2, test_bsz=small_bsz) @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16', [ @@ -314,7 +319,6 @@ def 
test_forward_stochastic(batch_size,
     ds_config.layer_id = None
     ds_config.batch_size = batch_size
     ds_config.hidden_size = hidden_size
-    ds_config.intermediate_size = hidden_size
     ds_config.max_seq_length = seq_len
     ds_config.heads = heads
     ds_config.attn_dropout_ratio = 0.0
@@ -325,4 +329,4 @@ def test_forward_stochastic(batch_size,
     ds_config.fp16 = use_fp16
     ds_config.stochastic_mode = True
 
-    run_forward(ds_config, atol=7e-2)
+    run_forward(ds_config, seq_len, atol=7e-2)
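Taken together, the patch makes the transformer kernels independent of a fixed sequence length: the launchers index work by flattened rows (batch * seq_length), BertTransformerLayer::SetSeqLength reconfigures the softmax, attention-probability dropout and strided-batch GEMMs at run time, and the forward binding triggers it whenever input.size(1) differs from the configured length. That only works because Context::GenWorkSpace now reallocates the shared workspace only when the cached buffer is too small (_workSpaceSize < size instead of !=). A simplified, hypothetical mirror of that grow-only policy (names are illustrative, not the DeepSpeed API):

#include <cuda_runtime.h>
#include <cstddef>

struct GrowOnlyWorkspace {
    void*  ptr   = nullptr;
    size_t bytes = 0;

    // Reallocate only when the request outgrows the cached buffer, so a
    // shorter sequence length keeps reusing the existing allocation.
    void* ensure(size_t requested)
    {
        if (ptr == nullptr) {
            cudaMalloc(&ptr, requested);
            bytes = requested;
        } else if (bytes < requested) {
            cudaFree(ptr);
            cudaMalloc(&ptr, requested);
            bytes = requested;
        }
        return ptr;
    }
};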