diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 7b332b6f0c..2f107e18bd 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -395,12 +395,15 @@ def warmup(self): with self.all_context(): max_batches = self.cache_config.max_batches num_tokens = max_batches - + dist_ctx = get_dist_manager().current_context() + dp = dist_ctx.dp # warmup prefill inputs = self.inputs_strategy.make_dummy(max_batches, is_decoding=False, device='cuda', vocab_size=self.model_config.vocab_size) + if dp > 1: + inputs.build_dp_meta() self._forward_impl(inputs) # warmup decoding(with cuda graph) @@ -411,6 +414,8 @@ def warmup(self): is_decoding=True, device='cuda', vocab_size=self.model_config.vocab_size) + if dp > 1: + inputs.build_dp_meta() self._forward_impl(inputs) def _slice_outs(self, inputs: torch.Tensor, seq_length: torch.LongTensor):