Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] support qqq(w4a8) for lmdeploy #2274

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

HandH1998
Copy link

@HandH1998 HandH1998 commented Aug 9, 2024

Motivation

We have implemented W4A8 quantization for the lmdeploy turbomind backend using our quantization algorithm QQQ to enhance inference throughput. We hope that lmdeploy users will find this beneficial. Additionally, we have submitted a PR to vLLM, which has been incorporated into vLLM v0.5.4.

Modification

We have completed the following tasks to enable the w4a8 pipeline:

  • Converted our QQQ quantized model weights to lmdeploy format.
  • Enabled the turbomind backend to load quantized model weights.
  • Added the Marlin QQQ w4a8 GEMM kernel.
  • Fused online quantization with element-wise operations such as RMSnorm and Silu.
  • Modified the inference pipeline to accommodate online activation quantization.
  • Fused gate and up weights into one weight.

Use cases

First you need to export the quantized model weights using our repo. Then, you can enable QQQ in the same manner as you would enable AWQ. Here, we provide two examples for inference and service.

Inference

from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig, ChatTemplateConfig

# Here we use completion template. You can modify capability following the official guidance.
chat_config = ChatTemplateConfig(model_name='llama2', capability='completion')

backend_config = TurbomindEngineConfig(model_format='qqq')
model_path = your_quantized_model_path

pipe = pipeline(model_path=model_path,
                chat_template_config=chat_config,
                backend_config=backend_config,
                log_level='INFO')

gen_config = GenerationConfig(top_p=0.95,
                              temperature=0.8,
                              repetition_penalty=1.0,
                              random_seed=0,
                              max_new_tokens=512)
prompts = ["Hi, pls intro yourself", "Shanghai is"]
response = pipe(prompts, gen_config=gen_config)
print(response)

Service

lmdeploy serve api_server your_quantized_model_path --backend turbomind --model-format qqq

Benchmark

Accuracy

We employ OpenCompass to evaluate the quantized model. Here we provide the evaluation results for llama2-13b-base.

ceval mmlu triviaqa gsm8k
FP16 38.46 41.35 67.36 29.72
AWQ-g128(W4A16) 36.00 41.48 66.53 29.87
QQQ(W4A8) 32.93 41.09 64.35 25.70
QQQ-g128(W4A8) 36.16 40.94 65.85 28.51

You can add the following script to configs to reproduce our results.

from mmengine.config import read_base
from opencompass.models.turbomind import TurboMindModel

with read_base():
    # choose a list of datasets
    from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
    from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
    from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
    # and output the results in a choosen format
    from .summarizers.medium import summarizer

datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

llama2_13b_base = dict(
        type=TurboMindModel,
        abbr='llama2-13b-base-qqq-g128',
        path=your_quantized_model_path,
        engine_config=dict(session_len=2048,
                           max_batch_size=8,
                           rope_scaling_factor=1.0,
                           model_format="qqq"),
        gen_config=dict(top_k=1, top_p=0.8,
                        temperature=1.0,
                        max_new_tokens=100),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        concurrency=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )

models = [llama2_13b_base]

Throughput

We use the script profile_restful_api.py and ShareGPT dataset to benchmark throughput. Here we provide the results for llama2-13b-base on one A100-80G.
Settings:

concurrency: 128
num_prompts: 1000
number of prompt tokens: 248339
number of completion tokens: 240582
RPS (request per second) token throughput (completion token) token throughput (prompt + completion token)
FP16 7.300 1756.165 3568.954
AWQ-g128(W4A16) 8.272 1990.156 4044.479
QQQ(W4A8) 9.454 2296.056 4666.144
QQQ-g128(W4A8) 8.484 2041.167 4148.146

@zhyncs
Copy link
Collaborator

zhyncs commented Aug 9, 2024

Hi @HandH1998 Nice work! May you merge the latest main branch and fix the conflicts?

@zhyncs
Copy link
Collaborator

zhyncs commented Aug 9, 2024

We might wait for the merge of #2090.

@lvhan028
Copy link
Collaborator

lvhan028 commented Aug 9, 2024

Brilliant!

@zhyncs
Copy link
Collaborator

zhyncs commented Aug 9, 2024

When implementing W8A8 in the future, some components may be reused.

@HandH1998
Copy link
Author

Hi @HandH1998 Nice work! May you merge the latest main branch and fix the conflicts?

Done

Copy link
Collaborator

@zhyncs zhyncs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HandH1998 May you fix the CI issue? Thanks.

@HandH1998 HandH1998 force-pushed the w4a8 branch 2 times, most recently from ea251f3 to ecee3aa Compare August 9, 2024 08:11
@zhyncs
Copy link
Collaborator

zhyncs commented Aug 9, 2024

@HandH1998 The Windows build is still failing.

@brisker
Copy link

brisker commented Aug 13, 2024

@HandH1998
It seems that, if w4a8 is per-channel quantized without group( group_size=-1), the w8a8 triton kernel in lmdeploy repo can be easily modified into a w4a8 one. Will the speedup of QQQ cuda implementation be similar to that triton-version one, since things get a lot simpler without group-quantization

@zhyncs
Copy link
Collaborator

zhyncs commented Aug 17, 2024

@HandH1998 Marlin W4A16 is mainly optimized for A100, but compared to TurboMind AWQ, its performance is still worse. Marlin's performance on H100 is average, especially compared to #2090, the gap is very large. After #2090 merges next week, this PR will be reviewed. There are probably 2 strategies: one is to review based on the current implementation first (of course, assuming you still need to merge the latest main and resolve some conflicts), and then reimplement it later according to the optimized implementation in TurboMind. Another strategy is to reimplement it directly (which can be based on some existing components), we'll discuss this at that time. @lzhangzz cc @irexyc @lvhan028

@zhyncs
Copy link
Collaborator

zhyncs commented Aug 17, 2024

And the difference should not be significant on A100. I have roughly verified it using SGLang's Marlin AWQ and LMDeploy TurboMind's AWQ on Llama 3.1 8B Instruct, and their performance is basically close (though I don't remember if this was based on whether LMDeploy had already fixed that chunked prefill bug).

Copy link
Collaborator

@zhyncs zhyncs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#2090 has been merged. Please merge the latest main and resolve the conflicts. Thanks.

@lzhangzz
Copy link
Collaborator

And the difference should not be significant on A100. I have roughly verified it using SGLang's Marlin AWQ and LMDeploy TurboMind's AWQ on Llama 3.1 8B Instruct, and their performance is basically close (though I don't remember if this was based on whether LMDeploy had already fixed that chunked prefill bug).

That's the old AWQ kernels. The new kernels achieve 26+ RPS with A100 and Llama 3.1 8B.

@zhyncs
Copy link
Collaborator

zhyncs commented Aug 26, 2024

@HandH1998 May you resolve the conflicts in these days? After that, @lzhangzz will help rewrite with the TurboMind’s style. We should move forward together.

@HandH1998
Copy link
Author

@HandH1998 May you resolve the conflicts in these days? After that, @lzhangzz will help rewrite with the TurboMind’s style. We should move forward together.

I am working on it. Since the main branch changed a lot, I still need time to resolve the conflicts and fix new bugs. Probably I can finish it in two days.

@HandH1998
Copy link
Author

@zhyncs @lzhangzz I have resolved the conflicts, and you can continue to do the optimization work. Two checks failed, but I think they are irrelevant with my code.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants