Skip to content

Commit

Permalink
[Spec Decoding] Use target model max length as default for draft model (
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill authored Aug 21, 2024
1 parent 6925cdb commit 9b73a2f
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
Expand Down Expand Up @@ -210,7 +211,8 @@ def __init__(
hf_config=self.hf_text_config,
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window())
sliding_window_len=self.get_hf_config_sliding_window(),
spec_target_max_model_len=spec_target_max_model_len)
self.served_model_name = get_served_model_name(model,
served_model_name)
self.multimodal_config = self._init_multimodal_config(
Expand Down Expand Up @@ -1134,6 +1136,7 @@ def maybe_create_spec_config(
code_revision=draft_code_revision,
tokenizer_revision=target_model_config.tokenizer_revision,
max_model_len=None,
spec_target_max_model_len=target_model_config.max_model_len,
quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager,
max_seq_len_to_capture=target_model_config.
Expand Down Expand Up @@ -1563,6 +1566,7 @@ def _get_and_verify_max_len(
max_model_len: Optional[int],
disable_sliding_window: bool,
sliding_window_len: Optional[int],
spec_target_max_model_len: Optional[int] = None,
) -> int:
"""Get and verify the model's maximum length."""
derived_max_model_len = float("inf")
Expand Down Expand Up @@ -1605,6 +1609,11 @@ def _get_and_verify_max_len(
# If max_model_len is specified, we use it.
return max_model_len

if spec_target_max_model_len is not None:
# If this is a speculative draft model, we use the max model len
# from the target model.
return spec_target_max_model_len

default_max_len = 2048
logger.warning(
"The model's config.json does not contain any of the following "
Expand Down

0 comments on commit 9b73a2f

Please sign in to comment.