diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py
index 4318bd52b9..f71712338b 100644
--- a/tpu_inference/models/common/model_loader.py
+++ b/tpu_inference/models/common/model_loader.py
@@ -63,6 +63,7 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
     # would cause JAX init failure when using multi hosts with Ray.
 
     from tpu_inference.models.jax.deepseek_v3 import DeepseekV3ForCausalLM
+    from tpu_inference.models.jax.dflash import DFlashForCausalLM
     from tpu_inference.models.jax.gpt_oss import GptOss
     from tpu_inference.models.jax.llama3 import LlamaForCausalLM
     from tpu_inference.models.jax.llama4 import Llama4ForCausalLM
@@ -73,6 +74,7 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
         Qwen2_5_VLForConditionalGeneration
     from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
     from tpu_inference.models.jax.qwen3_moe import Qwen3MoeForCausalLM
+    _MODEL_REGISTRY["DFlashDraftModel"] = DFlashForCausalLM
     _MODEL_REGISTRY["Llama4ForCausalLM"] = Llama4ForCausalLM
     _MODEL_REGISTRY["DeepseekV3ForCausalLM"] = DeepseekV3ForCausalLM
     _MODEL_REGISTRY["LlamaForCausalLM"] = LlamaForCausalLM
diff --git a/tpu_inference/models/jax/qwen3.py b/tpu_inference/models/jax/qwen3.py
index ba6d1038a0..f56f9aba6d 100644
--- a/tpu_inference/models/jax/qwen3.py
+++ b/tpu_inference/models/jax/qwen3.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from itertools import islice
 from typing import List, Optional, Tuple
 
 import jax
@@ -246,6 +247,38 @@ def __init__(self,
         )
 
 
+def _build_target_layer_ids(num_target_layers: int,
+                            num_draft_layers: int) -> list[int]:
+    """Spread ``num_draft_layers`` tap points across the target layers.
+
+    A single draft layer taps the middle target layer. Otherwise the ids
+    are evenly spaced (then rounded) from 1 to ``num_target_layers - 3``.
+    """
+    if num_draft_layers == 1:
+        return [num_target_layers // 2]
+    return [
+        round(1 + i * (num_target_layers - 4) / (num_draft_layers - 1))
+        for i in range(num_draft_layers)
+    ]
+
+
+def _get_dflash_target_layer_ids(target_num_layers: int,
+                                 draft_hf_config) -> list[int]:
+    """Return the target layer ids whose hidden states feed the drafter.
+
+    An explicit ``dflash_config["target_layer_ids"]`` on the draft config
+    wins; otherwise the ids are derived from the target depth and the
+    draft's ``num_hidden_layers``.
+    """
+    dflash_cfg = getattr(draft_hf_config, "dflash_config", None)
+    if dflash_cfg is not None:
+        target_layer_ids = dflash_cfg.get("target_layer_ids", None)
+        if target_layer_ids is not None:
+            return list(target_layer_ids)
+    num_draft_layers = draft_hf_config.num_hidden_layers
+    return _build_target_layer_ids(target_num_layers, num_draft_layers)
+
+
 class Qwen3Model(Qwen2Model):
 
     def __init__(self,
@@ -302,6 +335,72 @@ def __init__(self,
         else:
             self.norm = PPMissingLayer()
 
+        self._init_aux_hidden_state_layers(vllm_config)
+
+    def _init_aux_hidden_state_layers(self, vllm_config):
+        """Pick the layers whose hidden states are handed to the drafter.
+
+        Sets ``aux_hidden_state_layers`` (layer indices to capture) and
+        ``capture_aux_after_layer`` (capture a layer's output rather than
+        its input). Both keep their defaults without speculative decoding.
+        """
+        self.aux_hidden_state_layers = []
+        self.capture_aux_after_layer = False
+        spec_config = vllm_config.speculative_config
+        if not spec_config:
+            return
+        method = getattr(spec_config, "method", None)
+        if method == "eagle3":
+            num_layers = len(self.layers)
+            self.aux_hidden_state_layers = (2, num_layers // 2,
+                                            num_layers - 3)
+        elif method == "dflash":
+            # Share the module-level helper so the tapped layers follow a
+            # single rule; it tolerates a missing or None ``dflash_config``.
+            draft_hf_config = spec_config.draft_model_config.hf_config
+            self.aux_hidden_state_layers = tuple(
+                _get_dflash_target_layer_ids(len(self.layers),
+                                             draft_hf_config))
+            self.capture_aux_after_layer = True
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: Optional[jax.Array],
+        attention_metadata,
+        inputs_embeds: Optional[jax.Array] = None,
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
+        """Run this rank's decoder layers and collect aux hidden states.
+
+        Returns the updated ``kv_caches``, the output of ``self.norm`` and
+        the hidden states captured at ``aux_hidden_state_layers``.
+        """
+        if inputs_embeds is not None:
+            x = inputs_embeds
+        else:
+            x = self.embed_tokens(input_ids)
+
+        aux_hidden_states = []
+        # NOTE: ``i`` counts from ``start_layer``, so it only equals the
+        # global layer id when ``start_layer`` is 0.
+        for i, layer in enumerate(
+                islice(self.layers, self.start_layer, self.end_layer)):
+            if (not self.capture_aux_after_layer
+                    and i in self.aux_hidden_state_layers):
+                aux_hidden_states.append(x)
+            kv_cache = kv_caches[i]
+            kv_cache, x = layer(
+                kv_cache,
+                x,
+                attention_metadata,
+            )
+            kv_caches[i] = kv_cache
+            if (self.capture_aux_after_layer
+                    and i in self.aux_hidden_state_layers):
+                aux_hidden_states.append(x)
+        x = self.norm(x)
+        return kv_caches, x, aux_hidden_states
+
 
 class Qwen3ForCausalLM(JaxModule, LoadableWithIterator):
 
@@ -362,7 +461,7 @@ def __call__(
         if not is_first_rank:
             assert intermediate_tensors is not None
             inputs_embeds = intermediate_tensors["hidden_states"]
-        kv_caches, x = self.model(
+        kv_caches, x, aux_hidden_states = self.model(
             kv_caches,
             input_ids,
             attention_metadata,
@@ -370,7 +469,7 @@ def __call__(
         )
         if not is_last_rank:
             x = JaxIntermediateTensors(tensors={"hidden_states": x}, )
-        return kv_caches, x, []
+        return kv_caches, x, aux_hidden_states
 
     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
         if hasattr(self, 'lm_head'):
diff --git a/tpu_inference/runner/kv_cache_manager.py b/tpu_inference/runner/kv_cache_manager.py
index b0941a6f3d..8edb79ff6b 100644
--- a/tpu_inference/runner/kv_cache_manager.py
+++ b/tpu_inference/runner/kv_cache_manager.py
@@ -156,7 +156,8 @@ def get_kv_cache_spec(self):
                         head_size,
                         sliding_window=sliding_window)
 
-        if self.runner.speculative_config and self.runner.speculative_config.method == "eagle3":
+        if self.runner.speculative_config and self.runner.speculative_config.method in (
+                "eagle3", "dflash"):
             draft_model_config = self.runner.speculative_config.draft_model_config
             hf_config = draft_model_config.hf_config
             num_kv_heads = common_utils.get_padded_num_heads(
@@ -164,7 +165,8 @@ def get_kv_cache_spec(self):
             head_size = common_utils.get_padded_head_dim(
                 hf_config.hidden_size // hf_config.num_attention_heads)
-            # Eagle3 has only 1 layer
-            for i in range(1):
+            # One cache spec per draft layer (Eagle3 has only 1 layer).
+            draft_num_layers = getattr(hf_config, 'num_hidden_layers', 1)
+            for i in range(draft_num_layers):
                 if self.use_mla:
                     kv_cache_spec[
                         f"draft_layer.{i}"] = self._create_attention_spec(
diff --git a/tpu_inference/runner/speculative_decoding_manager.py b/tpu_inference/runner/speculative_decoding_manager.py
index 08ee78e9f9..0939919ae3 100644
--- a/tpu_inference/runner/speculative_decoding_manager.py
+++ b/tpu_inference/runner/speculative_decoding_manager.py
@@ -14,7 +14,7 @@
 
 from __future__ import annotations
 
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from typing import TYPE_CHECKING, Optional
 
 import jax.numpy as jnp
@@ -83,6 +83,15 @@ def propose_draft_token_ids(
                 scheduler_output,
                 input_ids,
             )
+        elif self.runner.speculative_config.method == "dflash":
+            self._draft_token_ids = self.propose_dflash_draft_token_ids(
+                sampled_token_ids,
+                aux_hidden_states,
+                attn_metadata,
+                spec_decode_metadata,
+                scheduler_output,
+                input_ids,
+            )
         else:
             raise NotImplementedError(
                 f"Speculative decoding method "
@@ -159,6 +168,95 @@ def propose_eagle3_draft_token_ids(
             draft_token_ids = np.expand_dims(draft_token_ids, axis=-1)
         return draft_token_ids.tolist()
 
+    def propose_dflash_draft_token_ids(
+        self,
+        sampled_token_ids: list[list[int]],
+        aux_hidden_states: Optional[tuple[jnp.ndarray, ...]],
+        attn_metadata: AttentionMetadata,
+        spec_decode_metadata: Optional[SpecDecodeMetadata],
+        scheduler_output: VllmSchedulerOutput,
+        input_ids: jnp.ndarray,
+    ) -> list[list[int]]:
+        """Propose draft tokens with the DFlash drafter.
+
+        Picks each request's next token, swaps ``seq_lens`` for
+        ``input_batch.num_tokens_no_spec``, then calls the drafter's
+        ``prepare_inputs`` and ``propose``. Updates ``runner.kv_caches``.
+
+        Returns:
+            The drafter output as nested lists; a 1-D output becomes one
+            token per row.
+        """
+        # TODO(woosuk): Refactor the loop.
+        req_ids = self.runner.input_batch.req_ids
+        next_token_ids: list[int] = []
+        for i, token_ids in enumerate(sampled_token_ids):
+            if token_ids:
+                # Common case.
+                next_token_id = token_ids[-1]
+            else:
+                # Partial prefill (rare case).
+                # Get the next token id from the request state.
+                req_id = req_ids[i]
+                req_state = self.runner.requests[req_id]
+                seq_len = (req_state.num_computed_tokens +
+                           scheduler_output.num_scheduled_tokens[req_id])
+                next_token_id = req_state.get_token_id(seq_len)
+            next_token_ids.append(next_token_id)
+
+        # Pad the batch size to match with existing padding for target model
+        num_padded_reqs = attn_metadata.seq_lens.shape[0]
+        pad_len = num_padded_reqs - len(next_token_ids)
+        assert pad_len >= 0
+        next_token_ids += [0] * pad_len
+
+        next_token_ids = device_array(
+            self.runner.mesh, np.array(next_token_ids, dtype=jnp.int32))
+
+        if spec_decode_metadata is None:
+            num_rejected_tokens = None
+        else:
+            num_draft_tokens = spec_decode_metadata.draft_lengths_cpu
+            num_rejected_tokens = [
+                int(n) + 1 - len(sampled_token_ids[i]) if n > 0 else 0
+                for i, n in enumerate(num_draft_tokens)
+            ]
+
+            pad_len = self.runner.max_num_reqs - len(num_rejected_tokens)
+            num_rejected_tokens += [0] * pad_len
+            num_rejected_tokens = device_array(
+                self.runner.mesh, np.array(num_rejected_tokens,
+                                           dtype=jnp.int32))
+
+        accepted_seq_lens = (
+            self.runner.input_batch.num_tokens_no_spec[:num_padded_reqs].copy())
+        accepted_attn_metadata = replace(
+            attn_metadata,
+            seq_lens=device_array(self.runner.mesh,
+                                  accepted_seq_lens.astype(np.int32)),
+        )
+
+        (target_hidden_states, input_ids, last_token_indices,
+         attn_metadata) = self.runner.drafter.prepare_inputs(
+            accepted_attn_metadata,
+            input_ids,
+            aux_hidden_states,
+            next_token_ids,
+            num_rejected_tokens,
+        )
+
+        self.runner.kv_caches, draft_token_ids = self.runner.drafter.propose(
+            kv_caches=self.runner.kv_caches,
+            input_ids=input_ids,
+            attn_metadata=attn_metadata,
+            last_token_indices=last_token_indices,
+            target_hidden_states=target_hidden_states,
+        )
+        draft_token_ids = np.array(draft_token_ids)
+        if draft_token_ids.ndim == 1:
+            draft_token_ids = np.expand_dims(draft_token_ids, axis=-1)
+        return draft_token_ids.tolist()
+
     def get_spec_decode_metadata(
         self,
         num_draft_tokens: np.ndarray,
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index 2f3a5893ed..83e01f9b13 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -411,6 +411,9 @@ def _init_speculative_decoding(self) -> None:
                 self.drafter = NgramProposer(self.vllm_config)
             elif self.speculative_config.method == "eagle3":
                 self.drafter = Eagle3Proposer(self.vllm_config, self)
+            elif self.speculative_config.method == "dflash":
+                from tpu_inference.spec_decode.jax.dflash import DFlashProposer
+                self.drafter = DFlashProposer(self.vllm_config, self)
             else:
                 raise NotImplementedError(
                     "Unsupported speculative decoding method: "