[Draft][Core] Refactor _prepare_model_input_tensors #5972
Conversation
rkooo567 left a comment
I remember the goal was to write logic agnostic to prefill/decode (mainly because prefill is a special case of decode). At least that was the direction we wanted last time (and this PR seems to revert that direction). That's also why the existing prepare_inputs avoids distinguishing prefill/decode as much as possible. That will enable features such as https://github.com/vllm-project/vllm/pull/6052/files#diff-d3df23c3e3bcfe97ee8507061c6de54f0eff23a8c75d7f5999062c42245290f8
How difficult would it be to not distinguish prefill/decode, at least at the metadata level? Also, cc @zhuohan123
The reason I separated prefill/decode is that I observed the following:
Meanwhile, this separation shouldn't affect #6052, which focuses on the forward logic that is orthogonal to prepare_input. And some attention backends (e.g. xformers) cannot be unified in this way anyway. However, if you feel it's still better not to separate them, I can revert that in this PR. Happy to discuss :)
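To make the trade-off concrete, the metadata-level separation being discussed could look something like the sketch below. This is purely illustrative: the class and field names are hypothetical and simplified, not vLLM's actual `AttentionMetadata` API — the point is only that separate prefill/decode sub-structures let each attention backend consume just the fields it needs, at the cost of diverging from a unified code path.

```python
from dataclasses import dataclass
from typing import List, Optional

# Hypothetical sketch: separate metadata for the prefill and decode
# phases, grouped under one top-level metadata object. Names are
# illustrative and do not match vLLM's real classes.

@dataclass
class PrefillMetadata:
    seq_lens: List[int]   # full prompt lengths for prefill sequences
    max_seq_len: int

@dataclass
class DecodeMetadata:
    context_lens: List[int]   # already-computed context per decode seq
    max_context_len: int

@dataclass
class AttentionMetadata:
    num_prefill_tokens: int
    num_decode_tokens: int
    # Either half may be absent when a batch is pure prefill or
    # pure decode, so backends must handle None.
    prefill: Optional[PrefillMetadata]
    decode: Optional[DecodeMetadata]
```

A backend that cannot unify the two phases (e.g. the xformers case mentioned above) would then branch on `metadata.prefill is not None` / `metadata.decode is not None` instead of inspecting flattened per-token fields.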
Let me cc @zhuohan123 and @simon-mo for this one. We discussed this before, and I combined prepare_prefill/decode into a single API, and that was the direction they wanted at the time. It is the second item in this proposal: https://docs.google.com/document/d/1rg8CoOnrtz1LT-hCK86ZsHuhoTDtqSEGs8KrN4wbITo/edit I agree the logic is complex, but I think this is actually not fundamental but rather due to tech debt.
Moved to #6164
NOTE: This PR will be rebased after the following PRs are merged: #4628 #5942.
Meanwhile, reviews and comments are welcome.
This PR refactors `_prepare_model_input_tensors`. Specifically, we introduce `ModelRunnerInputBuilder`, mainly for logic isolation and modularization. `ModelRunnerInputBuilder` manages all processed input data, including token IDs, positions, sequence lengths, etc., in one place, and isolates the following logic:

Note that the purpose of this PR is to enable follow-up refactoring and optimizations, so we don't expect an obvious performance improvement at this moment, although the following optimizations may be slightly helpful:

- `.extend()`.

With this isolation, we could further have follow-up optimizations:

- Refactor `AttentionMetadata` to only include on-device tensors, and move all related logic from `ModelRunnerInputBuilder`.
- `for seq_id in seq_ids` in `ModelRunnerInputBuilder._add_decode_seq_group()` by leveraging tensor processing.
- `for seq_group_metadata in seq_group_metadata_list`.
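The builder pattern described above can be sketched roughly as follows. This is a minimal illustrative mock, not the PR's actual implementation: the method names (`add_seq_group`, `build`) and the plain-list/dict representation are assumptions for clarity, standing in for the real per-sequence-group processing and tensor construction.

```python
from typing import Dict, List


class ModelRunnerInputBuilder:
    """Hypothetical sketch: accumulate all processed model inputs
    (token IDs, positions, sequence lengths, ...) in one place,
    then materialize them at the end via build()."""

    def __init__(self) -> None:
        self.input_tokens: List[int] = []
        self.input_positions: List[int] = []
        self.seq_lens: List[int] = []

    def add_seq_group(self, token_ids: List[int]) -> None:
        # Append one sequence group's data. Using list.extend()
        # batches the per-sequence work instead of appending
        # token by token.
        self.input_tokens.extend(token_ids)
        self.input_positions.extend(range(len(token_ids)))
        self.seq_lens.append(len(token_ids))

    def build(self) -> Dict[str, List[int]]:
        # In the real code this step would produce on-device
        # tensors; here we just return the accumulated lists.
        return {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "seq_lens": self.seq_lens,
        }
```

The isolation benefit is that the loop over `seq_group_metadata_list` only calls `add_seq_group`, so later optimizations (vectorizing the per-sequence loops, moving tensor construction out of `AttentionMetadata`) can happen inside the builder without touching the caller.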