Add document for speculative decoding #9492
Merged
@@ -108,8 +108,19 @@ PaddleNLP provides multiple quantization strategies, supporting Weight Only INT8 and INT4 inference

- `cachekv_int8_type`: whether to use cache KV INT8 quantization; defaults to `None`. Options are `dynamic` (no longer maintained and not recommended) and `static`; `static` requires an additional scale calibration table for the cache KV, and the `model_name_or_path` passed in must be a quantized model produced by PTQ calibration. See the [LLM quantization tutorial](../quantization.md) for exporting quantized models.

### 3.4 Speculative decoding parameters

- `speculate_method`: inference decoding algorithm; defaults to `None`. Options are `None` and `inference_with_reference`. With `None`, normal autoregressive decoding is used; with `inference_with_reference`, context-based speculative decoding is used ([paper](https://arxiv.org/pdf/2304.04487)).

- `speculate_max_draft_token_num`: maximum number of draft tokens produced per round of the speculative decoding algorithm; defaults to `1`.

- `speculate_max_ngram_size`: maximum window size for matching draft tokens by n-gram; defaults to `1`. The `inference_with_reference` algorithm first matches draft tokens from the prompt with a sliding n-gram window; the window size and the degree of input/output overlap together determine the cost of producing draft tokens, and therefore the speedup the algorithm achieves.

- `speculate_verify_window`: the K in the TopP + TopK verification used by the default speculative decoding verify strategy; defaults to `2`.

- `speculate_max_candidate_len`: maximum number of candidate tokens produced; verification compares the candidate tokens against the draft tokens; defaults to `5`.
> Review comment: this needs to be stated clearly; it only takes effect under the TopP + window verify strategy. It may be worth adding a separate subsection to this document describing the two verify strategies we currently support: top-1 verification and TopP + window verification.
>
> Reply: OK.

### 3.5 Decoding strategy parameters

- `decode_strategy`: inference decoding strategy; defaults to `sampling`. Options are `greedy_search`, `beam_search`, and `sampling`.
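The difference between `greedy_search` and `sampling` at a single decoding step can be sketched as follows (`decode_step` is an illustrative helper, not a PaddleNLP API; `beam_search`, which tracks several hypotheses in parallel, is omitted):

```python
import numpy as np

rng = np.random.default_rng(0)

def decode_step(probs, strategy="sampling"):
    # greedy_search always picks the highest-probability token,
    # while sampling draws a token from the full distribution.
    probs = np.asarray(probs)
    if strategy == "greedy_search":
        return int(np.argmax(probs))
    return int(rng.choice(len(probs), p=probs))

decode_step([0.1, 0.7, 0.2], strategy="greedy_search")  # always token 1
decode_step([0.1, 0.7, 0.2])                            # usually 1, occasionally 0 or 2
```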

@@ -119,7 +130,7 @@ PaddleNLP provides multiple quantization strategies, supporting Weight Only INT8 and INT4 inference

- `temperature`: in the sampling strategy, output logits are divided by `temperature` before sampling. Defaults to `1.0`, which leaves the distribution unchanged.
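The effect of `temperature` can be sketched with a minimal softmax illustration (not PaddleNLP's actual sampling kernel):

```python
import numpy as np

def apply_temperature(logits, temperature=1.0):
    # Dividing the logits by temperature before the softmax sharpens the
    # distribution when temperature < 1 and flattens it when temperature > 1;
    # temperature == 1.0 leaves it unchanged.
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    exps = np.exp(scaled - scaled.max())  # subtract max for numerical stability
    return exps / exps.sum()

apply_temperature([2.0, 1.0, 0.0], temperature=0.5)  # more peaked than the default
```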

### 3.6 Performance profiling parameters

- `benchmark`: whether to enable performance profiling; defaults to `False`. If set to `True`, model inputs are padded to `src_length`, decoding is forced to run to `max_length`, and the model's inference throughput and inference time are measured.

@@ -165,6 +176,7 @@ python ./predict/predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --

- [llama](./llama.md)
- [qwen](./qwen.md)
- [mixtral](./mixtral.md)
- [Speculative decoding](./speculative_decoding.md)

For environment setup, see:

@@ -190,4 +202,3 @@ python ./predict/predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --

## Acknowledgements

We referred to the [FlashInfer framework](https://github.com/flashinfer-ai/flashinfer) and implemented append attention on top of FlashInfer. Drawing on the page-partitioning idea of [PageAttention](https://github.com/vllm-project/vllm), we implemented block attention for the generation stage. Based on the KV-partitioning idea of [Flash Decoding](https://github.com/Dao-AILab/flash-attention), we implemented inference acceleration for long-sequence scenarios. We accelerated prefill-stage attention based on [Flash Attention2](https://github.com/Dao-AILab/flash-attention). FP8 GEMM is implemented with the high-performance template library [CUTLASS](https://github.com/NVIDIA/cutlass). Some operators, such as gemm_dequant, draw on the implementations and optimization ideas of [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) and [FasterTransformer](https://github.com/NVIDIA/FasterTransformer.git).
@@ -0,0 +1,19 @@
# Speculative decoding tutorial

Speculative decoding is an algorithm that speculatively guesses multiple tokens at once, then verifies and accepts them; it can greatly reduce inference latency. PaddleNLP provides a simple, efficient speculative decoding inference pipeline. Usage instructions for the speculative decoding algorithms in PaddleNLP follow.

## Inference with reference

This algorithm matches draft tokens from the prompt via an n-gram window. It suits scenarios with large overlap between input and output, such as code editing and document-based question answering; see the [paper](https://arxiv.org/pdf/2304.04487) for more information.
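The n-gram matching step described above can be sketched roughly as follows (the function name, scan order, and fallback are illustrative assumptions, not the actual PaddleNLP kernel):

```python
def match_draft_tokens(tokens, ngram_size, max_draft_token_num):
    # Try the longest trailing n-gram first; if the same n-gram occurred
    # earlier in the sequence, the tokens that followed that occurrence
    # become the draft tokens for this round.
    for n in range(ngram_size, 0, -1):
        suffix = tokens[-n:]
        # scan backwards so the most recent earlier occurrence wins
        for start in range(len(tokens) - n - 1, -1, -1):
            if tokens[start:start + n] == suffix:
                drafts = tokens[start + n:start + n + max_draft_token_num]
                if drafts:
                    return drafts
    return []  # no match: fall back to one normal autoregressive step

# prompt + generated tokens so far; the trailing [2, 3] also occurs earlier
match_draft_tokens([1, 2, 3, 4, 5, 2, 3], ngram_size=2, max_draft_token_num=3)
```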
### Usage

```shell
# Reference command for dynamic-graph model inference
python ./predict/predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --speculate_method inference_with_reference --speculate_max_draft_token_num 5 --speculate_max_ngram_size 2
```

**Note:**

1. This algorithm currently supports only llama-family models.
2. Speculative decoding also works with quantized inference; for commands, see the [inference examples](./inference.md) and add the speculative decoding arguments such as `speculate_method`.
> Review comment: the window here does not mean K; it means that all draft tokens within the window must be accepted by the top-k strategy together, or else rejected together.
>
> Reply: OK.
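Based on the reviewer's description, window verification might be sketched like this (a rough illustration of the accept-together/reject-together semantics; the names and exact windowing are assumptions, not PaddleNLP's implementation):

```python
def window_verify(draft_tokens, topk_candidates, window=2):
    # Draft tokens are checked window by window: every token in a window
    # must appear among the model's top-k candidates at its position for
    # the window to be accepted; otherwise the whole window (and everything
    # after it) is rejected.
    accepted = []
    for start in range(0, len(draft_tokens), window):
        chunk = draft_tokens[start:start + window]
        if all(tok in topk_candidates[start + j] for j, tok in enumerate(chunk)):
            accepted.extend(chunk)
        else:
            break
    return accepted

# first window [5, 7] passes; second window fails because 2 is not in [8, 6]
window_verify([5, 7, 9, 2], [[5, 1], [7, 3], [9, 4], [8, 6]], window=2)
```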