
Conversation


@xuechendi xuechendi commented Apr 6, 2025

Move to the new PR:
#1030
HabanaAI/vllm-hpu-extension#137

==============

Need to use vllm-hpu-extension: https://github.com/HabanaAI/vllm-hpu-extension/tree/dev/chendi/deepseek_r1

Status:

  1. Runnable with DeepSeek-R1.
  2. Accuracy check with block FP8 weights => garbage output.
  3. Accuracy check with BF16 weights => looks good.

TODO:

  1. Fix accuracy for FP8.
  2. Work with the INC team on a new calibration so we can test static FP8 performance.

Question:

  1. Should we enable dynamic FP8 quantization in vLLM or in INC? (A rough sketch of what dynamic quantization involves follows this list.)
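For reference, a minimal sketch of what dynamic FP8 quantization means here: per-token scales computed at runtime rather than loaded from a calibration file. This is illustrative only and assumes torch.float8_e4m3fn support; it is not vLLM's or INC's actual implementation.

```python
import torch

def dynamic_quant_fp8(x: torch.Tensor):
    """Quantize activations to FP8 with per-token scales computed on the fly."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # One scale per token (per last-dim row), derived from the runtime absmax.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    x_fp8 = (x / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_fp8, scale  # the scale is needed to rescale the matmul output

x_fp8, scale = dynamic_quant_fp8(torch.randn(4, 7168, dtype=torch.bfloat16))
```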

@xuechendi (Author) commented:

@yiliu30 , please help to review this PR for INC path

  1. I removed both "VLLM_REQUANT_FP8_INC=1" and "VLLM_ENABLE_RUNTIME_DEQUANT=1" and now treat them as always True.
     Test script is below:
from vllm import LLM, SamplingParams
import os

os.environ['VLLM_SKIP_WARMUP'] = 'true'
os.environ['PT_HPU_LAZY_MODE'] = '1'
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES']='true'
os.environ['PT_HPU_WEIGHT_SHARING']='0'
#os.environ['HABANA_LOGS']="vllm_inc_debug"
#os.environ["LOG_LEVEL_ALL"]="3"
os.environ['VLLM_MLA_DISABLE_REQUANTIZATION']='1'
#os.environ["QUANT_CONFIG"] = "inc_quant_with_fp8kv_config.json"
#os.environ["LOGLEVEL"] = "DEBUG"

 
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
 
if __name__ == "__main__":
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=16, ignore_eos=True)
 
    # Create an LLM.
    model_path = "/data/models/DeepSeek-R1"
 
    llm = LLM(model=model_path,
            trust_remote_code=True,
            enforce_eager=True,
            dtype="bfloat16",
            use_v2_block_manager=True,
            max_model_len=16384,
            tensor_parallel_size=8,
            distributed_executor_backend='mp',
            gpu_memory_utilization=0.8,
            #kv_cache_dtype="fp8_inc",
            seed=2024)
 
    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

assert topk_group is None, 'topk_group is not supported on HPU'
if layer is not None:
    return layer.hpu_fused_moe(x, router_logits, top_k)
if use_grouped_topk or custom_routing_function is not None:

The original BF16 fused_moe doesn't support grouped top-k or a custom routing function. That is fixed here.
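For context, a hedged sketch of what grouped top-k routing does (DeepSeek-style: pick the best expert groups first, then the top-k experts within them). Names and shapes are illustrative; this is not the PR's actual HPU kernel.

```python
import torch

def grouped_topk_sketch(router_logits: torch.Tensor, top_k: int,
                        num_expert_group: int, topk_group: int):
    """Pick top_k experts, restricted to the topk_group highest-scoring groups."""
    num_tokens, num_experts = router_logits.shape
    scores = torch.softmax(router_logits, dim=-1)
    # Score each group by its best expert, keep only the topk_group best groups.
    group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    # Zero out experts that live in discarded groups, then take the usual top-k.
    expert_mask = group_mask.repeat_interleave(num_experts // num_expert_group, dim=1)
    topk_weights, topk_ids = torch.topk(scores * expert_mask, k=top_k, dim=-1)
    return topk_weights, topk_ids

w, ids = grouped_topk_sketch(torch.randn(2, 256), top_k=8,
                             num_expert_group=8, topk_group=4)
```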


self.top_k = top_k
self.num_experts = num_experts
self.num_experts = num_experts // self.ep_size

All of the above is a temporary fix because habana_main lags far behind upstream and is missing important commits for expert parallelism.
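For readers unfamiliar with the expert-parallel (EP) changes being patched in above, here is a minimal sketch, assuming a contiguous per-rank split, of how the experts get divided by ep_size and how an experts_min/experts_max range arises. Variable names are illustrative, not the fork's actual code.

```python
def local_expert_range(num_experts: int, ep_size: int, ep_rank: int):
    """Each EP rank owns a contiguous slice of the global expert list."""
    assert num_experts % ep_size == 0, "experts must divide evenly across EP ranks"
    experts_per_rank = num_experts // ep_size   # matches num_experts // self.ep_size above
    experts_min = ep_rank * experts_per_rank    # first global expert id on this rank
    experts_max = experts_min + experts_per_rank - 1
    return experts_min, experts_max

# e.g. DeepSeek-R1's 256 routed experts over 8 ranks: rank 2 owns experts 64..95
print(local_expert_range(256, 8, 2))
```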

else self.num_experts // 64
num_expert_per_group = self.num_experts // moe_n_slice
experts_min, experts_max = ep_shift, self.num_experts + ep_shift - 1
if quant_config is not None:

@yiliu30 , I removed the VLLM_REQUANT_FP8_INC check and replaced it with a check on whether this function is FP8 or unquantized.

x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
raise NotImplementedError

# Chendi: Necessary base func added by INC team

@yiliu30 , I kept this function here, but I couldn't find where it is used. Please check whether you want to keep it.


if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic"
if current_platform.is_hpu():
layer = hpu_ops.fp8_block_moe_prepare_weights(layer)

@yiliu30 , I wrapped your original set_weight functions into vllm-hpu-extension and call them here.
BTW, I saw you put this at create_weights instead of process_weights_after_loading. I think it should go here, since we want to operate on the final loaded weights, right? Please check whether this change makes sense to you.


Agreed, we should put it in process_weights_after_loading, thanks!
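To make the create_weights vs. process_weights_after_loading distinction concrete, here is a minimal sketch of the quant-method lifecycle being discussed. The hook names follow vLLM's quantization method interface; the bodies are illustrative, not the actual INC/HPU code.

```python
import torch

class Fp8BlockMoEMethodSketch:
    def create_weights(self, layer: torch.nn.Module) -> None:
        # Runs before the checkpoint is read: only allocate placeholder
        # parameters here; their values are still uninitialized.
        layer.w13_weight = torch.nn.Parameter(torch.empty(16, 16),
                                              requires_grad=False)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Runs after the checkpoint values are in place, so this is where an
        # HPU repack such as hpu_ops.fp8_block_moe_prepare_weights(layer)
        # (the helper named in the diff above) can safely be applied.
        pass
```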

use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts")

Fix a rebase error made in habana_main.

else:
layer.register_parameter("input_scale", None)

def dequant_block_fp8_weight(self, layer) -> torch.Tensor:

This was added by the INC team. @yiliu30 , please check whether it is necessary to keep it in fp8.py.
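For reference, a hedged sketch of what a block-wise FP8 weight dequant like dequant_block_fp8_weight has to do. The 128x128 block size is an assumption based on DeepSeek-style block FP8 checkpoints; this is not the INC implementation.

```python
import torch

def dequant_block_fp8_weight_sketch(weight_fp8: torch.Tensor,
                                    weight_scale_inv: torch.Tensor,
                                    block: int = 128) -> torch.Tensor:
    """Expand per-block scales to element granularity and rescale to bf16."""
    out_f, in_f = weight_fp8.shape
    # weight_scale_inv holds one scale per (block x block) tile of the weight.
    scales = weight_scale_inv.repeat_interleave(block, dim=0)[:out_f] \
                             .repeat_interleave(block, dim=1)[:, :in_f]
    return weight_fp8.to(torch.bfloat16) * scales.to(torch.bfloat16)
```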



yiliu30 commented Apr 7, 2025

(Quoting the review request and test script from the comment above.)

We need to call shutdown to save the calibration results.

from vllm import LLM, SamplingParams
import os

os.environ['VLLM_SKIP_WARMUP'] = 'true'
os.environ['PT_HPU_LAZY_MODE'] = '1'
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES']='true'
os.environ['PT_HPU_WEIGHT_SHARING']='0'
#os.environ['HABANA_LOGS']="vllm_inc_debug"
#os.environ["LOG_LEVEL_ALL"]="3"
os.environ['VLLM_MLA_DISABLE_REQUANTIZATION']='1'
#os.environ["QUANT_CONFIG"] = "inc_quant_with_fp8kv_config.json"
#os.environ["LOGLEVEL"] = "DEBUG"

 
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
 
if __name__ == "__main__":
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=16, ignore_eos=True)
 
    # Create an LLM.
    model_path = "/data/models/DeepSeek-R1"
 
    llm = LLM(model=model_path,
            trust_remote_code=True,
            enforce_eager=True,
            dtype="bfloat16",
            use_v2_block_manager=True,
            max_model_len=1024,
            max_num_seqs=1,
            tensor_parallel_size=8,
            distributed_executor_backend='mp',
            gpu_memory_utilization=0.8,
            #kv_cache_dtype="fp8_inc",
            seed=2024)
 
    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    if os.environ.get("QUANT_CONFIG", None) is not None:
        llm.llm_engine.model_executor.shutdown()

self.batch2block_matmul = Matmul()
self.block2batch_matmul = Matmul()
self.latent_cache_k = VLLMKVCache()
self.latent_cache_v = VLLMKVCache()

Should we remove latent_cache_v, since it wasn't used?

Switched execution of versioned branches to _next and added log redirection
to a file.

@jikunshang jikunshang left a comment


LGTM

return self.o_proj(x)[0]

def _q_proj_and_k_up_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:


Can we remove this, since upstream has also removed it?


We noticed an accuracy drop with VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0, so let's keep the option to enable it until the accuracy issue is solved.
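For context on what VLLM_MLA_PERFORM_MATRIX_ABSORPTION controls, here is an illustrative sketch of MLA matrix absorption: folding the k up-projection W_UK into the query path so decode attention runs directly against the compressed latent KV cache. Shapes are made up for the example; this is not the backend's actual code.

```python
import torch

num_heads, qk_nope_head_dim, kv_lora_rank, seq_len = 16, 128, 512, 1024
q_nope    = torch.randn(1, num_heads, qk_nope_head_dim)            # no-RoPE query part
W_UK      = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank) # k up-projection
latent_kv = torch.randn(1, seq_len, kv_lora_rank)                  # cached compressed KV

# Absorbed path: fold W_UK into the query once per step, then score against the
# latent cache directly, instead of up-projecting every cached token with W_UK.
q_latent = torch.einsum("bhd,hdr->bhr", q_nope, W_UK)         # [1, heads, kv_lora_rank]
scores   = torch.einsum("bhr,btr->bht", q_latent, latent_kv)  # attention logits
```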

@staticmethod
def forward_decode(**kwargs) -> torch.Tensor:
if kwargs.get("kv_lora_rank", None):
return ops.flat_pa_mla(**kwargs)


Have we updated this in vllm-hpu-extension? If so, is there no need to update its version?

@kzawora-intel

/run-gaudi-tests

Signed-off-by: Chendi.Xue <[email protected]>
@xuechendi xuechendi force-pushed the dev/chendi/deepseek_r1 branch 2 times, most recently from c695238 to b147f2e on April 8, 2025 18:43
Signed-off-by: Chendi.Xue <[email protected]>
Signed-off-by: Chendi.Xue <[email protected]>
Signed-off-by: Chendi.Xue <[email protected]>
@xuechendi xuechendi force-pushed the dev/chendi/deepseek_r1 branch from b147f2e to 85f0693 on April 8, 2025 22:03
bmyrcha added 2 commits April 9, 2025 11:21
Adjusted method of extracting synapse build id for release branches
Comment on lines +322 to 338
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
# Chendi: This is a cherry-pick of missing commit from upstream
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
.view(-1, self.num_heads, self.qk_head_dim)

# can be W_Q or W_UQ depending q_lora_rank, the former if
# q_lora_rank is None, the latter otherwise. From the Attention
# backend perspective though we call these both W_Q and rely on
# the layer to pass in the correct matrix
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
.flatten(start_dim=1).contiguous()

# W_QR is small so for simplicity we dont bother requantizing it
self.W_QR = self.W_QR.to(act_dtype)
# W_QR is small so for simplicity we dont bother requantizing it
self.W_QR = self.W_QR.to(act_dtype)

if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:


Those two if statements should be merged into one.

jmaksymc and others added 9 commits April 10, 2025 14:13
This PR implements HPU support for pipeline parallelism. Tested accuracy
and it's the same as TP accuracy on:
- Llama3.1-70b-Instruct
- Llama3.2-3b-Instruct
- Mixtral-8x7b

To serve with PP:
`VLLM_DECODE_BS_BUCKET_MIN=384 VLLM_DECODE_BLOCK_BUCKET_MAX=896 vllm
serve /mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-70B-Instruct/
--tensor-parallel-size 1 --pipeline-parallel-size 4 --max-num-seqs 384
--disable-log-requests --dtype bfloat16 --gpu-memory-util 0.9
--disable-log-stats --num_scheduler_steps 1 --max-num-batched-tokens
2048 --max-model-len 256 --block-size 128`

Known issues:
* since for pipeline parallelism max_num_seqs acts as a microbatch for a
single virtual_engine, larger batch sizes hit a very specific corner case
and trigger a flat_pa error -> set batch_size to roughly the batch size
you would use with TP, divided by pp_size (see the sketch after this list)
* delayed sampling is not yet compatible with pipeline parallelism
* the virtual_engine ID is passed to HPUGraph, which results in pp_size
times the number of graphs
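A quick arithmetic sketch of the batch-size rule of thumb from the first known issue above (the numbers are only an example; 1536 is an assumed TP batch size, not a measured value):

```python
tp_batch_size = 1536   # hypothetical batch size you would use with pure TP
pp_size = 4            # pipeline stages, as in the serve command above
max_num_seqs = tp_batch_size // pp_size
print(max_num_seqs)    # -> 384, matching --max-num-seqs 384 above
```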

Signed-off-by: jmaksymczuk <[email protected]>
Co-authored-by: Rafal Litka <[email protected]>
Co-authored-by: Michał Kuligowski <[email protected]>
Cherry-pick of #921

Co-authored-by: Konrad Zawora <[email protected]>
Co-authored-by: Michał Kuligowski <[email protected]>
Signed-off-by: kwisniewski98 <[email protected]>
Signed-off-by: Chendi Xue <[email protected]>
Co-authored-by: Iryna Boiko <[email protected]>
Co-authored-by: Michał Kuligowski <[email protected]>
@jikunshang

I tried to run this script and got the error below. Is there a vllm-hpu-extension version mismatch, or is there anything else I need to change?

ImportError: cannot import name 'VllmMixtureOfExpertsOpFP8' from 'vllm_hpu_extension.ops' (/usr/local/lib/python3.10/dist-packages/vllm_hpu_extension/ops.py)

(Quoting the review request and test script from the comment above.)

michalkuligowski pushed a commit that referenced this pull request Apr 15, 2025
migrated from a PR to habana_main:
#1014

For best performance, it is recommended to run this PR with INC:
[[SW-223553] [VLLM] Merge deepseek changes into habana_main - Habana
Labs](https://jira.habana-labs.com/browse/SW-223553)

**test acc of G3**:
```bash
huggingface-cli download Yi30/inc-woq-default-pile-one-cache-408  --local-dir ./scripts/nc_workspace_measure_kvache

cat inc_quant_with_fp8kv_config.json
{
    "mode": "QUANTIZE",
    "observer": "maxabs",
    "scale_method": "maxabs_hw",
    "scale_format": "const",
    "allowlist": {
        "types": [],
        "names": []
    },
    "blocklist": {
        "types": [],
        "names": [
            "lm_head",
            "mlp\\.gate\\b",
            "block2batch_matmul"
        ]
    },
    "dump_stats_path": "./inc-woq-default-pile-one-cache-408-for-fp8-mla/inc_measure_output"
}


QUANT_CONFIG=inc_quant_with_fp8kv_config.json \
PT_HPU_LAZY_MODE=1 \
VLLM_SKIP_WARMUP=true \
PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
PT_HPU_WEIGHT_SHARING=0 \
VLLM_MLA_DISABLE_REQUANTIZATION=1 \
lm_eval --model vllm \
  --model_args "pretrained=/mnt/weka/data/pytorch/DeepSeek-R1/,tensor_parallel_size=8,distributed_executor_backend=mp,trust_remote_code=true,max_model_len=4096,use_v2_block_manager=True,dtype=bfloat16,kv_cache_dtype=fp8_inc" \
  --tasks gsm8k --num_fewshot "5" --limit "256" \
  --batch_size "8"
```

**test acc of G2**:
**convert original DeepSeek-R1** using
[convert_for_g2.py](https://github.com/yangulei/vllm-fork/blob/deepseek_r1_g2/scripts/convert_for_g2.py)
(this step will be removed once INC is updated.)

```bash

huggingface-cli download Yi30/inc-woq-default-pile-one-cache-412-g2  --local-dir ./scripts/nc_workspace_measure_kvache

cat inc_quant_with_fp8kv_config.json
{
    "mode": "QUANTIZE",
    "observer": "maxabs",
    "scale_method": "maxabs_hw",
    "scale_format": "const",
    "allowlist": {
        "types": [],
        "names": []
    },
    "blocklist": {
        "types": [],
        "names": [
            "lm_head",
            "mlp\\.gate\\b",
            "block2batch_matmul"
        ]
    },
    "dump_stats_path": "./nc_workspace_measure_kvache/inc_measure_output"
}
```


vllm
(pretrained=/mnt/weka/data/pytorch/DeepSeek-R1/,tensor_parallel_size=8,distributed_executor_backend=mp,trust_remote_code=true,max_model_len=4096,use_v2_block_manager=True,dtype=bfloat16,kv_cache_dtype=fp8_inc),
gen_kwargs: (None), limit: 256.0, num_fewshot: 5, batch_size: 128
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.9492|± |0.0137|
| | |strict-match | 5|exact_match|↑ |0.9453|± |0.0142|


----------
Need to use vllm-hpu-extension:
https://github.com/HabanaAI/vllm-hpu-extension/tree/dev/chendi/deepseek_r1

Status:

- Runnable with DeepSeek-R1.
- Accuracy check with block FP8 weights => garbage output.
- Accuracy check with BF16 weights => looks good.

test scripts:
```
from vllm import LLM, SamplingParams
import os

os.environ['VLLM_SKIP_WARMUP'] = 'true'
os.environ['PT_HPU_LAZY_MODE'] = '1'
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES']='true'
os.environ['PT_HPU_WEIGHT_SHARING']='0'
#os.environ['HABANA_LOGS']="vllm_inc_debug"
#os.environ["LOG_LEVEL_ALL"]="3"
os.environ['VLLM_MLA_DISABLE_REQUANTIZATION']='1'
#os.environ["QUANT_CONFIG"] = "inc_quant_with_fp8kv_config.json"
#os.environ["LOGLEVEL"] = "DEBUG"

 
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
 
if __name__ == "__main__":
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=16, ignore_eos=True)
 
    # Create an LLM.
    model_path = "/data/models/DeepSeek-R1"
 
    llm = LLM(model=model_path,
            trust_remote_code=True,
            enforce_eager=True,
            dtype="bfloat16",
            use_v2_block_manager=True,
            max_model_len=1024,
            max_num_seqs=1,
            tensor_parallel_size=8,
            distributed_executor_backend='mp',
            gpu_memory_utilization=0.8,
            #kv_cache_dtype="fp8_inc",
            seed=2024)
 
    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    if os.environ.get("QUANT_CONFIG", None) is not None:
        llm.llm_engine.model_executor.shutdown()
```

---------

Signed-off-by: Chendi.Xue <[email protected]>
Signed-off-by: kwisniewski98 <[email protected]>
Signed-off-by: Chendi Xue <[email protected]>
Co-authored-by: kwisniewski98 <[email protected]>
@zhenwei-intel zhenwei-intel mentioned this pull request Apr 21, 2025
xuechendi added a commit that referenced this pull request Apr 24, 2025
migrated from a PR to habana_main:
#1014
(commit message identical to the one quoted above)
xuechendi added a commit that referenced this pull request May 5, 2025
migrated from a PR to habana_main:
#1014
(commit message identical to the one quoted above)
xuechendi added a commit that referenced this pull request May 7, 2025
migrated from a PR to habana_main:
#1014
(commit message identical to the one quoted above)
@xuechendi xuechendi closed this Jun 17, 2025