diff --git a/tests/models/jax/test_qwen3_dflash.py b/tests/models/jax/test_qwen3_dflash.py new file mode 100644 index 0000000000..fa819faeba --- /dev/null +++ b/tests/models/jax/test_qwen3_dflash.py @@ -0,0 +1,51 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from types import SimpleNamespace + +from tpu_inference.models.jax.qwen3 import \ + _get_dflash_target_layer_ids as get_target_layer_ids_for_qwen3 +from tpu_inference.models.jax.qwen3_dflash import _build_target_layer_ids +from tpu_inference.models.jax.qwen3_dflash import \ + _get_dflash_target_layer_ids as get_target_layer_ids_for_qwen3_dflash + + +def test_build_target_layer_ids_default_layout(): + assert _build_target_layer_ids(32, 1) == [16] + assert _build_target_layer_ids(32, 4) == [1, 10, 20, 29] + + +def test_get_target_layer_ids_prefers_explicit_config(): + cfg = SimpleNamespace( + dflash_config={"target_layer_ids": [2, 6, 10]}, + num_target_layers=32, + num_hidden_layers=3, + ) + + assert get_target_layer_ids_for_qwen3_dflash(cfg, 32) == [2, 6, 10] + assert get_target_layer_ids_for_qwen3(32, cfg) == [2, 6, 10] + + +def test_get_target_layer_ids_fallback_matches_between_modules(): + cfg = SimpleNamespace( + dflash_config=None, + num_target_layers=32, + num_hidden_layers=3, + ) + + dflash_ids = get_target_layer_ids_for_qwen3_dflash(cfg, 32) + qwen3_ids = get_target_layer_ids_for_qwen3(32, cfg) + + assert dflash_ids == [1, 15, 29] + assert qwen3_ids == dflash_ids diff 
--git a/tests/models/jax/test_qwen3_dflash_attention.py b/tests/models/jax/test_qwen3_dflash_attention.py new file mode 100644 index 0000000000..5b9cb8a3d5 --- /dev/null +++ b/tests/models/jax/test_qwen3_dflash_attention.py @@ -0,0 +1,289 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from tpu_inference.layers.common.attention_metadata import AttentionMetadata +from tpu_inference.layers.common.dflash_attention_interface import \ + dflash_concat_attention +from tpu_inference.models.jax.qwen3_dflash import Qwen3DFlashAttention + + +def _make_attention_metadata(query_start_loc: list[int]) -> AttentionMetadata: + query_start_loc = np.asarray(query_start_loc, dtype=np.int32) + seq_lens = np.diff(query_start_loc) + total_tokens = int(query_start_loc[-1]) + return AttentionMetadata( + input_positions=jnp.arange(total_tokens, dtype=jnp.int32), + block_tables=jnp.zeros((max(1, total_tokens), ), dtype=jnp.int32), + seq_lens=jnp.asarray(seq_lens, dtype=jnp.int32), + query_start_loc=jnp.asarray(query_start_loc, dtype=jnp.int32), + request_distribution=jnp.asarray([0, 0, len(seq_lens)], + dtype=jnp.int32), + ) + + +def _dense_reference_attention( + q: jax.Array, + k: jax.Array, + v: jax.Array, + *, + sm_scale: float, +) -> jax.Array: + logits = jnp.einsum("qnh,knh->nqk", q.astype(jnp.float32), + k.astype(jnp.float32)) + logits = logits * sm_scale + probs = 
jax.nn.softmax(logits, axis=-1).astype(v.dtype) + return jnp.einsum("nqk,knh->qnh", probs, v) + + +def _build_stub_attention(impl: str) -> Qwen3DFlashAttention: + attention = object.__new__(Qwen3DFlashAttention) + attention.q_proj = lambda x: x + attention.q_norm = lambda x: x + attention.k_proj = lambda x: x + attention.k_norm = lambda x: x + attention.v_proj = lambda x: x + attention.o_proj = lambda x: x + attention.head_dim_original = 1 + attention.rope_theta = 10000.0 + attention.rope_scaling = None + attention.mesh = object() + attention.dflash_attention_impl = impl + attention.max_query_len = 2 + attention.kv_cache_quantized_dtype = None + attention._k_scale = 1.0 + attention._v_scale = 1.0 + return attention + + +def test_dflash_concat_attention_matches_concat_reference(): + q = jnp.array([[[1.0]], [[2.0]]], dtype=jnp.float32) + k_ctx = jnp.array([[[1.0]], [[0.0]]], dtype=jnp.float32) + k_noise = jnp.array([[[0.0]], [[1.0]]], dtype=jnp.float32) + v_ctx = jnp.array([[[10.0]], [[20.0]]], dtype=jnp.float32) + v_noise = jnp.array([[[100.0]], [[200.0]]], dtype=jnp.float32) + md = _make_attention_metadata([0, 2]) + sm_scale = 1.0 + + output = dflash_concat_attention( + q, + k_ctx, + k_noise, + v_ctx, + v_noise, + md, + max_query_len=2, + sm_scale=sm_scale, + ) + + k_cat = jnp.concatenate([k_ctx, k_noise], axis=0) + v_cat = jnp.concatenate([v_ctx, v_noise], axis=0) + expected = _dense_reference_attention(q, k_cat, v_cat, sm_scale=sm_scale) + np.testing.assert_allclose(np.asarray(output), + np.asarray(expected), + rtol=1e-5, + atol=1e-5) + + additive_expected = _dense_reference_attention(q, + k_ctx + k_noise, + v_ctx + v_noise, + sm_scale=sm_scale) + assert not np.allclose(np.asarray(output), np.asarray(additive_expected)) + + +def test_dflash_concat_attention_repeats_kv_heads_for_gqa(): + q = jnp.array( + [ + [[1.0], [2.0]], + [[3.0], [4.0]], + ], + dtype=jnp.float32, + ) + k_ctx = jnp.array([[[0.5]], [[1.5]]], dtype=jnp.float32) + k_noise = jnp.array([[[2.0]], 
[[3.0]]], dtype=jnp.float32) + v_ctx = jnp.array([[[11.0]], [[13.0]]], dtype=jnp.float32) + v_noise = jnp.array([[[17.0]], [[19.0]]], dtype=jnp.float32) + md = _make_attention_metadata([0, 2]) + + output = dflash_concat_attention( + q, + k_ctx, + k_noise, + v_ctx, + v_noise, + md, + max_query_len=2, + sm_scale=1.0, + ) + + k_cat = jnp.concatenate( + [jnp.repeat(k_ctx, 2, axis=1), + jnp.repeat(k_noise, 2, axis=1)], + axis=0, + ) + v_cat = jnp.concatenate( + [jnp.repeat(v_ctx, 2, axis=1), + jnp.repeat(v_noise, 2, axis=1)], + axis=0, + ) + expected = _dense_reference_attention(q, k_cat, v_cat, sm_scale=1.0) + np.testing.assert_allclose(np.asarray(output), + np.asarray(expected), + rtol=1e-5, + atol=1e-5) + + +def test_qwen3_dflash_attention_concat_impl(monkeypatch): + md = _make_attention_metadata([0, 2]) + hidden_states = jnp.array([[[1.0]], [[2.0]]], dtype=jnp.float32) + target_hidden_states = jnp.array([[[3.0]], [[4.0]]], dtype=jnp.float32) + kv_cache = jnp.array([0.0], dtype=jnp.float32) + + concat_calls = {} + cache_update_calls = {} + + def fake_concat_attention( + q, + k_ctx, + k_noise, + v_ctx, + v_noise, + _md, + *, + max_query_len, + sm_scale, + ): + concat_calls["q"] = np.asarray(q) + concat_calls["k_ctx"] = np.asarray(k_ctx) + concat_calls["k_noise"] = np.asarray(k_noise) + concat_calls["v_ctx"] = np.asarray(v_ctx) + concat_calls["v_noise"] = np.asarray(v_noise) + concat_calls["max_query_len"] = max_query_len + concat_calls["sm_scale"] = sm_scale + return jnp.full_like(q, 7.0) + + def fake_attention( + kv_cache, + q, + k, + v, + _md, + _mesh, + _head_dim_original, + **_kwargs, + ): + cache_update_calls["q"] = np.asarray(q) + cache_update_calls["k"] = np.asarray(k) + cache_update_calls["v"] = np.asarray(v) + return kv_cache + 1.0, jnp.full_like(q, -5.0) + + monkeypatch.setattr("tpu_inference.models.jax.qwen3_dflash.apply_rope", + lambda x, *_args, **_kwargs: x) + monkeypatch.setattr( + "tpu_inference.models.jax.qwen3_dflash.dflash_concat_attention", + 
fake_concat_attention) + monkeypatch.setattr("tpu_inference.models.jax.qwen3_dflash.attention", + fake_attention) + + attention = _build_stub_attention("concat_dense") + new_kv_cache, output = attention( + kv_cache=kv_cache, + hidden_states=hidden_states, + target_hidden_states=target_hidden_states, + attention_metadata=md, + ) + + np.testing.assert_allclose(np.asarray(output), np.full((2, 1, 1), 7.0)) + np.testing.assert_allclose(np.asarray(new_kv_cache), np.array([1.0])) + np.testing.assert_allclose(concat_calls["k_ctx"], + np.asarray(target_hidden_states)) + np.testing.assert_allclose(concat_calls["k_noise"], + np.asarray(hidden_states)) + np.testing.assert_allclose(cache_update_calls["k"], + np.asarray(hidden_states)) + np.testing.assert_allclose(cache_update_calls["v"], + np.asarray(hidden_states)) + assert concat_calls["max_query_len"] == 2 + + +def test_qwen3_dflash_attention_additive_legacy_impl(monkeypatch): + md = _make_attention_metadata([0, 2]) + hidden_states = jnp.array([[[1.0]], [[2.0]]], dtype=jnp.float32) + target_hidden_states = jnp.array([[[3.0]], [[4.0]]], dtype=jnp.float32) + kv_cache = jnp.array([0.0], dtype=jnp.float32) + + calls = {} + + def fake_attention( + kv_cache, + q, + k, + v, + _md, + _mesh, + _head_dim_original, + **_kwargs, + ): + calls["q"] = np.asarray(q) + calls["k"] = np.asarray(k) + calls["v"] = np.asarray(v) + return kv_cache + 2.0, jnp.full_like(q, 3.0) + + def fail_concat(*_args, **_kwargs): + raise AssertionError("concat path should not run for additive_legacy") + + monkeypatch.setattr("tpu_inference.models.jax.qwen3_dflash.apply_rope", + lambda x, *_args, **_kwargs: x) + monkeypatch.setattr( + "tpu_inference.models.jax.qwen3_dflash.dflash_concat_attention", + fail_concat) + monkeypatch.setattr("tpu_inference.models.jax.qwen3_dflash.attention", + fake_attention) + + attention = _build_stub_attention("additive_legacy") + new_kv_cache, output = attention( + kv_cache=kv_cache, + hidden_states=hidden_states, + 
target_hidden_states=target_hidden_states, + attention_metadata=md, + ) + + expected_k = np.asarray(target_hidden_states + hidden_states) + np.testing.assert_allclose(calls["k"], expected_k) + np.testing.assert_allclose(calls["v"], expected_k) + np.testing.assert_allclose(np.asarray(output), np.full((2, 1, 1), 3.0)) + np.testing.assert_allclose(np.asarray(new_kv_cache), np.array([2.0])) + + +def test_qwen3_dflash_attention_unknown_impl_raises(monkeypatch): + md = _make_attention_metadata([0, 2]) + hidden_states = jnp.array([[[1.0]], [[2.0]]], dtype=jnp.float32) + target_hidden_states = jnp.array([[[3.0]], [[4.0]]], dtype=jnp.float32) + kv_cache = jnp.array([0.0], dtype=jnp.float32) + + monkeypatch.setattr("tpu_inference.models.jax.qwen3_dflash.apply_rope", + lambda x, *_args, **_kwargs: x) + + attention = _build_stub_attention("bad_impl") + with pytest.raises(ValueError, match="Unsupported"): + attention( + kv_cache=kv_cache, + hidden_states=hidden_states, + target_hidden_states=target_hidden_states, + attention_metadata=md, + ) diff --git a/tests/spec_decode/test_dflash.py b/tests/spec_decode/test_dflash.py new file mode 100644 index 0000000000..c59ffadc62 --- /dev/null +++ b/tests/spec_decode/test_dflash.py @@ -0,0 +1,68 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import jax +import jax.numpy as jnp +import numpy as np + +from tpu_inference.spec_decode.jax.dflash import DFlashProposer + + +def _make_single_device_mesh() -> jax.sharding.Mesh: + devices = np.array(jax.devices()[:1]) + return jax.sharding.Mesh(devices, axis_names=("model", )) + + +def test_sample_block_draft_tokens_uses_target_model_logits(): + proposer = object.__new__(DFlashProposer) + proposer.mesh = _make_single_device_mesh() + proposer.num_speculative_tokens = 2 + + call_record = {} + # Use a JAX array as dummy state (JAX tracing requires array-like args) + target_state = jnp.array(0) + + def fake_compute_logits_fn(state, hidden_states, lora_metadata): + call_record["shape"] = hidden_states.shape + return jnp.array([[0.0, 2.0, 1.0], [4.0, 1.0, 0.0]], dtype=jnp.float32) + + proposer.compute_logits_fn = fake_compute_logits_fn + + # hidden_states layout: [context_token, draft_token_0, draft_token_1, ...] + # _sample_block_draft_tokens slices [1:1+num_speculative_tokens] + hidden_states = jnp.ones((3, 8), dtype=jnp.bfloat16) + draft_token_ids = proposer._sample_block_draft_tokens( + target_state, hidden_states) + + np.testing.assert_array_equal(np.asarray(draft_token_ids), + np.array([1, 0], dtype=np.int32)) + assert call_record["shape"] == (2, 8) + + +def test_sample_block_draft_tokens_returns_1d_int_ids(): + proposer = object.__new__(DFlashProposer) + proposer.mesh = _make_single_device_mesh() + proposer.num_speculative_tokens = 2 + + proposer.compute_logits_fn = lambda _state, _hidden, _lora: jnp.array( + [[1.0, 0.0], [0.0, 1.0]], dtype=jnp.float32) + + # 1 context + 2 draft positions + hidden_states = jnp.ones((3, 4), dtype=jnp.bfloat16) + draft_token_ids = proposer._sample_block_draft_tokens( + jnp.array(0), hidden_states) + + assert draft_token_ids.ndim == 1 + assert draft_token_ids.shape == (2, ) + assert jnp.issubdtype(draft_token_ids.dtype, jnp.integer) diff --git a/tpu_inference/layers/common/dflash_attention_interface.py 
b/tpu_inference/layers/common/dflash_attention_interface.py new file mode 100644 index 0000000000..def7784e9d --- /dev/null +++ b/tpu_inference/layers/common/dflash_attention_interface.py @@ -0,0 +1,137 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DFlash-specific attention helpers.""" + +from __future__ import annotations + +import functools + +import jax +import jax.numpy as jnp +from jax import lax + +from tpu_inference.layers.common.attention_metadata import AttentionMetadata + + +@functools.partial(jax.jit, static_argnames=("max_query_len", )) +def dflash_concat_attention( + q: jax.Array, # [T, N, H] + k_ctx: jax.Array, # [T, K, H] + k_noise: jax.Array, # [T, K, H] + v_ctx: jax.Array, # [T, K, H] + v_noise: jax.Array, # [T, K, H] + attention_metadata: AttentionMetadata, + *, + max_query_len: int, + sm_scale: float, +) -> jax.Array: + """Computes DFlash concat attention outputs for query tokens. + + This path follows DFlash semantics by concatenating context/noise keys and + values along token axis, while keeping query tokens as the noise stream. 
+ """ + if max_query_len <= 0: + raise ValueError(f"{max_query_len=} must be positive.") + if not (q.shape[0] == k_ctx.shape[0] == k_noise.shape[0] == v_ctx.shape[0] + == v_noise.shape[0]): + raise ValueError( + "All DFlash attention streams must share the same token count.") + + num_tokens, num_heads, _ = q.shape + num_kv_heads = k_ctx.shape[1] + if num_heads % num_kv_heads != 0: + raise ValueError( + f"Expected num_heads divisible by num_kv_heads, got {num_heads=} {num_kv_heads=}" + ) + + # Expand KV heads to match query head count for GQA/MQA. + kv_repeat = num_heads // num_kv_heads + if kv_repeat > 1: + k_ctx = jnp.repeat(k_ctx, kv_repeat, axis=1) + k_noise = jnp.repeat(k_noise, kv_repeat, axis=1) + v_ctx = jnp.repeat(v_ctx, kv_repeat, axis=1) + v_noise = jnp.repeat(v_noise, kv_repeat, axis=1) + + # Pad so dynamic_slice_in_dim always has static size inside fori_loop. + pad_len = max_query_len + q = jnp.pad(q, ((0, pad_len), (0, 0), (0, 0))) + k_ctx = jnp.pad(k_ctx, ((0, pad_len), (0, 0), (0, 0))) + k_noise = jnp.pad(k_noise, ((0, pad_len), (0, 0), (0, 0))) + v_ctx = jnp.pad(v_ctx, ((0, pad_len), (0, 0), (0, 0))) + v_noise = jnp.pad(v_noise, ((0, pad_len), (0, 0), (0, 0))) + + # Per-request token offsets and lengths. + query_start_loc = attention_metadata.query_start_loc + req_lens = query_start_loc[1:] - query_start_loc[:-1] + if attention_metadata.request_distribution is not None: + num_reqs = jnp.minimum(attention_metadata.request_distribution[2], + req_lens.shape[0]) + else: + num_reqs = req_lens.shape[0] + + # KV range is 2x because context and noise are concatenated. + arange_q = jnp.arange(max_query_len) + arange_kv = jnp.arange(2 * max_query_len) + + # Large negative for masking out padding positions before softmax. + mask_value = -0.7 * float(jnp.finfo(jnp.float32).max) + outputs = jnp.zeros_like(q) + + def _body(i: int, current: jax.Array) -> jax.Array: + # Process one request: slice, concat ctx+noise KV, attend, write back. 
+ start = query_start_loc[i] + req_len = req_lens[i] + req_len = jnp.clip(req_len, 0, max_query_len) + + q_blk = lax.dynamic_slice_in_dim(q, start, max_query_len, axis=0) + k_ctx_blk = lax.dynamic_slice_in_dim(k_ctx, + start, + max_query_len, + axis=0) + k_noise_blk = lax.dynamic_slice_in_dim(k_noise, + start, + max_query_len, + axis=0) + v_ctx_blk = lax.dynamic_slice_in_dim(v_ctx, + start, + max_query_len, + axis=0) + v_noise_blk = lax.dynamic_slice_in_dim(v_noise, + start, + max_query_len, + axis=0) + + # Concat context and noise KV along token axis. [2*max_query_len, N, H] + k_blk = jnp.concatenate([k_ctx_blk, k_noise_blk], axis=0) + v_blk = jnp.concatenate([v_ctx_blk, v_noise_blk], axis=0) + + # Mask out padding positions for both Q and KV. + q_valid = arange_q < req_len + kv_valid_len = jnp.maximum(2 * req_len, 1) + kv_valid = arange_kv < kv_valid_len + + logits = jnp.einsum("qnh,knh->nqk", q_blk.astype(jnp.float32), + k_blk.astype(jnp.float32)) + logits = logits * sm_scale + logits = jnp.where(kv_valid[None, None, :], logits, mask_value) + + probs = jax.nn.softmax(logits, axis=-1).astype(v_blk.dtype) + out_blk = jnp.einsum("nqk,knh->qnh", probs, v_blk) + out_blk = jnp.where(q_valid[:, None, None], out_blk, + jnp.zeros_like(out_blk)) + + return lax.dynamic_update_slice_in_dim(current, out_blk, start, axis=0) + + outputs = lax.fori_loop(0, num_reqs, _body, outputs) + return outputs[:num_tokens] diff --git a/tpu_inference/models/jax/dflash.py b/tpu_inference/models/jax/dflash.py new file mode 100644 index 0000000000..d99c7145ad --- /dev/null +++ b/tpu_inference/models/jax/dflash.py @@ -0,0 +1,599 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DFlash draft model for speculative decoding on JAX/TPU.""" + +from typing import List, Tuple + +import jax +import jax.numpy as jnp +from flax import nnx +from jax import lax +from jax.sharding import Mesh +from transformers import Qwen3Config +from vllm.config import VllmConfig + +from tpu_inference import utils +from tpu_inference.kernels.flash_attention.kernel import (BlockSizes, + SegmentIds, + flash_attention) +from tpu_inference.layers.common.sharding import ShardingAxisName +from tpu_inference.layers.jax.rope_interface import apply_rope +from tpu_inference.logger import init_logger +from tpu_inference.models.jax.utils.weight_utils import (BaseWeightLoader, + get_default_maps, + load_hf_weights) +from tpu_inference.utils import get_mesh_shape_product + +logger = init_logger(__name__) + +init_fn = nnx.initializers.uniform() + +# vmem budget for the flash_attention Pallas kernel (128 MiB). +_FA_VMEM_LIMIT = 128 * 1024 * 1024 + + +class DFlashAttention(nnx.Module): + """DFlash cross+self attention with on-device KV cache. + + Each call: + 1. Projects Q from noise embeddings, K/V from [context, noise]. + 2. Applies RoPE to Q and K. + 3. Expands K/V for GQA. + 4. Writes NEW K/V into the pre-allocated cache via dynamic_update_slice. + 5. Runs non-causal flash_attention over the full cache up to the valid + length, using segment_ids to mask padding. 
+ """ + + def __init__( + self, + config: Qwen3Config, + dtype: jnp.dtype, + rng: nnx.Rngs, + mesh: Mesh, + ): + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.rope_theta = config.rope_theta + self.rope_scaling = getattr(config, "rope_scaling", None) + self.rms_norm_eps = config.rms_norm_eps + + self.head_dim_original = getattr(config, "head_dim", + self.hidden_size // self.num_heads) + self.head_dim = utils.get_padded_head_dim(self.head_dim_original) + + sharding_size = get_mesh_shape_product(mesh, + ShardingAxisName.MLP_TENSOR) + self.num_heads = utils.get_padded_num_heads(self.num_heads, + sharding_size) + self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads, + sharding_size) + self.num_kv_groups = self.num_heads // self.num_kv_heads + + self.mesh = mesh + + self.q_proj = nnx.Einsum( + "TD,DNH->TNH", + (self.hidden_size, self.num_heads, self.head_dim), + param_dtype=dtype, + kernel_init=nnx.with_partitioning( + init_fn, (None, ShardingAxisName.ATTN_HEAD, None)), + rngs=rng, + ) + self.k_proj = nnx.Einsum( + "TD,DKH->TKH", + (self.hidden_size, self.num_kv_heads, self.head_dim), + param_dtype=dtype, + kernel_init=nnx.with_partitioning( + init_fn, (None, ShardingAxisName.ATTN_HEAD, None)), + rngs=rng, + ) + self.v_proj = nnx.Einsum( + "TD,DKH->TKH", + (self.hidden_size, self.num_kv_heads, self.head_dim), + param_dtype=dtype, + kernel_init=nnx.with_partitioning( + init_fn, (None, ShardingAxisName.ATTN_HEAD, None)), + rngs=rng, + ) + self.o_proj = nnx.Einsum( + "TNH,NHD->TD", + (self.num_heads, self.head_dim, self.hidden_size), + param_dtype=dtype, + kernel_init=nnx.with_partitioning( + init_fn, (ShardingAxisName.ATTN_HEAD, None, None)), + rngs=rng, + ) + + self.q_norm = nnx.RMSNorm( + self.head_dim, + epsilon=self.rms_norm_eps, + param_dtype=dtype, + scale_init=nnx.with_partitioning(init_fn, (None, )), + rngs=rng, + ) + self.k_norm = nnx.RMSNorm( + self.head_dim, 
+ epsilon=self.rms_norm_eps, + param_dtype=dtype, + scale_init=nnx.with_partitioning(init_fn, (None, )), + rngs=rng, + ) + + def __call__( + self, + x_noise: jax.Array, + target_hidden: jax.Array, + noise_positions: jax.Array, + ctx_positions: jax.Array, + kv_cache_k: jax.Array, + kv_cache_v: jax.Array, + cache_len: jax.Array, + actual_ctx_count: jax.Array, + ) -> Tuple[jax.Array, jax.Array, jax.Array]: + """Non-causal attention with on-device KV cache. + + Uses a two-phase cache write to handle padded context correctly: + Phase A: write context K/V (with padding zeroed) at ``cache_len``. + Phase B: write noise K/V at ``cache_len + actual_ctx_count``, + overwriting any padding zeros from Phase A. + + Args: + x_noise: (T_noise, D) noise hidden states. + target_hidden: (T_padded, D) padded context features. + noise_positions: (T_noise,) position ids for noise tokens. + ctx_positions: (T_padded,) position ids for context tokens. + kv_cache_k: (1, N_heads, max_kv_len, H) pre-allocated K cache. + kv_cache_v: (1, N_heads, max_kv_len, H) pre-allocated V cache. + cache_len: scalar int, valid entries already in cache. + actual_ctx_count: scalar int, real (non-padding) context tokens. 
+ + Returns: + (output, new_kv_cache_k, new_kv_cache_v) + """ + T_noise = x_noise.shape[0] + T_padded = target_hidden.shape[0] + + q = self.q_proj(x_noise) + q = self.q_norm(q) + q = apply_rope( + q, + noise_positions, + self.head_dim_original, + self.rope_theta, + self.rope_scaling, + ) + + x_new = jnp.concatenate([target_hidden, x_noise], axis=0) + k_new = self.k_proj(x_new) + v_new = self.v_proj(x_new) + k_new = self.k_norm(k_new) + + new_positions = jnp.concatenate([ctx_positions, noise_positions], + axis=0) + k_new = apply_rope( + k_new, + new_positions, + self.head_dim_original, + self.rope_theta, + self.rope_scaling, + ) + + if self.num_kv_groups > 1: + k_new = jnp.repeat(k_new, self.num_kv_groups, axis=1) + v_new = jnp.repeat(v_new, self.num_kv_groups, axis=1) + + k_ctx = k_new[:T_padded] + v_ctx = v_new[:T_padded] + k_noise = k_new[T_padded:] + v_noise = v_new[T_padded:] + + ctx_mask = (jnp.arange(T_padded) < actual_ctx_count) # (T_padded,) + ctx_mask_kv = ctx_mask[:, jnp.newaxis, jnp.newaxis] # (T_padded, 1, 1) + k_ctx = jnp.where(ctx_mask_kv, k_ctx, 0.0) + v_ctx = jnp.where(ctx_mask_kv, v_ctx, 0.0) + + k_ctx_4d = k_ctx.transpose(1, 0, 2)[jnp.newaxis, :, :, :] + v_ctx_4d = v_ctx.transpose(1, 0, 2)[jnp.newaxis, :, :, :] + kv_cache_k = lax.dynamic_update_slice(kv_cache_k, k_ctx_4d, + (0, 0, cache_len, 0)) + kv_cache_v = lax.dynamic_update_slice(kv_cache_v, v_ctx_4d, + (0, 0, cache_len, 0)) + + noise_start = cache_len + actual_ctx_count + k_noise_4d = k_noise.transpose(1, 0, 2)[jnp.newaxis, :, :, :] + v_noise_4d = v_noise.transpose(1, 0, 2)[jnp.newaxis, :, :, :] + kv_cache_k = lax.dynamic_update_slice(kv_cache_k, k_noise_4d, + (0, 0, noise_start, 0)) + kv_cache_v = lax.dynamic_update_slice(kv_cache_v, v_noise_4d, + (0, 0, noise_start, 0)) + + new_cache_len = cache_len + actual_ctx_count + T_noise + max_kv_len = kv_cache_k.shape[2] + + q_4d = q.transpose(1, 0, 2)[jnp.newaxis, :, :, :] + kv_ids = (jnp.arange(max_kv_len) < new_cache_len).astype(jnp.int32) + 
q_ids = jnp.ones(T_noise, dtype=jnp.int32) + seg_ids = SegmentIds( + q=q_ids[jnp.newaxis, :], + kv=kv_ids[jnp.newaxis, :], + ) + + sm_scale = self.head_dim_original**-0.5 + block_sizes = BlockSizes( + block_q=T_noise, + block_k_major=max_kv_len, + block_k=max_kv_len, + block_b=1, + ) + attn_out = flash_attention( + q_4d, + kv_cache_k, + kv_cache_v, + segment_ids=seg_ids, + causal=False, + sm_scale=sm_scale, + block_sizes=block_sizes, + vmem_limit_bytes=_FA_VMEM_LIMIT, + ) + + attn_out = attn_out[0].transpose(1, 0, 2) + output = self.o_proj(attn_out) + + return output, kv_cache_k, kv_cache_v + + +class DFlashMLP(nnx.Module): + + def __init__(self, config: Qwen3Config, dtype: jnp.dtype, rng: nnx.Rngs): + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + self.gate_proj = nnx.Linear( + hidden_size, + intermediate_size, + use_bias=False, + param_dtype=dtype, + kernel_init=nnx.with_partitioning( + init_fn, (None, ShardingAxisName.MLP_TENSOR)), + rngs=rng, + ) + self.up_proj = nnx.Linear( + hidden_size, + intermediate_size, + use_bias=False, + param_dtype=dtype, + kernel_init=nnx.with_partitioning( + init_fn, (None, ShardingAxisName.MLP_TENSOR)), + rngs=rng, + ) + self.down_proj = nnx.Linear( + intermediate_size, + hidden_size, + use_bias=False, + param_dtype=dtype, + kernel_init=nnx.with_partitioning( + init_fn, (ShardingAxisName.MLP_TENSOR, None)), + rngs=rng, + ) + + def __call__(self, x: jax.Array) -> jax.Array: + return self.down_proj(jax.nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class DFlashDecoderLayer(nnx.Module): + + def __init__( + self, + config: Qwen3Config, + dtype: jnp.dtype, + rng: nnx.Rngs, + mesh: Mesh, + ): + hidden_size = config.hidden_size + rms_norm_eps = config.rms_norm_eps + + self.input_layernorm = nnx.RMSNorm( + hidden_size, + epsilon=rms_norm_eps, + param_dtype=dtype, + scale_init=nnx.with_partitioning(init_fn, (None, )), + rngs=rng, + ) + self.self_attn = DFlashAttention( + config=config, + dtype=dtype, + 
rng=rng, + mesh=mesh, + ) + self.post_attention_layernorm = nnx.RMSNorm( + hidden_size, + epsilon=rms_norm_eps, + param_dtype=dtype, + scale_init=nnx.with_partitioning(init_fn, (None, )), + rngs=rng, + ) + self.mlp = DFlashMLP(config=config, dtype=dtype, rng=rng) + + def __call__( + self, + x: jax.Array, + target_hidden: jax.Array, + noise_positions: jax.Array, + ctx_positions: jax.Array, + kv_cache_k: jax.Array, + kv_cache_v: jax.Array, + cache_len: jax.Array, + actual_ctx_count: jax.Array, + ) -> Tuple[jax.Array, jax.Array, jax.Array]: + """Returns (hidden_states, new_kv_cache_k, new_kv_cache_v).""" + residual = x + x = self.input_layernorm(x) + x, kv_cache_k, kv_cache_v = self.self_attn( + x, + target_hidden, + noise_positions, + ctx_positions, + kv_cache_k, + kv_cache_v, + cache_len, + actual_ctx_count, + ) + x = residual + x + + residual = x + x = self.post_attention_layernorm(x) + x = self.mlp(x) + x = residual + x + return x, kv_cache_k, kv_cache_v + + +class DFlashModel(nnx.Module): + + def __init__( + self, + vllm_config: VllmConfig, + rng: nnx.Rngs, + mesh: Mesh, + ) -> None: + spec_config = vllm_config.speculative_config + assert spec_config is not None + hf_config = spec_config.draft_model_config.hf_config + dtype = jnp.bfloat16 + hidden_size = hf_config.hidden_size + rms_norm_eps = hf_config.rms_norm_eps + + self.embed_tokens = nnx.Embed( + num_embeddings=hf_config.vocab_size, + features=hidden_size, + param_dtype=dtype, + embedding_init=nnx.with_partitioning( + init_fn, (ShardingAxisName.VOCAB, None)), + rngs=rng, + ) + + self.layers = nnx.List([ + DFlashDecoderLayer( + config=hf_config, + dtype=dtype, + rng=rng, + mesh=mesh, + ) for _ in range(hf_config.num_hidden_layers) + ]) + + dflash_config = getattr(hf_config, "dflash_config", {}) + target_layer_ids = dflash_config.get("target_layer_ids", None) + num_target_layers = getattr(hf_config, "num_target_layers", None) + if target_layer_ids is not None: + num_context_features = len(target_layer_ids) + 
        elif num_target_layers is not None:
            num_context_features = num_target_layers
        else:
            num_context_features = hf_config.num_hidden_layers

        # fc folds the concatenated per-target-layer features into a single
        # draft-hidden-sized vector; input width scales with the number of
        # context features captured from the target model.
        target_hidden_size = getattr(hf_config, "target_hidden_size",
                                     hidden_size)
        fc_in_features = num_context_features * target_hidden_size

        self.fc = nnx.Linear(
            fc_in_features,
            hidden_size,
            use_bias=False,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, None)),
            rngs=rng,
        )

        # hidden_norm normalises the projected context features;
        # norm is the final pre-logits RMSNorm.
        self.hidden_norm = nnx.RMSNorm(
            hidden_size,
            epsilon=rms_norm_eps,
            param_dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
        )
        self.norm = nnx.RMSNorm(
            hidden_size,
            epsilon=rms_norm_eps,
            param_dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
        )


class DFlashWeightLoader(BaseWeightLoader):
    """Loads DFlash draft-model weights from a HF-format checkpoint.

    Uses the PyTorch ("pt") checkpoint framework of BaseWeightLoader and
    shards onto the provided mesh.
    """

    def __init__(self, vllm_config: VllmConfig, mesh: Mesh):
        super().__init__(vllm_config, framework="pt")
        self.vllm_config = vllm_config
        self.mesh = mesh

    def load_weights(self, model: "DFlashForCausalLM", mappings: dict):
        """Load checkpoint weights into `model` using name `mappings`."""
        metadata_map = get_default_maps(
            self.vllm_config.speculative_config.draft_model_config,
            self.mesh,
            mappings,
        )
        load_hf_weights(
            vllm_config=self.vllm_config,
            model=model,
            metadata_map=metadata_map,
            mesh=self.mesh,
            is_draft_model=True,
        )

        # If the embedding is not initialized, initialize it with a dummy
        # array here to pass jit compilation. The real weights will be shared
        # from the target model.
        if isinstance(model.model.embed_tokens.embedding.value,
                      jax.ShapeDtypeStruct):
            model.model.embed_tokens.embedding.value = jnp.zeros(
                model.model.embed_tokens.embedding.shape,
                dtype=model.model.embed_tokens.embedding.dtype,
            )


class DFlashForCausalLM(nnx.Module):
    """DFlash draft model for speculative decoding on TPU."""

    WeightLoader = DFlashWeightLoader

    def __init__(
        self,
        vllm_config: VllmConfig,
        rng_key: jax.Array,
        mesh: Mesh,
    ) -> None:
        nnx.Module.__init__(self)
        self.vllm_config = vllm_config
        self.rng = nnx.Rngs(rng_key)
        self.mesh = mesh

        spec_config = vllm_config.speculative_config
        assert spec_config is not None
        hf_config = spec_config.draft_model_config.hf_config
        self.hf_config = hf_config
        # block_size: number of noise tokens proposed per invocation.
        self.block_size = getattr(hf_config, "block_size", 8)
        dflash_config = getattr(hf_config, "dflash_config", {})
        self.mask_token_id = dflash_config.get("mask_token_id", 0)

        # "incremental" offsets positions by the current KV-cache length;
        # any other scheme starts positions at 0 (see __call__).
        self._position_scheme = dflash_config.get("position_scheme",
                                                  "incremental")

        self.model = DFlashModel(
            vllm_config=vllm_config,
            rng=self.rng,
            mesh=mesh,
        )

    def __call__(
        self,
        kv_caches: List[jax.Array],
        input_ids: jax.Array,
        target_hidden_states: jax.Array,
        attention_metadata,
    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
        """Forward pass for the DFlash draft model.

        ``target_hidden_states`` is a 3-tuple:
            (ctx_hidden, cache_len_arr, actual_ctx_count_arr)
        where:
            ctx_hidden: (T_padded, D) — padded context features.
            cache_len_arr: (1,) int32 — valid entries already in KV cache.
            actual_ctx_count_arr: (1,) int32 — real (non-padding) context
                count.

        ``kv_caches`` is a flat list of length ``2 * num_layers``:
            [k_cache_0, v_cache_0, k_cache_1, v_cache_1, ...]
        Each cache has shape ``(1, num_heads, max_kv_len, head_dim)``.

        Returns:
            (kv_caches, hidden_states, [target_hidden_states])

        NOTE(review): the code below actually returns an empty list as the
        third element, not ``[target_hidden_states]`` — confirm which is
        intended and align docstring/code.
        """
        ctx_hidden, cache_len_arr, actual_ctx_count_arr = target_hidden_states
        cache_len = cache_len_arr[0]  # scalar
        actual_ctx_count = actual_ctx_count_arr[0]  # scalar

        noise_emb = self.model.embed_tokens(input_ids)
        # Incremental scheme: positions continue from the cached prefix.
        pos_offset = cache_len if self._position_scheme == "incremental" else 0
        T_padded = ctx_hidden.shape[0]
        T_noise = input_ids.shape[0]
        ctx_positions = jnp.arange(T_padded, dtype=jnp.int32) + pos_offset
        # Noise tokens sit after the real (non-padding) context entries.
        noise_positions = (jnp.arange(T_noise, dtype=jnp.int32) + pos_offset +
                           actual_ctx_count)

        x = noise_emb
        # Layers receive and return their K/V caches explicitly; the flat
        # kv_caches list interleaves [k_0, v_0, k_1, v_1, ...].
        for i, layer in enumerate(self.model.layers):
            kv_k = kv_caches[2 * i]
            kv_v = kv_caches[2 * i + 1]
            x, kv_k, kv_v = layer(
                x,
                ctx_hidden,
                noise_positions,
                ctx_positions,
                kv_k,
                kv_v,
                cache_len,
                actual_ctx_count,
            )
            kv_caches[2 * i] = kv_k
            kv_caches[2 * i + 1] = kv_v

        x = self.model.norm(x)

        return kv_caches, x, []

    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
        """Compute logits using tied embedding weights."""
        return jnp.dot(hidden_states,
                       self.model.embed_tokens.embedding.value.T)

    def combine_hidden_states(self, hidden_states: jax.Array) -> jax.Array:
        """Project concatenated target auxiliary hidden states.

        Args:
            hidden_states: (T, num_target_layers * target_hidden_size)

        Returns:
            (T, hidden_size) projected + normalised context features.
        """
        return self.model.hidden_norm(self.model.fc(hidden_states))

    def load_weights(self, rng_key: jax.Array):
        """Load checkpoint weights via WeightLoader.

        NOTE(review): this rebinds ``self.rng`` (an nnx.Rngs in __init__)
        to a raw jax PRNG key, and ignores the ``rng_key`` argument in
        favour of the configured seed — confirm this is intentional.
        """
        self.rng = jax.random.key(self.vllm_config.model_config.seed)

        # Checkpoint-name -> module-path mapping for the draft model.
        mappings = {
            "layers.*.input_layernorm": "model.layers.*.input_layernorm.scale",
            "layers.*.self_attn.q_proj":
            "model.layers.*.self_attn.q_proj.kernel",
            "layers.*.self_attn.k_proj":
            "model.layers.*.self_attn.k_proj.kernel",
            "layers.*.self_attn.v_proj":
            "model.layers.*.self_attn.v_proj.kernel",
            "layers.*.self_attn.o_proj":
            "model.layers.*.self_attn.o_proj.kernel",
            "layers.*.self_attn.q_norm":
            "model.layers.*.self_attn.q_norm.scale",
            "layers.*.self_attn.k_norm":
            "model.layers.*.self_attn.k_norm.scale",
            "layers.*.post_attention_layernorm":
            "model.layers.*.post_attention_layernorm.scale",
            "layers.*.mlp.gate_proj": "model.layers.*.mlp.gate_proj.kernel",
            "layers.*.mlp.up_proj": "model.layers.*.mlp.up_proj.kernel",
            "layers.*.mlp.down_proj": "model.layers.*.mlp.down_proj.kernel",
            "fc": "model.fc.kernel",
            "hidden_norm": "model.hidden_norm.scale",
            "norm": "model.norm.scale",
            "embed_tokens": "model.embed_tokens.embedding",
        }

        loader = self.WeightLoader(self.vllm_config, self.mesh)
        loader.load_weights(self, mappings)
diff --git a/tpu_inference/models/jax/qwen3_dflash.py b/tpu_inference/models/jax/qwen3_dflash.py
new file mode 100644
index 0000000000..06bfd108c7
--- /dev/null
+++ b/tpu_inference/models/jax/qwen3_dflash.py
@@ -0,0 +1,529 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple

import jax
import jax.numpy as jnp
from flax import nnx
from jax.sharding import Mesh
from transformers import Qwen3Config
from vllm.config import VllmConfig

from tpu_inference import utils
from tpu_inference.layers.common.attention_interface import attention
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.layers.common.dflash_attention_interface import \
    dflash_concat_attention
from tpu_inference.layers.common.quantization import quantize_kv
from tpu_inference.layers.jax.embed import JaxEmbed
from tpu_inference.layers.jax.linear import JaxEinsum, JaxLinear
from tpu_inference.layers.jax.norm import JaxRmsNorm
from tpu_inference.layers.jax.rope_interface import apply_rope
from tpu_inference.layers.vllm.quantization.configs import VllmQuantConfig
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.qwen2 import Qwen2MLP as Qwen3MLP
from tpu_inference.models.jax.qwen3 import \
    _build_target_layer_ids as _build_target_layer_ids_shared
from tpu_inference.models.jax.qwen3 import \
    _get_dflash_target_layer_ids as _get_dflash_target_layer_ids_shared
from tpu_inference.models.jax.utils.weight_utils import (BaseWeightLoader,
                                                         get_default_maps,
                                                         load_hf_weights)

logger = init_logger(__name__)

init_fn = nnx.initializers.uniform()


def _build_target_layer_ids(num_target_layers: int,
                            num_draft_layers: int) -> list[int]:
    """Thin wrapper around the shared qwen3 implementation (kept so this
    module exposes a stable local name; see tests)."""
    return _build_target_layer_ids_shared(num_target_layers, num_draft_layers)


def _get_dflash_target_layer_ids(
    draft_hf_config: Qwen3Config,
    target_num_layers: int,
) -> list[int]:
    """Thin wrapper around the shared qwen3 implementation.

    Note the argument order differs from the shared function: here the
    draft config comes first, there the layer count comes first.
    """
    return _get_dflash_target_layer_ids_shared(target_num_layers,
                                               draft_hf_config)


class Qwen3DFlashAttention(nnx.Module):
    """DFlash attention for the Qwen3 draft model.

    Combines K/V derived from the target model's context features
    (``target_hidden_states``) with K/V from the draft's own noise stream,
    using one of two implementations selected by ``dflash_attention_impl``:
    "concat_dense" (dense concat attention) or "additive_legacy"
    (element-wise sum of the two K/V streams).
    """

    def __init__(self, config: Qwen3Config, dtype: jnp.dtype, rng: nnx.Rngs,
                 mesh: Mesh, kv_cache_dtype: str,
                 quant_config: VllmQuantConfig, dflash_attention_impl: str,
                 max_query_len: int):
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.rms_norm_eps = config.rms_norm_eps

        # RoPE is applied with the original head dim; projections use the
        # padded head dim required by the TPU kernels.
        self.head_dim_original = getattr(config, "head_dim",
                                         self.hidden_size // self.num_heads)
        self.head_dim = utils.get_padded_head_dim(self.head_dim_original)

        # Pad head counts so they divide evenly across the model axis.
        sharding_size = mesh.shape["model"]
        self.num_heads = utils.get_padded_num_heads(self.num_heads,
                                                    sharding_size)
        self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
                                                       sharding_size)

        self.mesh = mesh
        self.dflash_attention_impl = dflash_attention_impl
        if max_query_len <= 0:
            raise ValueError(f"{max_query_len=} must be positive.")
        self.max_query_len = int(max_query_len)

        self.q_proj = JaxEinsum(
            "TD,DNH->TNH",
            (self.hidden_size, self.num_heads, self.head_dim),
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
            rngs=rng,
            quant_config=quant_config,
        )
        self.q_norm = JaxRmsNorm(
            self.head_dim,
            epsilon=self.rms_norm_eps,
            param_dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
            quant_config=quant_config,
        )
        self.k_proj = JaxEinsum(
            "TD,DKH->TKH",
            (self.hidden_size, self.num_kv_heads, self.head_dim),
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
            rngs=rng,
            quant_config=quant_config,
        )
        self.k_norm = JaxRmsNorm(
            self.head_dim,
            epsilon=self.rms_norm_eps,
            param_dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
            quant_config=quant_config,
        )
        self.v_proj = JaxEinsum(
            "TD,DKH->TKH",
            (self.hidden_size, self.num_kv_heads, self.head_dim),
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
            rngs=rng,
            quant_config=quant_config,
        )
        self.o_proj = JaxEinsum(
            "TNH,NHD->TD",
            (self.num_heads, self.head_dim, self.hidden_size),
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, ("model", None, None)),
            rngs=rng,
            quant_config=quant_config,
        )

        # Static unit scales; only used when the KV cache is quantized.
        self._q_scale = 1.0
        self._k_scale = 1.0
        self._v_scale = 1.0
        self.kv_cache_quantized_dtype = None
        if kv_cache_dtype != "auto":
            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
                kv_cache_dtype)

    def __call__(
        self,
        kv_cache: jax.Array,
        hidden_states: jax.Array,
        target_hidden_states: jax.Array,
        attention_metadata: AttentionMetadata,
    ) -> Tuple[jax.Array, jax.Array]:
        """Run DFlash attention; returns (updated_kv_cache, output)."""
        md = attention_metadata
        q = self.q_proj(hidden_states)
        q = self.q_norm(q)
        q = apply_rope(q, md.input_positions, self.head_dim_original,
                       self.rope_theta, self.rope_scaling)

        if target_hidden_states.shape[0] != hidden_states.shape[0]:
            raise ValueError(
                "DFlash currently expects target/noise token counts to match, "
                f"got {target_hidden_states.shape[0]=} {hidden_states.shape[0]=}"
            )

        # Context K/V from target-model features; note both streams share
        # md.input_positions for RoPE.
        k_ctx = self.k_proj(target_hidden_states)
        k_ctx = self.k_norm(k_ctx)
        k_ctx = apply_rope(k_ctx, md.input_positions, self.head_dim_original,
                           self.rope_theta, self.rope_scaling)
        v_ctx = self.v_proj(target_hidden_states)

        # Noise K/V from the draft's own hidden states.
        k_noise = self.k_proj(hidden_states)
        k_noise = self.k_norm(k_noise)
        k_noise = apply_rope(k_noise, md.input_positions,
                             self.head_dim_original, self.rope_theta,
                             self.rope_scaling)
        v_noise = self.v_proj(hidden_states)

        if self.dflash_attention_impl == "additive_legacy":
            # Legacy behavior: element-wise sum of context and noise K/V.
            k = k_ctx + k_noise
            v = v_ctx + v_noise

            q_scale = k_scale = v_scale = None
            if self.kv_cache_quantized_dtype:
                k_scale = self._k_scale
                v_scale = self._v_scale
                k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v,
                                   k_scale, v_scale)

            new_kv_cache, outputs = attention(
                kv_cache,
                q,
                k,
                v,
                attention_metadata,
                self.mesh,
                self.head_dim_original,
                q_scale=q_scale,
                k_scale=k_scale,
                v_scale=v_scale,
            )
        elif self.dflash_attention_impl == "concat_dense":
            # Outputs come from dense concat attention over both streams…
            outputs = dflash_concat_attention(
                q,
                k_ctx,
                k_noise,
                v_ctx,
                v_noise,
                attention_metadata,
                max_query_len=self.max_query_len,
                sm_scale=self.head_dim_original**-0.5,
            )

            q_scale = k_scale = v_scale = None
            k_for_cache = k_noise
            v_for_cache = v_noise
            if self.kv_cache_quantized_dtype:
                k_scale = self._k_scale
                v_scale = self._v_scale
                k_for_cache, v_for_cache = quantize_kv(
                    self.kv_cache_quantized_dtype, k_for_cache, v_for_cache,
                    k_scale, v_scale)

            # Keep existing draft KV-cache update behavior using the noise
            # stream while computing DFlash outputs from concat K/V semantics.
            # (The attention outputs from this second call are discarded.)
            new_kv_cache, _ = attention(
                kv_cache,
                q,
                k_for_cache,
                v_for_cache,
                attention_metadata,
                self.mesh,
                self.head_dim_original,
                q_scale=q_scale,
                k_scale=k_scale,
                v_scale=v_scale,
            )
        else:
            raise ValueError(
                f"Unsupported {self.dflash_attention_impl=}. "
                "Expected one of {'concat_dense', 'additive_legacy'}.")

        o = self.o_proj(outputs)
        return new_kv_cache, o


class Qwen3DFlashDecoderLayer(nnx.Module):
    """Pre-norm transformer block: DFlash attention + Qwen3 MLP."""

    def __init__(self, config: Qwen3Config, dtype: jnp.dtype, rng: nnx.Rngs,
                 mesh: Mesh, kv_cache_dtype: str,
                 quant_config: VllmQuantConfig, dflash_attention_impl: str,
                 max_query_len: int):
        self.input_layernorm = JaxRmsNorm(
            config.hidden_size,
            epsilon=config.rms_norm_eps,
            param_dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
            quant_config=quant_config,
        )
        self.self_attn = Qwen3DFlashAttention(
            config=config,
            dtype=dtype,
            rng=rng,
            mesh=mesh,
            kv_cache_dtype=kv_cache_dtype,
            quant_config=quant_config,
            dflash_attention_impl=dflash_attention_impl,
            max_query_len=max_query_len)
        self.post_attention_layernorm = JaxRmsNorm(
            config.hidden_size,
            epsilon=config.rms_norm_eps,
            param_dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
            quant_config=quant_config,
        )
        self.mlp = Qwen3MLP(
            config=config,
            dtype=dtype,
            rng=rng,
            quant_config=quant_config,
        )

    def __call__(
        self,
        kv_cache: jax.Array,
        hidden_states: jax.Array,
        target_hidden_states: jax.Array,
        attention_metadata: AttentionMetadata,
    ) -> Tuple[jax.Array, jax.Array]:
        """Standard pre-norm residual block; returns (kv_cache, hidden)."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        kv_cache, attn_out = self.self_attn(
            kv_cache,
            hidden_states,
            target_hidden_states,
            attention_metadata,
        )
        hidden_states = residual + attn_out
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return kv_cache, hidden_states


class Qwen3DFlashModel(nnx.Module):
    """Qwen3-based DFlash draft model body (embedding + layers + norms)."""

    def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
                 mesh: Mesh) -> None:
        draft_model_config = vllm_config.speculative_config.draft_model_config
        hf_config = draft_model_config.hf_config
        target_model_config = vllm_config.model_config
        dtype = target_model_config.dtype
        additional_config = getattr(vllm_config, "additional_config",
                                    None) or {}

        # Vocab comes from the TARGET model so logits/embeddings line up
        # when the embedding is shared.
        self.embed_tokens = JaxEmbed(
            num_embeddings=target_model_config.get_vocab_size(),
            features=hf_config.hidden_size,
            param_dtype=dtype,
            embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
            rngs=rng,
            quant_config=vllm_config.quant_config,
        )

        self.layers = [
            # DFlash proposes up to `num_speculative_tokens + 1` query tokens
            # per request per model invocation.
            Qwen3DFlashDecoderLayer(
                config=hf_config,
                dtype=dtype,
                rng=rng,
                mesh=mesh,
                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
                quant_config=vllm_config.quant_config,
                dflash_attention_impl=additional_config.get(
                    "dflash_attention_impl", "concat_dense"),
                max_query_len=int(
                    vllm_config.speculative_config.num_speculative_tokens) + 1,
            ) for _ in range(hf_config.num_hidden_layers)
        ]

        # fc input width = target hidden size × number of captured target
        # layers (see _get_dflash_target_layer_ids).
        target_layer_ids = _get_dflash_target_layer_ids(
            hf_config, target_model_config.hf_config.num_hidden_layers)
        self.target_layer_ids = tuple(target_layer_ids)
        target_hidden_size = target_model_config.get_hidden_size()
        combined_hidden_size = target_hidden_size * len(target_layer_ids)

        self.fc = JaxLinear(
            combined_hidden_size,
            hf_config.hidden_size,
            use_bias=False,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
            rngs=rng,
            quant_config=vllm_config.quant_config,
        )
        self.hidden_norm = JaxRmsNorm(
            hf_config.hidden_size,
            epsilon=hf_config.rms_norm_eps,
            param_dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
            quant_config=vllm_config.quant_config,
        )
        self.norm = JaxRmsNorm(
            hf_config.hidden_size,
            epsilon=hf_config.rms_norm_eps,
            param_dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
            quant_config=vllm_config.quant_config,
        )

    def __call__(
        self,
        kv_caches: List[jax.Array],
        input_ids: jax.Array,
        target_hidden_states: jax.Array,
        attention_metadata: AttentionMetadata,
    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
        hidden_states = self.embed_tokens(input_ids)

        # The draft's KV caches occupy the TAIL of the framework kv_caches
        # list; earlier entries belong to the target model.
        num_draft_layers = len(self.layers)
        draft_kv_start = max(0, len(kv_caches) - num_draft_layers)
        for i, layer in enumerate(self.layers):
            kv_idx = draft_kv_start + i
            kv_cache, hidden_states = layer(
                kv_caches[kv_idx],
                hidden_states,
                target_hidden_states,
                attention_metadata,
            )
            kv_caches[kv_idx] = kv_cache

        # Pre-norm hidden states are returned as the auxiliary output.
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        return kv_caches, hidden_states, [residual]

    def combine_hidden_states(self, hidden_states: jax.Array) -> jax.Array:
        """Project (T, combined) target features to (T, hidden) + RMSNorm."""
        hidden_states = self.fc(hidden_states)
        hidden_states = self.hidden_norm(hidden_states)
        return hidden_states


class Qwen3DFlashWeightLoader(BaseWeightLoader):
    """Loads the Qwen3 DFlash draft weights from a HF checkpoint."""

    def __init__(self, vllm_config: VllmConfig, mesh: Mesh):
        super().__init__(vllm_config, framework="pt")
        self.vllm_config = vllm_config
        self.mesh = mesh

    def load_weights(self, model: "Qwen3DFlashForCausalLM", mappings: dict):
        metadata_map = get_default_maps(
            self.vllm_config.speculative_config.draft_model_config, self.mesh,
            mappings)

        # We only load the subset needed by the DFlash draft model path.
        # Support both raw DFlash checkpoints (e.g., `layers.*`) and
        # `model.*`-prefixed variants.
        filter_regex = (
            r"^((model\.)?embed_tokens\.weight|"
            r"(model\.)?layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight|"
            r"(model\.)?layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj|q_norm|k_norm)\.weight|"
            r"(model\.)?layers\.\d+\.mlp\.(gate_proj|up_proj|down_proj)\.weight|"
            r"(model\.)?(fc|hidden_norm|norm)\.weight)$")

        load_hf_weights(
            vllm_config=self.vllm_config,
            model=model,
            metadata_map=metadata_map,
            mesh=self.mesh,
            filter_regex=filter_regex,
            is_draft_model=True,
        )


class Qwen3DFlashForCausalLM(nnx.Module):
    """Top-level Qwen3 DFlash draft model exposed to the runner."""

    WeightLoader = Qwen3DFlashWeightLoader

    def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
                 mesh: Mesh):
        self.vllm_config = vllm_config
        self.rng = nnx.Rngs(rng_key)
        self.mesh = mesh
        self.model = Qwen3DFlashModel(
            vllm_config=vllm_config,
            rng=self.rng,
            mesh=mesh,
        )

    def __call__(
        self,
        kv_caches: List[jax.Array],
        input_ids: jax.Array,
        target_hidden_states: jax.Array,
        attention_metadata: AttentionMetadata,
    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
        return self.model(
            kv_caches,
            input_ids,
            target_hidden_states,
            attention_metadata,
        )

    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
        # Tied embedding: decode() projects hidden states onto the vocab.
        return self.model.embed_tokens.decode(hidden_states)

    def combine_hidden_states(self, hidden_states: jax.Array) -> jax.Array:
        return self.model.combine_hidden_states(hidden_states)

    def load_weights(self, _rng_key: jax.Array):
        # Each checkpoint key is mapped twice: once for raw `layers.*`
        # checkpoints and once for `model.*`-prefixed ones.
        mappings = {
            "embed_tokens": "model.embed_tokens.weight",
            "model.embed_tokens": "model.embed_tokens.weight",
            "layers.*.input_layernorm":
            "model.layers.*.input_layernorm.weight",
            "model.layers.*.input_layernorm":
            "model.layers.*.input_layernorm.weight",
            "layers.*.post_attention_layernorm":
            "model.layers.*.post_attention_layernorm.weight",
            "model.layers.*.post_attention_layernorm":
            "model.layers.*.post_attention_layernorm.weight",
            "layers.*.self_attn.q_proj":
            "model.layers.*.self_attn.q_proj.weight",
            "model.layers.*.self_attn.q_proj":
            "model.layers.*.self_attn.q_proj.weight",
            "layers.*.self_attn.k_proj":
            "model.layers.*.self_attn.k_proj.weight",
            "model.layers.*.self_attn.k_proj":
            "model.layers.*.self_attn.k_proj.weight",
            "layers.*.self_attn.v_proj":
            "model.layers.*.self_attn.v_proj.weight",
            "model.layers.*.self_attn.v_proj":
            "model.layers.*.self_attn.v_proj.weight",
            "layers.*.self_attn.o_proj":
            "model.layers.*.self_attn.o_proj.weight",
            "model.layers.*.self_attn.o_proj":
            "model.layers.*.self_attn.o_proj.weight",
            "layers.*.self_attn.q_norm":
            "model.layers.*.self_attn.q_norm.weight",
            "model.layers.*.self_attn.q_norm":
            "model.layers.*.self_attn.q_norm.weight",
            "layers.*.self_attn.k_norm":
            "model.layers.*.self_attn.k_norm.weight",
            "model.layers.*.self_attn.k_norm":
            "model.layers.*.self_attn.k_norm.weight",
            "layers.*.mlp.gate_proj": "model.layers.*.mlp.gate_proj.weight",
            "model.layers.*.mlp.gate_proj":
            "model.layers.*.mlp.gate_proj.weight",
            "layers.*.mlp.up_proj": "model.layers.*.mlp.up_proj.weight",
            "model.layers.*.mlp.up_proj": "model.layers.*.mlp.up_proj.weight",
            "layers.*.mlp.down_proj": "model.layers.*.mlp.down_proj.weight",
            "model.layers.*.mlp.down_proj":
            "model.layers.*.mlp.down_proj.weight",
            "fc": "model.fc.weight",
            "model.fc": "model.fc.weight",
            "hidden_norm": "model.hidden_norm.weight",
            "model.hidden_norm": "model.hidden_norm.weight",
            "norm": "model.norm.weight",
            "model.norm": "model.norm.weight",
        }

        loader = self.WeightLoader(self.vllm_config, self.mesh)
        loader.load_weights(self, mappings)
diff --git a/tpu_inference/spec_decode/jax/dflash.py b/tpu_inference/spec_decode/jax/dflash.py
new file mode 100644
index 0000000000..2bc6c70545
--- /dev/null
+++ b/tpu_inference/spec_decode/jax/dflash.py
@@ -0,0 +1,339 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file
# except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DFlash proposer for speculative decoding on JAX/TPU."""

import functools
from dataclasses import replace
from typing import Any, Optional

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from jax import lax
from jax.sharding import NamedSharding, PartitionSpec
from vllm.config import VllmConfig

from tpu_inference import utils
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.layers.common.sharding import ShardingAxisName
from tpu_inference.logger import init_logger
from tpu_inference.models.common.model_loader import get_model
from tpu_inference.utils import device_array, get_mesh_shape_product

logger = init_logger(__name__)


class DFlashProposer:
    """Proposer for speculative decoding using DFlash block diffusion."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        runner: Any,
    ):
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        assert self.speculative_config is not None
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        self.runner = runner
        self.mesh = runner.mesh
        self.num_speculative_tokens = (
            self.speculative_config.num_speculative_tokens)

        hf_config = self.draft_model_config.hf_config
        # One noise block = the speculative tokens plus the known first token.
        self.block_size = getattr(hf_config, "block_size",
                                  self.num_speculative_tokens + 1)
        dflash_config = getattr(hf_config, "dflash_config", {})
        self.mask_token_id = dflash_config.get("mask_token_id", 0)
        self.hidden_size = hf_config.hidden_size
        self.num_layers = hf_config.num_hidden_layers

        self.rng_key = jax.random.key(self.vllm_config.model_config.seed)
        self.max_num_tokens = runner.max_num_tokens
        self.max_model_len = runner.max_model_len

        # Context length tracking (host-side counter)
        self._ctx_len: int = 0

        # On-device KV caches (allocated in load_model)
        self._draft_kv_caches: Optional[list[jax.Array]] = None
        self._cache_len: int = 0
        self._max_kv_len: int = 0

        # Track previous seq_len for GPU-compatible crop semantics.
        # GPU calls past_key_values_draft.crop(start) AFTER each forward
        # pass, where start = beginning of the CURRENT iteration's block.
        # This equals the seq_len from the PREVIOUS call to prepare_inputs.
        # We must match this: cache_len = prev_seq_len, not current seq_len.
        self._prev_seq_len: int = 0

    def load_model(self, target_model: Any) -> None:
        """Load the DFlash draft model and share embeddings from target."""
        (
            self.model_fn,
            self.compute_logits_fn,
            self.combine_hidden_states_fn,
            _,
            self.state,
            _,
            _,
        ) = get_model(self.vllm_config,
                      self.rng_key,
                      self.mesh,
                      is_draft_model=True)

        # Share the target model's embedding with the draft model.
        # An all-zero draft embedding (see DFlashWeightLoader's dummy init)
        # is treated as "uninitialized" here.
        draft_embed = getattr(self.state.model, "embed_tokens", None)
        target_embed = getattr(target_model.model, "embed_tokens", None)
        if target_embed is None:
            target_embed = getattr(target_model.model, "embed", None)
        if target_embed is not None:
            if draft_embed is None or not jnp.any(draft_embed.embedding):
                logger.info(
                    "Sharing target model embedding with DFlash draft model.")
                self.state.model.embed_tokens = target_embed
            elif jnp.array_equal(draft_embed.embedding,
                                 target_embed.embedding):
                logger.info("Draft embedding identical to target; sharing.")
                self.state.model.embed_tokens = target_embed

        # Allocate on-device KV caches
        hf_config = self.draft_model_config.hf_config

        sharding_size = get_mesh_shape_product(self.mesh,
                                               ShardingAxisName.MLP_TENSOR)
        num_heads = utils.get_padded_num_heads(hf_config.num_attention_heads,
                                               sharding_size)
        head_dim_orig = getattr(
            hf_config, "head_dim",
            hf_config.hidden_size // hf_config.num_attention_heads)
        head_dim = utils.get_padded_head_dim(head_dim_orig)

        # Flat layout: [k_0, v_0, k_1, v_1, ...], matching the model's
        # expected kv_caches ordering.
        self._max_kv_len = self._next_padded_size(self.max_model_len)
        cache_shape = (1, num_heads, self._max_kv_len, head_dim)
        self._draft_kv_caches = []
        for _ in range(self.num_layers):
            k_cache = jnp.zeros(cache_shape, dtype=jnp.bfloat16)
            v_cache = jnp.zeros(cache_shape, dtype=jnp.bfloat16)
            self._draft_kv_caches.append(k_cache)
            self._draft_kv_caches.append(v_cache)
        self._cache_len = 0

        logger.info(
            "Allocated DFlash on-device KV caches: %d layers, shape %s",
            self.num_layers,
            cache_shape,
        )

    @functools.partial(jax.jit, static_argnums=(0,))
    def _project_aux_hidden(
            self, state: nnx.State,
            aux_hidden_states: tuple[jax.Array, ...]) -> jax.Array:
        """Project and normalise auxiliary hidden states."""
        raw = jnp.concatenate(aux_hidden_states, axis=-1)
        return self.combine_hidden_states_fn(state, raw)

    @staticmethod
    def _next_padded_size(n: int) -> int:
        """Round n up to the next power-of-two (min 16)."""
        if n <= 16:
            return 16
        p = 16
        while p < n:
            p *= 2
        return p

    @functools.partial(jax.jit, static_argnums=(0, 3, 4))
    def _build_noise_block(
        self,
        seq_len_arr: jax.Array,
        next_token_ids: jax.Array,
        mask_token_id: int,
        block_size: int,
    ) -> tuple[jax.Array, jax.Array]:
        """Build noise block and positions (JIT-compiled)."""
        seq_len = seq_len_arr[0]
        first_token = next_token_ids[0]
        # Block = [first accepted token, mask, mask, ...]; positions start
        # at the current sequence length.
        noise_input_ids = jnp.full((block_size, ),
                                   mask_token_id,
                                   dtype=jnp.int32)
        noise_input_ids = noise_input_ids.at[0].set(first_token)
        noise_positions = jnp.arange(block_size, dtype=jnp.int32) + seq_len
        return noise_input_ids, noise_positions

    def prepare_inputs(
        self,
        attn_metadata: AttentionMetadata,
        input_ids: jax.Array,
        aux_hidden_states: tuple[jax.Array, ...],
        next_token_ids: jax.Array,
        num_rejected_tokens: Optional[jax.Array] = None,
    ) -> tuple[jax.Array, jax.Array, jax.Array, AttentionMetadata]:
        """Prepare DFlash inputs with on-device KV cache."""
        assert aux_hidden_states is not None and len(aux_hidden_states) > 0

        # 1. Current sequence length
        # NOTE(review): only request 0 is consulted here — presumably this
        # path assumes a single active request; confirm for batched use.
        seq_len_jax = attn_metadata.seq_lens[0]
        seq_len = int(jax.device_get(seq_len_jax))

        # 2. Crop cache to match GPU DynamicCache.crop(start) semantics.
        #
        # GPU reference (zhongyan_dev/dflash/model/dflash.py line 246):
        #     past_key_values_draft.crop(start)
        # where `start` = beginning of the CURRENT block = position of
        # the first accepted token from the previous iteration.
        #
        # After crop, GPU cache_seq_len = start, which equals the seq_len
        # from the PREVIOUS prepare_inputs call (not the current one).
        # Context + noise are then written starting from this position.
        #
        # Bug was: self._cache_len = seq_len (CURRENT accepted position),
        # which left stale noise K/V entries from the previous iteration
        # in positions [prev_seq_len, seq_len) and shifted all subsequent
        # RoPE positions, accumulating errors every iteration.
        if self._prev_seq_len > 0:
            self._cache_len = self._prev_seq_len

        if seq_len < self._ctx_len:
            self._ctx_len = seq_len
        self._prev_seq_len = seq_len

        # 3. Project new auxiliary hidden states (on-device, JIT'd)
        projected = self._project_aux_hidden(self.state, aux_hidden_states)

        # 4. Compute context update — slicing and padding stay on device
        # to avoid host<->TPU transfer overhead.
        num_new = seq_len - self._ctx_len
        if num_new <= 0:
            # Full rejection — trim context tracking, use zero placeholder.
            # Noise writes at cache_len + 0, completely overwriting padding.
            self._ctx_len = seq_len
            self._cache_len = min(self._cache_len, seq_len)
            actual_new_ctx_count = 0
            new_ctx_jax = device_array(
                self.mesh,
                jnp.zeros((16, self.hidden_size), dtype=jnp.bfloat16),
            )
        else:
            end = min(self._ctx_len + num_new, self.max_model_len)
            n_copy = end - self._ctx_len
            actual_new_ctx_count = n_copy
            self._ctx_len = end

            # 5. Slice and pad on device — no host<->TPU transfer.
            # Padding to power-of-2 sizes (16/32/64/128) means JIT only
            # traces ~4 unique shapes, eliminating per-token retracing.
            ctx = projected[:n_copy].astype(jnp.bfloat16)
            padded_size = self._next_padded_size(n_copy)
            if padded_size > n_copy:
                pad = jnp.zeros(
                    (padded_size - n_copy, self.hidden_size),
                    dtype=jnp.bfloat16,
                )
                ctx = jnp.concatenate([ctx, pad], axis=0)
            new_ctx_jax = device_array(self.mesh, ctx)

        # 6. Build noise block
        seq_len_arr = device_array(self.mesh,
                                   np.array([seq_len], dtype=np.int32))
        noise_input_ids, noise_positions = self._build_noise_block(
            seq_len_arr,
            next_token_ids,
            self.mask_token_id,
            self.block_size,
        )

        # 7. Pack target_hidden_states as 3-tuple (always same pytree shape)
        cache_len_arr = device_array(
            self.mesh, np.array([self._cache_len], dtype=np.int32))
        actual_ctx_count_arr = device_array(
            self.mesh, np.array([actual_new_ctx_count], dtype=np.int32))
        target_hidden = (new_ctx_jax, cache_len_arr, actual_ctx_count_arr)

        # 8. Build draft attention metadata
        # The draft model's KV-cache group is assumed to be the LAST group.
        num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
        draft_kv_cache_group_id = num_kv_cache_groups - 1
        block_tables = (
            self.runner.input_batch.block_table[draft_kv_cache_group_id].
            get_cpu_tensor().reshape(-1))
        num_reqs = attn_metadata.seq_lens.shape[0]
        draft_attn_metadata = replace(
            attn_metadata,
            input_positions=noise_positions,
            query_start_loc=jnp.array([0, self.block_size], dtype=jnp.int32),
            block_tables=device_array(self.mesh, block_tables),
        )

        dummy_last_indices = jnp.zeros(num_reqs, dtype=jnp.int32)
        return (
            target_hidden,
            noise_input_ids,
            dummy_last_indices,
            draft_attn_metadata,
        )

    @functools.partial(jax.jit, static_argnums=(0, ))
    def _sample_block_draft_tokens(
        self,
        state: nnx.State,
        hidden_states: jax.Array,
    ) -> jax.Array:
        """Greedy-sample draft tokens from the block output."""
        # Skip index 0 (the known first token); take the speculative slots.
        draft_hidden = hidden_states[1:1 + self.num_speculative_tokens]
        logits = self.compute_logits_fn(state, draft_hidden, None)
        draft_ids = jnp.argmax(logits, axis=-1)
        # Fully replicate the tiny draft-id array across the mesh.
        return lax.with_sharding_constraint(
            draft_ids, NamedSharding(self.mesh, PartitionSpec()))

    def propose(
        self,
        kv_caches: list[jax.Array],
        input_ids: jax.Array,
        attn_metadata: AttentionMetadata,
        last_token_indices: jax.Array,
        target_hidden_states,
    ) -> tuple[list[jax.Array], jnp.ndarray]:
        """Generate all draft tokens in one forward pass."""
        # Use our own on-device KV caches
        draft_kv_caches, hidden_states, _ = self.model_fn(
            self.state,
            self._draft_kv_caches,
            input_ids,
            target_hidden_states,
            attn_metadata,
        )

        # Update cached references
        self._draft_kv_caches = draft_kv_caches

        # Update cache_len: model wrote actual_ctx_count + T_noise entries.
        # This will be corrected at the start of the next prepare_inputs
        # to match the actual accepted seq_len.
        _, cache_len_arr, actual_ctx_count_arr = target_hidden_states
        old_cache_len = int(jax.device_get(cache_len_arr)[0])
        actual_ctx_count = int(jax.device_get(actual_ctx_count_arr)[0])
        T_noise = self.block_size
        self._cache_len = old_cache_len + actual_ctx_count + T_noise

        draft_token_ids = self._sample_block_draft_tokens(
            self.state, hidden_states)

        # Callers expect a 2-D (num_reqs, num_draft) array.
        if draft_token_ids.ndim == 1:
            draft_token_ids = draft_token_ids[jnp.newaxis, :]

        # Pass the FRAMEWORK kv_caches through unchanged
        return kv_caches, draft_token_ids