@@ -137,10 +137,11 @@ __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdT
137137 const uint32_t bx = blockIdx .x , tx = threadIdx .x ;
138138 const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
139139
140- extern __shared__ __align__ (alignof (
141- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
140+ extern __shared__ __align__ (
141+ alignof (SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
142+ uint8_t smem_sampling[];
142143 auto & temp_storage = reinterpret_cast <
143- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem );
144+ SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling );
144145 temp_storage.data .sampled_id = d - 1 ;
145146 __syncthreads ();
146147
@@ -171,10 +172,11 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
171172 const uint32_t batch_size = gridDim .x ;
172173 const uint32_t bx = blockIdx .x , tx = threadIdx .x ;
173174
174- extern __shared__ __align__ (alignof (
175- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
175+ extern __shared__ __align__ (
176+ alignof (SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
177+ uint8_t smem_sampling[];
176178 auto & temp_storage = reinterpret_cast <
177- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem );
179+ SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling );
178180
179181 vec_t <DType, VEC_SIZE> probs_vec;
180182 DType aggregate;
@@ -264,10 +266,11 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
264266 }
265267 const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
266268
267- extern __shared__ __align__ (alignof (
268- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
269+ extern __shared__ __align__ (
270+ alignof (SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
271+ uint8_t smem_sampling[];
269272 auto & temp_storage = reinterpret_cast <
270- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem );
273+ SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling );
271274
272275 vec_t <DType, VEC_SIZE> probs_vec;
273276 DType aggregate;
@@ -454,9 +457,9 @@ __global__ void TopPRenormProbKernel(DType* probs, IdType* renormed_prob, float
454457 const uint32_t row_idx = bx;
455458
456459 extern __shared__ __align__ (alignof (RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
457- uint8_t smem [];
460+ uint8_t smem_renorm [];
458461 auto & temp_storage =
459- reinterpret_cast <RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem );
462+ reinterpret_cast <RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm );
460463 temp_storage.data .max_val = DType (0 );
461464 vec_t <DType, VEC_SIZE> probs_vec;
462465 DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
@@ -543,9 +546,9 @@ __global__ void TopKRenormProbKernel(DType* probs, IdType* renormed_prob, uint32
543546 const uint32_t row_idx = bx;
544547
545548 extern __shared__ __align__ (alignof (RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
546- uint8_t smem [];
549+ uint8_t smem_renorm [];
547550 auto & temp_storage =
548- reinterpret_cast <RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem );
551+ reinterpret_cast <RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm );
549552 temp_storage.data .max_val = DType (0 );
550553 vec_t <DType, VEC_SIZE> probs_vec;
551554 DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
@@ -674,10 +677,11 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
674677 const uint32_t bx = blockIdx .x , tx = threadIdx .x ;
675678 const uint32_t row_idx = bx;
676679
677- extern __shared__ __align__ (alignof (
678- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
680+ extern __shared__ __align__ (
681+ alignof (SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
682+ uint8_t smem_sampling[];
679683 auto & temp_storage = reinterpret_cast <
680- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem );
684+ SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling );
681685
682686 uint32_t pos = 0 ;
683687 for (pos = 0 ; pos < num_speculative_tokens; ++pos) {
0 commit comments