Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tpu_inference/models/common/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
# would cause JAX init failure when using multi hosts with Ray.

from tpu_inference.models.jax.deepseek_v3 import DeepseekV3ForCausalLM
from tpu_inference.models.jax.dflash import DFlashForCausalLM
from tpu_inference.models.jax.gpt_oss import GptOss
from tpu_inference.models.jax.llama3 import LlamaForCausalLM
from tpu_inference.models.jax.llama4 import Llama4ForCausalLM
Expand All @@ -73,6 +74,7 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
Qwen2_5_VLForConditionalGeneration
from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
from tpu_inference.models.jax.qwen3_moe import Qwen3MoeForCausalLM
_MODEL_REGISTRY["DFlashDraftModel"] = DFlashForCausalLM
_MODEL_REGISTRY["Llama4ForCausalLM"] = Llama4ForCausalLM
_MODEL_REGISTRY["DeepseekV3ForCausalLM"] = DeepseekV3ForCausalLM
_MODEL_REGISTRY["LlamaForCausalLM"] = LlamaForCausalLM
Expand Down
85 changes: 83 additions & 2 deletions tpu_inference/models/jax/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from itertools import islice
from typing import List, Optional, Tuple

import jax
Expand Down Expand Up @@ -246,6 +247,27 @@ def __init__(self,
)


def _build_target_layer_ids(num_target_layers: int,
num_draft_layers: int) -> list[int]:
if num_draft_layers == 1:
return [num_target_layers // 2]
return [
round(1 + i * (num_target_layers - 4) / (num_draft_layers - 1))
for i in range(num_draft_layers)
]


def _get_dflash_target_layer_ids(target_num_layers: int,
draft_hf_config) -> list[int]:
dflash_cfg = getattr(draft_hf_config, "dflash_config", None)
if dflash_cfg is not None:
target_layer_ids = dflash_cfg.get("target_layer_ids", None)
if target_layer_ids is not None:
return list(target_layer_ids)
num_draft_layers = draft_hf_config.num_hidden_layers
return _build_target_layer_ids(target_num_layers, num_draft_layers)


class Qwen3Model(Qwen2Model):

def __init__(self,
Expand Down Expand Up @@ -302,6 +324,65 @@ def __init__(self,
else:
self.norm = PPMissingLayer()

self._init_aux_hidden_state_layers(vllm_config)

def _init_aux_hidden_state_layers(self, vllm_config):
self.aux_hidden_state_layers = []
self.capture_aux_after_layer = False
if vllm_config.speculative_config:
method = getattr(vllm_config.speculative_config, "method", None)
if method == "eagle3":
num_layers = len(self.layers)
self.aux_hidden_state_layers = (2, num_layers // 2,
num_layers - 3)
elif method == "dflash":
draft_config = (
vllm_config.speculative_config.draft_model_config)
dflash_cfg = getattr(draft_config.hf_config, "dflash_config",
{})
target_layer_ids = dflash_cfg.get("target_layer_ids", None)
if target_layer_ids is not None:
self.aux_hidden_state_layers = tuple(target_layer_ids)
else:
num_target = getattr(draft_config.hf_config,
"num_target_layers", 5)
num_layers = len(self.layers)
step = max(1, (num_layers - 4) // (num_target - 1))
self.aux_hidden_state_layers = tuple(
range(1, num_layers - 2, step))[:num_target]
self.capture_aux_after_layer = True

    def __call__(
        self,
        kv_caches: List[jax.Array],
        input_ids: Optional[jax.Array],
        attention_metadata,
        inputs_embeds: Optional[jax.Array] = None,
    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
        """Run the decoder stack, threading KV caches through each layer.

        Returns the updated ``kv_caches``, the final (normed) hidden state,
        and any auxiliary hidden states captured for speculative decoding
        (empty when no capture layers are configured).
        """
        if inputs_embeds is not None:
            x = inputs_embeds
        else:
            x = self.embed_tokens(input_ids)

        aux_hidden_states = []
        for i, layer in enumerate(
                islice(self.layers, self.start_layer, self.end_layer)):
            # Capture the layer *input* (eagle3-style) when configured.
            # NOTE(review): `i` is relative to `start_layer`, while the
            # configured layer ids appear absolute — confirm behavior under
            # pipeline parallelism (start_layer > 0).
            if (not self.capture_aux_after_layer
                    and i in self.aux_hidden_state_layers):
                aux_hidden_states.append(x)
            kv_cache = kv_caches[i]
            # Each layer consumes and returns its KV cache functionally.
            kv_cache, x = layer(
                kv_cache,
                x,
                attention_metadata,
            )
            kv_caches[i] = kv_cache
            # Capture the layer *output* (dflash-style) when configured.
            if (self.capture_aux_after_layer
                    and i in self.aux_hidden_state_layers):
                aux_hidden_states.append(x)
        x = self.norm(x)
        return kv_caches, x, aux_hidden_states


class Qwen3ForCausalLM(JaxModule, LoadableWithIterator):

Expand Down Expand Up @@ -362,15 +443,15 @@ def __call__(
if not is_first_rank:
assert intermediate_tensors is not None
inputs_embeds = intermediate_tensors["hidden_states"]
kv_caches, x = self.model(
kv_caches, x, aux_hidden_states = self.model(
kv_caches,
input_ids,
attention_metadata,
inputs_embeds,
)
if not is_last_rank:
x = JaxIntermediateTensors(tensors={"hidden_states": x}, )
return kv_caches, x, []
return kv_caches, x, aux_hidden_states

def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
if hasattr(self, 'lm_head'):
Expand Down
6 changes: 4 additions & 2 deletions tpu_inference/runner/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,15 +156,17 @@ def get_kv_cache_spec(self):
head_size,
sliding_window=sliding_window)

if self.runner.speculative_config and self.runner.speculative_config.method == "eagle3":
if self.runner.speculative_config and self.runner.speculative_config.method in (
"eagle3", "dflash"):
draft_model_config = self.runner.speculative_config.draft_model_config
hf_config = draft_model_config.hf_config
num_kv_heads = common_utils.get_padded_num_heads(
hf_config.num_key_value_heads, model_cnt)
head_size = common_utils.get_padded_head_dim(
hf_config.hidden_size // hf_config.num_attention_heads)
# Eagle3 has only 1 layer
for i in range(1):
draft_num_layers = getattr(hf_config, 'num_hidden_layers', 1)
for i in range(draft_num_layers):
if self.use_mla:
kv_cache_spec[
f"draft_layer.{i}"] = self._create_attention_spec(
Expand Down
93 changes: 92 additions & 1 deletion tpu_inference/runner/speculative_decoding_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Optional

import jax.numpy as jnp
Expand Down Expand Up @@ -83,6 +83,15 @@ def propose_draft_token_ids(
scheduler_output,
input_ids,
)
elif self.runner.speculative_config.method == "dflash":
self._draft_token_ids = self.propose_dflash_draft_token_ids(
sampled_token_ids,
aux_hidden_states,
attn_metadata,
spec_decode_metadata,
scheduler_output,
input_ids,
)
else:
raise NotImplementedError(
f"Speculative decoding method "
Expand Down Expand Up @@ -159,6 +168,88 @@ def propose_eagle3_draft_token_ids(
draft_token_ids = np.expand_dims(draft_token_ids, axis=-1)
return draft_token_ids.tolist()

    def propose_dflash_draft_token_ids(
        self,
        sampled_token_ids: list[list[int]],
        aux_hidden_states: Optional[tuple[jnp.ndarray, ...]],
        attn_metadata: AttentionMetadata,
        spec_decode_metadata: Optional[SpecDecodeMetadata],
        scheduler_output: VllmSchedulerOutput,
        input_ids: jnp.ndarray,
    ) -> list[list[int]]:
        """Propose draft tokens with the DFlash drafter.

        Builds drafter inputs from the target model's sampled tokens and the
        auxiliary hidden states captured during the target forward pass,
        runs the draft model once, and returns per-request lists of draft
        token ids.
        """
        # TODO(woosuk): Refactor the loop.
        req_ids = self.runner.input_batch.req_ids
        next_token_ids: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = self.runner.requests[req_id]
                seq_len = (req_state.num_computed_tokens +
                           scheduler_output.num_scheduled_tokens[req_id])
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)

        # Pad the batch size to match with existing padding for target model
        pad_len = attn_metadata.seq_lens.shape[0] - len(next_token_ids)
        assert pad_len >= 0
        next_token_ids += [0] * pad_len

        next_token_ids = device_array(
            self.runner.mesh, np.array(next_token_ids, dtype=jnp.int32))

        if spec_decode_metadata is None:
            # No in-flight draft tokens this step: nothing was rejected.
            num_rejected_tokens = None
        else:
            num_draft_tokens = spec_decode_metadata.draft_lengths_cpu
            # A request with n drafts that kept k sampled tokens rejected
            # n + 1 - k (the +1 accounts for the bonus target token).
            num_rejected_tokens = [
                int(n) + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                for i, n in enumerate(num_draft_tokens)
            ]

            pad_len = self.runner.max_num_reqs - len(num_rejected_tokens)
            num_rejected_tokens += [0] * pad_len
            num_rejected_tokens = device_array(
                self.runner.mesh, np.array(num_rejected_tokens,
                                           dtype=jnp.int32))

        # Per-request lengths with speculative tokens stripped, truncated to
        # the (padded) batch size the target model used this step.
        accepted_seq_lens = self.runner.input_batch.num_tokens_no_spec[:
                                                                       attn_metadata
                                                                       .
                                                                       seq_lens
                                                                       .shape[
                                                                           0]].copy(
                                                                           )
        # Re-issue the metadata with accepted lengths only; the original
        # attn_metadata still reflects the speculative batch layout.
        accepted_attn_metadata = replace(
            attn_metadata,
            seq_lens=device_array(self.runner.mesh,
                                  accepted_seq_lens.astype(np.int32)),
        )

        target_hidden_states, input_ids, last_token_indices, attn_metadata = self.runner.drafter.prepare_inputs(
            accepted_attn_metadata,
            input_ids,
            aux_hidden_states,
            next_token_ids,
            num_rejected_tokens,
        )

        self.runner.kv_caches, draft_token_ids = self.runner.drafter.propose(
            kv_caches=self.runner.kv_caches,
            input_ids=input_ids,
            attn_metadata=attn_metadata,
            last_token_indices=last_token_indices,
            target_hidden_states=target_hidden_states,
        )
        # Normalize to one row per request; a 1-D result means a single
        # draft token per request.
        draft_token_ids = np.array(draft_token_ids)
        if draft_token_ids.ndim == 1:
            draft_token_ids = np.expand_dims(draft_token_ids, axis=-1)
        return draft_token_ids.tolist()

def get_spec_decode_metadata(
self,
num_draft_tokens: np.ndarray,
Expand Down
3 changes: 3 additions & 0 deletions tpu_inference/runner/tpu_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,9 @@ def _init_speculative_decoding(self) -> None:
self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.method == "eagle3":
self.drafter = Eagle3Proposer(self.vllm_config, self)
elif self.speculative_config.method == "dflash":
from tpu_inference.spec_decode.jax.dflash import DFlashProposer
self.drafter = DFlashProposer(self.vllm_config, self)
else:
raise NotImplementedError(
"Unsupported speculative decoding method: "
Expand Down