Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 18 additions & 23 deletions csrc/transformer/inference/csrc/pt_binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,22 @@ at::Tensor ds_softmax(at::Tensor& attn_scores,
template <typename T>
void allocate_workspace(size_t hidden_dim,
size_t batch_size,
size_t prompt_length,
unsigned num_layers,
unsigned num_heads,
unsigned mp_size = 1,
bool external_cache = false,
unsigned rank = 0)
{
Context::Instance().GenWorkSpace(
num_layers, batch_size, hidden_dim, mp_size, external_cache, sizeof(T), rank);
Context::Instance().GenWorkSpace(num_layers,
num_heads,
batch_size,
prompt_length,
hidden_dim,
mp_size,
external_cache,
sizeof(T),
rank);
}

template <typename T>
Expand Down Expand Up @@ -352,15 +361,7 @@ void attention_unfused(T* prev_key_cont,
float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0;
float alpha = norm_factor * norm_factor / layer_scale;
float gemm_beta = 0.0;
T* workspace;
if (seq_len == 1) {
workspace = (T*)output + bsz * seq_len * heads * k;
} else {
// If we are doing the prompt, switch to the tail workspace
T* scratch = (T*)Context::Instance().GetWorkSpace();
workspace = scratch + ((Context::Instance().get_workspace_size() / sizeof(T)) -
bsz * heads * seq_len * soft_len);
}
T* workspace = (T*)Context::Instance().GetAttentionUnfusedWorkspace();

cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
Expand Down Expand Up @@ -733,11 +734,6 @@ std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
int out_size = q_int8 ? weight.size(0) : weight.size(1);
if (!workspace)
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
allocate_workspace<T>(input.size(2), input.size(0), num_layers, mp_size, external_cache, rank);
workspace = (T*)Context::Instance().GetWorkSpace();

auto options = at::TensorOptions()
.dtype(input.options().dtype())
Expand Down Expand Up @@ -846,13 +842,6 @@ at::Tensor ds_linear_layer(at::Tensor& input,

int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
// Reallocate memory if we received a new prompt
if (!workspace || input.size(1) != 1) {
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
allocate_workspace<T>(input.size(2), input.size(0), num_layers);
workspace = (T*)Context::Instance().GetWorkSpace();
}
auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options);

float alpha = (T)1.0;
Expand Down Expand Up @@ -1425,4 +1414,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
&einsum_sec_sm_ecm<__half>,
"DeepSpeed vector-MM with fp16 (CUDA)");
m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)");
m.def("allocate_workspace_fp32",
&allocate_workspace<float>,
"DeepSpeed memory allocation for GPT inference with fp32 (CUDA)");
m.def("allocate_workspace_fp16",
&allocate_workspace<__half>,
"DeepSpeed memory allocation for GPT inference with fp16 (CUDA)");
}
35 changes: 28 additions & 7 deletions csrc/transformer/inference/includes/inference_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ class Context {
_curr_offset(0),
_stream(0),
_free_memory_size(0),
_num_tokens(1)
_num_tokens(1),
_attention_unfused_workspace_offset(0)
{
if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
auto message = std::string("Fail to create cublas handle.");
Expand Down Expand Up @@ -88,7 +89,9 @@ class Context {
}

void GenWorkSpace(const unsigned& num_layers,
const unsigned& num_heads,
const size_t& batch_size,
const size_t& prompt_len,
const size_t& hidden_dim,
const unsigned& mp_size,
const bool& external_cache,
Expand All @@ -99,14 +102,25 @@ class Context {
if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); }

size_t activation_size = 16 * hidden_dim * batch_size;
size_t temp_size = batch_size * num_heads * prompt_len * prompt_len * elem_size / mp_size;
size_t cache_size = num_layers * batch_size * (hidden_dim / mp_size) * 2;
_max_seq_len =
(((_free_memory_size - (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE) /
elem_size)) /
(activation_size + cache_size);
size_t workSpaceSize = (external_cache ? activation_size : (activation_size + cache_size)) *
_max_seq_len * elem_size;
size_t minimal_requirements =
temp_size + (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE;
if (_free_memory_size < minimal_requirements) {
printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n",
minimal_requirements,
_free_memory_size,
total_size);
throw std::runtime_error("Workspace can't be allocated, no enough memory.");
}

_max_seq_len = ((_free_memory_size - minimal_requirements) / elem_size) /
(activation_size + cache_size);
_max_seq_len = std::min((size_t)MAX_OUT_TOKENS, _max_seq_len);
size_t workSpaceSize =
((external_cache ? activation_size : (activation_size + cache_size))) * _max_seq_len *
elem_size +
temp_size;
if (rank == 0 && !_workspace)
printf(
"Free memory : %lu (Bytes) Total memory: %lu (Bytes) Setting maximum total "
Expand All @@ -130,13 +144,18 @@ class Context {
throw std::runtime_error("Workspace is null.");
}
_workSpaceSize = workSpaceSize;
_attention_unfused_workspace_offset = workSpaceSize - temp_size;
}
inline size_t GetMaxTokenLenght() const { return _max_seq_len; }

cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }

size_t get_workspace_size() const { return _workSpaceSize; }
void* GetWorkSpace() { return _workspace; }
void* GetAttentionUnfusedWorkspace()
{
return _workspace + _attention_unfused_workspace_offset;
}

inline unsigned new_token(unsigned layer_id)
{
Expand Down Expand Up @@ -202,6 +221,8 @@ class Context {
cudaEvent_t _comm_event;

void* _workspace;
// offset from _workspace for attention unfused memory
size_t _attention_unfused_workspace_offset;
uint64_t _seed;
uint64_t _curr_offset;

Expand Down
13 changes: 13 additions & 0 deletions deepspeed/ops/transformer/inference/transformer_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,8 @@ def __init__(self,
device=device),
requires_grad=False)
self.layer_past = None
self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if (not config.fp16) else \
inference_cuda_module.allocate_workspace_fp16

def forward(
self,
Expand All @@ -820,6 +822,17 @@ def forward(
# This needs to be redesigned later!
layer_head_mask=None,
past_key_value=None):
# Allocate memory only on first layer forward
if self.config.layer_id == 0:
self.allocate_workspace(self.config.hidden_size,
input.size()[0],
input.size()[1],
DeepSpeedTransformerInference.layer_id,
self.config.heads,
self.config.mp_size,
self.config.bigscience_bloom,
dist.get_rank() if dist.is_initialized() else 0)

get_present = (get_present or get_key_value or use_cache)
input_mask = input_mask if attention_mask is None else attention_mask

Expand Down