16 changes: 16 additions & 0 deletions deepseek_dequantize.py
@@ -0,0 +1,16 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor.modeling.moe.linearize import linearize_moe_model

model = AutoModelForCausalLM.from_pretrained(
"deepseek-ai/DeepSeek-V4-Flash",
torch_dtype="auto",
device_map="cpu",
)
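# Drop the weight-conversion bookkeeping attached by the transformers loader so the
# model can be re-saved as-is (workaround; exact behavior depends on the transformers version).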
delattr(model, "_weight_conversions")
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V4-Flash")

save_dir = "DeepSeek-V4-Flash-bf16"
#model.dequantize(torch.bfloat16)
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
145 changes: 145 additions & 0 deletions examples/quantizing_moe/deepseek_v4_example.py
@@ -0,0 +1,145 @@
from compressed_tensors.offload import load_offloaded_model
from compressed_tensors.quantization.quant_scheme import (
FP8_BLOCK,
NVFP4,
QuantizationScheme,
)
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modeling.moe.linearize import linearize_moe_model
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils.dev import skip_weights_download

# Select model and load it.
MODEL_ID = "RedHatAI/DeepSeek-V4-Flash-BF16"

with load_offloaded_model():
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype="auto",
# device_map="auto_offload",
device_map="cpu",
# max_memory={"cpu": 3e10},
# offload_folder="offload_folder",
)
# from transformers.core_model_loading import revert_weight_conversion
# from compressed_tensors.offload import disable_onloading
# with disable_onloading():
# new_state_dict = revert_weight_conversion(model, model.state_dict())
# print(new_state_dict.keys())
# exit(0)

linearize_moe_model(model)

# kluge for the way I saved the decompressed checkpoint
# mds = model.model.layers[-1].self_attn.wq_a._hf_hook.weights_map.dataset.index
# mds["model.hc_head.base"] = mds['model.hc_head.hc_base']
# mds["model.hc_head.fn"] = mds['model.hc_head.hc_fn']
# mds["model.hc_head.scale"] = mds['model.hc_head.hc_scale']

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start; 64 is used
# here to keep this example fast. Increasing the number of samples can improve
# accuracy.
NUM_CALIBRATION_SAMPLES = 64
MAX_SEQUENCE_LENGTH = 512

# Load dataset and preprocess.
ds = load_dataset(
DATASET_ID,
split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]", # get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
)
ds = ds.shuffle(seed=42)


def preprocess(example):
# DeepSeek-V4 does not have a traditional chat template.
# Encode manually per https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash/tree/main/encoding
BOS = "<|begin▁of▁sentence|>"
EOS = "<|end▁of▁sentence|>"
text = BOS
for message in example["messages"]:
role = message["role"]
content = message["content"]
if role == "system":
text += content
elif role == "user":
text += f"<|User|>{content}"
elif role == "assistant":
text += f"<|Assistant|></think>{content}{EOS}"

return {"text": text}


ds = ds.map(preprocess)
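
# Optional sanity check (not part of the original example): print one encoded
# sample to confirm the manual DeepSeek-V4 template above. The exact text
# depends on the shuffled dataset split.
print(ds[0]["text"][:300])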


def tokenize(sample):
return tokenizer(
sample["text"],
padding=False,
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
add_special_tokens=False,
)


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run:
#   * quantize MLP/expert weights to NVFP4
#   * quantize attention projection weights to FP8_BLOCK
# Example attention module: model.model.layers.0.self_attn.q_a_proj
#
# BF16 checkpoint name | HF module name
#   wq_a | q_a_proj
#   wq_b | q_b_proj
#   wkv  | kv_proj
#   wo_a | o_a_proj
#   wo_b | o_b_proj
# (An optional check of which modules these targets match follows the recipe.)

recipe = QuantizationModifier(
config_groups={
"attention": QuantizationScheme(
targets=[
r"re:.*attn\.(q_a_proj|q_b_proj|kv_proj|o_a_proj|o_b_proj)$",
r"re:.*attn\.compressor\.indexer\.q_b_proj$",
],
**FP8_BLOCK,
),
"experts": QuantizationScheme(
targets=[
r"re:.*mlp\.experts.*(gate|up|down)_proj$",
r"re:.*mlp\.shared_experts.*(gate|up|down)_proj$",
],
**NVFP4,
),
},
ignore=[],
)
# model.layers.4.self_attn.compressor.indexer.weights_proj
# model.layers.3.ffn_hc
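
# Optional sanity check (illustrative, not part of the recipe): preview which
# modules the regex targets above resolve to before running calibration. The
# patterns mirror the "re:" targets; module names are those produced by
# linearize_moe_model and may differ between model revisions. Note this can
# print many lines for a large MoE.
import re

for _name, _module in model.named_modules():
    if any(
        re.match(p, _name)
        for p in (
            r".*attn\.(q_a_proj|q_b_proj|kv_proj|o_a_proj|o_b_proj)$",
            r".*attn\.compressor\.indexer\.q_b_proj$",
            r".*mlp\.experts.*(gate|up|down)_proj$",
            r".*mlp\.shared_experts.*(gate|up|down)_proj$",
        )
    ):
        print(_name, type(_module).__name__)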

# Apply algorithms.
# Due to the large size of DeepSeek-V4, we specify sequential targets so that
# only one decoder layer is loaded into GPU memory at a time.
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
sequential_targets=["DeepseekV4DecoderLayer"],
batch_size=1,
shuffle_calibration_samples=True,
)
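
# Optional sanity check (sketch only, not in the original example): confirm the
# quantized model still generates text. With a model of this size on CPU this
# can be extremely slow, so keep max_new_tokens small or skip this step.
sample = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
output = model.generate(**sample, max_new_tokens=20)
print(tokenizer.decode(output[0]))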

# Save to disk compressed.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4-FP8-BLOCK"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
183 changes: 183 additions & 0 deletions fix_checkpoint_keys.py
@@ -0,0 +1,183 @@
"""Resave a DeepSeek-V4-Flash NVFP4 checkpoint with key names matching the BF16
checkpoint structure. Quantization parameter suffixes (weight_packed, weight_scale,
input_global_scale, weight_global_scale) are preserved; only prefixes and module
names are changed."""
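
# Usage (paths are placeholders):
#   python fix_checkpoint_keys.py <input_nvfp4_checkpoint_dir> <output_dir>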

import argparse
import json
import re
import shutil
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import save_file


def rename_key(key: str) -> str:
if key == "head.weight":
return key

if key.startswith("model."):
key = key[len("model."):]

top_level = {
"embed_tokens.weight": "embed.weight",
"norm.weight": "norm.weight",
"hc_head.hc_base": "hc_head_base",
"hc_head.hc_fn": "hc_head_fn",
"hc_head.hc_scale": "hc_head_scale",
}
if key in top_level:
return top_level[key]

m = re.match(r"(layers\.\d+\.)(.*)", key)
if not m:
raise ValueError(f"Unrecognized key: {key}")

prefix = m.group(1)
rest = m.group(2)

# --- layer norms ---
if rest == "input_layernorm.weight":
return prefix + "attn_norm.weight"
if rest == "post_attention_layernorm.weight":
return prefix + "ffn_norm.weight"

# --- hardware counters ---
hc_map = {
"attn_hc.base": "hc_attn_base",
"attn_hc.fn": "hc_attn_fn",
"attn_hc.scale": "hc_attn_scale",
"ffn_hc.base": "hc_ffn_base",
"ffn_hc.fn": "hc_ffn_fn",
"ffn_hc.scale": "hc_ffn_scale",
}
if rest in hc_map:
return prefix + hc_map[rest]

# --- compressor.indexer (most specific first) ---
ci_exact = {
"self_attn.compressor.indexer.gate_proj.weight": "attn.indexer.compressor.wgate.weight",
"self_attn.compressor.indexer.kv_norm.weight": "attn.indexer.compressor.norm.weight",
"self_attn.compressor.indexer.kv_proj.weight": "attn.indexer.compressor.wkv.weight",
"self_attn.compressor.indexer.position_bias": "attn.indexer.compressor.ape",
"self_attn.compressor.indexer.weights_proj.weight": "attn.indexer.weights_proj.weight",
}
if rest in ci_exact:
return prefix + ci_exact[rest]
m2 = re.match(r"self_attn\.compressor\.indexer\.q_b_proj\.(.*)", rest)
if m2:
return prefix + "attn.indexer.wq_b." + m2.group(1)

# --- compressor (without indexer) ---
c_exact = {
"self_attn.compressor.gate_proj.weight": "attn.compressor.wgate.weight",
"self_attn.compressor.kv_norm.weight": "attn.compressor.norm.weight",
"self_attn.compressor.kv_proj.weight": "attn.compressor.wkv.weight",
"self_attn.compressor.position_bias": "attn.compressor.ape",
}
if rest in c_exact:
return prefix + c_exact[rest]

# --- self-attention (exact matches) ---
attn_exact = {
"self_attn.sinks": "attn.attn_sink",
"self_attn.kv_norm.weight": "attn.kv_norm.weight",
"self_attn.q_a_norm.weight": "attn.q_norm.weight",
}
if rest in attn_exact:
return prefix + attn_exact[rest]

# --- self-attention projections (with possible quant suffixes) ---
attn_proj_map = {
"self_attn.kv_proj": "attn.wkv",
"self_attn.o_a_proj": "attn.wo_a",
"self_attn.o_b_proj": "attn.wo_b",
"self_attn.q_a_proj": "attn.wq_a",
"self_attn.q_b_proj": "attn.wq_b",
}
for old, new in attn_proj_map.items():
m2 = re.match(rf"{re.escape(old)}\.(.*)", rest)
if m2:
return prefix + new + "." + m2.group(1)

# --- MLP gate ---
gate_map = {
"mlp.gate.weight": "ffn.gate.weight",
"mlp.gate.tid2eid": "ffn.gate.tid2eid",
"mlp.gate.e_score_correction_bias": "ffn.gate.bias",
}
if rest in gate_map:
return prefix + gate_map[rest]

# --- MLP experts ---
proj_map = {"gate_proj": "w1", "down_proj": "w2", "up_proj": "w3"}
m2 = re.match(r"mlp\.experts\.(\d+)\.(gate_proj|down_proj|up_proj)\.(.*)", rest)
if m2:
eid, proj, suffix = m2.group(1), m2.group(2), m2.group(3)
return prefix + f"ffn.experts.{eid}.{proj_map[proj]}.{suffix}"

# --- MLP shared experts ---
m2 = re.match(r"mlp\.shared_experts\.(gate_proj|down_proj|up_proj)\.(.*)", rest)
if m2:
proj, suffix = m2.group(1), m2.group(2)
return prefix + f"ffn.shared_experts.{proj_map[proj]}.{suffix}"

raise ValueError(f"Unrecognized key: layers.*.{rest}")


def main():
parser = argparse.ArgumentParser(
description="Resave NVFP4 checkpoint with BF16-style key names"
)
parser.add_argument("input_dir", type=Path)
parser.add_argument("output_dir", type=Path)
args = parser.parse_args()

args.output_dir.mkdir(parents=True, exist_ok=True)

index_path = args.input_dir / "model.safetensors.index.json"
with open(index_path) as f:
index = json.load(f)

shard_names = sorted(set(index["weight_map"].values()))

new_weight_map = {}
for old_key, shard_name in index["weight_map"].items():
new_weight_map[rename_key(old_key)] = shard_name

for i, shard_name in enumerate(shard_names):
src = args.input_dir / shard_name
dst = args.output_dir / shard_name
print(f"[{i + 1}/{len(shard_names)}] Processing {shard_name} ...")

tensors = {}
with safe_open(str(src), framework="pt") as f:
for key in f.keys():
tensors[rename_key(key)] = f.get_tensor(key)

save_file(tensors, str(dst))
del tensors
print(f" Saved {dst}")

new_index = {
"metadata": index.get("metadata", {}),
"weight_map": new_weight_map,
}
out_index = args.output_dir / "model.safetensors.index.json"
with open(out_index, "w") as f:
json.dump(new_index, f, indent=2, sort_keys=False)
print(f"Saved {out_index}")

for name in ("config.json", "generation_config.json",
"tokenizer.json", "tokenizer_config.json"):
src = args.input_dir / name
if src.exists():
shutil.copy2(src, args.output_dir / name)
print(f"Copied {name}")

print("Done.")


if __name__ == "__main__":
main()
6 changes: 1 addition & 5 deletions src/llmcompressor/entrypoints/oneshot.py
@@ -22,7 +22,6 @@
from llmcompressor.core.session_functions import active_session
from llmcompressor.datasets import get_calibration_dataloader
from llmcompressor.entrypoints.utils import post_process, pre_process
from llmcompressor.modeling.moe_context import moe_calibration_context
from llmcompressor.modeling.offset_norm import norm_calibration_context
from llmcompressor.pipelines import CalibrationPipeline

@@ -219,10 +218,7 @@ def apply_recipe_modifiers(

# (Helen INFERENG-661): validate recipe modifiers before initialization
# Apply calibration contexts for the entire calibration process
with norm_calibration_context(self.model), moe_calibration_context(
self.model,
calibrate_all_experts=self.dataset_args.moe_calibrate_all_experts,
):
with norm_calibration_context(self.model):
session.initialize(
model=self.model,
start=-1,
11 changes: 0 additions & 11 deletions src/llmcompressor/modeling/__init__.py
@@ -10,18 +10,7 @@
"""

# trigger registration
from .afmoe import CalibrationAfmoeMoE # noqa: F401
from .deepseek_v3 import CalibrationDeepseekV3MoE # noqa: F401
from .glm4_moe import CalibrationGlm4MoeMoE # noqa: F401
from .glm4_moe_lite import CalibrationGlm4MoeLiteMoE # noqa: F401
from .glm_moe_dsa import CalibrationGlmMoeDsaMoE # noqa: F401
from .llama4 import SequentialLlama4TextMoe # noqa: F401
from .qwen3_moe import CalibrationQwen3MoeSparseMoeBlock # noqa: F401
from .qwen3_5_moe import CalibrationQwen3_5MoeSparseMoeBlock
from .qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock # noqa: F401
from .qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock # noqa: F401
from .offset_norm import CalibrationOffsetNorm # noqa: F401
from .gemma4 import SequentialGemma4TextExperts # noqa: F401
# TODO: add granite4

from .fuse import *