adjust schedule to improve TTFT in pytorch engine #2477

Merged 1 commit on Sep 20, 2024

grimoire
Collaborator

Prefill enough requests before decoding to improve time to first token (TTFT).

@@ -636,7 +631,11 @@ async def __long_context_single_forward(inputs):
         ret['logits'] = ret['logits'][:, last_token_loc]
         return ret
     else:
-        return await __long_context_single_forward(inputs)
+        ret = await __long_context_single_forward(inputs)
+        if not return_logits and not inputs.is_decoding:
Collaborator

What benefits do we get from this change?

Collaborator Author

def __get_last_logits():
    """get last logits."""
    seq_length = inputs.seq_length
    if len(seq_length) == logits.size(0):
        return logits
    last_idx = seq_length.cumsum(-1) - 1
    return logits[last_idx, :]
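The indexing in the snippet above can be illustrated without PyTorch. This is a minimal sketch, assuming the logits of a batch are packed row by row into one flat list and that seq_length holds each sequence's token count. The helper name last_token_rows and the toy data are hypothetical and are not part of the engine.

```python
from itertools import accumulate


def last_token_rows(logits, seq_length):
    """Return one logits row per sequence from a packed batch.

    Mirrors the cumsum-based indexing quoted above, using plain lists.
    """
    # One row per sequence already (e.g. decoding): nothing to select.
    if len(seq_length) == len(logits):
        return logits
    # Cumulative lengths minus one give the row of each sequence's last token.
    last_idx = [end - 1 for end in accumulate(seq_length)]
    return [logits[i] for i in last_idx]


# Two packed sequences of lengths 3 and 2 (rows are placeholder logits).
packed = [[0.0], [0.1], [0.2], [1.0], [1.1]]
selected = last_token_rows(packed, [3, 2])
```

With lengths 3 and 2, the cumulative sums are 3 and 5, so rows 2 and 4 are kept: one row per sequence instead of one row per token.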

-return await __long_context_single_forward(inputs)
+ret = await __long_context_single_forward(inputs)
+if not return_logits and not inputs.is_decoding:
+    last_token_loc = [-1]
Collaborator

Does this affect the subsequent computation?

Collaborator Author

__long_context_single_forward is used for very long prefills, so only the last output is required.
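To make the "only the last output" point concrete, here is a toy sketch. It assumes the long prompt is processed in fixed-size chunks, which is an assumption for illustration; the thread only states that the last output is the one that matters. The names chunked_prefill_last_output and toy_forward are hypothetical.

```python
def chunked_prefill_last_output(token_ids, chunk_size, forward_chunk):
    """Run a long prompt chunk by chunk and keep only the final output."""
    last_output = None
    for start in range(0, len(token_ids), chunk_size):
        chunk = token_ids[start:start + chunk_size]
        outputs = forward_chunk(chunk)  # one output per token in the chunk
        last_output = outputs[-1]  # outputs of earlier tokens are discarded
    return last_output


def toy_forward(chunk):
    # Stand-in for a model forward pass: the "output" is the token id times 10.
    return [t * 10 for t in chunk]


result = chunked_prefill_last_output(
    list(range(1, 8)), chunk_size=3, forward_chunk=toy_forward
)
```

Because the next token is sampled from the final position only, dropping the intermediate outputs does not change what the sampler sees in this sketch.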

 loop_count = 1 if is_prefill else (prefill_interval - 1)
-if len(running) == 0:
-    raise NoRunningSeqs()
+assert len(running) > 0
Collaborator

How do we make sure this assertion always holds?

Collaborator Author

if not self.scheduler.has_unfinished():
    await asyncio.sleep(0.01)
    continue

Iterations with no unfinished requests are skipped before reaching this point.
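The guard can be sketched as a small self-contained loop. ToyScheduler and engine_loop are hypothetical stand-ins, and in this toy has_unfinished() being true implies that schedule() returns a non-empty list; whether the real scheduler gives the same guarantee is not established by this thread.

```python
import asyncio


class ToyScheduler:
    """Hypothetical stand-in for the engine's scheduler."""

    def __init__(self):
        self.pending = []

    def has_unfinished(self):
        return len(self.pending) > 0

    def schedule(self):
        # Hand over everything that is pending and clear the queue.
        running, self.pending = self.pending, []
        return running


async def engine_loop(scheduler, max_iterations):
    """Poll the scheduler and only run a step when there is work."""
    steps = []
    for i in range(max_iterations):
        if i == 2:
            scheduler.pending.append("req-0")  # a request arrives later
        if not scheduler.has_unfinished():
            await asyncio.sleep(0.01)  # idle: skip this iteration
            continue
        running = scheduler.schedule()
        assert len(running) > 0  # holds here because of the guard above
        steps.append(running)
    return steps


steps = asyncio.run(engine_loop(ToyScheduler(), max_iterations=4))
```

Only the iteration in which the request arrives reaches the assertion; the idle iterations sleep and continue.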

lvhan028 merged commit 82d0c00 into InternLM:main on Sep 20, 2024
5 checks passed
lvhan028 changed the title from "Pytorch Engine reduce TTFT" to "adjust schedule to improve TTFT in pytorch engine" on Sep 20, 2024