File tree Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -1049,6 +1049,7 @@ def forward(
10491049 position_ids : Optional [torch .IntTensor ] = None ,
10501050 inputs_embeds : Optional [torch .FloatTensor ] = None ,
10511051 spec_metadata : Optional [SpecMetadata ] = None ,
1052+ ** kwargs ,
10521053 ) -> torch .Tensor :
10531054 if (input_ids is None ) ^ (inputs_embeds is not None ):
10541055 raise ValueError (
@@ -1145,6 +1146,7 @@ def forward(
11451146 )
11461147
11471148 if spec_metadata and spec_metadata .spec_dec_mode .is_mtp ():
1149+ # TODO Merge API with EagleWorker in modeling_speculative.py
11481150 # get logits
11491151 logits = self .logits_processor .forward (
11501152 hidden_states [spec_metadata .gather_ids ],
Original file line number Diff line number Diff line change @@ -186,6 +186,8 @@ def get_hidden_states(self):
186186class Eagle3OneModelSpecMetadata (SpecMetadata ):
187187 # The hidden states
188188 hidden_states : Optional [torch .Tensor ] = None
189+ # The number of layers to be captured
190+ num_capture_layers : int = 3
189191 # The layers to be captured
190192 layers_to_capture : Tuple [int , ...] = field (init = False )
191193 # The hidden size of the hidden states
@@ -198,8 +200,8 @@ class Eagle3OneModelSpecMetadata(SpecMetadata):
198200 batch_indices_cuda : Optional [torch .Tensor ] = None
199201
200202 def __post_init__ (self ):
201- if self .num_layers == 1 :
202- self .layers_to_capture = (1 , )
203+ if self .num_layers == 1 or self . num_capture_layers == 1 :
204+ self .layers_to_capture = (self . num_layers - 1 , )
203205 else :
204206 if self .num_layers <= 5 :
205207 raise ValueError ("Not enough hidden layers for EAGLE" )
You can’t perform that action at this time.
0 commit comments