48 changes: 20 additions & 28 deletions fastdeploy/spec_decode/mtp.py
@@ -80,7 +80,6 @@ def __init__(
self._init_model_inputs()

# CUDA Graph
self.use_cudagraph = self.graph_opt_config.use_cudagraph
self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes

@@ -119,6 +118,11 @@ def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode
num_tokens // batch_size,
self.model_config.max_model_len - max_dec_len,
)

# TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
if self.fd_config.parallel_config.enable_expert_parallel:
input_length = min(input_length, 32)

block_num = (
input_length + self.cache_config.block_size - 1
) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
@@ -551,7 +555,7 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
self.model_inputs["not_need_stop"][0] = True
self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

def _initialize_forward_meta(self):
def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
"""
Initialize forward meta and attention meta data
"""
@@ -587,23 +591,7 @@ def _initialize_forward_meta(self):
for attn_backend in self.attn_backends:
attn_backend.init_attention_metadata(self.forward_meta)

# Update Batch type for cuda graph
only_decode_batch = True
prefill_exists = None

# Mix ep in single node
if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
only_decode_batch_list = []
prefill_exists = self.exist_prefill()
paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
only_decode_batch = all(only_decode_batch_list)
self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"

self.forward_meta.step_use_cudagraph = (
self.use_cudagraph
and only_decode_batch
and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
)
self.forward_meta.step_use_cudagraph = step_use_cudagraph

def exist_prefill(self):
"""
@@ -689,9 +677,12 @@ def _post_process(self, sampled_token_ids):
self.parallel_config.use_ep,
)

def _propose(self):
def _propose(self, step_use_cudagraph: bool = False):
"""
Main process for MTP inference
Main process for MTP inference.
Args:
step_use_cudagraph: bool
Whether to use CUDA Graph for this step. The flag comes from the target model to avoid hangs under expert parallelism (EP).
"""
for substep in range(self.num_model_steps):
if self.model_inputs["not_need_stop"]:
@@ -715,7 +706,7 @@

# Initialize forward meta data
self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
self.model_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
self.model_inputs["batch_id_per_token"][:] = -1
self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)

@@ -724,7 +715,8 @@
self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)

# Initialize forward meta data
self._initialize_forward_meta()
self._initialize_forward_meta(step_use_cudagraph=step_use_cudagraph)
self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)

# Padding inputs for cuda graph
self.padding_cudagraph_inputs()
@@ -752,7 +744,7 @@ def _propose(self):
previous_hidden_states=self.model_inputs["target_hidden_states"],
forward_meta=self.forward_meta,
)
if self.use_cudagraph:
if self.forward_meta.step_use_cudagraph:
model_output = model_output[: self.real_token_num]
hidden_states = rebuild_padding(
model_output,
@@ -871,10 +863,10 @@ def _extend_draft_token_with_ngram_match(self):
self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()

def _run_impl(self, full_hidden_states):
""""""
def _run_impl(self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False):
"""Execute Draft Model"""
self._prepare_inputs(full_hidden_states)
self._propose()
self._propose(step_use_cudagraph=step_use_cudagraph)
self._update_status()
if self.hybrid_mode:
self._extend_draft_token_with_ngram_match()
@@ -891,7 +883,7 @@ def padding_cudagraph_inputs(self) -> None:
# In init_attention_metadata, the decode buffer has already been cleared

# To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
if self.use_cudagraph:
if self.forward_meta.step_use_cudagraph:
self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer
self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
return
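
In short, this change moves the per-step CUDA Graph decision out of the MTP proposer (which previously re-derived it with paddle.distributed.all_gather_object across EP ranks) and into the target model runner, which now passes a single step_use_cudagraph flag down through run() → _propose() → _initialize_forward_meta(). A minimal sketch of that flow follows; it uses simplified stand-in classes (DraftProposer, TargetRunner) with trimmed signatures, not the actual FastDeploy classes.

# Illustrative sketch only (not FastDeploy code): how the CUDA Graph decision
# now flows from the target model into the MTP proposer. The class and method
# names below are simplified stand-ins for GPUModelRunner / MTPProposer.

class DraftProposer:
    def __init__(self):
        self.step_use_cudagraph = False

    def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
        # The proposer no longer all-gathers prefill state itself; it simply
        # records the flag decided once by the target model.
        self.step_use_cudagraph = step_use_cudagraph

    def run(self, full_hidden_states, step_use_cudagraph: bool = False):
        self._initialize_forward_meta(step_use_cudagraph=step_use_cudagraph)
        # _propose / _update_status would follow here in the real proposer.


class TargetRunner:
    def __init__(self, use_cudagraph: bool, proposer: DraftProposer):
        self.use_cudagraph = use_cudagraph
        self.proposer = proposer

    def step(self, model_output, only_decode_batch: bool, prefill_exists: bool):
        # Decide once, on the target model side, whether this step can replay
        # a captured CUDA Graph (pure-decode batch, no prefill present).
        step_use_cudagraph = self.use_cudagraph and only_decode_batch and not prefill_exists
        # Hand the same flag to the draft model so every EP rank takes the
        # same branch without a second collective inside the proposer.
        self.proposer.run(full_hidden_states=model_output, step_use_cudagraph=step_use_cudagraph)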
9 changes: 6 additions & 3 deletions fastdeploy/worker/gpu_model_runner.py
@@ -1501,10 +1501,11 @@ def _dummy_sampler_run(
skip_save_output=True,
async_output_queue=self.async_output_queue,
)

if self.speculative_decoding:
if self.speculative_method == "mtp":
self.proposer.run(full_hidden_states=model_output)
self.proposer.run(
full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
)
else:
self.proposer.run(share_inputs=self.share_inputs)

@@ -1948,7 +1949,9 @@ class at the server level, which is too granular for ModelRunner.
# 6. Speculative decode
if self.speculative_decoding:
if self.speculative_method == "mtp":
self.proposer.run(full_hidden_states=model_output)
self.proposer.run(
full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
)
else:
self.proposer.run(share_inputs=self.share_inputs)
