44 changes: 18 additions & 26 deletions fastdeploy/spec_decode/mtp.py
@@ -79,7 +79,7 @@ def __init__(
         self._init_model_inputs()
 
         # CUDA Graph
-        self.use_cudagraph = False  # self.graph_opt_config.use_cudagraph
+        self.use_cudagraph = False  # TODO(gongshaotian): Use Target Model flag
         self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
         self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
 
@@ -117,6 +117,9 @@ def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode
             self.parallel_config.max_model_len - max_dec_len,
         )
 
+        if self.fd_config.parallel_config.enable_expert_parallel:
+            input_length = min(input_length, 32)
+
         block_num = (
             input_length + self.cache_config.block_size - 1
         ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
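In isolation, the capped dummy-prefill length and the resulting KV-cache block count reduce to the following arithmetic (a minimal sketch of the hunk above; the wrapper function and example values are illustrative, while the cap of 32 and the ceiling division come from the diff):

```python
def dummy_prefill_block_num(
    input_length: int,
    block_size: int,
    enc_dec_block_num: int,
    enable_expert_parallel: bool,
) -> int:
    # With expert parallel enabled, the dummy prefill request is capped at 32 tokens
    # so the warm-up batch stays small on every rank.
    if enable_expert_parallel:
        input_length = min(input_length, 32)
    # Ceiling division over the cache block size, plus the reserved encoder/decoder blocks.
    return (input_length + block_size - 1) // block_size + enc_dec_block_num


# e.g. a 1000-token dummy request with 64-token blocks and 2 reserved blocks:
assert dummy_prefill_block_num(1000, 64, 2, enable_expert_parallel=False) == 18
assert dummy_prefill_block_num(1000, 64, 2, enable_expert_parallel=True) == 3
```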
@@ -541,7 +544,7 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
         self.model_inputs["not_need_stop"][0] = True
         self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
 
-    def _initialize_forward_meta(self):
+    def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
         """
         Initialize forward meta and attention meta data
         """
@@ -569,23 +572,8 @@ def _initialize_forward_meta(self):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)
 
-        # Update Batch type for cuda graph
-        only_decode_batch = True
-        prefill_exists = None
-
-        # Mix ep in single node
-        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
-            only_decode_batch_list = []
-            prefill_exists = self.exist_prefill()
-            paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
-            only_decode_batch = all(only_decode_batch_list)
-            self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
-
-        self.forward_meta.step_use_cudagraph = (
-            self.use_cudagraph
-            and only_decode_batch
-            and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
-        )
+        # TODO(gongshaotian): Use CUDAGraph with Draft Model
+        self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.use_cudagraph
 
     def exist_prefill(self):
         """
@@ -671,9 +659,12 @@ def _post_process(self, sampled_token_ids):
             self.parallel_config.use_ep,
         )
 
-    def _propose(self):
+    def _propose(self, step_use_cudagraph: bool = False):
         """
-        Main process for MTP inference
+        Main process for MTP inference.
+        Args:
+            step_use_cudagraph: bool
+                Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP.
         """
         for substep in range(self.num_model_steps):
             if self.model_inputs["not_need_stop"]:
@@ -697,15 +688,16 @@ def _propose(self):
 
                 # Initialize forward meta data
                 self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
-                self.model_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
+                self.model_inputs["batch_id_per_token"][:] = -1
                 self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
                 self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
                 # for speculative decoding
                 self.model_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
                 self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)
 
                 # Initialize forward meta data
-                self._initialize_forward_meta()
+                self._initialize_forward_meta(step_use_cudagraph=step_use_cudagraph)
+                self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
 
                 # Padding inputs for cuda graph
                 self.padding_cudagraph_inputs()
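The `batch_id_per_token` reordering is easy to miss: the persistent model input is now reset to -1 each step, and the real token-to-request mapping is copied into `forward_meta.batch_id_per_token` only after `_initialize_forward_meta` has run. A plausible reading (an assumption, not stated in the diff) is that slots padded for CUDA-graph capture then keep the -1 sentinel instead of stale ids. A minimal sketch of that pattern, with illustrative sizes:

```python
import paddle

max_num_tokens = 8  # illustrative; the real buffer is sized from the runner config
batch_id_per_token = paddle.full([max_num_tokens], -1, dtype="int32")  # reset to the sentinel

real_mapping = paddle.to_tensor([0, 0, 1, 2], dtype="int32")  # token index -> request index
batch_id_per_token[: real_mapping.shape[0]] = real_mapping

# Slots 4..7 stay at -1, so code consuming the padded (graph-captured) batch
# can tell padding apart from real tokens.
print(batch_id_per_token.numpy())  # [ 0  0  1  2 -1 -1 -1 -1]
```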
@@ -861,10 +853,10 @@ def _extend_draft_token_with_ngram_match(self):
             self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
             self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
 
-    def _run_impl(self, full_hidden_states):
-        """"""
+    def _run_impl(self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False):
+        """Execute Draft Model"""
         self._prepare_inputs(full_hidden_states)
-        self._propose()
+        self._propose(step_use_cudagraph=step_use_cudagraph)
         self._update_status()
         if self.hybrid_mode:
             self._extend_draft_token_with_ngram_match()
8 changes: 6 additions & 2 deletions fastdeploy/worker/gpu_model_runner.py
@@ -1253,7 +1253,9 @@ def _dummy_run(
 
         if self.speculative_decoding:
             if self.speculative_method == "mtp":
-                self.proposer.run(full_hidden_states=model_output)
+                self.proposer.run(
+                    full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
+                )
             else:
                 self.proposer.run(share_inputs=self.share_inputs)
 
@@ -1600,7 +1602,9 @@ class at the server level, which is too granular for ModelRunner.
         # 6. Speculative decode
         if self.speculative_decoding:
             if self.speculative_method == "mtp":
-                self.proposer.run(full_hidden_states=model_output)
+                self.proposer.run(
+                    full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
+                )
             else:
                 self.proposer.run(share_inputs=self.share_inputs)
 
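Taken together with the mtp.py changes above, the flag now flows target runner → `proposer.run` → `_run_impl` → `_propose` → `_initialize_forward_meta`. A self-contained sketch of that plumbing (the underscore-prefixed classes are illustrative stand-ins, not the FastDeploy classes; only the `step_use_cudagraph and use_cudagraph` combination is taken from the diff):

```python
from dataclasses import dataclass, field


@dataclass
class _ForwardMeta:
    step_use_cudagraph: bool = False


@dataclass
class _MTPProposerSketch:
    # Draft-model CUDA-graph switch; the diff still hard-codes this to False
    # (see the TODO in __init__), so draft graphs stay disabled for now.
    use_cudagraph: bool = False
    forward_meta: _ForwardMeta = field(default_factory=_ForwardMeta)

    def run(self, full_hidden_states, step_use_cudagraph: bool = False) -> None:
        # Mirrors _run_impl -> _propose -> _initialize_forward_meta: the draft model
        # combines the target model's per-step decision with its own switch.
        self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.use_cudagraph


proposer = _MTPProposerSketch()
proposer.run(full_hidden_states=None, step_use_cudagraph=True)
assert proposer.forward_meta.step_use_cudagraph is False  # draft graphs remain off
```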