Optimization for long prompt #15
Merged
Implement an optimization for processing long prompts.
The original PyTorch generation code processes the whole prompt sequence in one forward() pass. This pass computes the cached keys and values for all of the prompt tokens, which are needed for the subsequent decoding; the model then generates new tokens one at a time. In the XLA-optimized algorithm, we instead process one prompt token at a time to avoid the dynamism and graph recompilation caused by varying prompt lengths. This works well when the prompt is short, but becomes inefficient when the prompt is long (1000+ tokens).

To optimize for the long-prompt case and strike a balance between graph compilation and efficiency, this PR implements exponentially scaled prompt processing: the prompt is consumed in sections whose lengths come from a small fixed set of sizes, so only a handful of graph shapes ever need to be compiled while long prompts are still processed in large chunks.
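Roughly, the sectioning works like the sketch below. This is a minimal illustration under assumptions, not the code in this PR: the helper name `prompt_section_lengths` and the bucket sizes (powers of 8, inferred from the example log further down) are hypothetical.

```python
def prompt_section_lengths(prompt_len, buckets=(512, 64, 8, 1)):
    """Split a prompt into sections whose lengths come from a small fixed
    set of bucket sizes, largest-fit first.

    Because every section length is one of len(buckets) values, only that
    many graph shapes are compiled for prompt processing, no matter how
    long the prompt actually is.
    """
    sections = []
    remaining = prompt_len
    for size in sorted(buckets, reverse=True):
        while remaining >= size:
            sections.append(size)
            remaining -= size
    return sections


# The generation loop then runs one forward pass per section, advancing the
# start position so the KV cache is filled exactly as before, e.g.:
#
#   pos = 0
#   for length in prompt_section_lengths(prompt_len):
#       logits = model(tokens[:, pos:pos + length], start_pos=pos)  # hypothetical call
#       pos += length
```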
Example:
65B model on V4-32 with arguments:
--max_seq_len 2048 --max_batch_size 1 --prompt_len 1500 --max_gen_len 256

The warmed-up total time is reduced from 30.13373 seconds to 4.40351 seconds.
Before PR:
"Decoded 1756 tokens in 174.13809 seconds""Decoded 1756 tokens in 30.13373 seconds"After PR:
"Processed prompts with 1501 to 1501 tokens, and generated 256 tokens""Totally decoded 1756 tokens in 643.19650 seconds""Processed prompts with 1501 to 1501 tokens, and generated 256 tokens""Totally decoded 1756 tokens in 4.40351 seconds""Processing prompt pos [0, 512), section length 512
Processing prompt pos [512, 1024), section length 512
Processing prompt pos [1024, 1088), section length 64
Processing prompt pos [1088, 1152), section length 64
Processing prompt pos [1152, 1216), section length 64
Processing prompt pos [1216, 1280), section length 64
Processing prompt pos [1280, 1344), section length 64
Processing prompt pos [1344, 1408), section length 64
Processing prompt pos [1408, 1472), section length 64
Processing prompt pos [1472, 1480), section length 8
Processing prompt pos [1480, 1488), section length 8
Processing prompt pos [1488, 1496), section length 8
Processing prompt pos [1496, 1497), section length 1
Processing prompt pos [1497, 1498), section length 1
Processing prompt pos [1498, 1499), section length 1
Processing prompt pos [1499, 1500), section length 1
Processing prompt pos [1500, 1501), section length 1
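The section lengths in this log (two sections of 512, seven of 64, three of 8, and five of 1 for a 1501-token prompt) match what the largest-fit-first sketch above would produce with power-of-8 buckets:

```python
# Check against the log above, using the hypothetical helper from the sketch.
assert prompt_section_lengths(1501) == [512] * 2 + [64] * 7 + [8] * 3 + [1] * 5
```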