diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index f29d47f4bf..32e9f4dfc9 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -137,10 +137,11 @@ __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdT const uint32_t bx = blockIdx.x, tx = threadIdx.x; const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx]; - extern __shared__ __align__(alignof( - SamplingTempStorage)) uint8_t smem[]; + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; auto& temp_storage = reinterpret_cast< - SamplingTempStorage&>(smem); + SamplingTempStorage&>(smem_sampling); temp_storage.data.sampled_id = d - 1; __syncthreads(); @@ -171,10 +172,11 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; - extern __shared__ __align__(alignof( - SamplingTempStorage)) uint8_t smem[]; + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; auto& temp_storage = reinterpret_cast< - SamplingTempStorage&>(smem); + SamplingTempStorage&>(smem_sampling); vec_t probs_vec; DType aggregate; @@ -264,10 +266,11 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, } const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx]; - extern __shared__ __align__(alignof( - SamplingTempStorage)) uint8_t smem[]; + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; auto& temp_storage = reinterpret_cast< - SamplingTempStorage&>(smem); + SamplingTempStorage&>(smem_sampling); vec_t probs_vec; DType aggregate; @@ -454,9 +457,9 @@ __global__ void TopPRenormProbKernel(DType* probs, IdType* renormed_prob, float const uint32_t row_idx = bx; extern __shared__ __align__(alignof(RenormTempStorage)) - uint8_t smem[]; + uint8_t smem_renorm[]; auto& temp_storage = - reinterpret_cast&>(smem); + reinterpret_cast&>(smem_renorm); temp_storage.data.max_val = DType(0); vec_t probs_vec; DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0 @@ -543,9 +546,9 @@ __global__ void TopKRenormProbKernel(DType* probs, IdType* renormed_prob, uint32 const uint32_t row_idx = bx; extern __shared__ __align__(alignof(RenormTempStorage)) - uint8_t smem[]; + uint8_t smem_renorm[]; auto& temp_storage = - reinterpret_cast&>(smem); + reinterpret_cast&>(smem_renorm); temp_storage.data.max_val = DType(0); vec_t probs_vec; DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0 @@ -674,10 +677,11 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token const uint32_t bx = blockIdx.x, tx = threadIdx.x; const uint32_t row_idx = bx; - extern __shared__ __align__(alignof( - SamplingTempStorage)) uint8_t smem[]; + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; auto& temp_storage = reinterpret_cast< - SamplingTempStorage&>(smem); + SamplingTempStorage&>(smem_sampling); uint32_t pos = 0; for (pos = 0; pos < num_speculative_tokens; ++pos) {