-
Notifications
You must be signed in to change notification settings - Fork 5.3k
style refinement for hisparse #21198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
719c9b7
7937b11
b80b055
8ed29bf
845d73a
9b1552a
21f1d9b
11050b1
19ab214
ff7d02a
9b9dd3d
3eee1e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,11 +53,24 @@ __device__ __forceinline__ int warp_inclusive_scan(int* s_data, int lane_id, int | |
| return accumulator; | ||
| } | ||
|
|
||
| // Shared memory size calculation for dynamic allocation. | ||
| // Layout: int32_t region (4-byte aligned) followed by int16_t region (2-byte aligned). | ||
| template <int NUM_TOP_K, int HOT_BUFFER_SIZE> | ||
| struct SmemLayout { | ||
| static constexpr int HASH_SIZE = NUM_TOP_K * 2; | ||
| static constexpr int NUM_BUFFER_CHUNKS = (HOT_BUFFER_SIZE + WARP_SIZE - 1) / WARP_SIZE; | ||
| // int32_t region: top_k_tokens + chunk_offset + evict_chunk_offset + hash_keys + total_hits + newest_hit | ||
| static constexpr int TOTAL_INT32 = NUM_TOP_K + (NUM_BUFFER_CHUNKS + 1) + (NUM_BUFFER_CHUNKS + 1) + HASH_SIZE + 2; | ||
| // int16_t region: lru_slots_out + hash_vals | ||
| static constexpr int TOTAL_INT16 = HOT_BUFFER_SIZE + HASH_SIZE; | ||
| static constexpr size_t BYTES = TOTAL_INT32 * sizeof(int32_t) + TOTAL_INT16 * sizeof(int16_t); | ||
| }; | ||
|
|
||
| // Each block processes one request | ||
| // req_pool_indices are int64_t (pool indices can be large), seq_lens are int32_t | ||
| // req_pool_indices are int64_t (pool indices can be large), seq_lens can be int32_t or int64_t | ||
| // Layout: [HOT_BUFFER_SIZE slots for LRU] + [page_size slots for newest token] | ||
| // newest_slot is at HOT_BUFFER_SIZE (first position of extra page) | ||
| template <int BLOCK_SIZE, int NUM_TOP_K, int HOT_BUFFER_SIZE, bool IsMLA> | ||
| template <int BLOCK_SIZE, int NUM_TOP_K, int HOT_BUFFER_SIZE, bool IsMLA, typename SeqLensT> | ||
| __global__ void load_cache_to_device_buffer_kernel( | ||
| const int32_t* __restrict__ top_k_tokens, | ||
| int32_t* __restrict__ device_buffer_tokens, | ||
|
|
@@ -69,7 +82,7 @@ __global__ void load_cache_to_device_buffer_kernel( | |
| void* __restrict__ device_buffer_v, | ||
| int32_t* __restrict__ top_k_device_locs, | ||
| const int64_t* __restrict__ req_pool_indices, | ||
| const int32_t* __restrict__ seq_lens, | ||
| const SeqLensT* __restrict__ seq_lens, | ||
| int16_t* __restrict__ lru_slots, | ||
| const int32_t* __restrict__ num_real_reqs, | ||
| int64_t buffer_stride_0, | ||
|
|
@@ -118,21 +131,29 @@ __global__ void load_cache_to_device_buffer_kernel( | |
| return; | ||
| } | ||
|
|
||
| // Dynamic shared memory layout: int32_t arrays first, then int16_t arrays. | ||
| extern __shared__ char smem_raw[]; | ||
| using Layout = SmemLayout<NUM_TOP_K, HOT_BUFFER_SIZE>; | ||
| constexpr int HASH_SIZE = Layout::HASH_SIZE; | ||
|
|
||
| int32_t* smem_i32 = reinterpret_cast<int32_t*>(smem_raw); | ||
| // Top-k token positions; reused as miss-token scratch in the copy phase | ||
| __shared__ int32_t s_top_k_tokens[NUM_TOP_K]; | ||
| int32_t* s_top_k_tokens = smem_i32; | ||
| // Prefix-sum offsets for hit counting and miss counting | ||
| __shared__ int32_t s_chunk_offset[NUM_BUFFER_CHUNKS + 1]; | ||
| int32_t* s_chunk_offset = s_top_k_tokens + NUM_TOP_K; | ||
| // Prefix-sum offsets for evictable counting | ||
| __shared__ int32_t s_evict_chunk_offset[NUM_BUFFER_CHUNKS + 1]; | ||
| int32_t* s_evict_chunk_offset = s_chunk_offset + (NUM_BUFFER_CHUNKS + 1); | ||
| // Open-addressing hash table: top-k token_id → top-k index (keys) | ||
| int32_t* s_hash_keys = s_evict_chunk_offset + (NUM_BUFFER_CHUNKS + 1); | ||
| // Scalar counters | ||
| int32_t& s_total_hits = s_hash_keys[HASH_SIZE]; | ||
| int32_t& s_newest_hit = s_hash_keys[HASH_SIZE + 1]; | ||
|
|
||
| int16_t* smem_i16 = reinterpret_cast<int16_t*>(smem_i32 + Layout::TOTAL_INT32); | ||
| // Compacted slot ordering: [hits fwd→ ... ←evictables bwd] | ||
| __shared__ int16_t s_lru_slots_out[HOT_BUFFER_SIZE]; | ||
| // Open-addressing hash table: top-k token_id → top-k index | ||
| constexpr int HASH_SIZE = NUM_TOP_K * 2; | ||
| __shared__ int32_t s_hash_keys[HASH_SIZE]; | ||
| __shared__ int16_t s_hash_vals[HASH_SIZE]; | ||
|
|
||
| __shared__ int32_t s_total_hits; | ||
| __shared__ int32_t s_newest_hit; | ||
| int16_t* s_lru_slots_out = smem_i16; | ||
| // Open-addressing hash table: top-k token_id → top-k index (values) | ||
| int16_t* s_hash_vals = s_lru_slots_out + HOT_BUFFER_SIZE; | ||
|
Comment on lines
+134
to
+156
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Switching from static shared memory arrays to dynamic allocation using |
||
|
|
||
| // Initialize shared memory: counters, hash table, prefix-sum offsets. | ||
| if (tid == 0) { | ||
|
|
@@ -363,28 +384,47 @@ void load_cache_to_device_buffer( | |
| const int64_t top_k_device_locs_stride = top_k_device_locs.strides()[0]; | ||
| const auto device = LaunchKernel::resolve_device(top_k_tokens.device()); | ||
|
|
||
| LaunchKernel(bs, BLOCK_SIZE, device)( | ||
| load_cache_to_device_buffer_kernel<BLOCK_SIZE, NUM_TOP_K, HOT_BUFFER_SIZE, IsMLA>, | ||
| static_cast<const int32_t*>(top_k_tokens.data_ptr()), | ||
| static_cast<int32_t*>(device_buffer_tokens.data_ptr()), | ||
| static_cast<const int64_t*>(host_cache_locs.data_ptr()), | ||
| static_cast<const int32_t*>(device_buffer_locs.data_ptr()), | ||
| host_cache_k.data_ptr(), | ||
| (IsMLA || host_cache_v.ndim() == 0) ? (const void*)nullptr : host_cache_v.data_ptr(), | ||
| device_buffer_k.data_ptr(), | ||
| (IsMLA || device_buffer_v.ndim() == 0) ? (void*)nullptr : device_buffer_v.data_ptr(), | ||
| static_cast<int32_t*>(top_k_device_locs.data_ptr()), | ||
| static_cast<const int64_t*>(req_pool_indices.data_ptr()), | ||
| static_cast<const int32_t*>(seq_lens.data_ptr()), | ||
| static_cast<int16_t*>(lru_slots.data_ptr()), | ||
| static_cast<const int32_t*>(num_real_reqs.data_ptr()), | ||
| buffer_stride_0, | ||
| host_stride, | ||
| lru_slot_stride_0, | ||
| top_k_tokens_stride, | ||
| top_k_device_locs_stride, | ||
| page_size, | ||
| item_size_bytes); | ||
| // Generic lambda: both int32 and int64 kernel variants are compiled; | ||
| // the correct one is selected at runtime based on seq_lens dtype. | ||
| auto launch = [&](auto kernel_fn, const auto* seq_lens_ptr) { | ||
| constexpr size_t smem_bytes = SmemLayout<NUM_TOP_K, HOT_BUFFER_SIZE>::BYTES; | ||
| if constexpr (smem_bytes > 48u * 1024u) { | ||
| cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); | ||
| } | ||
| LaunchKernel(bs, BLOCK_SIZE, device, smem_bytes)( | ||
| kernel_fn, | ||
| static_cast<const int32_t*>(top_k_tokens.data_ptr()), | ||
| static_cast<int32_t*>(device_buffer_tokens.data_ptr()), | ||
| static_cast<const int64_t*>(host_cache_locs.data_ptr()), | ||
| static_cast<const int32_t*>(device_buffer_locs.data_ptr()), | ||
| host_cache_k.data_ptr(), | ||
| (IsMLA || host_cache_v.ndim() == 0) ? (const void*)nullptr : host_cache_v.data_ptr(), | ||
| device_buffer_k.data_ptr(), | ||
| (IsMLA || device_buffer_v.ndim() == 0) ? (void*)nullptr : device_buffer_v.data_ptr(), | ||
| static_cast<int32_t*>(top_k_device_locs.data_ptr()), | ||
| static_cast<const int64_t*>(req_pool_indices.data_ptr()), | ||
| seq_lens_ptr, | ||
| static_cast<int16_t*>(lru_slots.data_ptr()), | ||
| static_cast<const int32_t*>(num_real_reqs.data_ptr()), | ||
| buffer_stride_0, | ||
| host_stride, | ||
| lru_slot_stride_0, | ||
| top_k_tokens_stride, | ||
| top_k_device_locs_stride, | ||
| page_size, | ||
| item_size_bytes); | ||
| }; | ||
|
|
||
| const auto dtype = seq_lens.dtype(); | ||
| if (dtype.code == kDLInt && dtype.bits == 64) { | ||
| launch( | ||
| load_cache_to_device_buffer_kernel<BLOCK_SIZE, NUM_TOP_K, HOT_BUFFER_SIZE, IsMLA, int64_t>, | ||
| static_cast<const int64_t*>(seq_lens.data_ptr())); | ||
| } else { | ||
| launch( | ||
| load_cache_to_device_buffer_kernel<BLOCK_SIZE, NUM_TOP_K, HOT_BUFFER_SIZE, IsMLA, int32_t>, | ||
| static_cast<const int32_t*>(seq_lens.data_ptr())); | ||
| } | ||
| } | ||
|
|
||
| } // namespace | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2146,6 +2146,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: | |
| chunked_req_to_exclude.add(self.chunked_req) | ||
| self.stash_chunked_request(self.chunked_req) | ||
|
|
||
| # HiSparse has its own prefill-to-decode transition; skip last_batch merge. | ||
| if self.enable_hisparse: | ||
| ready_reqs = self.hisparse_coordinator.collect_ready_reqs() | ||
| if len(ready_reqs) > 0: | ||
|
|
@@ -2155,31 +2156,35 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: | |
| else: | ||
| self.running_batch.merge_batch(new_batch) | ||
| self.running_batch.hisparse_coordinator = self.hisparse_coordinator | ||
| else: | ||
| if self.last_batch and self.last_batch.forward_mode.is_extend(): | ||
| if self.last_batch.chunked_req is not None: | ||
| # In the context of pipeline parallelism, after the last chunk, the current microbatch still tracks the outdated chunked_req. | ||
| # We need to discard it. | ||
| chunked_req_to_exclude.add(self.last_batch.chunked_req) | ||
|
|
||
| if self.dllm_config is not None and self.last_batch.reqs: | ||
| chunked_req_to_exclude.update(self.last_batch.reqs) | ||
|
|
||
| # Filter batch | ||
| last_bs = self.last_batch.batch_size() | ||
| self.last_batch.filter_batch( | ||
| chunked_req_to_exclude=list(chunked_req_to_exclude) | ||
| ) | ||
| if self.last_batch.batch_size() < last_bs: | ||
| self.running_batch.batch_is_full = False | ||
|
|
||
| # Merge the new batch into the running batch. | ||
| if not self.last_batch.is_empty(): | ||
| if self.running_batch.is_empty(): | ||
| self.running_batch = self.last_batch | ||
| else: | ||
| # Merge running_batch with prefill batch | ||
| self.running_batch.merge_batch(self.last_batch) | ||
| if ( | ||
| not self.enable_hisparse | ||
| and self.last_batch | ||
| and self.last_batch.forward_mode.is_extend() | ||
|
Comment on lines
+2160
to
+2163
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| ): | ||
| if self.last_batch.chunked_req is not None: | ||
| # In the context of pipeline parallelism, after the last chunk, the current microbatch still tracks the outdated chunked_req. | ||
| # We need to discard it. | ||
| chunked_req_to_exclude.add(self.last_batch.chunked_req) | ||
|
|
||
| if self.dllm_config is not None and self.last_batch.reqs: | ||
| chunked_req_to_exclude.update(self.last_batch.reqs) | ||
|
|
||
| # Filter batch | ||
| last_bs = self.last_batch.batch_size() | ||
| self.last_batch.filter_batch( | ||
| chunked_req_to_exclude=list(chunked_req_to_exclude) | ||
| ) | ||
| if self.last_batch.batch_size() < last_bs: | ||
| self.running_batch.batch_is_full = False | ||
|
|
||
| # Merge the new batch into the running batch. | ||
| if not self.last_batch.is_empty(): | ||
| if self.running_batch.is_empty(): | ||
| self.running_batch = self.last_batch | ||
| else: | ||
| # Merge running_batch with prefill batch | ||
| self.running_batch.merge_batch(self.last_batch) | ||
|
|
||
| # For prefill-only batch, filter out finished requests since they | ||
| # won't go through the decode step. This keeps running_batch accurate | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The introduction of the `SmemLayout` struct for dynamic shared memory calculation is a good practice. It centralizes the shared memory layout definition, making it easier to manage and reason about the memory usage, especially with different data types and alignment requirements. This approach enhances flexibility and maintainability for CUDA kernels.