Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add speculative decoding #1120

Merged
merged 25 commits into from
Jan 31, 2024
Merged

Add speculative decoding #1120

merged 25 commits into from
Jan 31, 2024

Conversation

abetlen
Copy link
Owner

@abetlen abetlen commented Jan 23, 2024

Uses prompt lookup decoding but the draft model class can be extended to support almost any existing method.

Server Usage

python3 -m llama_cpp.server --model models/7B/llama-model.gguf --draft_model=prompt-lookup-decoding --draft_model_num_pred_tokens=2

Python Usage

>>> from llama_cpp import Llama
>>> from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
>>> llm = Llama(
    model_path="./models/7B/llama-model.gguf", 
    draft_model=LlamaPromptLookupDecoding(
        num_pred_tokens=10, # Good default for gpu offloading, 2 is better for cpu-only machines
        max_ngram_size=2, # 2 is the huggingface implementation and found to work the best for me as well.
    )
)

Performance

This is a very dumb / easy example but it looks like it's working!

With prompt lookup decoding

image

Without prompt lookup decoding

image

Closes #675

@abetlen abetlen mentioned this pull request Jan 23, 2024
@abetlen abetlen changed the title Add speculative decoding support Add speculative decoding Jan 23, 2024
@abetlen
Copy link
Owner Author

abetlen commented Jan 24, 2024

Tried on a more realistic example and got worse performance, think I'll need to tune / implement a heuristic for draft models similar to https://huggingface.co/blog/assisted-generation

Adjust the number of candidate tokens to be produced in the next iteration — our original heuristic increases it by 2 if ALL tokens match and decreases it by 1 otherwise.

@abetlen
Copy link
Owner Author

abetlen commented Jan 24, 2024

Added the adaptive heuristic and it does do better but still occasionally slower even with termperature=0, will need to investigate.

@oobabooga
Copy link
Contributor

Highly appreciated PR. Is it possible to make prompt_lookup_num_tokens a generation parameter on the same footing as temperature as is done in the transformers library? That would make it possible to change that parameter without having to reload the model.

@abetlen
Copy link
Owner Author

abetlen commented Jan 25, 2024

@oobabooga I saw, I was looking at the hf implementation as a reference. I could add it as a general num_pred_tokens because I want to keep it open to other implementations of speculative decoding. I'll think on that one.

@abetlen
Copy link
Owner Author

abetlen commented Jan 31, 2024

@oobabooga going to merge this now. For updating the draft model or it's properties without re-creating the entire Llama model class just assume that you can access llm.draft_model, set it to None to disable.

@abetlen abetlen merged commit fb762a6 into main Jan 31, 2024
16 checks passed
@oobabooga
Copy link
Contributor

Awesome, thanks @abetlen!

@abetlen abetlen deleted the add-speculative-decoding branch January 31, 2024 20:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Speculative sampling
2 participants