Commit 59e3793

zixi-qi authored and radeksm committed
[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE3 (vllm-project#17504)
Signed-off-by: qizixi <[email protected]>
1 parent 462c39d commit 59e3793
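
This commit removes the EAGLE3-specific carve-outs so the draft model runs through the same torch.compile (CompilationLevel.PIECEWISE) and CUDA-graph path as EAGLE. Below is a minimal sketch of how that path might be exercised from the offline API; the target/draft model names, the speculative_config keys, and the token count are illustrative assumptions, not part of this commit.

# sketch only; model names and config values are assumed
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",            # assumed target model
    speculative_config={
        "method": "eagle3",                               # selects the EAGLE3 drafter
        "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",   # assumed draft model
        "num_speculative_tokens": 3,
    },
    # enforce_eager is left False so compilation stays at PIECEWISE and
    # use_cuda_graph evaluates to True for the drafter (see eagle.py below).
)

out = llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)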

File tree

2 files changed: +36 -31 lines changed


vllm/model_executor/models/llama_eagle3.py

Lines changed: 17 additions & 8 deletions
@@ -6,7 +6,8 @@
 import torch.nn as nn
 from transformers import LlamaConfig
 
-from vllm.config import ModelConfig, VllmConfig
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -76,17 +77,19 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class LlamaModel(nn.Module):
 
     def __init__(
         self,
         *,
-        model_config: ModelConfig,
+        vllm_config: VllmConfig,
         start_layer_id: int = 0,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.config = model_config.hf_config
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             self.config.vocab_size,
@@ -119,8 +122,7 @@ def forward(
         hidden_states: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_embeds = self.embed_tokens(input_ids)
-        if (hidden_states.shape[-1] != input_embeds.shape[-1]):
-            hidden_states = self.fc(hidden_states)
+        assert hidden_states.shape[-1] == input_embeds.shape[-1]
 
         residual = None
         hidden_states, residual = self.layers[0](
@@ -169,9 +171,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
 
     def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
         nn.Module.__init__(self)
-        model_config = vllm_config.speculative_config.draft_model_config
-        self.config = model_config.hf_config
-        self.model = LlamaModel(model_config=model_config,
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
+        self.model = LlamaModel(vllm_config=vllm_config,
                                 start_layer_id=start_layer_id,
                                 prefix="model")
 
@@ -214,6 +216,13 @@ def compute_logits(
         logits_new[:, targets] = logits
         return logits_new
 
+    def combine_hidden_states(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        # combine multiple auxiliary hidden states returned by eagle3
+        return self.model.fc(hidden_states)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(
             self,
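
The substantive change in this file is that the shape-dependent fc branch is removed from LlamaModel.forward and exposed as an explicit combine_hidden_states() method, and LlamaModel now takes the full VllmConfig, which is what vLLM's @support_torch_compile decorator expects the wrapped module's constructor to accept. A toy sketch of the resulting shape contract follows; the hidden size, the number of auxiliary layers, and the standalone nn.Linear are assumptions standing in for the real model.fc:

import torch
import torch.nn as nn

hidden_size = 4096   # assumed draft-model hidden size
num_aux = 3          # EAGLE3 concatenates hidden states from several target layers
fc = nn.Linear(num_aux * hidden_size, hidden_size, bias=False)  # stand-in for model.fc

aux = torch.randn(8, num_aux * hidden_size)  # 8 tokens of concatenated aux states
combined = fc(aux)                           # what combine_hidden_states() now returns

# forward() can assert a fixed width instead of branching on shape,
# keeping the traced graph static for torch.compile and CUDA graphs.
assert combined.shape[-1] == hidden_size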

vllm/v1/spec_decode/eagle.py

Lines changed: 19 additions & 23 deletions
@@ -10,6 +10,7 @@
 from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 
@@ -39,11 +40,9 @@ def __init__(
 
         self.hidden_size = vllm_config.model_config.get_hidden_size()
 
-        # TODO: make eagle3 compatible with cudagraph
-        self.use_cuda_graph = self.method != 'eagle3' and \
-            (self.vllm_config.compilation_config.level
-             == CompilationLevel.PIECEWISE and
-             not self.vllm_config.model_config.enforce_eager)
+        self.use_cuda_graph = (self.vllm_config.compilation_config.level
+                               == CompilationLevel.PIECEWISE and
+                               not self.vllm_config.model_config.enforce_eager)
 
         self.cudagraph_batch_sizes = list(
             reversed(
@@ -90,6 +89,12 @@ def propose(
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
 
+        if self.method == "eagle3":
+            assert isinstance(self.model, Eagle3LlamaForCausalLM)
+            target_hidden_states = self.model.combine_hidden_states(
+                target_hidden_states)
+            assert target_hidden_states.shape[-1] == self.hidden_size
+
         # Shift the input ids by one token.
         # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
         self.input_ids[:num_tokens - 1] = target_token_ids[1:]
@@ -126,20 +131,15 @@ def propose(
         # copy inputs to buffer for cudagraph
         self.positions[:num_tokens] = target_positions
 
-        if self.method == 'eagle':
-            self.hidden_states[:num_tokens] = target_hidden_states
-            hidden_states = self.hidden_states
-        else:
-            # TODO: make eagle3 compatible with cuda graph
-            hidden_states = target_hidden_states
+        self.hidden_states[:num_tokens] = target_hidden_states
 
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
             last_hidden_states, hidden_states = self.model(
                 input_ids=self.input_ids[:num_input_tokens],
                 positions=self.positions[:num_input_tokens],
-                hidden_states=hidden_states[:num_input_tokens],
+                hidden_states=self.hidden_states[:num_input_tokens],
             )
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
@@ -209,10 +209,7 @@ def propose(
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
 
-            if self.method == 'eagle':
-                # TODO: make eagle3 compatible with cudagraph.
-                self.hidden_states[:batch_size] = hidden_states
-                hidden_states = self.hidden_states
+            self.hidden_states[:batch_size] = hidden_states
 
             # Run the model.
             with set_forward_context(attn_metadata,
@@ -221,7 +218,7 @@ def propose(
                 last_hidden_states, hidden_states = self.model(
                     input_ids=self.input_ids[:input_batch_size],
                     positions=self.positions[:input_batch_size],
-                    hidden_states=hidden_states[:input_batch_size],
+                    hidden_states=self.hidden_states[:input_batch_size],
                 )
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
@@ -314,12 +311,11 @@ def dummy_run(
     ) -> None:
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
-            if self.method == 'eagle':
-                self.model(
-                    input_ids=self.input_ids[:num_tokens],
-                    positions=self.positions[:num_tokens],
-                    hidden_states=self.hidden_states[:num_tokens],
-                )
+            self.model(
+                input_ids=self.input_ids[:num_tokens],
+                positions=self.positions[:num_tokens],
+                hidden_states=self.hidden_states[:num_tokens],
+            )
 
 
 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
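
With the eagle/eagle3 branching gone, propose() and dummy_run() always stage inputs into the persistent self.hidden_states buffer, the usual prerequisite for CUDA-graph replay, since a captured graph replays fixed memory addresses rather than whatever tensor happens to be passed in. A device-agnostic sketch of that fixed-buffer pattern (buffer sizes and the toy matmul are assumptions, not vLLM code):

import torch

max_tokens, hidden = 64, 4096
buf = torch.zeros(max_tokens, hidden)      # analogue of self.hidden_states
weight = torch.randn(hidden, hidden)       # stand-in for the captured computation

def step(new_states: torch.Tensor) -> torch.Tensor:
    n = new_states.shape[0]
    buf[:n] = new_states                   # copy inputs into the capture buffer
    return buf[:n] @ weight                # the captured region always reads buf

print(step(torch.randn(8, hidden)).shape)  # torch.Size([8, 4096])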
