From 65adc00539c30059324aa0dbe07cd050a33742d6 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Fri, 2 Sep 2022 14:13:09 -0700
Subject: [PATCH 1/7] only override forward if using cuda-graph

---
 deepspeed/inference/engine.py | 46 +++++++----------------------
 1 file changed, 9 insertions(+), 37 deletions(-)

diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index a4b57a05f37b..81566e7165c5 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -162,10 +162,7 @@ def __init__(self,
             torch.cuda.set_rng_state(_rng_state.cpu())
 
         if self.mp_world_size > 1:
-            self.model_orig_fwd = self.module.forward
-            self.module.forward = self.forward
-        else:
-            self.module.register_forward_pre_hook(self._pre_forward_hook)
+            assert not self.enable_cuda_graph, "Cuda graph is not supported for model parallelism"
 
     def _get_model_config_generate(self, config):
         self.config = getattr(self.module, 'config', None) if config is None else config
@@ -475,14 +472,6 @@ def _convert_to_dtype(self):
         elif self.dtype == torch.float:
             self.module.float()
 
-    def _pre_forward_hook(self, module, *inputs, **kwargs):
-        for input in inputs:
-            if torch.is_tensor(input):
-                input = input.to(torch.cuda.current_device())
-        for k in kwargs:
-            if torch.is_tensor(kwargs[k]):
-                kwargs[k] = kwargs[k].to(torch.cuda.current_device())
-
     def _create_cuda_graph(self, *inputs, **kwargs):
         # warmup to create the workspace and cublas handle
         cuda_stream = torch.cuda.Stream()
@@ -519,30 +508,13 @@ def forward(self, *inputs, **kwargs):
             *inputs: Variable length input list
             **kwargs: variable length keyword arguments
         """
-
-        if self.mp_world_size > 1:
-            if self.mpu is None:
-                for input in inputs:
-                    if torch.is_tensor(input):
-                        input = input.to(torch.cuda.current_device())
-                        if not input.is_contiguous():
-                            input = input.contiguous()
-                        dist.broadcast(input, 0)
-                for k in kwargs:
-                    if torch.is_tensor(kwargs[k]):
-                        kwargs[k] = kwargs[k].to(torch.cuda.current_device())
-                        if not kwargs[k].is_contiguous():
-                            kwargs[k] = kwargs[k].contiguous()
-                        dist.broadcast(kwargs[k], 0)
-            outputs = self.model_orig_fwd(*inputs, **kwargs)
-        else:
-            if self.enable_cuda_graph:
-                if self.cuda_graph_created:
-                    outputs = self._graph_replay(*inputs, **kwargs)
-                else:
-                    self._create_cuda_graph(*inputs, **kwargs)
-                    outputs = self._graph_replay(*inputs, **kwargs)
+        if self.enable_cuda_graph:
+            if self.cuda_graph_created:
+                outputs = self._graph_replay(*inputs, **kwargs)
             else:
-                outputs = self.module(*inputs, **kwargs)
-        #outputs = self.module(*inputs, **kwargs)
+                self._create_cuda_graph(*inputs, **kwargs)
+                outputs = self._graph_replay(*inputs, **kwargs)
+        else:
+            outputs = self.module(*inputs, **kwargs)
+
         return outputs

From 94bbe8b30e4ff0b592b31ed4bf3cb52990efc0a1 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Wed, 7 Sep 2022 21:56:06 -0700
Subject: [PATCH 2/7] override forward if cuda graph enabled

---
 deepspeed/inference/engine.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 81566e7165c5..1e729f832fba 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -163,6 +163,9 @@ def __init__(self,
 
         if self.mp_world_size > 1:
             assert not self.enable_cuda_graph, "Cuda graph is not supported for model parallelism"
+        elif self.enable_cuda_graph:
+            self.model_orig_fwd = self.module.forward
+            self.module.forward = self.forward
 
     def _get_model_config_generate(self, config):
         self.config = getattr(self.module, 'config', None) if config is None else config
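Note on the forward() path that patch 1 settles on: it dispatches between one-time
CUDA-graph capture, graph replay, and a plain module call. Below is a minimal,
standalone sketch of that dispatch using PyTorch's public CUDA-graph API. The class
and variable names are illustrative, it assumes a CUDA-capable GPU and PyTorch >= 1.10,
and it is not the DeepSpeed engine code itself.

    import torch

    class CudaGraphRunner:
        # Illustrative stand-in for the engine's simplified forward() dispatch.
        def __init__(self, module, enable_cuda_graph=False):
            self.module = module
            self.enable_cuda_graph = enable_cuda_graph
            self.cuda_graph_created = False

        def _create_cuda_graph(self, *inputs):
            # Warm up on a side stream to create workspaces/handles, then capture
            # one replayable graph over the module call (tensor inputs assumed).
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                for _ in range(3):
                    self.module(*inputs)
            torch.cuda.current_stream().wait_stream(stream)

            self._graph = torch.cuda.CUDAGraph()
            self._static_inputs = inputs
            with torch.cuda.graph(self._graph):
                self._static_output = self.module(*self._static_inputs)
            self.cuda_graph_created = True

        def _graph_replay(self, *inputs):
            # Copy fresh data into the captured input buffers, then replay.
            for static, new in zip(self._static_inputs, inputs):
                static.copy_(new)
            self._graph.replay()
            return self._static_output

        def forward(self, *inputs):
            # Mirrors the simplified forward() after patch 1: capture once,
            # replay afterwards, or fall back to a plain module call.
            if self.enable_cuda_graph:
                if not self.cuda_graph_created:
                    self._create_cuda_graph(*inputs)
                return self._graph_replay(*inputs)
            return self.module(*inputs)

Tensors fed to a replayed graph must keep the shapes and devices seen at capture time,
which is why the sketch copies new data into the static input buffers instead of
passing fresh tensors through.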
From 8f2db03b93af1e22b8be706085c89cc9a09c26cb Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Thu, 8 Sep 2022 13:43:08 -0700
Subject: [PATCH 3/7] remove fwd override

---
 deepspeed/inference/engine.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 1e729f832fba..81566e7165c5 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -163,9 +163,6 @@ def __init__(self,
 
         if self.mp_world_size > 1:
             assert not self.enable_cuda_graph, "Cuda graph is not supported for model parallelism"
-        elif self.enable_cuda_graph:
-            self.model_orig_fwd = self.module.forward
-            self.module.forward = self.forward
 
     def _get_model_config_generate(self, config):
         self.config = getattr(self.module, 'config', None) if config is None else config

From eef3fd16543f9e8c298be7bf60cbcbc90144a508 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Fri, 9 Sep 2022 14:18:23 -0700
Subject: [PATCH 4/7] use latest HF in inf test

---
 .github/workflows/nv-inference.yml | 2 --
 1 file changed, 2 deletions(-)

diff --git a/.github/workflows/nv-inference.yml b/.github/workflows/nv-inference.yml
index dc00682edd46..d2551b0f79d2 100644
--- a/.github/workflows/nv-inference.yml
+++ b/.github/workflows/nv-inference.yml
@@ -40,8 +40,6 @@ jobs:
         run: |
           git clone https://github.com/huggingface/transformers
           cd transformers
-          # if needed switch to the last known good SHA until transformers@master is fixed
-          git checkout v4.21.2
           git rev-parse --short HEAD
           pip uninstall --yes transformers
           pip install .

From f787b340859de7c2e361558becc8b6429b6b0e1e Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Mon, 12 Sep 2022 15:13:19 -0700
Subject: [PATCH 5/7] turn off CG for MP tests

---
 tests/unit/inference/test_inference.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index 1b1efdc595fe..198dc62add97 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -308,7 +308,6 @@ def test(
         self,
         model_w_task,
         dtype,
-        enable_cuda_graph,
         query,
         inf_kwargs,
         assert_fn,
@@ -325,14 +324,11 @@ def test(
         pipe = pipeline(task, model=model, device=-1, framework="pt")
         bs_output = pipe(query, **inf_kwargs)
 
-        pipe.model = deepspeed.init_inference(
-            pipe.model,
-            mp_size=self.world_size,
-            dtype=dtype,
-            replace_method="auto",
-            replace_with_kernel_inject=True,
-            enable_cuda_graph=enable_cuda_graph,
-        )
+        pipe.model = deepspeed.init_inference(pipe.model,
+                                              mp_size=self.world_size,
+                                              dtype=dtype,
+                                              replace_method="auto",
+                                              replace_with_kernel_inject=True)
         # Switch device to GPU so that input tensors are not on CPU
         pipe.device = torch.device(f"cuda:{local_rank}")
         ds_output = pipe(query, **inf_kwargs)

From 6da9f1d2135295afb763a46d0d75b9e9d3292af8 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Tue, 13 Sep 2022 15:17:53 -0700
Subject: [PATCH 6/7] swap gpt2 for gpt_neo in mp tests

---
 tests/unit/inference/test_inference.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index 198dc62add97..d9dbad45ef06 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -292,7 +292,7 @@ def test(
 
 @pytest.mark.seq_inference
 @pytest.mark.parametrize("model_w_task",
-                         [("gpt2",
+                         [("EleutherAI/gpt-neo-1.3B",
                            "text-generation"),
                           ("EleutherAI/gpt-neox-20b",
                            "text-generation"),
From 40d4d0f9a98ca9f8af48a49983ae4c2595fc477c Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Wed, 14 Sep 2022 10:37:16 -0700
Subject: [PATCH 7/7] rename gpt-neo id

---
 tests/unit/inference/test_inference.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index d9dbad45ef06..04e9320accc8 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -298,7 +298,7 @@ def test(
                            "text-generation"),
                           ("bigscience/bloom-3b",
                            "text-generation")],
-                         ids=["gpt2",
+                         ids=["gpt-neo",
                               "gpt-neox",
                               "bloom"])
 class TestMPSize(DistributedTest):
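After the series, the multi-GPU test exercises the engine roughly as sketched below.
The values here are illustrative (the real test derives mp_size, dtype, task, model,
and the query from its parametrization and fixtures); the init_inference keyword
arguments mirror the call introduced in patch 5.

    import torch
    import deepspeed
    from transformers import pipeline

    # Build the pipeline on CPU first, as the test does before handing the model
    # to DeepSpeed. Run under the launcher so two ranks exist for mp_size=2,
    # e.g. `deepspeed --num_gpus 2 this_script.py`.
    pipe = pipeline("text-generation",
                    model="EleutherAI/gpt-neo-1.3B",
                    device=-1,
                    framework="pt")

    # enable_cuda_graph is deliberately not passed: with mp_size > 1 the engine
    # now asserts that CUDA graphs are disabled (patches 1-2).
    pipe.model = deepspeed.init_inference(pipe.model,
                                          mp_size=2,
                                          dtype=torch.float16,
                                          replace_method="auto",
                                          replace_with_kernel_inject=True)

    # Switch device to GPU so that input tensors are not on CPU.
    pipe.device = torch.device("cuda:0")
    output = pipe("DeepSpeed is", do_sample=False, max_length=50)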