[core][optimization] use a pool of numpy ndarray to hold seq data #5942
youkaichao wants to merge 14 commits into
Conversation
    output_token_ids: Optional[List[int]] = None,
) -> None:
    self.tokens = _SEQUENCE_DATA_POOL.alloc_array()
    self.prompt_token_ids_list = prompt_token_ids
Is there any opportunity to get rid of this list (and the output token ids list)? It completely duplicates the numpy arrays, and we should avoid that if possible.
I want to delete it, too. However, sometimes we need the prompt token ids as a list of ints because users expect a list of ints. If we don't store it here, we would need to create a copy from the numpy array, which is expensive.
Fortunately, this is just a reference, so performance-wise it is fine.
I searched the code base and it seems like only batch expansion uses get_prompt_token_ids() and get_output_token_ids(), so it should be possible, as batch expansion is going to be removed by @LiuXiaoxuanPKU.
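For context, a minimal sketch of what the pool behind `_SEQUENCE_DATA_POOL.alloc_array()` could look like. The PR's actual implementation is not shown in this thread; the `NdArrayPool` class, its `free_array` method, and the 128k default capacity are assumptions, illustrating only the alloc/reuse pattern being discussed.

```python
import numpy as np

class NdArrayPool:
    """Hypothetical pool of fixed-size numpy arrays for sequence token data.

    This is a sketch of the pattern, not the PR's actual code. It assumes a
    fixed capacity (128k tokens, per the boundary discussion below) and
    reuses freed arrays to avoid repeated allocations.
    """

    def __init__(self, capacity: int = 128 * 1024):
        self.capacity = capacity
        self._free: list = []  # arrays returned via free_array, ready for reuse

    def alloc_array(self) -> np.ndarray:
        # Pop a previously freed array if one exists; otherwise allocate fresh.
        if self._free:
            return self._free.pop()
        return np.empty(self.capacity, dtype=np.int64)

    def free_array(self, arr: np.ndarray) -> None:
        # Return an array to the pool; no zeroing, callers track logical length.
        self._free.append(arr)

# Module-level singleton, mirroring the _SEQUENCE_DATA_POOL name in the diff.
_SEQUENCE_DATA_POOL = NdArrayPool()
```

Under this sketch, `SequenceData.__init__` grabs an array from the pool instead of growing Python lists, which is where the throughput gain in the benchmark below would come from.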
def append_token_id(self, token_id: int, logprob: float) -> None:
    self.output_token_ids.append(token_id)
    self.tokens[self.num_prompt_tokens + self.num_output_tokens] = token_id
Ideally we should have an assertion to check the boundary, even though 128k should always be sufficient at the moment. Let's add an assert if it doesn't hurt performance; otherwise we could just add a comment that we assume the context length won't exceed 128k.
I think numpy array indexing already has a boundary check.
I want to somehow know the max seq length in SequenceData, but I don't know how to pass that information across so many levels.
Setting a fixed length makes sense to me considering the simplicity. Hmm, maybe it's okay to keep the current implementation then. If someone really hits the boundary and sees the numpy error, we will know what's going on...
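The boundary check being relied on here is numpy's own: scalar indexing past the end of an ndarray raises `IndexError` rather than silently writing out of bounds. A small sketch of the failure mode the reviewers describe (tiny capacity chosen purely for illustration):

```python
import numpy as np

# A tiny fixed-capacity token array standing in for SequenceData.tokens.
tokens = np.empty(4, dtype=np.int64)

tokens[3] = 42  # last valid slot: fine

# One past the end: numpy raises IndexError, so a sequence exceeding the
# pool's fixed capacity fails loudly instead of corrupting memory.
try:
    tokens[4] = 43
except IndexError as exc:
    print("out of bounds:", exc)
```

So even without an explicit assert, overrunning the fixed 128k capacity would surface as a numpy `IndexError`, which is the error message the thread agrees is an acceptable signal.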
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!
This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it. Thank you!
The remaining part of #5877 after separating #5882 out.
The same benchmark command:

python benchmarks/benchmark_throughput.py --output-len 256 --input 256 --model meta-llama/Llama-2-7b-hf -tp 8

The same machine: 8*H100
Before (current main): Throughput: 38.89 requests/s, 19909.29 tokens/s
After (this PR): Throughput: 40.12 requests/s, 20541.11 tokens/s