Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/cuda/beam_search_scorer_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "generators.h"
#include "search.h"
#include "search_cuda.h"
#include "cuda_common.h"
#include "beam_search_scorer_cuda.cuh"
#include "beam_search_scorer_cuda.h"
#include "interface.h"
Expand All @@ -22,7 +23,7 @@ BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const GeneratorParams& parameters,
state_cpu_->not_done_count_ = parameters.search.batch_size;
state_cpu_->hypothesis_buffer_used_ = 0;
state_gpu_ = CudaMallocArray<cuda::BeamScorerState>(1);
cudaMemcpyAsync(state_gpu_.get(), state_cpu_.get(), sizeof(cuda::BeamScorerState), ::cudaMemcpyHostToDevice, stream_);
CUDA_CHECK(cudaMemcpyAsync(state_gpu_.get(), state_cpu_.get(), sizeof(cuda::BeamScorerState), ::cudaMemcpyHostToDevice, stream_));

size_t batch_beam_size = state_cpu_->batch_size_ * state_cpu_->num_beams_;

Expand Down Expand Up @@ -63,7 +64,7 @@ void BeamSearchScorer_Cuda::Process(Sequences& sequences,
next_tokens,
next_indices,
stream_);
cudaEventRecord(event_process_complete_, stream_);
CUDA_CHECK(cudaEventRecord(event_process_complete_, stream_));

cuda::LaunchBeamSearchScorer_AppendNextTokenToSequences(*state_cpu_,
*state_gpu_,
Expand All @@ -76,7 +77,7 @@ void BeamSearchScorer_Cuda::Process(Sequences& sequences,
}

bool BeamSearchScorer_Cuda::IsDoneLater() const {
cudaEventSynchronize(event_process_complete_);
CUDA_CHECK(cudaEventSynchronize(event_process_complete_));
return state_cpu_->not_done_count_ == 0;
}

Expand All @@ -90,7 +91,7 @@ DeviceSpan<int32_t> BeamSearchScorer_Cuda::GetBeamHypothesis(size_t batch_id, si
cuda_host_unique_ptr<int> hypothesis_length = CudaMallocHostArray<int>(1);
cuda_host_unique_ptr<float> hypothesis_score = CudaMallocHostArray<float>(1);
cuda::LaunchBeamSearchScorer_GetHypothesisPtr(batch_id, beam_id, beam_hyps_, hypothesis_ptr.get(), hypothesis_length.get(), hypothesis_score.get(), stream_);
CudaCheck() == cudaStreamSynchronize(stream_);
CUDA_CHECK(cudaStreamSynchronize(stream_));
std::span<int32_t> hypothesis(*hypothesis_ptr.get(), *hypothesis_length.get());
// Translate the hypothesis span back to the original device buffer span
return hypothesis_buffer_.subspan(hypothesis.data() - hypothesis_buffer_.Span().data(), hypothesis.size());
Expand Down
7 changes: 7 additions & 0 deletions src/cuda/beam_search_scorer_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <assert.h>
#include <algorithm>
#include "span.h"
#include "cuda_common.h"
#include "beam_search_scorer_cuda.cuh"

namespace Generators {
Expand Down Expand Up @@ -49,6 +50,7 @@ void LaunchInitializeBeamHypotheses(std::span<BeamHypotheses> beam_hyps,
length_penalty,
beams.data(),
num_beams);
CUDA_CHECK_LAUNCH();
}

__device__ void BeamHypotheses::Add(const int32_t* hypothesis, int hypothesis_length, float sum_logprobs) {
Expand Down Expand Up @@ -187,6 +189,7 @@ void LaunchBeamSearchScorer_Process(BeamScorerState& state_cpu,
next_scores.data(),
next_tokens.data(),
next_indices.data());
CUDA_CHECK_LAUNCH();
}

__global__ void BeamSearchScorer_AppendNextTokenToSequences1(BeamScorerState& state,
Expand Down Expand Up @@ -254,6 +257,7 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp
next_sequences.data(),
sequence_length,
next_beam_tokens.data());
CUDA_CHECK_LAUNCH();
}

__global__ void BeamSearchScorer_Finalize(BeamScorerState& state,
Expand Down Expand Up @@ -298,6 +302,7 @@ void LaunchBeamSearchScorer_Finalize(int batch_size,
beam_hyps.data(),
hypothesis_buffer.data(),
final_beam_scores.data());
CUDA_CHECK_LAUNCH();
}

__global__ void BeamSearchScorer_GetHypothesisPtr(size_t batch_id,
Expand Down Expand Up @@ -326,6 +331,7 @@ void LaunchBeamSearchScorer_GetHypothesisPtr(size_t batch_id,
hypothesis_ptr,
hypothesis_length,
hypothesis_score);
CUDA_CHECK_LAUNCH();
}

__global__ void InitScoresKernel(float* beam_scores,
Expand All @@ -347,6 +353,7 @@ void LaunchInitScoresKernel(
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
InitScoresKernel<<<gridSize, blockSize, 0, stream>>>(beam_scores, num_beams, total_elements);
CUDA_CHECK_LAUNCH();
}

} // namespace cuda
Expand Down
12 changes: 9 additions & 3 deletions src/cuda/beam_search_topk.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <cub/cub.cuh>
#include <limits>
#include "beam_search_topk.h"
#include "cuda_common.h"

namespace Generators {
namespace cuda {
Expand Down Expand Up @@ -124,17 +125,20 @@ void LaunchBeamSearchOnlineTopKStage2Kernel(
if (parts_per_beam <= 32) {
BeamSearchOnlineTopKStage2Kernel<T, max_k, 32><<<batch_beam_size, 32, smem_stage2_size, stream>>>(
topk_values_tmp, topk_indices_tmp, K, vocab_size, parts_per_beam, output_values, output_indices);
CUDA_CHECK_LAUNCH();
return;
}

if (parts_per_beam <= 64) {
BeamSearchOnlineTopKStage2Kernel<T, max_k, 64><<<batch_beam_size, 64, smem_stage2_size, stream>>>(
topk_values_tmp, topk_indices_tmp, K, vocab_size, parts_per_beam, output_values, output_indices);
CUDA_CHECK_LAUNCH();
return;
}

BeamSearchOnlineTopKStage2Kernel<T, max_k, 128><<<batch_beam_size, 128, smem_stage2_size, stream>>>(
topk_values_tmp, topk_indices_tmp, K, vocab_size, parts_per_beam, output_values, output_indices);
CUDA_CHECK_LAUNCH();
return;
}

Expand All @@ -160,12 +164,13 @@ void TopKLauncherMaxK(

dim3 grid(batch_beam_size, voc_parts);

cudaFuncSetAttribute(BeamSearchOnlineTopKStage1Kernel<T, max_k, kThreadBlockSize>,
cudaFuncAttributePreferredSharedMemoryCarveout,
cudaSharedmemCarveoutMaxL1);
CUDA_CHECK(cudaFuncSetAttribute(BeamSearchOnlineTopKStage1Kernel<T, max_k, kThreadBlockSize>,
cudaFuncAttributePreferredSharedMemoryCarveout,
cudaSharedmemCarveoutMaxL1));

BeamSearchOnlineTopKStage1Kernel<T, max_k, kThreadBlockSize>
<<<grid, kThreadBlockSize, 0, stream>>>(input, K, vocab_size, (vocab_size + voc_parts - 1) / voc_parts, output_values_tmp, output_indices_tmp);
CUDA_CHECK_LAUNCH();

LaunchBeamSearchOnlineTopKStage2Kernel<T, max_k>(
output_values_tmp,
Expand Down Expand Up @@ -242,6 +247,7 @@ void LaunchBatchTopKKernel(const T* topk_scores,
} else {
BatchTopKKernelLauncher(64);
}
CUDA_CHECK_LAUNCH();
}

template void LaunchBatchTopKKernel(const float* topk_scores,
Expand Down
101 changes: 40 additions & 61 deletions src/cuda/cuda_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,51 @@ namespace Generators {

cudaStream_t GetStream();

void OnCudaError(cudaError_t error);
#define CeilDiv(a, b) ((a + (b - 1)) / b)

struct CudaCheck {
void operator==(cudaError_t error) {
if (error != cudaSuccess)
OnCudaError(error);
}
class CudaError : public std::runtime_error {
public:
explicit CudaError(const std::string& msg, cudaError_t code)
: std::runtime_error(msg), code_(code) {}

cudaError_t code() const noexcept { return code_; }

private:
cudaError_t code_;
};

#define CUDA_CHECK(call) \
do { \
cudaError_t err = (call); \
if (err != cudaSuccess) { \
std::stringstream ss; \
ss << "CUDA error in " << __func__ << " at " << __FILE__ \
<< ":" << __LINE__ << " - " << cudaGetErrorString(err); \
(void)cudaGetLastError(); \
throw Generators::CudaError(ss.str(), err); \
} \
} while (0)

#define CUDA_CHECK_LAUNCH() \
do { \
cudaError_t err = cudaPeekAtLastError(); \
if (err != cudaSuccess) { \
std::stringstream ss; \
ss << "CUDA launch error in " << __func__ << " at " \
<< __FILE__ << ":" << __LINE__ << " - " \
<< cudaGetErrorString(err); \
(void)cudaGetLastError(); \
throw Generators::CudaError(ss.str(), err); \
} \
} while (0)

struct cuda_event_holder {
cuda_event_holder() {
cudaEventCreate(&v_);
CUDA_CHECK(cudaEventCreate(&v_));
}

cuda_event_holder(unsigned flags) {
cudaEventCreateWithFlags(&v_, flags);
CUDA_CHECK(cudaEventCreateWithFlags(&v_, flags));
}

~cuda_event_holder() {
Expand All @@ -48,7 +77,7 @@ struct cuda_event_holder {
struct cuda_stream_holder {
void Create() {
assert(!v_);
cudaStreamCreate(&v_);
CUDA_CHECK(cudaStreamCreate(&v_));
}

~cuda_stream_holder() {
Expand Down Expand Up @@ -81,7 +110,7 @@ using cuda_host_unique_ptr = std::unique_ptr<T, CudaHostDeleter>;
template <typename T>
cuda_host_unique_ptr<T> CudaMallocHostArray(size_t count, std::span<T>* p_span = nullptr) {
T* p;
::cudaMallocHost(&p, sizeof(T) * count);
CUDA_CHECK(::cudaMallocHost(&p, sizeof(T) * count));
if (p_span)
*p_span = std::span<T>(p, count);
return cuda_host_unique_ptr<T>{p};
Expand All @@ -99,60 +128,10 @@ using cuda_unique_ptr = std::unique_ptr<T, CudaDeleter>;
template <typename T>
cuda_unique_ptr<T> CudaMallocArray(size_t count, std::span<T>* p_span = nullptr) {
T* p;
::cudaMalloc(&p, sizeof(T) * count);
CUDA_CHECK(::cudaMalloc(&p, sizeof(T) * count));
if (p_span)
*p_span = std::span<T>(p, count);
return cuda_unique_ptr<T>{p};
}

#define CeilDiv(a, b) ((a + (b - 1)) / b)

class CudaError : public std::runtime_error {
public:
explicit CudaError(const std::string& msg, cudaError_t code)
: std::runtime_error(msg), code_(code) {}

cudaError_t code() const noexcept { return code_; }

private:
cudaError_t code_;
};

#define CUDA_CHECK(call) \
do { \
cudaError_t err = (call); \
if (err != cudaSuccess) { \
std::stringstream ss; \
ss << "CUDA error in " << __func__ << " at " << __FILE__ \
<< ":" << __LINE__ << " - " << cudaGetErrorString(err); \
throw Generators::CudaError(ss.str(), err); \
} \
} while (0)

#ifdef NDEBUG
#define CUDA_CHECK_LAUNCH() \
do { \
cudaError_t err = cudaPeekAtLastError(); \
if (err != cudaSuccess) { \
std::stringstream ss; \
ss << "CUDA launch error in " << __func__ << " at " \
<< __FILE__ << ":" << __LINE__ << " - " \
<< cudaGetErrorString(err); \
throw Generators::CudaError(ss.str(), err); \
} \
} while (0)
#else
#define CUDA_CHECK_LAUNCH() \
do { \
cudaError_t err = cudaGetLastError(); \
if (err != cudaSuccess) { \
std::stringstream ss; \
ss << "CUDA launch error in " << __func__ << " at " \
<< __FILE__ << ":" << __LINE__ << " - " \
<< cudaGetErrorString(err); \
throw Generators::CudaError(ss.str(), err); \
} \
} while (0)
#endif

} // namespace Generators
2 changes: 2 additions & 0 deletions src/cuda/cuda_topk.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <cassert>

#include "cuda_topk.h"
#include "cuda_common.h"
#include "cuda_topk_benchmark_cache.h"
#include "cuda_topk_benchmark.cuh"
#include "cuda_topk_common.cuh"
Expand Down Expand Up @@ -153,6 +154,7 @@ void TopkDataCompact::CompactOutput(int batch_size, int k, cudaStream_t stream)
dim3 block(256);
CompactStridedData<float><<<grid, block, 0, stream>>>(topk_scores, topk_scores_compact.get(), k, batch_size, topk_stride);
Comment thread
tianleiwu marked this conversation as resolved.
CompactStridedData<int><<<grid, block, 0, stream>>>(topk_indices, topk_indices_compact.get(), k, batch_size, topk_stride);
CUDA_CHECK_LAUNCH();
}

// Main dispatcher for Top-K. It implements the caching and benchmarking logic to select and run the best algorithm.
Expand Down
14 changes: 11 additions & 3 deletions src/cuda/cuda_topk_benchmark.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ static TopkAlgo BenchmarkAndSelectBestAlgo(TopkData* topk_data,
int vocab_size,
int batch_size,
int k) {
// Clear any stale CUDA errors from previous operations to prevent false failures.
// Successful CUDA API calls do NOT clear the thread-local error state, so a stale
// error (e.g., from TopkData construction or prior inference) can persist and be
// falsely detected by CUDA_CHECK_LAUNCH() inside the benchmark kernels.
Comment thread
tianleiwu marked this conversation as resolved.
cudaGetLastError();

float min_latency = std::numeric_limits<float>::max();
TopkAlgo best_algo = TopkAlgo::UNKNOWN;

Expand Down Expand Up @@ -147,9 +153,11 @@ static TopkAlgo BenchmarkAndSelectBestAlgo(TopkData* topk_data,
});
}

// Candidate: Hybrid Sort. This is a robust fallback. We benchmark it if either the cooperative
// kernels are not supported, or if the vocab size is small, where hybrid can sometimes be faster.
if (!use_iterative_sort && !use_cascaded_sort && !use_flash_convergent || vocab_size <= 4096) {
// Candidate: Hybrid Sort. This is a robust fallback. We benchmark it if the cooperative
// kernels are not supported, if their benchmarks all failed at runtime (best_algo is still
// UNKNOWN despite IsSupported returning true), or if the vocab size is small, where hybrid
// can sometimes be faster.
if (best_algo == TopkAlgo::UNKNOWN || (!use_iterative_sort && !use_cascaded_sort && !use_flash_convergent) || vocab_size <= 4096) {
if (hybrid_sort::IsSupported(batch_size, vocab_size, k)) {
BENCHMARK_KERNEL(TopkAlgo::HYBRID, [&]() {
hybrid_sort::RunTopK(topk_data, stream, scores_in, vocab_size, batch_size, k);
Expand Down
3 changes: 2 additions & 1 deletion src/cuda/cuda_topk_per_batch_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ void RunTopK(TopkData* data, cudaStream_t stream, const float* scores_in, int vo

// Populate workspace buffers with the current batch item's scores and indices.
FillInput<<<blocks_per_batch, block_size, 0, stream>>>(current_scores_in, workspace_scores, workspace_indices, vocab_size);
Comment thread
tianleiwu marked this conversation as resolved.
CUDA_CHECK_LAUNCH();
// Launch the CUB radix sort. It sorts from the workspace directly into the final output buffers.
cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_storage_bytes, workspace_scores, final_scores_out, workspace_indices, final_indices_out, vocab_size, 0, sizeof(float) * 8, stream);
CUDA_CHECK(cub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_storage_bytes, workspace_scores, final_scores_out, workspace_indices, final_indices_out, vocab_size, 0, sizeof(float) * 8, stream));
}

data->topk_scores = final_scores_buffer;
Expand Down
Loading
Loading