Skip to content

Kimi K2.x online cross-node datagen over RDMA (retain in-process hidden-states generator)#1

Open
k-l-lambda wants to merge 100 commits into
mainfrom
feature/kimi-rdma
Open

Kimi K2.x online cross-node datagen over RDMA (retain in-process hidden-states generator)#1
k-l-lambda wants to merge 100 commits into
mainfrom
feature/kimi-rdma

Conversation

@k-l-lambda

Copy link
Copy Markdown

Summary

This PR brings the novitalabs fork line (Kimi K2.x speculative-decoding work) onto main. The central, deliberate divergence from upstream is that we retain the in-process VllmHiddenStatesGenerator instead of moving to the disk-only offline .safetensors pipeline — because our training topology streams hidden states online, cross-node, over RDMA (Mooncake) rather than through a shared disk.

99 commits, 83 files (+11.5k / −1.4k). It is intended to be merged as the fork's develop-equivalent baseline, not as an upstream contribution.

Why we keep vllm_hidden_states_generator.py (RDMA path)

Upstream removed the in-process generator (PR vllm-project#433 / 8fdee2d, in v0.5.0) in favor of a standalone vLLM server that writes per-sample .safetensors to disk, with the trainer reading them back offline. That is a sound choice for offline batch generation, but it has no in-process hook on the hidden states, which is exactly what our online cross-node transport needs.

Our topology (e.g. Kimi K2.6, ~1T MoE verifier + Eagle3-MLA draft):

  • rollout/datagen node (TP=8) produces hidden states in-process via VllmHiddenStatesGenerator
  • trainer node (separate host) consumes them
  • control plane: Redis Streams (sample-ready events + bounded-lag backpressure)
  • data plane: Mooncake over RDMA — tensors move node-to-node from the producer's memory segment; disk .safetensors is never the primary transport

The cost is that the generator couples to private vllm.v1.* internals and is hand-ported per vLLM release (e.g. eos_token_id dropped for 0.17.1; the current form targets vLLM 0.20 with block_hasher). This trade-off and the upstream-vs-fork comparison are documented in docs/online_datagen_rdma.md and the README "Fork Notes" section.

What's included

Online / dynamic training data generation

  • Retain + maintain data_generation/vllm_hidden_states_generator.py (vLLM 0.20 compat: VLLM_WORKER_MULTIPROC_METHOD=spawn, v1 scheduler/executor loop, block_hasher); Kimi tokenizer + vLLM Request API support.
  • Dynamic-mode training (pre-tokenized HF datasets, prefetcher, failure tracking, enforce-eager flag), gen_hf_dataset_from_csv, apilog/ZClawBench dataset prep, eval/acceptance scripts.

Kimi K2 / MTP + Eagle3-MLA model support

  • models/mtp/ (config + core), models/eagle3/ MLA attention/core/model_definitions, K2 MTP reference config under scripts/k2_mtp_config/.
  • Checkpoint conversion (recover_d2t_and_convert_checkpoint), d2t/vocab-mapping fixes for lightseek weights.

Config / Pydantic robustness (the fixes that make Eagle3SpeculatorConfig load)

  • __pydantic_fields_set__/__new__ pre-init chain resolving the Pydantic ⇄ transformers.PretrainedConfig MRO conflict.
  • torch added to the Pydantic type namespace in model_rebuild; ClassVar stripping; deferred-ref schema rebuild; save_pretrained/serialization hardening; num_workers-gated prefetch settings; resume that skips incomplete checkpoint dirs.

Docs

  • docs/online_datagen_rdma.md — full rationale for the in-process generator + RDMA transport, with a when-to-use-which table.
  • README "Fork Notes (novitalabs/develop)".

Notes for reviewers

  • This is the fork baseline, deliberately divergent from upstream main; it is not meant to be upstreamed as-is.
  • The config-init fixes are the validated path for K2.6 eagle3-mla training; please keep the __new__/__pydantic_fields_set__ chain intact when rebasing.
  • vLLM coupling: the generator targets vLLM 0.20; bumping vLLM will require re-porting the v1-internal call sites.

fynnsu and others added 30 commits March 6, 2026 11:48
- mtp_eval_generate.py: Phase 1 dataset generation + hidden state extraction
- mtp_eval_acceptance.py: Phase 2 standalone MTP forward pass + acceptance rate
- Fix trust_remote_code and vLLM 0.17 Request API compatibility

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add scripts/k2_mtp_config/ with configuration_deepseek.py, modeling_deepseek.py,
and config.json so the eval script works without external model directory.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- K2.5 mode: --mtp-weights for separate mtp.safetensors (INT4 GPTQ dequant)
- V3 mode: --model-dir to extract MTP layer from model shards (FP8 block dequant)
- Backward compatible: K2.5 results verified identical (1/202700)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… 0.17 compat

- mtp_eval_acceptance.py: add V3 mode (--model-dir, FP8 block dequant from shards)
- vllm_hidden_states_generator.py: fix DeviceConfig(device=cuda) and explicit worker_cls
  for vLLM 0.17 compatibility

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…eight from weight_scale_inv

weight_scale_inv was being matched as weight due to no end anchor in alternation.
Reorder to (weight_scale_inv|weight) to fix FP8 expert dequantization.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- src/speculators/models/mtp/: MTPDraftModel + MTPSpeculatorConfig
  - Architecture: embed_tokens + enorm + hnorm + eh_proj + frozen DeepSeek decoder + shared_head
  - Supports CE and KL-div loss
  - Phase A: freeze decoder (K2.5 layer 60), train only ~2.45B MTP-exclusive params
  - Registered as "mtp" alongside eagle3
- src/speculators/train/data.py: add standardize_data_mtp for single-layer hidden states
- examples/mtp_k25_train.py: end-to-end training script for K2.5 MTP Phase A
- src/speculators/models/__init__.py: export MTPDraftModel, MTPSpeculatorConfig

Data pipeline verified: [1, 2048, 7168] BF16 tensors, shift_batch alignment correct
…h fixes

- core.py: build 4D causal attention mask for DeepSeekV3DecoderLayer
- core.py: freeze decoder in eval() mode (MoE gate asserts not training)
- core.py: add train() override to maintain frozen layers in eval mode
- core.py: fix k2_mtp_config path (src/speculators/models/mtp -> 5 levels up)
- core.py: add trust_remote_code=True to AutoConfig.from_pretrained
- examples/mtp_k25_train.py: fix maybe_setup_distributed return value (4-tuple)
- examples/mtp_k25_train.py: use local K2.5 snapshot path
- examples/mtp_k25_train.py: add trust_remote_code=True to AutoConfig
- examples/mtp_k25_train.py: fix setup_metric_logger signature
- examples/mtp_k25_train.py: cast model to bfloat16 before training

Smoke test result (200 samples, 3 epochs, random decoder init):
  Epoch 1: val_loss=5.876, val_acc=23.2%
  Epoch 2: val_loss=4.648, val_acc=34.8%
  Epoch 3: val_loss=4.558, val_acc=37.1%  ← training works end-to-end
- Add _load_k25_layer_weights() helper: finds all layer keys from index.json,
  dequantizes 1152 INT4 compressed-tensors expert weights (weight_packed/scale/shape)
  to BF16, loads 14 direct BF16 keys (attention, norms, gate)
- Replace glob-based load_model_layers(["model.layers.60.*"]) with proper
  enumeration - loads all 1166 layer 60 keys in ~20s
- Add --skip N argument to mtp_eval_generate.py for independent test set generation
  (skip first N samples after shuffle to avoid training set overlap)
core.py:
- mtp_loss_ce: use logits.transpose(1,2) instead of squeeze(0) for batch>1 compat
- forward(): batch_size-aware target_ids padding (new_zeros(B,1) not new_zeros(1,1))
- forward(): batch_size-aware default adjusted_mask (ones(B,T) not ones(1,T))
- forward(): batch_size-aware KL target padding (new_zeros(B,1,V))
- forward(): expand position_ids to [B,T] for decoder compat
- _init_decoder_from_verifier: raise RuntimeError (not warn) when frozen decoder load fails
- _init_decoder_from_verifier: add layer_idx bounds check after negative normalization

config.py:
- architectures default: "MTPSpeculator" -> "MTPDraftModel" (correct class name)
- serialize_decoder_config: to_diff_dict() -> to_dict() for checkpoint portability
- validate_decoder_config: add explicit isinstance(PretrainedConfig) branch + TypeError for invalid types

Note: batch-size fixes are forward-compatible; current training uses batch=1 via
speculators multipack collation and is unaffected.
…e3 training data

- Changed --layer-id (single int) to --layer-ids (nargs=+)
- Multi-layer: saves hidden_states as list of tensors (Eagle3 format)
- Single-layer: saves hidden_states as single tensor (MTP format, backward-compat)
- Default: [-1] (last layer, MTP mode)
- Example: --layer-ids 2 30 58 60 (Eagle3 for K2.5 with 61 layers)
…ining

base_components.py:
- Import DeepseekV3DecoderLayer, DeepseekV3RMSNorm, DeepseekV3Config from bundled
  k2_mtp_config (uses importlib to avoid sys.path pollution)
- Register "kimi_k2" in model_classes with MoE decoder + lambda rotary (MLA internal)

model_definitions.py:
- Add DeepseekV3DecoderEagle3FirstLayer: Eagle3 first layer adapted for MLA attention
  - Patches q_a_proj (H -> q_lora_rank) and kv_a_proj_with_mqa (H -> kv_lora_rank+rope_dim)
    to accept 2*H input, instead of standard q/k/v patching
  - Sets _attn_implementation="eager" (required for DeepSeek attention)
  - forward() inherited from Eagle3FirstLayerMixin (embeds/hidden split + norm)
- Register kimi_k2 Eagle3 first layer via override_components()
base_components.py:
- Load modeling_deepseek/configuration_deepseek via importlib, register in sys.modules
- Add _DeepseekRotaryProxy: returns (None,None) so MLA attention computes own rotary
- Add _BlockMaskCompatDecoder: wraps DeepseekV3DecoderLayer to convert Eagle3 BlockMask
  to standard 4D causal tensor (DeepseekV3Attention expects tensor, not BlockMask)

model_definitions.py:
- Use _DsBlockMaskDecoder as base for Eagle3 first layer
- Add custom forward: BlockMask conversion + reimplemented Eagle3 first-layer logic
  that handles 3-tuple return from DeepseekV3Attention (output, weights, past_kv)
- Register both kimi_k2 and deepseek_v3 model types

eagle3/core.py: add trust_remote_code=True to AutoConfig.from_pretrained

scripts/train.py:
- Add kimi_k2 to DRAFT_ARCH_CONFIGS using DeepseekV3Config
- Add trust_remote_code=True to both AutoConfig calls

Training status: runs for 40+ steps before triton kernel assertion fails
(vocab indexing issue during compute_accuracy, under investigation)
Eagle3 TTT loop uses cache_position=arange(step*S, (step+1)*S).
For ttt_step>=1, K2.5 MLA interprets large cache_position values as
kv_seq_len offset, causing index out of bounds in triton kernels.
The _BlockMaskCompatDecoder already resets cache_position for subsequent
layers, but the Eagle3 first layer was missing this reset.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Eagle3 core.py creates position_ids = 1 + arange(total_seq_len), making
them 1-indexed with max value = total_seq_len. K2.5 DeepseekV3Attention
rotary_emb.forward() returns cos[:seq_len] (0-indexed, size=seq_len).
When a packed batch is exactly total_seq_len=4096 tokens, cos[4096]
triggers device-side assert (index 4096 into tensor of size 4096).

This crashed deterministically at step 596 — the first batch that
packed exactly to total_seq_len=4096.

Fix: shift position_ids to 0-indexed in both DeepseekV3DecoderEagle3FirstLayer
and _BlockMaskCompatDecoder before passing to self_attn.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…atches

CRITICAL: _BlockMaskCompatDecoder was converting Eagle3 BlockMask to a
simple causal mask, dropping document boundaries. This allowed cross-sequence
attention in multipack batches, inflating training val accuracy (71.2%)
vs true single-sequence accuracy (40.6%).

Changes:
- core.py: add build_packed_attention_mask() that creates block-diagonal
  causal mask from lengths tensor. Used for K2.5/DeepSeek-V3 models;
  Llama/Qwen3 still use BlockMask via flex_attention.
- base_components.py: _BlockMaskCompatDecoder now raises TypeError on
  BlockMask instead of silently converting to simple causal.
- model_definitions.py: Eagle3 first layer also rejects BlockMask.
- mtp/core.py: use document-boundary mask when lengths has multiple docs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- eval_eagle3_checkpoint.py: add --apply-shift flag to apply shift_batch
  (same alignment as training collate_fn: input_ids[j]=x_{j+1}, hs[j]=g_j)
- core.py build_packed_attention_mask: add lengths.flatten() to handle
  2D batch tensors passed from eval script

Without shift_batch, eval was testing a different task than training
(off-by-one token alignment), causing 37% vs 67% apparent discrepancy.
With shift: independent test set gives 69.5% matching training val 67.5%.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…4-layer data

Layer 0 (layer 2) of K2.5 has std≈0.01 (near-zero signal after embedding).
Layer -1 (layer 60) has std≈2.8 (rich semantic representation).
Using h[0] for MTP training caused model to learn from zero signal.
Also fix mtp_eval_acceptance.py to consistently use h[-1].

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… collate_fn)

The eval script used manual h[:-2]/ids[1:-1]/mask[2:] alignment which
had an off-by-one in the loss_mask vs training shift_batch alignment.
This caused 45.3% eval vs 68.0% training val on the same data.

With shift_batch: independent test gives 68.4%, matching training val.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…mask arg)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
In speculators mode, the eval script used build_mtp_model which created
a random decoder layer (1166 missing keys from checkpoint). Now uses
MTPDraftModel.from_training_args which properly loads K2.5 layer 60
frozen decoder weights before loading the checkpoint trainable weights.

Added --verifier-name-or-path arg (defaults to K2.5 local snapshot).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… large-scale data gen

- Pre-filter rows with estimated input > seq_length (char/token ratio ~3.5)
- Sample preferring has-response rows to minimize vLLM generation calls
- Fix multimodal content (list instead of str) in text extraction

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
k-l-lambda and others added 28 commits April 3, 2026 18:49
…_hf_dataset_from_csv

Left-truncating the prefix corrupts the generation context: K2.5's hidden
states for response tokens differ from what they were at generation time.
Right-truncating the response preserves the full prefix so the context seen
by vLLM during dynamic training matches the original generation context.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Training command for Eagle3 dynamic hidden states training on
export_3c6b3075_16k_hf (276K samples, 16K context).
Pretrain weights: eagle3_v2_apilog/7. Loss: CE, ttt_steps=1.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Resolves 'NameError: name torch is not defined' when reloading Pydantic schemas
in vLLM environments where torch type annotations are not in global scope.

The fix:
- Import torch at top of config.py and pydantic_utils.py
- Pass _types_namespace={'torch': torch} to model_rebuild in reload_schema

This allows Pydantic to resolve torch type hints during schema validation,
fixing incompatibility with vLLM 0.20 and similar environments.

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
Resolves AttributeError when PretrainedConfig tries to setattr ClassVar fields
during initialization. Pydantic blocks setattr on ClassVar names, so they must
be removed from kwargs before passing to PretrainedConfig.__init__.

This fixes Eagle3SpeculatorConfig initialization failure:
  AttributeError: object has no attribute '__pydantic_fields_set__'

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
…init__

- Add explicit __pydantic_fields_set__ initialization before calling PretrainedConfig.__init__
- Prevents AttributeError when PretrainedConfig.__setattr__ triggers Pydantic's __setattr__
- Use object.__setattr__ for transformers_version to avoid Pydantic interception
- Harden Eagle3 transformer config validator to mirror MTP pattern
  - Return early for PretrainedConfig objects
  - Raise TypeError for unsupported types instead of silently accepting
…retrainedConfig

- Use BaseModel.__init__ directly instead of PydanticClassRegistryMixin.__init__
- This ensures __pydantic_fields_set__ is created before PretrainedConfig can trigger __setattr__
- Simplifies the approach: directly initialize Pydantic state that PretrainedConfig needs
…etrainedConfig.__init__

- Set __pydantic_fields_set__, __pydantic_extra__, __pydantic_private__ manually
- This prevents AttributeError when PretrainedConfig's setattr triggers Pydantic's __setattr__
- The attributes must exist before BaseModel.__init__ attempts to use them
…l attribute setting

- Bypass PretrainedConfig.__init__ which calls setattr() and triggers Pydantic's __setattr__
- Use object.__setattr__ to directly set attributes without Pydantic interception
- This avoids the MRO conflict entirely by not calling PretrainedConfig.__init__
…nit__ calls

- Override __new__ to create instance and set Pydantic internal state first
- This ensures __pydantic_fields_set__ exists before any __init__ chain runs
- PretrainedConfig.__init__ will still be called by MRO but __pydantic_fields_set__ is already ready
…ers > 0

- These DataLoader options are only valid with multiprocessing (num_workers > 0)
- Conditionally add them to avoid ValueError when running single-threaded
@k-l-lambda

Copy link
Copy Markdown
Author

@codex review

…ization

Addresses review findings on the Pydantic/PretrainedConfig default-materialization
path (the __new__ fast-path that pre-seeds __pydantic_fields_set__ skips Pydantic's
own default application, so unset fields leak their class-level FieldInfo into
to_dict/to_diff_dict and crash save_pretrained):

- to_dict no longer mutates instance state. The old _materialize_field_defaults
  added resolved names to __pydantic_fields_set__; that set is Pydantic semantic
  state (fields explicitly supplied), and mutating it during serialization
  corrupts later model_dump(exclude_unset=True)/diff behavior. A field resolved
  to its default is by definition NOT explicitly set, so it is no longer added.
- Use FieldInfo.get_default(call_default_factory=True) instead of hand-resolving
  default/default_factory. This gives Pydantic's proper per-instance (deep-)copy
  of mutable/factory defaults, so two configs no longer share the same default
  transformer_layer_config object.
- A leaked field with no default (PydanticUndefined) now raises instead of being
  silently set to None. That state means a required field was never set (the
  fast-path skipped its validation); fabricating None would serialize an invalid
  config. Surface it instead.
- Collapse the triplicated materialization logic (model_post_init +
  _materialize_field_defaults + to_dict) into the single _materialize_field_defaults
  helper; model_post_init and to_dict both call it.
- validate(): drop the no-op __class_validators__ loop (a Pydantic v1-era pattern
  that does nothing under v2) and instead walk the MRO for a real
  PretrainedConfig.validate INSTANCE method, dispatching to it if present so HF's
  own config validation is not silently shadowed. Still neutralizes the MRO
  collision with Pydantic's deprecated BaseModel.validate classmethod (no such HF
  instance method exists in transformers 5.x, so it stays a safe no-op there).

Tests: add coverage that to_dict is side-effect-free wrt __pydantic_fields_set__
and idempotent, that factory defaults are per-instance (not shared), and that
materialization raises on an unset required field.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants