40 changes: 20 additions & 20 deletions csrc/transformer/inference/csrc/gelu.cu
@@ -174,7 +174,7 @@ __global__ void fused_bias_residual(float* input,
                                     float* attnbias,
                                     int total_count,
                                     int intermediate_size,
-                                    int mp_size,
+                                    float mp_size,
                                     bool preln)
{
    float4* input_cast = reinterpret_cast<float4*>(input);
@@ -191,10 +191,10 @@ __global__ void fused_bias_residual(float* input,
        float4 bias_data = bias_cast[offset % intermediate_size];
        float4 attn_bias = attnbias_cast[offset % intermediate_size];
        if (preln) {
-           data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x);
-           data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y);
-           data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z);
-           data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w);
+           data.x = (data.x + res_vec.x + bias_data.x + attn_bias.x) * mp_size + (out.x);
+           data.y = (data.y + res_vec.y + bias_data.y + attn_bias.y) * mp_size + (out.y);
+           data.z = (data.z + res_vec.z + bias_data.z + attn_bias.z) * mp_size + (out.z);
+           data.w = (data.w + res_vec.w + bias_data.w + attn_bias.w) * mp_size + (out.w);
        } else {
            data.x = data.x + out.x + bias_data.x;
            data.y = data.y + out.y + bias_data.y;
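In the pre-LN branch the bias terms now sit inside the term that gets scaled by `mp_size`, instead of being added after the scaling. A minimal scalar sketch of the two formulas (not the kernel itself), under the assumption that `mp_size` carries a reciprocal tensor-parallel scale such as `1.0 / world_size`, which would also explain why the parameter changes from `int` to `float`:

```python
# Scalar sketch of the old vs. new pre-LN residual math; mp_scale is a
# hypothetical name, assumed to be the reciprocal of the tensor-parallel
# world size.
def residual_old(hidden, residual, mlp_out, bias, attn_bias, mp_scale):
    return (hidden + residual) * mp_scale + (mlp_out + bias + attn_bias)

def residual_new(hidden, residual, mlp_out, bias, attn_bias, mp_scale):
    return (hidden + residual + bias + attn_bias) * mp_scale + mlp_out

# With two ranks and mp_scale = 0.5, summing the per-rank results (as an
# all-reduce would) adds the bias terms exactly once in the new form, but
# twice in the old form.
```

The same reordering is applied to the `__half` kernel and to `gptj_residual_add` below.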
@@ -212,7 +212,7 @@ __global__ void fused_bias_residual(__half* input,
                                     __half* attn_bias,
                                     int total_count,
                                     int intermediate_size,
-                                    int mp_size,
+                                    float mp_size,
                                     bool preln)
{
#ifdef HALF_PRECISION_AVAILABLE
@@ -257,13 +257,13 @@ __global__ void fused_bias_residual(__half* input,

        if (preln) {
            low_data.x =
-               (low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x));
+               (low_data.x + low_res.x + (low_bias.x + attn_low_bias.x)) * mp_size + low_out.x;
            low_data.y =
-               (low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y));
-           high_data.x = (high_data.x + high_res.x) * mp_size +
-                         (high_out.x + (high_bias.x + attn_high_bias.x));
-           high_data.y = (high_data.y + high_res.y) * mp_size +
-                         (high_out.y + (high_bias.y + attn_high_bias.y));
+               (low_data.y + low_res.y + (low_bias.y + attn_low_bias.y)) * mp_size + low_out.y;
+           high_data.x = (high_data.x + high_res.x + (high_bias.x + attn_high_bias.x)) * mp_size +
+                         high_out.x;
+           high_data.y = (high_data.y + high_res.y + (high_bias.y + attn_high_bias.y)) * mp_size +
+                         high_out.y;
        } else {
            low_data.x = (low_data.x + low_out.x + low_bias.x);
            low_data.y = (low_data.y + low_out.y + low_bias.y);
@@ -332,10 +332,10 @@ __global__ void gptj_residual_add(float* input,
            data.z += attn_bias.z;
            data.w += attn_bias.w;
        }
-       data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x);
-       data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y);
-       data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z);
-       data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w);
+       data.x = (data.x + res_vec.x + bias_data.x) * mp_size + (out.x);
+       data.y = (data.y + res_vec.y + bias_data.y) * mp_size + (out.y);
+       data.z = (data.z + res_vec.z + bias_data.z) * mp_size + (out.z);
+       data.w = (data.w + res_vec.w + bias_data.w) * mp_size + (out.w);

        output_cast[offset] = data;
    }
@@ -395,10 +395,10 @@ __global__ void gptj_residual_add(__half* input,
            high_data.y += attn_high_bias.y;
        }

-       low_data.x = low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x));
-       low_data.y = low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y));
-       high_data.x = high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x));
-       high_data.y = high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y));
+       low_data.x = (low_data.x + low_res.x + (low_bias.x)) * mp_size + low_out.x;
+       low_data.y = (low_data.y + low_res.y + (low_bias.y)) * mp_size + low_out.y;
+       high_data.x = (high_data.x + high_res.x + (high_bias.x)) * mp_size + high_out.x;
+       high_data.y = (high_data.y + high_res.y + (high_bias.y)) * mp_size + high_out.y;

        vals_half[0] = __float22half2_rn(low_data);
        vals_half[1] = __float22half2_rn(high_data);
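The `__half` variants unpack each `half2` into a `float2`, apply the same reordered formula in fp32, and convert back with `__float22half2_rn`. A rough PyTorch-level sketch of that numeric behaviour (same `mp_scale` assumption as above; the optional attention bias is omitted):

```python
import torch

def residual_add_half(hidden, residual, bias, mlp_out, mp_scale):
    # Accumulate in fp32, as the kernel does via float2 lanes, then round back to fp16.
    acc = (hidden.float() + residual.float() + bias.float()) * mp_scale + mlp_out.float()
    return acc.half()
```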
14 changes: 6 additions & 8 deletions deepspeed/module_inject/replace_module.py
@@ -72,8 +72,8 @@ def qkv_copy(self, dst, src):
                    torch.cat([qkv_s[i] for qkv_s in qkv_split],
                              axis=1) for i in range(len(qkv_split[0]))
                ]
-               dst.data.copy_(weight_split[self.gpu_index].to(
-                   torch.cuda.current_device()).contiguous())
+               dst = weight_split[self.gpu_index].contiguous().to(
+                   torch.cuda.current_device())
            else:
                if src_shape[0] == dst_shape[0]:
                    return torch.nn.Parameter(src)
@@ -84,8 +84,7 @@ def qkv_copy(self, dst, src):
                    torch.cat([qkv_s[i] for qkv_s in qkv_split],
                              axis=0) for i in range(len(qkv_split[0]))
                ]
-               dst.data.copy_(bias_split[self.gpu_index].to(
-                   torch.cuda.current_device()).contiguous())
+               dst = bias_split[self.gpu_index].contiguous().to(torch.cuda.current_device())

        return torch.nn.Parameter(dst)

@@ -108,15 +107,14 @@ def copy(self, dst, src):
                self.merge_assert(src_shape[1], dst_shape[1])
                weight_split = torch.split(src.data, dst_shape[1], dim=1)

-               dst.data.copy_(weight_split[self.gpu_index].to(
-                   torch.cuda.current_device()).contiguous())
+               dst = weight_split[self.gpu_index].contiguous().to(
+                   torch.cuda.current_device())
            else:
                if src_shape[0] == dst_shape[0]:
                    return torch.nn.Parameter(src)

                bias_split = torch.split(src.data, dst_shape[-1])
-               dst.data.copy_(bias_split[self.gpu_index].to(
-                   torch.cuda.current_device()).contiguous())
+               dst = bias_split[self.gpu_index].contiguous().to(torch.cuda.current_device())

        return torch.nn.Parameter(dst)
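Rather than copying the selected shard into a preallocated `dst` buffer with `dst.data.copy_(...)`, `dst` is now rebound to the contiguous shard itself, and that tensor is what `torch.nn.Parameter` wraps. A standalone sketch of the new pattern (hypothetical helper and shapes; assumes a CUDA device is available):

```python
import torch

def shard_column_parallel(src: torch.Tensor, gpu_index: int, mp_size: int) -> torch.nn.Parameter:
    # Split the full weight along dim=1, one shard per model-parallel rank.
    shard = torch.split(src, src.shape[1] // mp_size, dim=1)[gpu_index]
    # New behaviour: wrap the shard directly, so no destination tensor of the
    # shard's exact shape needs to exist beforehand.
    return torch.nn.Parameter(shard.contiguous().to(torch.cuda.current_device()))

# Old behaviour, for contrast, wrote into an existing buffer of matching shape:
#     dst.data.copy_(shard.to(torch.cuda.current_device()).contiguous())
```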

@@ -276,7 +276,6 @@ def selfAttention_fp():
            else:
                qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \
                            inference_cuda_module.qkv_gemm_fp32
-
            qkv_out = qkv_func(input,
                               attn_qkvw,
                               (attn_qkvb if attn_qkvb is not None else norm_b),
@@ -312,11 +311,13 @@ def selfAttention_int8():
                                          (q_groups * (3 if qkv_merging else 1) * (2**merge_count)),
                                          (attn_qkvb is not None))
            context_layer, key_layer, value_layer = compute_attention(qkv_out)
+
            output = inference_cuda_module.vector_matmul_int8(context_layer,
                                                              attn_ow,
                                                              q_scales[1],
                                                              q_groups,
                                                              (merge_count))
+
            return output, key_layer, value_layer, context_layer

        if config.q_int8:
51 changes: 48 additions & 3 deletions tests/unit/test_inference.py
@@ -5,6 +5,7 @@
import pytest
import itertools
import deepspeed
+from torch import distributed as dist
from deepspeed.git_version_info import torch_info
from collections import defaultdict
from .common import distributed_test
@@ -230,7 +231,8 @@ def _go():
        _ = pipe(query, **inf_kwargs)
        torch.cuda.synchronize()
        start = time.time()
-       bs_output = pipe(query, **inf_kwargs)
+       for i in range(10):
+           bs_output = pipe(query, **inf_kwargs)
        torch.cuda.synchronize()
        bs_time = time.time() - start

@@ -247,14 +249,15 @@ def _go():
        _ = pipe(query, **inf_kwargs)
        torch.cuda.synchronize()
        start = time.time()
-       ds_output = pipe(query, **inf_kwargs)
+       for i in range(10):
+           ds_output = pipe(query, **inf_kwargs)
        torch.cuda.synchronize()
        ds_time = time.time() - start

        if task == "text-generation":
            bs_output = pipe(query, **inf_kwargs)

-       # These performance tests are only measuring the time for a single
+       # These performance tests are only measuring the time for a few
        # inference request, we just want to check that performance isn't terrible
        assert ds_time <= (bs_time * 1.1)
        assert assert_fn(bs_output, ds_output)
@@ -323,3 +326,45 @@ def _go():
        assert ppl_diff < 0.01

    _go()


+@pytest.mark.inference
+@pytest.mark.parametrize("model_w_task",
+                         [("EleutherAI/gpt-neo-2.7B",
+                           "text-generation")])
+@pytest.mark.parametrize("dtype",
+                         [torch.half])  # FP32 tests are failing due to OOM on 16GB V100
+def test_multi_gpu(model_w_task,
+                   dtype,
+                   enable_cuda_graph,
+                   query,
+                   inf_kwargs,
+                   assert_fn,
+                   invalid_model_task_config):
+    if invalid_model_task_config:
+        pytest.skip(invalid_model_task_config)
+
+    world_size = 2
+    model, task = model_w_task
+
+    @distributed_test(world_size=[world_size])
+    def _go():
+        local_rank = int(os.getenv("LOCAL_RANK", "0"))
+
+        pipe = pipeline(task, model=model, device=local_rank, framework="pt")
+        pipe.model = deepspeed.init_inference(
+            pipe.model,
+            mp_size=world_size,
+            dtype=dtype,
+            replace_method="auto",
+            replace_with_kernel_inject=True,
+            enable_cuda_graph=enable_cuda_graph,
+        )
+        response = pipe(query, **inf_kwargs)
+
+        outputs = [None] * world_size
+        dist.all_gather_object(outputs, response)
+        for out_1, out_2 in itertools.combinations(outputs, 2):
+            assert assert_fn(out_1, out_2)
+
+    _go()
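The new test runs the pipeline on two ranks via `@distributed_test`, injects kernels with `mp_size=world_size`, and then checks that all ranks produced an equivalent response using `torch.distributed.all_gather_object`. A minimal standalone sketch of that cross-rank consistency check (assuming the process group is already initialized, as `@distributed_test` arranges):

```python
from itertools import combinations
import torch.distributed as dist

def assert_ranks_agree(local_output, world_size, assert_fn):
    # Gather every rank's (picklable) output onto every rank.
    outputs = [None] * world_size
    dist.all_gather_object(outputs, local_output)
    # Require every pair of gathered outputs to match.
    for out_a, out_b in combinations(outputs, 2):
        assert assert_fn(out_a, out_b)
```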