3 changes: 3 additions & 0 deletions requirements/requirements-inf.txt
@@ -1,2 +1,5 @@
google
lm-eval>=0.2.0
protobuf
transformers
transformers[sentencepiece]
4 changes: 4 additions & 0 deletions tests/conftest.py
@@ -2,10 +2,14 @@

import sys
import pytest
import os
from os.path import abspath, dirname, join
import torch
import warnings

# Set this environment variable for the T5 inference unittest(s) (e.g. google/t5-v1_1-small)
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'

# allow having multiple repository checkouts and not needing to remember to rerun
# 'pip install -e .[dev]' when switching between checkouts and running tests.
git_repo_path = abspath(join(dirname(dirname(__file__)), "src"))
72 changes: 72 additions & 0 deletions tests/unit/inference/test_inference.py
@@ -9,6 +9,8 @@
from packaging import version as pkg_version
from deepspeed.ops.op_builder import OpBuilder
from transformers import pipeline
from transformers.models.t5.modeling_t5 import T5Block
from transformers.models.roberta.modeling_roberta import RobertaLayer
from huggingface_hub import HfApi

rocm_version = OpBuilder.installed_rocm_version()
@@ -55,6 +57,7 @@
"text-classification",
"token-classification",
"text-generation",
"text2text-generation",
]
pytest.all_models = {
task: [m.modelId for m in _all_models if m.pipeline_tag == task]
@@ -150,6 +153,8 @@ def query(model_w_task):
return "My name is jean-baptiste and I live in montreal."
elif task == "text-generation":
return "DeepSpeed is the greatest"
elif task == "text2text-generation":
return "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
else:
raise NotImplementedError(f'query for task "{task}" is not implemented')

@@ -187,6 +192,11 @@ def text_generation_assert(x, y):
for res in y)


def text2text_generation_assert(x, y):
return set(res["generated_text"] for res in x) == set(res["generated_text"]
for res in y)


@pytest.fixture
def assert_fn(model_w_task):
model, task = model_w_task
@@ -196,6 +206,7 @@ def assert_fn(model_w_task):
"text-classification": text_classification_assert,
"token-classification": token_classification_assert,
"text-generation": text_generation_assert,
"text2text-generation": text2text_generation_assert,
}
assert_fn = assert_fn_dict.get(task, None)
if assert_fn is None:
@@ -323,6 +334,67 @@ def test(
assert assert_fn(bs_output, ds_output)


@pytest.mark.seq_inference
@pytest.mark.parametrize(
"model_w_task, injection_policy",
[
(("google/t5-v1_1-small",
"text2text-generation"),
{
T5Block: ('SelfAttention.o',
'EncDecAttention.o',
'DenseReluDense.wo')
}),
(("roberta-large",
"fill-mask"),
{
RobertaLayer: ('output.dense')
}),
],
ids=["t5",
"roberta"],
)
@pytest.mark.parametrize("dtype", [torch.float], ids=["fp32"])
@pytest.mark.parametrize("enable_cuda_graph", [False], ids=["noCG"])
class TestInjectionPolicy(DistributedTest):
world_size = [1, 2]

def test(
self,
model_w_task,
injection_policy,
query,
inf_kwargs,
assert_fn,
invalid_model_task_config,
dtype,
enable_cuda_graph,
):
if invalid_model_task_config:
pytest.skip(invalid_model_task_config)

model, task = model_w_task
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "2"))

# These models are too large to fit in GPU memory at load time, so the
# pipeline loads them on CPU first
pipe = pipeline(task, model=model, device=-1, framework="pt")
bs_output = pipe(query, **inf_kwargs)

pipe.model = deepspeed.init_inference(pipe.model,
mp_size=world_size,
dtype=dtype,
injection_policy=injection_policy)
# Switch device to GPU so that input tensors are not on CPU
pipe.device = torch.device(f"cuda:{local_rank}")
ds_output = pipe(query, **inf_kwargs)

print(local_rank, "baseline", bs_output)
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)
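
# For reference, a minimal standalone sketch of how an injection policy like the
# T5 entry above might be used outside this test harness. This is illustrative
# only (not code from this PR); it assumes a single GPU and reuses the model
# name, module names, and query from the test parametrization above.
#
#   import torch
#   import deepspeed
#   from transformers import pipeline
#   from transformers.models.t5.modeling_t5 import T5Block
#
#   pipe = pipeline("text2text-generation",
#                   model="google/t5-v1_1-small",
#                   device=-1,
#                   framework="pt")
#   pipe.model = deepspeed.init_inference(
#       pipe.model,
#       mp_size=1,
#       dtype=torch.float,
#       injection_policy={
#           T5Block: ('SelfAttention.o', 'EncDecAttention.o', 'DenseReluDense.wo')
#       })
#   pipe.device = torch.device("cuda:0")
#   print(pipe("Is this review positive or negative? Review: this is the best "
#              "cast iron skillet you will ever buy"))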


@pytest.mark.nightly
@pytest.mark.parametrize(
"model_family, model_name",