diff --git a/.github/workflows/amd.yml b/.github/workflows/amd.yml index 6847180e97b0..6f9eae41cd45 100644 --- a/.github/workflows/amd.yml +++ b/.github/workflows/amd.yml @@ -63,5 +63,5 @@ jobs: run: | if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi cd tests - TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -x -n 4 -m 'not sequential' unit/ + TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -x -n 4 unit/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -x -m 'sequential' unit/ diff --git a/.github/workflows/nv-inference.yml b/.github/workflows/nv-inference.yml new file mode 100644 index 000000000000..d985f882254b --- /dev/null +++ b/.github/workflows/nv-inference.yml @@ -0,0 +1,63 @@ +name: nv-inference + +on: + push: + branches: + - 'master' + - 'staging**' + paths-ignore: + - 'docs/**' + pull_request: + paths-ignore: + - 'docs/**' + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + unit-tests: + runs-on: [self-hosted, nvidia, cu111, v100] + + steps: + - uses: actions/checkout@v2 + + - name: environment + run: | + nvidia-smi + which python + python --version + which nvcc + nvcc --version + pip install --upgrade pip + pip uninstall --yes torch torchvision + pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html + python -c "import torch; print('torch:', torch.__version__, torch)" + python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + + - name: Install transformers + run: | + git clone https://github.com/huggingface/transformers + cd transformers + # if needed switch to the last known good SHA until transformers@master is fixed + # git checkout 1cc453d33 + git rev-parse --short HEAD + pip uninstall --yes transformers + pip install . 
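Note: the explicit -m 'not sequential' filter dropped from the pytest invocations above (and from the other nv-* workflow files below) is replaced by the default marker expression in the new tests/pytest.ini added later in this diff. A minimal sketch of how the markers behave, assuming standard pytest semantics where a -m given on the command line (as the workflow steps do) replaces the one supplied via addopts; the file and test names here are illustrative, not part of the patch:

    # marker_sketch.py -- illustrative only
    import pytest

    @pytest.mark.inference           # collected only by `pytest -m 'inference' unit/`
    def test_kernel_inject_smoke():
        assert True

    @pytest.mark.sequential          # deselected by the default addopts filter,
    def test_needs_exclusive_gpu():  # picked up by the follow-up `-m 'sequential'` pass
        assert True

With addopts = -m "not sequential and not nightly and not inference", a plain `pytest unit/` run collects neither of these tests; the sequential, nightly, and inference suites are only exercised by the workflow steps that pass the matching -m expression.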
+ + - name: Python environment + run: | + pip list + + - name: Install deepspeed + run: | + pip uninstall --yes deepspeed + pip install .[dev,1bit,autotuning,sparse_attn,inf] + ds_report + + - name: Unit tests + run: | + unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch + if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi + cd tests + TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 -m 'inference' unit/ diff --git a/.github/workflows/nv-nightly.yml b/.github/workflows/nv-nightly.yml new file mode 100644 index 000000000000..bb3c1a9430e2 --- /dev/null +++ b/.github/workflows/nv-nightly.yml @@ -0,0 +1,52 @@ +name: nv-nightly + +on: + schedule: + - cron: "0 0 * * *" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + unit-tests: + runs-on: [self-hosted, nvidia, cu111, v100] + + steps: + - uses: actions/checkout@v2 + + - name: environment + run: | + nvidia-smi + which python + python --version + which nvcc + nvcc --version + pip install --upgrade pip + pip uninstall --yes torch torchvision + pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html + python -c "import torch; print('torch:', torch.__version__, torch)" + python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + + - name: Install transformers + run: | + git clone https://github.com/huggingface/transformers + cd transformers + # if needed switch to the last known good SHA until transformers@master is fixed + # git checkout 1cc453d33 + git rev-parse --short HEAD + pip uninstall --yes transformers + pip install . + + - name: Install deepspeed + run: | + pip uninstall --yes deepspeed + pip install .[dev,1bit,autotuning,sparse_attn] + ds_report + + - name: Unit tests + run: | + unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch + if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi + cd tests + TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'nightly' unit/ diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml index d9c14ca440c4..fd6859c7fb45 100644 --- a/.github/workflows/nv-torch-latest-v100.yml +++ b/.github/workflows/nv-torch-latest-v100.yml @@ -60,5 +60,5 @@ jobs: unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi cd tests - TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 -m 'not sequential' unit/ + TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 unit/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'sequential' unit/ diff --git a/.github/workflows/nv-torch-nightly-v100.yml b/.github/workflows/nv-torch-nightly-v100.yml index bb53bef5e240..e1c916afba2d 100644 --- a/.github/workflows/nv-torch-nightly-v100.yml +++ b/.github/workflows/nv-torch-nightly-v100.yml @@ -53,5 +53,5 @@ jobs: unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi cd tests - TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 -m 'not sequential' unit/ + TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 unit/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest 
--color=yes --durations=0 --forked --verbose -m 'sequential' unit/ diff --git a/.github/workflows/nv-torch18-v100.yml b/.github/workflows/nv-torch18-v100.yml index abe941e527b7..a596393e890f 100644 --- a/.github/workflows/nv-torch18-v100.yml +++ b/.github/workflows/nv-torch18-v100.yml @@ -60,5 +60,5 @@ jobs: unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi cd tests - TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 -m 'not sequential' unit/ + TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 unit/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'sequential' unit/ diff --git a/.github/workflows/nv-transformers-v100.yml b/.github/workflows/nv-transformers-v100.yml index 97e8eea35aa0..efbd015ce1b0 100644 --- a/.github/workflows/nv-transformers-v100.yml +++ b/.github/workflows/nv-transformers-v100.yml @@ -57,6 +57,8 @@ jobs: pip install .[testing] # find reqs used in ds integration tests find examples/pytorch -regextype posix-egrep -regex '.*(language-modeling|question-answering|summarization|image-classification|text-classification|translation).*/requirements.txt' -exec grep -v 'torch' {} \; | xargs -I {} pip install --upgrade {} + # force datasets version due to issues + pip install datasets==2.2.2 # force protobuf version due to issues pip install "protobuf<4.21.0" pip list diff --git a/bin/dsr b/bin/dsr new file mode 120000 index 000000000000..747bf4722c42 --- /dev/null +++ b/bin/dsr @@ -0,0 +1 @@ +ds_report \ No newline at end of file diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index 00c70eea22b5..acb35325dd9e 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -174,7 +174,8 @@ __global__ void fused_bias_residual(float* input, float* attnbias, int total_count, int intermediate_size, - int mp_size) + int mp_size, + bool preln) { float4* input_cast = reinterpret_cast(input); float4* output_cast = reinterpret_cast(output); @@ -189,12 +190,17 @@ __global__ void fused_bias_residual(float* input, float4 res_vec = attn_cast[offset]; float4 bias_data = bias_cast[offset % intermediate_size]; float4 attn_bias = attnbias_cast[offset % intermediate_size]; - - data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x); - data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y); - data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z); - data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w); - + if (preln) { + data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x); + data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y); + data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z); + data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w); + } else { + data.x = data.x + out.x + bias_data.x; + data.y = data.y + out.y + bias_data.y; + data.z = data.z + out.z + bias_data.z; + data.w = data.w + out.w + bias_data.w; + } output_cast[offset] = data; } } @@ -206,7 +212,8 @@ __global__ void fused_bias_residual(__half* input, __half* attn_bias, int total_count, int intermediate_size, - int mp_size) + int mp_size, + bool preln) { #ifdef HALF_PRECISION_AVAILABLE @@ -248,15 +255,21 @@ __global__ void 
fused_bias_residual(__half* input, float2 attn_low_bias = __half22float2(attnbias_half[0]); float2 attn_high_bias = __half22float2(attnbias_half[1]); - low_data.x = - (low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x)); - low_data.y = - (low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y)); - high_data.x = - (high_data.x + high_res.x) * mp_size + (high_out.x + (high_bias.x + attn_high_bias.x)); - high_data.y = - (high_data.y + high_res.y) * mp_size + (high_out.y + (high_bias.y + attn_high_bias.y)); - + if (preln) { + low_data.x = + (low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x)); + low_data.y = + (low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y)); + high_data.x = (high_data.x + high_res.x) * mp_size + + (high_out.x + (high_bias.x + attn_high_bias.x)); + high_data.y = (high_data.y + high_res.y) * mp_size + + (high_out.y + (high_bias.y + attn_high_bias.y)); + } else { + low_data.x = (low_data.x + low_out.x + low_bias.x); + low_data.y = (low_data.y + low_out.y + low_bias.y); + high_data.x = (high_data.x + high_out.x + high_bias.x); + high_data.y = (high_data.y + high_out.y + high_bias.y); + } vals_half[0] = __float22half2_rn(low_data); vals_half[1] = __float22half2_rn(high_data); @@ -274,6 +287,7 @@ void launch_bias_residual(T* input, int batch, int hidden_dim, int mp_size, + bool preln, cudaStream_t stream) { int total_count = batch * hidden_dim / 4; @@ -281,20 +295,13 @@ void launch_bias_residual(T* input, dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size); fused_bias_residual<<>>( - input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size); + input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size, preln); } -template void -launch_bias_residual(float*, float*, float*, float*, float*, int, int, int, cudaStream_t); -template void launch_bias_residual<__half>(__half*, - __half*, - __half*, - __half*, - __half*, - int, - int, - int, - cudaStream_t); +template void launch_bias_residual< + float>(float*, float*, float*, float*, float*, int, int, int, bool, cudaStream_t); +template void launch_bias_residual< + __half>(__half*, __half*, __half*, __half*, __half*, int, int, int, bool, cudaStream_t); __global__ void gptj_residual_add(float* input, float* output, diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index ef4f0597d985..2c300c0b6c92 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -787,17 +787,17 @@ at::Tensor ds_vector_matmul_int8(at::Tensor& input, } template -void mlp_unfused_cublas(at::Tensor& output, - at::Tensor& input, - at::Tensor& residual, - at::Tensor& input_bias, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - bool preLayerNorm, - bool mlp_after_attn) +at::Tensor mlp_unfused_cublas(at::Tensor& output, + at::Tensor& input, + at::Tensor& residual, + at::Tensor& input_bias, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool preLayerNorm, + bool mlp_after_attn) { int bsz = input.size(0) * input.size(1); auto inp_norm = at::empty_like(input); @@ -840,18 +840,19 @@ void mlp_unfused_cublas(at::Tensor& output, weight.size(1), bsz, Context::Instance().GetCurrentStream()); + return inp_norm; } template -at::Tensor ds_mlp_gemm(at::Tensor& input, - at::Tensor& 
residual, - at::Tensor& input_bias, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - bool preLayerNorm, - bool mlp_after_attn) +std::vector ds_mlp_gemm(at::Tensor& input, + at::Tensor& residual, + at::Tensor& input_bias, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool preLayerNorm, + bool mlp_after_attn) { auto input_cont = input.contiguous(); auto options = at::TensorOptions() @@ -863,19 +864,19 @@ at::Tensor ds_mlp_gemm(at::Tensor& input, auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); int bsz = input_cont.size(0) * input_cont.size(1); - mlp_unfused_cublas(output, - mlp_after_attn ? input : residual, - residual, - input_bias, - weight, - bias, - gamma, - beta, - epsilon, - preLayerNorm, - mlp_after_attn); - - return output; + auto res_add = mlp_unfused_cublas(output, + mlp_after_attn ? input : residual, + residual, + input_bias, + weight, + bias, + gamma, + beta, + epsilon, + preLayerNorm, + mlp_after_attn); + + return {output, res_add}; } template @@ -1001,7 +1002,8 @@ void residual_add_bias(at::Tensor& output, at::Tensor& attention_b, int mp_size, bool mlp_after_attn, - bool add_bias) + bool add_bias, + bool preln) { int bsz = input.size(0) * input.size(1); int hidden_size = input.size(2); @@ -1017,6 +1019,7 @@ void residual_add_bias(at::Tensor& output, bsz, hidden_size, mp_size, + preln, Context::Instance().GetCurrentStream()); else launch_gptj_residual_add((float*)input.data_ptr(), @@ -1037,6 +1040,7 @@ void residual_add_bias(at::Tensor& output, bsz, hidden_size, mp_size, + preln, Context::Instance().GetCurrentStream()); else launch_gptj_residual_add<__half>((__half*)input.data_ptr(), diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/custom_cuda_layers.h index c8a0b79a111b..a48b2d7f06cc 100644 --- a/csrc/transformer/inference/includes/custom_cuda_layers.h +++ b/csrc/transformer/inference/includes/custom_cuda_layers.h @@ -58,6 +58,7 @@ void launch_bias_residual(T* input, int batch, int hidden_dim, int mp_size, + bool preln, cudaStream_t stream); template diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 17999a9b3a38..ec2cba4f9f14 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -3,21 +3,22 @@ ''' import torch import os -from torch.nn.modules import Module + import deepspeed.comm as dist +import deepspeed.utils.groups as groups + +from torch.nn.modules import Module +from packaging import version as pkg_version + from ..runtime.state_dict_factory import SDLoaderFactory from ..runtime.weight_quantizer import WeightQuantization from ..module_inject.replace_module import replace_transformer_layer from ..utils import logger from ..comm.comm import init_distributed - from ..pipe import PipelineModule from ..moe.utils import has_moe_layers from ..moe.layer import MoE -import deepspeed.comm as dist -import deepspeed.utils.groups as groups - DS_INFERENCE_ENABLED = False @@ -88,9 +89,13 @@ def __init__(self, self.ep_group = ep_group self.expert_mp_group = expert_mp_group self.enable_cuda_graph = enable_cuda_graph - self.cuda_grah_created = False + self.cuda_graph_created = False self._init_quantization_setting(quantization_setting) + if enable_cuda_graph: + assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \ + "If you want to use cuda graph, please upgrade torch to at least v1.10" + if 
self.checkpoint: self._load_checkpoint(self.checkpoint) @@ -372,7 +377,7 @@ def _create_cuda_graph(self, *inputs, **kwargs): with torch.cuda.graph(self._cuda_graphs): self.static_output = self.module(*self.static_inputs, **self.static_kwargs) - self.cuda_grah_created = True + self.cuda_graph_created = True def _graph_replay(self, *inputs, **kwargs): for i in range(len(inputs)): @@ -409,7 +414,7 @@ def forward(self, *inputs, **kwargs): outputs = self.model_orig_fwd(*inputs, **kwargs) else: if self.enable_cuda_graph: - if self.cuda_grah_created: + if self.cuda_graph_created: outputs = self._graph_replay(*inputs, **kwargs) else: self._create_cuda_graph(*inputs, **kwargs) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 54e7998acecc..72599e9e43a1 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -351,8 +351,11 @@ def replace_with_policy(child, # linear layer is created with [input, output] shape # transpose it here to reduce inference cost! def transpose(data): + # temp move to cpu to avoid requiring extra GPU memory during the reshape + data = data.to('cpu') data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1)) data = data.reshape(data.shape[-1], data.shape[-2]) + data.to(torch.cuda.current_device()) return data if attn_linear_layer: @@ -460,7 +463,7 @@ def _transpose(x): new_module.norm_b.data = input_nb.to(torch.cuda.current_device()) else: transformer_config = deepspeed.DeepSpeedTransformerConfig( - batch_size=micro_batch_size, + batch_size=micro_batch_size if micro_batch_size > 0 else 1, hidden_size=config.hidden_size, heads=config.num_attention_heads, attn_dropout_ratio=config.attention_probs_dropout_prob, diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index 12de114e55bd..cac9429c15ca 100755 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -471,7 +471,7 @@ def forward(ctx, config.pre_layer_norm, False) else: - intermediate = mlp_gemm_func(input, + intermediate, residual_add = mlp_gemm_func(input, residual, bias, inter_w, @@ -482,14 +482,16 @@ def forward(ctx, config.pre_layer_norm, config.mlp_after_attn) output = vector_matmul_func(intermediate, output_w, False) - inference_cuda_module.residual_add(output, - residual, - input, - output_b, - bias if bias is not None else output_b, - config.mp_size, - config.mlp_after_attn, - bias is not None) + inference_cuda_module.residual_add( + output, + residual if config.pre_layer_norm else residual_add, + input, + output_b, + bias if bias is not None else output_b, + config.mp_size, + config.mlp_after_attn, + bias is not None, + config.pre_layer_norm) if mp_group is not None and dist.get_world_size(group=mp_group) > 1: dist.all_reduce(output, group=mp_group) return output diff --git a/requirements/requirements-inf.txt b/requirements/requirements-inf.txt new file mode 100644 index 000000000000..e9f45a392e1d --- /dev/null +++ b/requirements/requirements-inf.txt @@ -0,0 +1,2 @@ +lm-eval>=0.2.0 +transformers diff --git a/setup.py b/setup.py index 4a0105f17fe8..3d484f8edc88 100755 --- a/setup.py +++ b/setup.py @@ -61,7 +61,8 @@ def fetch_requirements(path): 'dev': fetch_requirements('requirements/requirements-dev.txt'), 'autotuning': fetch_requirements('requirements/requirements-autotuning.txt'), 'autotuning_ml': 
fetch_requirements('requirements/requirements-autotuning-ml.txt'), - 'sparse_attn': fetch_requirements('requirements/requirements-sparse_attn.txt') + 'sparse_attn': fetch_requirements('requirements/requirements-sparse_attn.txt'), + 'inf': fetch_requirements('requirements/requirements-inf.txt') } # Add specific cupy version to both onebit extension variants @@ -291,6 +292,7 @@ def create_dir_symlink(src, dest): 'bin/ds', 'bin/ds_ssh', 'bin/ds_report', + 'bin/dsr', 'bin/ds_elastic' ], classifiers=[ diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 000000000000..a52a49e5bbc3 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +addopts = -m "not sequential and not nightly and not inference" +markers = + sequential:Tests that need to be run sequentially + inference:Inference model tests + nightly:Tests that should be run nightly diff --git a/tests/unit/test_inference.py b/tests/unit/test_inference.py index 5b3a0cc681bc..ce176b9268b5 100644 --- a/tests/unit/test_inference.py +++ b/tests/unit/test_inference.py @@ -1,123 +1,323 @@ import os +import time import torch import pytest +import itertools import deepspeed +from deepspeed.git_version_info import torch_info from collections import defaultdict -from transformers import pipeline from .common import distributed_test from packaging import version as pkg_version +from deepspeed.ops.op_builder import OpBuilder -pytest.task_query_dict = { - "fill-mask": - defaultdict( - lambda: "Hello I'm a [MASK] model.", - {"roberta-base": "Hello I'm a model."}, - ), - "question-answering": - defaultdict(lambda: { - "question": "What is the greatest?", - "context": "DeepSpeed is the greatest", - }), - "text-classification": - defaultdict(lambda: "DeepSpeed is the greatest"), - "token-classification": - defaultdict(lambda: "My name is jean-baptiste and I live in montreal."), - "text-generation": - defaultdict(lambda: "DeepSpeed is the greatest"), -} -pytest.task_model_dict = { - "fill-mask": { - "bert": "bert-base-cased", - "roberta": "roberta-base" - }, - "question-answering": { - "bert": "deepset/minilm-uncased-squad2", - "roberta": "deepset/roberta-base-squad2", - }, - "text-classification": { - "bert": "cross-encoder/ms-marco-MiniLM-L-12-v2", - "roberta": "j-hartmann/emotion-english-distilroberta-base", - }, - "token-classification": { - "bert": "dslim/bert-base-NER", - "roberta": "Jean-Baptiste/roberta-large-ner-english", - }, - "text-generation": { - "gpt2": "distilgpt2", - "gpt_neo": "Norod78/hebrew-bad_wiki-gpt_neo-tiny", - "gptj": "EleutherAI/gpt-j-6B", - }, +try: + import lm_eval + import lm_eval.models + import lm_eval.tasks + from lm_eval.evaluator import evaluate + from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer + from huggingface_hub import HfApi +except ImportError: + pytest.skip("please install w. 
[inf] extra to run this test", + allow_module_level=True) + +rocm_version = OpBuilder.installed_rocm_version() +if rocm_version != (0, 0): + pytest.skip("skip inference tests on rocm for now", allow_module_level=True) + +_bert_models = [ + "bert-base-cased", + "bert-base-uncased", + "bert-large-cased", + "bert-large-uncased", + "bert-base-multilingual-cased", + "bert-base-multilingual-uncased", + "deepset/minilm-uncased-squad2", + "cross-encoder/ms-marco-MiniLM-L-12-v2", + "dslim/bert-base-NER", + "bert-large-uncased-whole-word-masking-finetuned-squad", + "distilbert-base-cased-distilled-squad", +] +_roberta_models = [ + "roberta-large", + "roberta-base", + "deepset/roberta-base-squad2", + "j-hartmann/emotion-english-distilroberta-base", + "Jean-Baptiste/roberta-large-ner-english", +] +_gpt_models = [ + "gpt2", + "distilgpt2", + "Norod78/hebrew-bad_wiki-gpt_neo-tiny", + "EleutherAI/gpt-j-6B", +] +_all_models = HfApi().list_models() + +test_models = set(_bert_models + _roberta_models + _gpt_models) +test_tasks = [ + "fill-mask", + "question-answering", + "text-classification", + "token-classification", + "text-generation", +] +pytest.all_models = { + task: [m.modelId for m in _all_models if m.pipeline_tag == task] + for task in test_tasks } +_model_w_tasks = itertools.product(*[test_models, test_tasks]) + + +def _valid_model_task(model_task): + m, t = model_task + return m in pytest.all_models[t] + + +pytest.models_w_tasks = list(filter(_valid_model_task, _model_w_tasks)) +pytest.mt_names = [f"{m}-{t}" for m, t in pytest.models_w_tasks] +""" +These fixtures iterate all combinations of tasks and models, dtype, & cuda_graph +""" + + +@pytest.fixture(params=pytest.models_w_tasks, ids=pytest.mt_names) +def model_w_task(request): + return request.param + + +@pytest.fixture(params=[torch.float, torch.half], ids=["fp32", "fp16"]) +def dtype(request): + return request.param + + +@pytest.fixture(params=[True, False], ids=["CG", "noCG"]) +def enable_cuda_graph(request): + return request.param + + +""" +This fixture will validate the configuration +""" + + +@pytest.fixture() +def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph): + model, task = model_w_task + if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"): + msg = "DS inference injection doesn't work well on older torch versions" + elif model not in pytest.all_models[task]: + msg = f"Not a valid model / task combination: {model} / {task}" + elif enable_cuda_graph and (torch_info["cuda_version"] == "0.0"): + msg = "CUDA not detected, cannot use CUDA Graph" + elif enable_cuda_graph and pkg_version.parse( + torch.__version__) < pkg_version.parse("1.10"): + msg = "CUDA Graph is only available in torch versions >= 1.10" + elif ("gpt-j-6B" in model) and (dtype == torch.float): + msg = f"Not enough GPU memory to run {model} with dtype {dtype}" + else: + msg = "" + return msg + + +""" +These fixtures can be used to customize the query, inference args, and assert +statement for each combination of model /task +""" + @pytest.fixture -def model(task, model_family): - if model_family not in pytest.task_model_dict[task]: - pytest.skip(f"No models in family {model_family} for task {task}") - return pytest.task_model_dict[task][model_family] +def query(model_w_task): + model, task = model_w_task + if task == "fill-mask": + if "roberta" in model: + return "Hello I'm a model." + else: + return "Hell I'm a [MASK] model." 
+ elif task == "question-answering": + return { + "question": "What's my name?", + "context": "My name is Clara and I live in Berkeley", + } + elif task == "text-classification": + return "DeepSpeed is the greatest" + elif task == "token-classification": + return "My name is jean-baptiste and I live in montreal." + elif task == "text-generation": + return "DeepSpeed is the greatest" + else: + NotImplementedError(f'query for task "{task}" is not implemented') @pytest.fixture -def query(task, model): - return pytest.task_query_dict[task][model] +def inf_kwargs(model_w_task): + model, task = model_w_task + if task == "text-generation": + return {"do_sample": False} + else: + return {} -@pytest.mark.parametrize( - "task", - ( - "fill-mask", - "question-answering", - "text-classification", - "token-classification", - "text-generation", - ), -) -@pytest.mark.parametrize("model_family", ("bert", "roberta", "gpt2", "gpt_neo")) -def test_model_task_inject(task, model, query, dtype=torch.float): - if pkg_version.parse(torch.__version__) <= pkg_version.parse('1.2'): - pytest.skip("DS inference injection doesn't work well on older torch versions") +@pytest.fixture +def assert_fn(model_w_task): + model, task = model_w_task + if task == "fill-mask": + return lambda x, y: set(res["token_str"] for res in x) == set( + res["token_str"] for res in y + ) + elif task == "question-answering": + return lambda x, y: x["answer"] == y["answer"] + elif task == "text-classification": + return lambda x, y: set(res["label"] for res in x) == set( + res["label"] for res in y + ) + elif task == "token-classification": + return lambda x, y: set(ent["word"] for ent in x) == set( + ent["word"] for ent in y + ) + elif task == "text-generation": + return lambda x, y: set(res["generated_text"] for res in x) == set( + res["generated_text"] for res in y + ) + else: + NotImplementedError(f'assert_fn for task "{task}" is not implemented') + + +""" +Tests +""" + + +@pytest.mark.inference +def test_model_task( + model_w_task, + dtype, + enable_cuda_graph, + query, + inf_kwargs, + assert_fn, + invalid_model_task_config, +): + if invalid_model_task_config: + pytest.skip(invalid_model_task_config) + + model, task = model_w_task @distributed_test(world_size=[1]) def _go(): local_rank = int(os.getenv("LOCAL_RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - generator = pipeline(task, model=model, device=local_rank) - generator.model = deepspeed.init_inference( - generator.model, - mp_size=world_size, + if "gpt-j-6B" in model and dtype == torch.half: + _model = AutoModelForCausalLM.from_pretrained(model) + tokenizer = AutoTokenizer.from_pretrained(model) + _model.half() + pipe = pipeline( + task, + model=_model, + tokenizer=tokenizer, + device=local_rank, + framework="pt", + ) + else: + pipe = pipeline(task, model=model, device=local_rank, framework="pt") + if dtype == torch.half: + pipe.model.half() + + # Warm-up queries for perf measurement + for i in range(10): + _ = pipe(query, **inf_kwargs) + torch.cuda.synchronize() + start = time.time() + bs_output = pipe(query, **inf_kwargs) + torch.cuda.synchronize() + bs_time = time.time() - start + + pipe.model = deepspeed.init_inference( + pipe.model, + mp_size=1, dtype=dtype, replace_method="auto", replace_with_kernel_inject=True, + enable_cuda_graph=enable_cuda_graph, ) + # Warm-up queries for perf measurement + for i in range(10): + _ = pipe(query, **inf_kwargs) + torch.cuda.synchronize() + start = time.time() + ds_output = pipe(query, **inf_kwargs) + torch.cuda.synchronize() + 
ds_time = time.time() - start - response = generator(query) + if task == "text-generation": + bs_output = pipe(query, **inf_kwargs) - _go() + # These performance tests are only measuring the time for a single + # inference request, we just want to check that performance isn't terrible + assert ds_time <= (bs_time * 1.1) + assert assert_fn(bs_output, ds_output) + _go() -@pytest.mark.parametrize("dtype", [(torch.float), (torch.half)]) -def test_gpt2_inject(dtype): - if pkg_version.parse(torch.__version__) <= pkg_version.parse('1.2'): - pytest.skip("DS inference injection doesn't work well on older torch versions") +@pytest.mark.nightly +@pytest.mark.parametrize( + "model_family, model_name", + ( + ["gpt2", + "EleutherAI/gpt-neo-2.7B"], + ["gpt2", + "EleutherAI/gpt-j-6B"], + ["gpt2", + "gpt2-xl"], + ), +) +@pytest.mark.parametrize("task", ["lambada"]) +def test_lm_correctness(model_family, model_name, task): @distributed_test(world_size=[1]) def _go(): - local_rank = int(os.getenv("LOCAL_RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - generator = pipeline("text-generation", model="gpt2", device=local_rank) + local_rank = os.getenv("LOCAL_RANK", "0") + device = torch.device(f"cuda:{local_rank}") + dtype = torch.float + task_dict = lm_eval.tasks.get_task_dict([task]) + + if 'gpt-j-6B' in model_name: + dtype = torch.half + lm = lm_eval.models.get_model(model_family).create_from_arg_string( + f"pretrained={model_name}", + {"device": "cpu"}) + setattr(lm, model_family, getattr(lm, model_family).half().to(device)) + lm._device = device + else: + lm = lm_eval.models.get_model(model_family).create_from_arg_string( + f"pretrained={model_name}", + {"device": f"cuda:{local_rank}"}) + + torch.cuda.synchronize() + start = time.time() + bs_output = evaluate(lm=lm, task_dict=task_dict) + torch.cuda.synchronize() + bs_time = time.time() - start - generator.model = deepspeed.init_inference( - generator.model, - mp_size=world_size, + ds_model = deepspeed.init_inference( + getattr(lm, + model_family), + mp_size=1, dtype=dtype, replace_method="auto", replace_with_kernel_inject=True, + enable_cuda_graph=False, ) + setattr(lm, model_family, ds_model) + torch.cuda.synchronize() + start = time.time() + ds_output = evaluate(lm=lm, task_dict=task_dict) + torch.cuda.synchronize() + ds_time = time.time() - start - prompt = "DeepSpeed is" - string_1 = generator(prompt, do_sample=False, max_length=128) - string_2 = generator(prompt, do_sample=False, max_length=128) - assert string_1 == string_2 + ppl_diff = abs(bs_output["results"][task]["ppl"] - + ds_output["results"][task]["ppl"]) + assert ds_time <= bs_time + assert ppl_diff < 0.01 _go()
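For reviewers, a rough reference of what the updated fused_bias_residual kernel now computes, mirroring the two branches added in csrc/transformer/inference/csrc/gelu.cu. This is a simplified sketch, not part of the patch: the function name is made up, the parameter names follow the kernel's local variables, the float4/half2 vectorization and the bias broadcast over intermediate_size are glossed over, and mp_size is the model-parallel size that the launcher inverts before passing it to the kernel.

    # residual_add_reference.py -- illustrative sketch only
    import torch

    def fused_bias_residual_reference(data, out, res_vec, bias_data, attn_bias,
                                      mp_size: int, preln: bool) -> torch.Tensor:
        if preln:
            # pre-LayerNorm path (math unchanged by this PR): fold in the attention
            # residual, scale by 1/mp_size, then add the projected output and both biases
            return (data + res_vec) * (1.0 / mp_size) + (out + bias_data + attn_bias)
        else:
            # post-LayerNorm path introduced by this change: only the projected
            # output and the MLP bias are accumulated
            return data + out + bias_data

This pairs with the pt_binding.cpp and transformer_inference.py changes above: ds_mlp_gemm now also returns the pre-bias residual (residual_add), and the residual_add op receives that tensor instead of the original residual when config.pre_layer_norm is False, along with the new preln flag that selects between the two branches.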