- 
                Notifications
    You must be signed in to change notification settings 
- Fork 271
Closed
Labels
bug — Something isn't working
Description
⚙️ Your current environment
The output of `python collect_env.py`:
### Environment Information ###
Operating System: `Linux-6.16.8-arch3-1-x86_64-with-glibc2.42`
Python Version: `3.12.9 (main, Mar 17 2025, 21:01:58) [Clang 20.1.0 ]`
llm-compressor Version: `0.7.2a20250825`
compressed-tensors Version: `0.11.1a20250912`
transformers Version: `4.56.2`
torch Version: `2.8.0+cu129`
CUDA Devices: `['NVIDIA RTX PRO 6000 Blackwell Workstation Edition', 'NVIDIA RTX PRO 6000 Blackwell Workstation Edition']`
AMD Devices: `None`
🐛 Describe the bug
I am trying to quantize the KV cache of ByteDance-Seed/Seed-OSS-36B-Instruct, which supports a context of up to 512K tokens (about 128 GB of KV cache in FP16).
However, I get the following error:
 
Full error text
Traceback (most recent call last):
  File "[...]/.venv/lib/python3.12/site-packages/llmcompressor/pipelines/sequential/helpers.py", line 73, in forward
    outputs = forward_fn(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 19, in forward
  File "[...]/.venv/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/transformers/models/seed_oss/modeling_seed_oss.py", line 259, in forward
    hidden_states, _ = self.self_attn(
                       ^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
           ^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1840, in inner
    hook_result = hook(self, args, result)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/llmcompressor/modifiers/utils/hooks.py", line 93, in wrapped_hook
    return hook(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/llmcompressor/modifiers/quantization/calibration.py", line 260, in calibrate_kv_cache_output_hook
    k_scale = kv_cache.k_scales[module.layer_idx]
              ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
IndexError: list index out of range
Source:
llm-compressor/src/llmcompressor/modifiers/quantization/calibration.py
Lines 255 to 263 in 6304ecf
def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
    """
    Hook to update k_scale and v_scale parameters when running kv_cache quantization.
    """
    kv_cache = getattr(module, "kv_cache")
    k_scale = kv_cache.k_scales[module.layer_idx]
    v_scale = kv_cache.v_scales[module.layer_idx]
    update_offload_parameter(module, KVCacheScaleType.KEY.value, k_scale)
    update_offload_parameter(module, KVCacheScaleType.VALUE.value, v_scale)
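The error is quite similar to #1295: the `IndexError` means `kv_cache.k_scales` holds no entry for this attention module's `layer_idx`, i.e. the calibration cache recorded fewer than `layer_idx + 1` key scales.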
This is the quantization script I'm using:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot
from llmcompressor.utils import dispatch_for_generation
from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)
CALIBRATION_DATASET = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
SHUFFLE_SEED = 42
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 4096
MODEL_ID = "ByteDance-Seed/Seed-OSS-36B-Instruct"
# Output directory is named after the last path component of MODEL_ID, so this
# works both for hub IDs ("org/name") and for local paths ("/models/name").
MODEL_OUT = MODEL_ID.rstrip("/").rsplit("/", 1)[-1] + "-FP8-KV"
# Load model.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Makes the sample generation at the end of the script use sampling.
model.generation_config.do_sample = True
# Dataset processing.
# Only the first NUM_CALIBRATION_SAMPLES rows of the split are loaded, so the
# shuffle below reorders those rows; it does not draw a different subset.
ds = load_dataset(CALIBRATION_DATASET, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
# Use the named seed constant rather than repeating the literal.
ds = ds.shuffle(seed=SHUFFLE_SEED)
def process_and_tokenize(example):
    """Render one chat example through the chat template and tokenize it.

    Reads ``example["messages"]``, formats it with
    ``tokenizer.apply_chat_template`` (as text, not token IDs), then tokenizes
    that text without padding, truncated to ``MAX_SEQUENCE_LENGTH`` tokens.
    Returns whatever the tokenizer call returns.
    """
    chat_text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    # add_special_tokens=False: presumably the chat template already inserts the
    # model's special tokens, so adding them again would duplicate them — verify
    # against the Seed-OSS chat template.
    encoded = tokenizer(
        chat_text,
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )
    return encoded
ds = ds.map(process_and_tokenize, remove_columns=ds.column_names)

# Weights: static, symmetric FP8 with one scale per 128x128 block
# (DeepSeek-V3 style block quantization).
weight_args = QuantizationArgs(
    num_bits=8,
    type=QuantizationType.FLOAT,
    dynamic=False,
    symmetric=True,
    strategy=QuantizationStrategy.BLOCK,
    block_structure=[128, 128],
)
# Input activations: dynamic, symmetric FP8 with one scale per group of 128
# values (dynamic=True, observer=None: scales are not calibrated here).
activation_args = QuantizationArgs(
    num_bits=8,
    type=QuantizationType.FLOAT,
    strategy=QuantizationStrategy.GROUP,
    symmetric=True,
    dynamic=True,
    observer=None,
    group_size=128,
)
# KV cache: static, symmetric, per-tensor FP8. Its k/v scales are the ones
# filled in during oneshot calibration (see calibrate_kv_cache_output_hook in
# the traceback above).
kv_cache_args = QuantizationArgs(
    num_bits=8,
    type=QuantizationType.FLOAT,
    dynamic=False,
    symmetric=True,
    strategy=QuantizationStrategy.TENSOR,
)
recipe = [
    QuantizationModifier(
        ignore=["lm_head"],
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"],
                weights=weight_args,
                input_activations=activation_args,
            ),
        },
        kv_cache_scheme=kv_cache_args,
    )
]

oneshot(
    model=model,
    recipe=recipe,
    dataset=ds,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    # pipeline="basic",  # left disabled; this run used the sequential pipeline (see traceback)
)

# Save to disk in compressed-tensors format.
model.save_pretrained(MODEL_OUT, save_compressed=True)
tokenizer.save_pretrained(MODEL_OUT)
print(f'SUCCESS: files saved in {MODEL_OUT}')

# Smoke test: generate a short continuation with the quantized model.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
prompt_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids
prompt_ids = prompt_ids.to(model.device)
generated = model.generate(prompt_ids, max_new_tokens=100)
print(tokenizer.decode(generated[0]))
print("==========================================\n\n")
And my pyproject.toml for `uv run`:
[project]
name = "quantizers"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "compressed-tensors>=0.11.0",
    "llmcompressor >= 0.7.1",
    "torch >= 2.8.0",
    "transformers >= 4.56.2",
    "zstandard>=0.23.0",
    "mistral-common >= 1.6.2",
    "huggingface_hub",
    "accelerate",
    "fla-core",
    "flash-linear-attention",
    "causal-conv1d",
]
[tool.uv.sources]
torch = [{ index = "pytorch-cu129"}]
torchaudio = [{ index = "pytorch-cu129"}]
torchvision = [{ index = "pytorch-cu129"}]
# uv pip install -U --prerelease allow torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129
[[tool.uv.index]]
name = "pytorch-cu129"
url = "https://download.pytorch.org/whl/cu129"
explicit = true
🛠️ Steps to reproduce
No response
Metadata
Metadata
Assignees
Labels
bug — Something isn't working