
Conversation


@Liyang90 Liyang90 commented May 12, 2023

Implements an optimization for processing long prompts.

The original PyTorch generation code processes the whole prompt sequence in one forward() pass. This pass computes the cached keys and values for all the prompt tokens, which are needed for the subsequent decoding; generation then proceeds one token at a time.
In the XLA-optimized algorithm, we process one prompt token at a time to avoid the dynamism and graph recompilation caused by varying prompt lengths. This works well when the prompt is short, but becomes inefficient when the prompt is long (1000+ tokens).
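For illustration, a minimal sketch of this per-token prompt processing, assuming a llama-style forward(tokens, start_pos) call (the names and signature here are assumptions, not the PR's actual code):

```python
import torch

def process_prompt_per_token(model, tokens: torch.Tensor, prompt_len: int):
    """Feed one prompt token per step. Every slice has the fixed shape
    (batch, 1), so XLA reuses a single compiled graph for any prompt
    length -- at the cost of one graph execution per prompt token."""
    for pos in range(prompt_len):
        logits = model.forward(tokens[:, pos:pos + 1], start_pos=pos)
    return logits
```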

To optimize for the long-prompt case and strike a balance between graph compilation and efficiency, this PR implements exponentially scaled prompt processing.
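The section plan can be sketched as a greedy decomposition: always take the largest bucket that still fits in the remaining prompt. A minimal sketch, assuming the bucket sizes (512, 64, 8, 1) used in the example below (the function name is hypothetical):

```python
def plan_prompt_sections(prompt_len: int, buckets=(512, 64, 8, 1)):
    """Greedily cover [0, prompt_len) with the largest bucket that fits.

    Only len(buckets) distinct input shapes are ever compiled, yet any
    prompt length up to the max sequence length is covered exactly.
    """
    sections = []
    pos = 0
    while pos < prompt_len:
        # buckets is sorted in descending order, and 1 always fits.
        size = next(b for b in buckets if pos + b <= prompt_len)
        sections.append((pos, pos + size))
        pos += size
    return sections
```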

Example:
65B model on a TPU v4-32 with arguments: --max_seq_len 2048 --max_batch_size 1 --prompt_len 1500 --max_gen_len 256

The warmed-up total time is reduced from 30.13373 seconds to 4.40351 seconds (a ~6.8x speedup).

Before PR:

  • Warmup pass: compiles 1 major graph and executes 1756 major graphs
    "Decoded 1756 tokens in 174.13809 seconds"
  • Second pass: executes 1756 major graphs
    "Decoded 1756 tokens in 30.13373 seconds"

After PR:

  • Warmup pass: compiles 4 major graphs and executes 272 major graphs
    "Processed prompts with 1501 to 1501 tokens, and generated 256 tokens"
    "Totally decoded 1756 tokens in 643.19650 seconds"
  • Second pass: executes 272 major graphs
    "Processed prompts with 1501 to 1501 tokens, and generated 256 tokens"
    "Totally decoded 1756 tokens in 4.40351 seconds"
  • The 4 compiled graphs are for input lengths 512, 64, 8, and 1; they can cover any future prompt length between 1 and 2047. In this example, they are automatically combined to process the 1501-token prompt in the following manner (see the sketch after this list):
    "Processing prompt pos [0, 512), section length 512
    Processing prompt pos [512, 1024), section length 512
    Processing prompt pos [1024, 1088), section length 64
    Processing prompt pos [1088, 1152), section length 64
    Processing prompt pos [1152, 1216), section length 64
    Processing prompt pos [1216, 1280), section length 64
    Processing prompt pos [1280, 1344), section length 64
    Processing prompt pos [1344, 1408), section length 64
    Processing prompt pos [1408, 1472), section length 64
    Processing prompt pos [1472, 1480), section length 8
    Processing prompt pos [1480, 1488), section length 8
    Processing prompt pos [1488, 1496), section length 8
    Processing prompt pos [1496, 1497), section length 1
    Processing prompt pos [1497, 1498), section length 1
    Processing prompt pos [1498, 1499), section length 1
    Processing prompt pos [1499, 1500), section length 1
    Processing prompt pos [1500, 1501), section length 1"
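Running the hypothetical plan_prompt_sections sketch above with a prompt length of 1501 reproduces this exact schedule:

```python
for start, end in plan_prompt_sections(1501):
    print(f"Processing prompt pos [{start}, {end}), section length {end - start}")
```

With a ratio of 8 between consecutive bucket sizes, at most 7 sections of each smaller size are ever needed. Here that yields 2x512 + 7x64 + 3x8 + 5x1 = 17 sections, which together with the 255 single-token decode steps accounts for the 272 major graph executions reported above.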


@JackCaoG JackCaoG left a comment


Ah, this is smart

@AlexWertheim

LGTM, really clever approach!

@Liyang90 Liyang90 merged commit f61383e into prompt-gen May 12, 2023

@alanwaketan alanwaketan left a comment


Smart!

@Liyang90 Liyang90 deleted the liyanglu/bucketized_prompt_len branch July 25, 2023 19:40