Kimi K2.x online cross-node datagen over RDMA (retain in-process hidden-states generator)#1
Open
k-l-lambda wants to merge 100 commits into
- mtp_eval_generate.py: Phase 1 dataset generation + hidden state extraction
- mtp_eval_acceptance.py: Phase 2 standalone MTP forward pass + acceptance rate
- Fix trust_remote_code and vLLM 0.17 Request API compatibility
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add scripts/k2_mtp_config/ with configuration_deepseek.py, modeling_deepseek.py, and config.json so the eval script works without external model directory. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- K2.5 mode: --mtp-weights for separate mtp.safetensors (INT4 GPTQ dequant)
- V3 mode: --model-dir to extract MTP layer from model shards (FP8 block dequant)
- Backward compatible: K2.5 results verified identical (1/202700)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… 0.17 compat
- mtp_eval_acceptance.py: add V3 mode (--model-dir, FP8 block dequant from shards)
- vllm_hidden_states_generator.py: fix DeviceConfig(device=cuda) and explicit worker_cls for vLLM 0.17 compatibility
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…eight from weight_scale_inv
weight_scale_inv was being matched as weight due to no end anchor in alternation. Reorder to (weight_scale_inv|weight) to fix FP8 expert dequantization.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
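The alternation-order problem in this commit can be reproduced in isolation. The sketch below is only an illustration: the key layout `experts.<n>.<proj>.<suffix>` and the two pattern names are assumptions made for the example, not the patterns used in the eval script.

```python
import re

# Hypothetical key layout, chosen only to illustrate the bug.
KEY = "experts.7.gate_proj.weight_scale_inv"

# Without an end anchor, the first alternative that matches wins,
# so "weight" is taken as a prefix of "weight_scale_inv".
BUGGY = re.compile(r"experts\.(\d+)\.(\w+)\.(weight|weight_scale_inv)")

# Listing the longer alternative first makes the full suffix match.
FIXED = re.compile(r"experts\.(\d+)\.(\w+)\.(weight_scale_inv|weight)")

buggy_suffix = BUGGY.match(KEY).group(3)
fixed_suffix = FIXED.match(KEY).group(3)
plain_suffix = FIXED.match("experts.7.gate_proj.weight").group(3)
```

Anchoring the pattern at the end of the key would be another way to get the same result; the commit chose reordering.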
- src/speculators/models/mtp/: MTPDraftModel + MTPSpeculatorConfig
- Architecture: embed_tokens + enorm + hnorm + eh_proj + frozen DeepSeek decoder + shared_head
- Supports CE and KL-div loss
- Phase A: freeze decoder (K2.5 layer 60), train only ~2.45B MTP-exclusive params
- Registered as "mtp" alongside eagle3
- src/speculators/train/data.py: add standardize_data_mtp for single-layer hidden states
- examples/mtp_k25_train.py: end-to-end training script for K2.5 MTP Phase A
- src/speculators/models/__init__.py: export MTPDraftModel, MTPSpeculatorConfig
Data pipeline verified: [1, 2048, 7168] BF16 tensors, shift_batch alignment correct
…h fixes
- core.py: build 4D causal attention mask for DeepSeekV3DecoderLayer
- core.py: freeze decoder in eval() mode (MoE gate asserts not training)
- core.py: add train() override to maintain frozen layers in eval mode
- core.py: fix k2_mtp_config path (src/speculators/models/mtp -> 5 levels up)
- core.py: add trust_remote_code=True to AutoConfig.from_pretrained
- examples/mtp_k25_train.py: fix maybe_setup_distributed return value (4-tuple)
- examples/mtp_k25_train.py: use local K2.5 snapshot path
- examples/mtp_k25_train.py: add trust_remote_code=True to AutoConfig
- examples/mtp_k25_train.py: fix setup_metric_logger signature
- examples/mtp_k25_train.py: cast model to bfloat16 before training
Smoke test result (200 samples, 3 epochs, random decoder init):
Epoch 1: val_loss=5.876, val_acc=23.2%
Epoch 2: val_loss=4.648, val_acc=34.8%
Epoch 3: val_loss=4.558, val_acc=37.1% ← training works end-to-end
- Add _load_k25_layer_weights() helper: finds all layer keys from index.json, dequantizes 1152 INT4 compressed-tensors expert weights (weight_packed/scale/shape) to BF16, loads 14 direct BF16 keys (attention, norms, gate)
- Replace glob-based load_model_layers(["model.layers.60.*"]) with proper enumeration - loads all 1166 layer 60 keys in ~20s
- Add --skip N argument to mtp_eval_generate.py for independent test set generation (skip first N samples after shuffle to avoid training set overlap)
core.py:
- mtp_loss_ce: use logits.transpose(1,2) instead of squeeze(0) for batch>1 compat
- forward(): batch_size-aware target_ids padding (new_zeros(B,1) not new_zeros(1,1))
- forward(): batch_size-aware default adjusted_mask (ones(B,T) not ones(1,T))
- forward(): batch_size-aware KL target padding (new_zeros(B,1,V))
- forward(): expand position_ids to [B,T] for decoder compat
- _init_decoder_from_verifier: raise RuntimeError (not warn) when frozen decoder load fails
- _init_decoder_from_verifier: add layer_idx bounds check after negative normalization
config.py:
- architectures default: "MTPSpeculator" -> "MTPDraftModel" (correct class name)
- serialize_decoder_config: to_diff_dict() -> to_dict() for checkpoint portability
- validate_decoder_config: add explicit isinstance(PretrainedConfig) branch + TypeError for invalid types
Note: batch-size fixes are forward-compatible; current training uses batch=1 via speculators multipack collation and is unaffected.
…e3 training data
- Changed --layer-id (single int) to --layer-ids (nargs=+)
- Multi-layer: saves hidden_states as list of tensors (Eagle3 format)
- Single-layer: saves hidden_states as single tensor (MTP format, backward-compat)
- Default: [-1] (last layer, MTP mode)
- Example: --layer-ids 2 30 58 60 (Eagle3 for K2.5 with 61 layers)
…ining
base_components.py:
- Import DeepseekV3DecoderLayer, DeepseekV3RMSNorm, DeepseekV3Config from bundled
k2_mtp_config (uses importlib to avoid sys.path pollution)
- Register "kimi_k2" in model_classes with MoE decoder + lambda rotary (MLA internal)
model_definitions.py:
- Add DeepseekV3DecoderEagle3FirstLayer: Eagle3 first layer adapted for MLA attention
- Patches q_a_proj (H -> q_lora_rank) and kv_a_proj_with_mqa (H -> kv_lora_rank+rope_dim)
to accept 2*H input, instead of standard q/k/v patching
- Sets _attn_implementation="eager" (required for DeepSeek attention)
- forward() inherited from Eagle3FirstLayerMixin (embeds/hidden split + norm)
- Register kimi_k2 Eagle3 first layer via override_components()
base_components.py:
- Load modeling_deepseek/configuration_deepseek via importlib, register in sys.modules
- Add _DeepseekRotaryProxy: returns (None,None) so MLA attention computes own rotary
- Add _BlockMaskCompatDecoder: wraps DeepseekV3DecoderLayer to convert Eagle3 BlockMask to standard 4D causal tensor (DeepseekV3Attention expects tensor, not BlockMask)
model_definitions.py:
- Use _DsBlockMaskDecoder as base for Eagle3 first layer
- Add custom forward: BlockMask conversion + reimplemented Eagle3 first-layer logic that handles 3-tuple return from DeepseekV3Attention (output, weights, past_kv)
- Register both kimi_k2 and deepseek_v3 model types
eagle3/core.py: add trust_remote_code=True to AutoConfig.from_pretrained
scripts/train.py:
- Add kimi_k2 to DRAFT_ARCH_CONFIGS using DeepseekV3Config
- Add trust_remote_code=True to both AutoConfig calls
Training status: runs for 40+ steps before triton kernel assertion fails (vocab indexing issue during compute_accuracy, under investigation)
…t, trust_remote_code)
Eagle3 TTT loop uses cache_position=arange(step*S, (step+1)*S). For ttt_step>=1, K2.5 MLA interprets large cache_position values as kv_seq_len offset, causing index out of bounds in triton kernels. The _BlockMaskCompatDecoder already resets cache_position for subsequent layers, but the Eagle3 first layer was missing this reset. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Eagle3 core.py creates position_ids = 1 + arange(total_seq_len), making them 1-indexed with max value = total_seq_len. K2.5 DeepseekV3Attention rotary_emb.forward() returns cos[:seq_len] (0-indexed, size=seq_len). When a packed batch is exactly total_seq_len=4096 tokens, cos[4096] triggers device-side assert (index 4096 into tensor of size 4096). This crashed deterministically at step 596 — the first batch that packed exactly to total_seq_len=4096. Fix: shift position_ids to 0-indexed in both DeepseekV3DecoderEagle3FirstLayer and _BlockMaskCompatDecoder before passing to self_attn. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
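The off-by-one described in this commit can be shown without tensors. This is a minimal sketch using plain lists as a stand-in for the rotary table; `lookup_rotary`, `cos_table` and the tiny `seq_len` are illustrative names and values, not code from the repository.

```python
def lookup_rotary(cos_table, position_ids):
    """Index a rotary table of size seq_len by position (stand-in for cos[position_ids])."""
    return [cos_table[p] for p in position_ids]

seq_len = 8  # stands in for total_seq_len=4096
cos_table = [float(i) for i in range(seq_len)]       # valid indices: 0..seq_len-1

one_indexed = [1 + i for i in range(seq_len)]        # 1..seq_len, max is out of range
zero_indexed = [p - 1 for p in one_indexed]          # the fix: shift to 0..seq_len-1

try:
    lookup_rotary(cos_table, one_indexed)
    overflowed = False
except IndexError:
    overflowed = True

shifted = lookup_rotary(cos_table, zero_indexed)
```

The failure only appears when a batch fills the table exactly, which matches the commit's observation that shorter packed batches ran without error.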
…atches
CRITICAL: _BlockMaskCompatDecoder was converting Eagle3 BlockMask to a simple causal mask, dropping document boundaries. This allowed cross-sequence attention in multipack batches, inflating training val accuracy (71.2%) vs true single-sequence accuracy (40.6%).
Changes:
- core.py: add build_packed_attention_mask() that creates block-diagonal causal mask from lengths tensor. Used for K2.5/DeepSeek-V3 models; Llama/Qwen3 still use BlockMask via flex_attention.
- base_components.py: _BlockMaskCompatDecoder now raises TypeError on BlockMask instead of silently converting to simple causal.
- model_definitions.py: Eagle3 first layer also rejects BlockMask.
- mtp/core.py: use document-boundary mask when lengths has multiple docs.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
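A block-diagonal causal mask of the kind this commit describes can be sketched as follows. This is not the repository's build_packed_attention_mask(): it is a pure-Python illustration over nested lists, and the function name and boolean convention (True means "may attend") are assumptions.

```python
def packed_causal_mask_sketch(lengths):
    """Return a total x total boolean mask for documents packed back to back.

    mask[q][k] is True only when query q and key k belong to the same
    document and k is not in the future of q.
    """
    doc_id = []
    for d, n in enumerate(lengths):
        doc_id.extend([d] * n)          # document index of every packed position
    total = len(doc_id)
    return [
        [doc_id[q] == doc_id[k] and k <= q for k in range(total)]
        for q in range(total)
    ]

# Two documents of lengths 2 and 3 packed into one sequence of 5 tokens.
mask = packed_causal_mask_sketch([2, 3])
```

A plain causal mask would set mask[2][1] to True, letting the first token of the second document attend to the first document; that is the cross-sequence leak the commit removes.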
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- eval_eagle3_checkpoint.py: add --apply-shift flag to apply shift_batch
(same alignment as training collate_fn: input_ids[j]=x_{j+1}, hs[j]=g_j)
- core.py build_packed_attention_mask: add lengths.flatten() to handle
2D batch tensors passed from eval script
Without shift_batch, eval was testing a different task than training
(off-by-one token alignment), causing 37% vs 67% apparent discrepancy.
With shift: independent test set gives 69.5% matching training val 67.5%.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
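The alignment stated in this commit (input_ids[j]=x_{j+1}, hs[j]=g_j) can be written out on toy data. The helper below is a hypothetical illustration of that index relationship only; it is not the repository's shift_batch and it leaves the loss mask out, since the source does not spell out how the mask is shifted.

```python
def shift_alignment_sketch(token_ids, hidden_states):
    """Pair hidden state g_j with the next token x_{j+1}.

    token_ids:     [x_0, ..., x_{T-1}]
    hidden_states: [g_0, ..., g_{T-1}], g_j produced at position j
    """
    if len(token_ids) != len(hidden_states):
        raise ValueError("token_ids and hidden_states must have equal length")
    input_ids = token_ids[1:]       # drop x_0 so input_ids[j] == x_{j+1}
    hs = hidden_states[:-1]         # drop g_{T-1}, which has no next token
    return input_ids, hs

tokens = [10, 11, 12, 13]
hidden = ["g0", "g1", "g2", "g3"]
input_ids, hs = shift_alignment_sketch(tokens, hidden)
```

Evaluating without this shift pairs g_j with x_j instead, which is the different task the commit message refers to.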
…4-layer data
Layer 0 (layer 2) of K2.5 has std≈0.01 (near-zero signal after embedding). Layer -1 (layer 60) has std≈2.8 (rich semantic representation). Using h[0] for MTP training caused model to learn from zero signal. Also fix mtp_eval_acceptance.py to consistently use h[-1].
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… collate_fn)
The eval script used manual h[:-2]/ids[1:-1]/mask[2:] alignment which had an off-by-one in the loss_mask vs training shift_batch alignment. This caused 45.3% eval vs 68.0% training val on the same data. With shift_batch: independent test gives 68.4%, matching training val.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…mask arg) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
In speculators mode, the eval script used build_mtp_model which created a random decoder layer (1166 missing keys from checkpoint). Now uses MTPDraftModel.from_training_args which properly loads K2.5 layer 60 frozen decoder weights before loading the checkpoint trainable weights. Added --verifier-name-or-path arg (defaults to K2.5 local snapshot). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… large-scale data gen
- Pre-filter rows with estimated input > seq_length (char/token ratio ~3.5)
- Sample preferring has-response rows to minimize vLLM generation calls
- Fix multimodal content (list instead of str) in text extraction
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
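The length pre-filter and the list-vs-str content handling can be sketched together. Everything named here is hypothetical: the row shape ({"content": ...}), the OpenAI-style list of {"type": "text", "text": ...} parts, and the helper names are assumptions for illustration; only the ~3.5 chars/token ratio comes from the commit.

```python
import math

CHARS_PER_TOKEN = 3.5  # rough ratio quoted in the commit message

def extract_text(content):
    """Accept either a plain string or a list of content parts (assumed shape)."""
    if isinstance(content, str):
        return content
    return "".join(
        part.get("text", "")
        for part in content
        if isinstance(part, dict) and part.get("type") == "text"
    )

def estimated_tokens(text):
    """Cheap token estimate that avoids running the tokenizer on every row."""
    return math.ceil(len(text) / CHARS_PER_TOKEN)

def prefilter(rows, seq_length):
    """Keep rows whose estimated input length fits within seq_length."""
    return [
        row for row in rows
        if estimated_tokens(extract_text(row["content"])) <= seq_length
    ]

rows = [
    {"content": "a" * 35},                                   # about 10 tokens
    {"content": "a" * 36},                                   # about 11 tokens
    {"content": [{"type": "text", "text": "hi"},
                 {"type": "image_url", "image_url": {}}]},   # multimodal parts
]
kept = prefilter(rows, seq_length=10)
```

Because the estimate is approximate, rows near the limit can still exceed seq_length once tokenized, so an exact check after tokenization remains necessary.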
…_hf_dataset_from_csv
Left-truncating the prefix corrupts the generation context: K2.5's hidden states for response tokens differ from what they were at generation time. Right-truncating the response preserves the full prefix so the context seen by vLLM during dynamic training matches the original generation context.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
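The difference between the two truncation strategies is easiest to see on token lists. The function below is a hypothetical sketch, not the code in gen_hf_dataset_from_csv; in particular, returning None when the prefix alone exceeds the budget is an assumption about how such rows would be handled.

```python
def truncate_response_right(prefix_ids, response_ids, max_len):
    """Keep the whole prefix and cut the response tail to fit max_len."""
    budget = max_len - len(prefix_ids)
    if budget <= 0:
        return None                      # assumed: row cannot fit, drop it
    return prefix_ids + response_ids[:budget]

def truncate_prefix_left(prefix_ids, response_ids, max_len):
    """The rejected strategy: drop the oldest prefix tokens instead."""
    return (prefix_ids + response_ids)[-max_len:]

prefix = [1, 2, 3, 4]
response = [10, 11, 12, 13]

right = truncate_response_right(prefix, response, max_len=6)
left = truncate_prefix_left(prefix, response, max_len=6)
```

With right truncation every surviving response token still sees the same prefix it was generated from; with left truncation the remaining response tokens are conditioned on a shortened context.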
Training command for Eagle3 dynamic hidden states training on export_3c6b3075_16k_hf (276K samples, 16K context). Pretrain weights: eagle3_v2_apilog/7. Loss: CE, ttt_steps=1. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…5K, max-checkpoints 3
…atch 163840 vs 32000)
Resolves 'NameError: name torch is not defined' when reloading Pydantic schemas
in vLLM environments where torch type annotations are not in global scope.
The fix:
- Import torch at top of config.py and pydantic_utils.py
- Pass _types_namespace={'torch': torch} to model_rebuild in reload_schema
This allows Pydantic to resolve torch type hints during schema validation,
fixing incompatibility with vLLM 0.20 and similar environments.
Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
Resolves AttributeError when PretrainedConfig tries to setattr ClassVar fields during initialization. Pydantic blocks setattr on ClassVar names, so they must be removed from kwargs before passing to PretrainedConfig.__init__. This fixes Eagle3SpeculatorConfig initialization failure: AttributeError: object has no attribute '__pydantic_fields_set__' Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
…init__
- Add explicit __pydantic_fields_set__ initialization before calling PretrainedConfig.__init__
- Prevents AttributeError when PretrainedConfig.__setattr__ triggers Pydantic's __setattr__
- Use object.__setattr__ for transformers_version to avoid Pydantic interception
- Harden Eagle3 transformer config validator to mirror MTP pattern
- Return early for PretrainedConfig objects
- Raise TypeError for unsupported types instead of silently accepting
…retrainedConfig
- Use BaseModel.__init__ directly instead of PydanticClassRegistryMixin.__init__
- This ensures __pydantic_fields_set__ is created before PretrainedConfig can trigger __setattr__
- Simplifies the approach: directly initialize Pydantic state that PretrainedConfig needs
…etrainedConfig.__init__
- Set __pydantic_fields_set__, __pydantic_extra__, __pydantic_private__ manually
- This prevents AttributeError when PretrainedConfig's setattr triggers Pydantic's __setattr__
- The attributes must exist before BaseModel.__init__ attempts to use them
…l attribute setting
- Bypass PretrainedConfig.__init__ which calls setattr() and triggers Pydantic's __setattr__
- Use object.__setattr__ to directly set attributes without Pydantic interception
- This avoids the MRO conflict entirely by not calling PretrainedConfig.__init__
…nit__ calls
- Override __new__ to create instance and set Pydantic internal state first
- This ensures __pydantic_fields_set__ exists before any __init__ chain runs
- PretrainedConfig.__init__ will still be called by MRO but __pydantic_fields_set__ is already ready
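The ordering problem these commits work through can be reproduced with two stand-in classes. This is an analogy in plain Python, not Pydantic or transformers code: StrictAttrs imitates a __setattr__ that needs bookkeeping state, LegacyConfig imitates an __init__ that assigns attributes, and every name here is hypothetical.

```python
class StrictAttrs:
    """Stand-in for a base whose __setattr__ requires pre-existing state."""
    def __setattr__(self, name, value):
        self._fields_set.add(name)          # AttributeError if state is missing
        object.__setattr__(self, name, value)

class LegacyConfig:
    """Stand-in for a base whose __init__ assigns attributes via setattr."""
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

class BrokenConfig(StrictAttrs, LegacyConfig):
    pass                                     # state never created before __init__

class FixedConfig(StrictAttrs, LegacyConfig):
    def __new__(cls, *args, **kwargs):
        inst = super().__new__(cls)
        # Seed the bookkeeping state before any __init__ in the MRO runs.
        object.__setattr__(inst, "_fields_set", set())
        return inst

try:
    BrokenConfig(hidden_size=8)
    broke = False
except AttributeError:
    broke = True

cfg = FixedConfig(hidden_size=8)
```

Seeding state in __new__ sidesteps the ordering of the __init__ chain, but it also bypasses whatever the skipped initializer would normally set up, which is why later commits in this PR revisit default materialization.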
…ers > 0
- These DataLoader options are only valid with multiprocessing (num_workers > 0)
- Conditionally add them to avoid ValueError when running single-threaded
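A sketch of the gating, assuming the options in question are prefetch_factor and persistent_workers (the commit title is truncated, so this is a guess based on which torch.utils.data.DataLoader arguments are rejected when num_workers is 0). The helper name and defaults are hypothetical.

```python
def dataloader_kwargs_sketch(batch_size, num_workers,
                             prefetch_factor=2, persistent_workers=True):
    """Build keyword arguments for a DataLoader, adding worker-only
    options only when worker processes are actually used."""
    kwargs = {"batch_size": batch_size, "num_workers": num_workers}
    if num_workers > 0:
        # Assumed to be the options the commit refers to.
        kwargs["prefetch_factor"] = prefetch_factor
        kwargs["persistent_workers"] = persistent_workers
    return kwargs

single = dataloader_kwargs_sketch(batch_size=1, num_workers=0)
multi = dataloader_kwargs_sketch(batch_size=1, num_workers=4)
```

The resulting dict would be unpacked into the loader constructor, for example DataLoader(dataset, **multi).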
…idate classmethod
Author comment: @codex review
…ization
Addresses review findings on the Pydantic/PretrainedConfig default-materialization path (the __new__ fast-path that pre-seeds __pydantic_fields_set__ skips Pydantic's own default application, so unset fields leak their class-level FieldInfo into to_dict/to_diff_dict and crash save_pretrained):
- to_dict no longer mutates instance state. The old _materialize_field_defaults added resolved names to __pydantic_fields_set__; that set is Pydantic semantic state (fields explicitly supplied), and mutating it during serialization corrupts later model_dump(exclude_unset=True)/diff behavior. A field resolved to its default is by definition NOT explicitly set, so it is no longer added.
- Use FieldInfo.get_default(call_default_factory=True) instead of hand-resolving default/default_factory. This gives Pydantic's proper per-instance (deep-)copy of mutable/factory defaults, so two configs no longer share the same default transformer_layer_config object.
- A leaked field with no default (PydanticUndefined) now raises instead of being silently set to None. That state means a required field was never set (the fast-path skipped its validation); fabricating None would serialize an invalid config. Surface it instead.
- Collapse the triplicated materialization logic (model_post_init + _materialize_field_defaults + to_dict) into the single _materialize_field_defaults helper; model_post_init and to_dict both call it.
- validate(): drop the no-op __class_validators__ loop (a Pydantic v1-era pattern that does nothing under v2) and instead walk the MRO for a real PretrainedConfig.validate INSTANCE method, dispatching to it if present so HF's own config validation is not silently shadowed. Still neutralizes the MRO collision with Pydantic's deprecated BaseModel.validate classmethod (no such HF instance method exists in transformers 5.x, so it stays a safe no-op there).
Tests: add coverage that to_dict is side-effect-free wrt __pydantic_fields_set__ and idempotent, that factory defaults are per-instance (not shared), and that materialization raises on an unset required field.
Summary
This PR brings the novitalabs fork line (Kimi K2.x speculative-decoding work) onto main. The central, deliberate divergence from upstream is that we retain the in-process VllmHiddenStatesGenerator instead of moving to the disk-only offline .safetensors pipeline, because our training topology streams hidden states online, cross-node, over RDMA (Mooncake) rather than through a shared disk.

99 commits, 83 files (+11.5k / −1.4k). It is intended to be merged as the fork's develop-equivalent baseline, not as an upstream contribution.

Why we keep vllm_hidden_states_generator.py (RDMA path)

Upstream removed the in-process generator (PR vllm-project#433 / 8fdee2d, in v0.5.0) in favor of a standalone vLLM server that writes per-sample .safetensors to disk, with the trainer reading them back offline. That is a sound choice for offline batch generation, but it has no in-process hook on the hidden states, which is exactly what our online cross-node transport needs.

Our topology (e.g. Kimi K2.6, ~1T MoE verifier + Eagle3-MLA draft):
- … VllmHiddenStatesGenerator
- … .safetensors is never the primary transport

The cost is that the generator couples to private vllm.v1.* internals and is hand-ported per vLLM release (e.g. eos_token_id dropped for 0.17.1; the current form targets vLLM 0.20 with block_hasher). This trade-off and the upstream-vs-fork comparison are documented in docs/online_datagen_rdma.md and the README "Fork Notes" section.

What's included

Online / dynamic training data generation
- data_generation/vllm_hidden_states_generator.py (vLLM 0.20 compat: VLLM_WORKER_MULTIPROC_METHOD=spawn, v1 scheduler/executor loop, block_hasher); Kimi tokenizer + vLLM Request API support.
- gen_hf_dataset_from_csv, apilog/ZClawBench dataset prep, eval/acceptance scripts.

Kimi K2 / MTP + Eagle3-MLA model support
- models/mtp/ (config + core), models/eagle3/ MLA attention/core/model_definitions, K2 MTP reference config under scripts/k2_mtp_config/.
- … recover_d2t_and_convert_checkpoint), d2t/vocab-mapping fixes for lightseek weights.

Config / Pydantic robustness (the fixes that make Eagle3SpeculatorConfig load)
- __pydantic_fields_set__ / __new__ pre-init chain resolving the Pydantic ⇄ transformers.PretrainedConfig MRO conflict.
- … model_rebuild; ClassVar stripping; deferred-ref schema rebuild; save_pretrained/serialization hardening; num_workers-gated prefetch settings; resume that skips incomplete checkpoint dirs.

Docs
- docs/online_datagen_rdma.md: full rationale for the in-process generator + RDMA transport, with a when-to-use-which table.

Notes for reviewers
- … main; it is not meant to be upstreamed as-is.
- … __new__ / __pydantic_fields_set__ chain intact when rebasing.