Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions include/flashinfer/sampling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,11 @@ __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdT
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];

extern __shared__ __align__(alignof(
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
extern __shared__ __align__(
alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage = reinterpret_cast<
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
temp_storage.data.sampled_id = d - 1;
__syncthreads();

Expand Down Expand Up @@ -171,10 +172,11 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
const uint32_t batch_size = gridDim.x;
const uint32_t bx = blockIdx.x, tx = threadIdx.x;

extern __shared__ __align__(alignof(
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
extern __shared__ __align__(
alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage = reinterpret_cast<
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

vec_t<DType, VEC_SIZE> probs_vec;
DType aggregate;
Expand Down Expand Up @@ -264,10 +266,11 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
}
const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];

extern __shared__ __align__(alignof(
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
extern __shared__ __align__(
alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage = reinterpret_cast<
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

vec_t<DType, VEC_SIZE> probs_vec;
DType aggregate;
Expand Down Expand Up @@ -454,9 +457,9 @@ __global__ void TopPRenormProbKernel(DType* probs, IdType* renormed_prob, float
const uint32_t row_idx = bx;

extern __shared__ __align__(alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
uint8_t smem[];
uint8_t smem_renorm[];
auto& temp_storage =
reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem);
reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
temp_storage.data.max_val = DType(0);
vec_t<DType, VEC_SIZE> probs_vec;
DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
Expand Down Expand Up @@ -543,9 +546,9 @@ __global__ void TopKRenormProbKernel(DType* probs, IdType* renormed_prob, uint32
const uint32_t row_idx = bx;

extern __shared__ __align__(alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
uint8_t smem[];
uint8_t smem_renorm[];
auto& temp_storage =
reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem);
reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
temp_storage.data.max_val = DType(0);
vec_t<DType, VEC_SIZE> probs_vec;
DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
Expand Down Expand Up @@ -674,10 +677,11 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
const uint32_t row_idx = bx;

extern __shared__ __align__(alignof(
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
extern __shared__ __align__(
alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage = reinterpret_cast<
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

uint32_t pos = 0;
for (pos = 0; pos < num_speculative_tokens; ++pos) {
Expand Down