diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index 1b3f0d5ad833..0389ce1299a8 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -349,6 +349,7 @@ endif()
 set(VLLM_EXT_SRC
     "csrc/cpu/activation.cpp"
     "csrc/cpu/utils.cpp"
+    "csrc/cpu/spec_decode_utils.cpp"
     "csrc/cpu/layernorm.cpp"
     "csrc/cpu/mla_decode.cpp"
     "csrc/cpu/pos_encoding.cpp"
@@ -383,6 +384,7 @@ if (ENABLE_X86_ISA)
         "csrc/cpu/cpu_wna16.cpp"
         "csrc/cpu/cpu_fused_moe.cpp"
         "csrc/cpu/utils.cpp"
+        "csrc/cpu/spec_decode_utils.cpp"
         "csrc/cpu/cpu_attn.cpp"
         "csrc/cpu/dnnl_kernels.cpp"
         "csrc/cpu/torch_bindings.cpp"
@@ -395,6 +397,7 @@ if (ENABLE_X86_ISA)
 
     set(VLLM_EXT_SRC_AVX2
         "csrc/cpu/utils.cpp"
+        "csrc/cpu/spec_decode_utils.cpp"
         "csrc/cpu/cpu_attn.cpp"
         "csrc/cpu/torch_bindings.cpp"
         # TODO: Remove these files
diff --git a/csrc/cpu/spec_decode_utils.cpp b/csrc/cpu/spec_decode_utils.cpp
new file mode 100644
index 000000000000..a76b8bc69376
--- /dev/null
+++ b/csrc/cpu/spec_decode_utils.cpp
@@ -0,0 +1,409 @@
+#include "cpu_types.hpp"
+
+#include <algorithm>
+
+namespace cpu_utils {
+
+void eagle_prepare_inputs_padded_kernel_impl(
+    const torch::Tensor& cu_num_draft_tokens,
+    const torch::Tensor& valid_sampled_tokens_count,
+    const torch::Tensor& query_start_loc_gpu,
+    torch::Tensor& token_indices_to_sample,
+    torch::Tensor& num_rejected_tokens_gpu, const int64_t num_reqs) {
+  const int64_t* cu_draft_ptr = cu_num_draft_tokens.data_ptr<int64_t>();
+  const int64_t* valid_count_ptr =
+      valid_sampled_tokens_count.data_ptr<int64_t>();
+  const int32_t* query_loc_ptr = query_start_loc_gpu.data_ptr<int32_t>();
+  int32_t* indices_out_ptr = token_indices_to_sample.data_ptr<int32_t>();
+  int64_t* rejected_out_ptr = num_rejected_tokens_gpu.data_ptr<int64_t>();
+
+#pragma omp parallel for
+  for (int64_t req_idx = 0; req_idx < num_reqs; ++req_idx) {
+    int64_t start_idx = req_idx == 0 ? 0 : cu_draft_ptr[req_idx - 1];
+    int64_t num_draft_tokens = cu_draft_ptr[req_idx] - start_idx;
+    int64_t num_valid_tokens = valid_count_ptr[req_idx];
+
+    int64_t num_rejected = 0;
+    if (num_draft_tokens > 0) {
+      num_rejected = num_draft_tokens + 1 - num_valid_tokens;
+    }
+
+    int32_t q_last_tok_idx = query_loc_ptr[req_idx + 1] - 1;
+    int32_t index_to_sample = q_last_tok_idx - num_rejected;
+
+    indices_out_ptr[req_idx] = index_to_sample;
+    rejected_out_ptr[req_idx] = num_rejected;
+  }
+}
+
+void eagle_prepare_next_token_padded_kernel_impl(
+    const torch::Tensor& sampled_token_ids,
+    const torch::Tensor& discard_request_mask,
+    const torch::Tensor& backup_next_token_ids, torch::Tensor& next_token_ids,
+    torch::Tensor& valid_sampled_tokens_count, const int64_t vocab_size,
+    const int64_t num_sampled_tokens_per_req, const int64_t num_reqs) {
+  const int64_t* sampled_ids_ptr = sampled_token_ids.data_ptr<int64_t>();
+  const bool* discard_mask_ptr = discard_request_mask.data_ptr<bool>();
+  const int64_t* backup_ids_ptr = backup_next_token_ids.data_ptr<int64_t>();
+  int64_t* next_ids_out_ptr = next_token_ids.data_ptr<int64_t>();
+  int64_t* valid_count_out_ptr = valid_sampled_tokens_count.data_ptr<int64_t>();
+
+  const int64_t stride = sampled_token_ids.stride(0);
+
+#pragma omp parallel for
+  for (int64_t req_idx = 0; req_idx < num_reqs; ++req_idx) {
+    const int64_t* row_ptr = sampled_ids_ptr + req_idx * stride;
+    int64_t valid_count = 0;
+    int64_t last_valid_token = -1;
+
+    for (int64_t pos = 0; pos < num_sampled_tokens_per_req; ++pos) {
+      int64_t token = row_ptr[pos];
+      if (token != -1 && token < vocab_size) {
+        valid_count++;
+        last_valid_token = token;
+      }
+    }
+
+    bool discard = discard_mask_ptr[req_idx];
+    if (discard) {
+      next_ids_out_ptr[req_idx] = backup_ids_ptr[req_idx];
+      valid_count_out_ptr[req_idx] = 0;
+    } else {
+      next_ids_out_ptr[req_idx] =
+          (valid_count > 0) ? last_valid_token : backup_ids_ptr[req_idx];
+      valid_count_out_ptr[req_idx] = valid_count;
+    }
+  }
+}
+
+void eagle_step_slot_mapping_metadata_kernel_impl(
+    const torch::Tensor& positions, const torch::Tensor& block_table,
+    torch::Tensor& seq_lens, torch::Tensor& out_clamped_positions,
+    torch::Tensor& out_slot_mapping, const int64_t block_size,
+    const int64_t max_model_len, const int64_t PAD_ID) {
+  const int64_t batch_size = positions.size(0);
+  const int64_t input_batch_size = out_slot_mapping.size(0);
+
+  const int64_t* pos_ptr = positions.data_ptr<int64_t>();
+  const int32_t* bt_ptr = block_table.data_ptr<int32_t>();
+  int32_t* seq_lens_ptr = seq_lens.data_ptr<int32_t>();
+  int64_t* out_clamped_ptr = out_clamped_positions.data_ptr<int64_t>();
+  int64_t* out_slot_ptr = out_slot_mapping.data_ptr<int64_t>();
+
+  const int64_t bt_stride = block_table.stride(0);
+  const int64_t n_blocks_per_req = block_table.size(1);
+
+#pragma omp parallel for
+  for (int64_t req_idx = 0; req_idx < input_batch_size; ++req_idx) {
+    if (req_idx >= batch_size) {
+      out_slot_ptr[req_idx] = PAD_ID;
+      continue;
+    }
+
+    int64_t position = pos_ptr[req_idx];
+    int64_t new_position = position + 1;
+    bool exceeds_max = new_position >= max_model_len;
+    int64_t clamped_position = exceeds_max ? 0 : new_position;
+
+    out_clamped_ptr[req_idx] = clamped_position;
+
+    int64_t block_number = clamped_position / block_size;
+    block_number = std::min(block_number, n_blocks_per_req - 1);
+    int32_t block_id = bt_ptr[req_idx * bt_stride + block_number];
+    int64_t slot_id = block_id * block_size + (clamped_position % block_size);
+    out_slot_ptr[req_idx] = exceeds_max ? PAD_ID : slot_id;
+
+    int32_t seq_len = seq_lens_ptr[req_idx];
+    int32_t new_seq_len = exceeds_max ? 1 : (seq_len + 1);
+    new_seq_len = std::min(new_seq_len, static_cast<int32_t>(max_model_len));
+    seq_lens_ptr[req_idx] = new_seq_len;
+  }
+}
+
+void copy_and_expand_eagle_inputs_kernel_impl(
+    const torch::Tensor& target_token_ids,
+    const torch::Tensor& target_positions, const torch::Tensor& next_token_ids,
+    torch::Tensor& out_input_ids, torch::Tensor& out_positions,
+    torch::Tensor& out_is_rejected_token_mask,
+    torch::Tensor& out_is_masked_token_mask,
+    torch::Tensor& out_new_token_indices,
+    torch::Tensor& out_hidden_state_mapping,
+    const torch::Tensor& query_start_loc, const torch::Tensor& query_end_loc,
+    const int64_t padding_token_id, const int64_t parallel_drafting_token_id,
+    const int64_t total_input_tokens,
+    const int64_t num_padding_slots_per_request, const bool shift_input_ids) {
+  const int64_t num_reqs = query_end_loc.size(0);
+
+  const int64_t* target_ids_ptr = target_token_ids.data_ptr<int64_t>();
+  const int64_t* target_pos_ptr = target_positions.data_ptr<int64_t>();
+  const int64_t* next_ids_ptr = next_token_ids.data_ptr<int64_t>();
+  const int32_t* query_start_ptr = query_start_loc.data_ptr<int32_t>();
+  const int32_t* query_end_ptr = query_end_loc.data_ptr<int32_t>();
+
+  int64_t* out_ids_ptr = out_input_ids.data_ptr<int64_t>();
+  int64_t* out_pos_ptr = out_positions.data_ptr<int64_t>();
+  bool* out_rej_mask_ptr = out_is_rejected_token_mask.data_ptr<bool>();
+  bool* out_mask_ptr = out_is_masked_token_mask.data_ptr<bool>();
+  int32_t* out_new_idx_ptr = out_new_token_indices.data_ptr<int32_t>();
+  int32_t* out_hidden_map_ptr = out_hidden_state_mapping.data_ptr<int32_t>();
+
+#pragma omp parallel for
+  for (int64_t req_idx = 0; req_idx < num_reqs; ++req_idx) {
+    int32_t q_start = query_start_ptr[req_idx];
+    int32_t next_q_start = query_start_ptr[req_idx + 1];
+    int32_t q_end = query_end_ptr[req_idx];
+
+    int64_t num_valid_tokens =
+        shift_input_ids ? (q_end - q_start) : (q_end - q_start + 1);
+    int64_t input_offset = shift_input_ids ? 1 : 0;
+
+    int64_t out_start = q_start + req_idx * (num_padding_slots_per_request -
+                                             (shift_input_ids ? 1 : 0));
+    int64_t num_rejected = next_q_start - q_end - 1;
+    int64_t total_output_tokens =
+        num_valid_tokens + num_padding_slots_per_request + num_rejected;
+
+    int64_t start_pos = target_pos_ptr[q_start];
+    int64_t bonus_token = next_ids_ptr[req_idx];
+
+    for (int64_t j = 0; j < total_output_tokens; ++j) {
+      int64_t out_idx = out_start + j;
+      bool is_valid = j < num_valid_tokens;
+      bool is_bonus = j == num_valid_tokens;
+      bool is_parallel = (j > num_valid_tokens) &&
+                         (j < num_valid_tokens + num_padding_slots_per_request);
+      bool is_rejected = j >= num_valid_tokens + num_padding_slots_per_request;
+
+      int64_t in_idx =
+          std::min(static_cast<int64_t>(q_start + input_offset + j),
+                   total_input_tokens - 1);
+
+      int64_t token_id = padding_token_id;
+      if (is_valid)
+        token_id = target_ids_ptr[in_idx];
+      else if (is_bonus)
+        token_id = bonus_token;
+      else if (is_parallel)
+        token_id = parallel_drafting_token_id;
+
+      out_ids_ptr[out_idx] = token_id;
+      out_pos_ptr[out_idx] = is_rejected ? 0 : (start_pos + j);
+      out_rej_mask_ptr[out_idx] = is_rejected;
+      out_mask_ptr[out_idx] = is_parallel;
+
+      if (is_bonus || is_parallel) {
+        int64_t new_token_local_idx = j - num_valid_tokens;
+        int64_t new_token_out_idx =
+            req_idx * num_padding_slots_per_request + new_token_local_idx;
+        out_new_idx_ptr[new_token_out_idx] = out_idx;
+      }
+    }
+
+    if (shift_input_ids) {
+      int64_t n_input = next_q_start - q_start;
+      for (int64_t j = 0; j < n_input; ++j) {
+        out_hidden_map_ptr[q_start + j] = out_start + j;
+      }
+    }
+  }
+}
+
+void rejection_greedy_sample_kernel_impl(
+    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
+    const torch::Tensor& draft_token_ids, const torch::Tensor& target_argmax,
+    const torch::Tensor& bonus_token_ids,
+    const std::optional<torch::Tensor>& is_greedy, const int64_t max_spec_len) {
+  const int64_t batch_size = cu_num_draft_tokens.size(0);
+
+  int64_t* out_ptr = output_token_ids.data_ptr<int64_t>();
+  const int64_t* cu_draft_ptr = cu_num_draft_tokens.data_ptr<int64_t>();
+  const int64_t* draft_ids_ptr = draft_token_ids.data_ptr<int64_t>();
+  const int64_t* target_argmax_ptr = target_argmax.data_ptr<int64_t>();
+  const int64_t* bonus_ids_ptr = bonus_token_ids.data_ptr<int64_t>();
+  const bool* greedy_ptr =
+      is_greedy.has_value() ? is_greedy.value().data_ptr<bool>() : nullptr;
+
+  const int64_t out_stride = output_token_ids.stride(0);
+  const int64_t bonus_stride = bonus_token_ids.stride(0);
+
+#pragma omp parallel for
+  for (int64_t req_idx = 0; req_idx < batch_size; ++req_idx) {
+    if (greedy_ptr && !greedy_ptr[req_idx]) continue;
+
+    int64_t start_idx = req_idx == 0 ? 0 : cu_draft_ptr[req_idx - 1];
+    int64_t end_idx = cu_draft_ptr[req_idx];
+    int64_t num_draft_tokens = end_idx - start_idx;
+
+    bool rejected = false;
+    for (int64_t pos = 0; pos < num_draft_tokens; ++pos) {
+      int64_t target_id = target_argmax_ptr[start_idx + pos];
+      out_ptr[req_idx * out_stride + pos] = target_id;
+
+      if (draft_ids_ptr[start_idx + pos] != target_id) {
+        rejected = true;
+        break;
+      }
+    }
+
+    if (!rejected) {
+      out_ptr[req_idx * out_stride + num_draft_tokens] =
+          bonus_ids_ptr[req_idx * bonus_stride];
+    }
+  }
+}
+
+void rejection_random_sample_kernel_impl(
+    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
+    const torch::Tensor& draft_token_ids,
+    const std::optional<torch::Tensor>& draft_probs,
+    const torch::Tensor& target_probs, const torch::Tensor& bonus_token_ids,
+    const torch::Tensor& recovered_token_ids,
+    const torch::Tensor& uniform_probs,
+    const std::optional<torch::Tensor>& is_greedy, const int64_t max_spec_len,
+    const int64_t vocab_size, const bool no_draft_probs) {
+  const int64_t batch_size = cu_num_draft_tokens.size(0);
+
+  int64_t* out_ptr = output_token_ids.data_ptr<int64_t>();
+  const int64_t* cu_draft_ptr = cu_num_draft_tokens.data_ptr<int64_t>();
+  const int64_t* draft_ids_ptr = draft_token_ids.data_ptr<int64_t>();
+  const float* draft_probs_ptr =
+      no_draft_probs ? nullptr : draft_probs.value().data_ptr<float>();
+  const float* target_probs_ptr = target_probs.data_ptr<float>();
+  const int64_t* bonus_ids_ptr = bonus_token_ids.data_ptr<int64_t>();
+  const int64_t* recovered_ids_ptr = recovered_token_ids.data_ptr<int64_t>();
+  const float* uniform_probs_ptr = uniform_probs.data_ptr<float>();
+  const bool* greedy_ptr =
+      is_greedy.has_value() ? is_greedy.value().data_ptr<bool>() : nullptr;
+
+  const int64_t out_stride = output_token_ids.stride(0);
+  const int64_t bonus_stride = bonus_token_ids.stride(0);
+  const int64_t target_stride = target_probs.stride(0);
+  const int64_t draft_probs_stride =
+      no_draft_probs ? 0 : draft_probs.value().stride(0);
+
+#pragma omp parallel for
+  for (int64_t req_idx = 0; req_idx < batch_size; ++req_idx) {
+    if (greedy_ptr && greedy_ptr[req_idx]) continue;
+
+    int64_t start_idx = req_idx == 0 ? 0 : cu_draft_ptr[req_idx - 1];
+    int64_t end_idx = cu_draft_ptr[req_idx];
+    int64_t num_draft_tokens = end_idx - start_idx;
+
+    bool rejected = false;
+    for (int64_t pos = 0; pos < num_draft_tokens; ++pos) {
+      int64_t token_idx = start_idx + pos;
+      int64_t draft_id = draft_ids_ptr[token_idx];
+
+      float p = target_probs_ptr[token_idx * target_stride + draft_id];
+      float q =
+          no_draft_probs
+              ? 1.0f
+              : draft_probs_ptr[token_idx * draft_probs_stride + draft_id];
+      float uniform_p = uniform_probs_ptr[token_idx];
+
+      float ratio = (q > 0.0f) ? (p / q) : 0.0f;
+
+      if (ratio >= uniform_p) {
+        out_ptr[req_idx * out_stride + pos] = draft_id;
+      } else {
+        out_ptr[req_idx * out_stride + pos] = recovered_ids_ptr[token_idx];
+        rejected = true;
+        break;
+      }
+    }
+
+    if (!rejected) {
+      out_ptr[req_idx * out_stride + num_draft_tokens] =
+          bonus_ids_ptr[req_idx * bonus_stride];
+    }
+  }
+}
+
+void expand_kernel_impl(torch::Tensor& output, const torch::Tensor& input,
+                        const torch::Tensor& cu_num_tokens,
+                        const int64_t replace_from, const int64_t replace_to) {
+  const int64_t batch_size = cu_num_tokens.size(0);
+  const int64_t* cu_tokens_ptr = cu_num_tokens.data_ptr<int64_t>();
+
+  int64_t* out_ptr = output.data_ptr<int64_t>();
+  const int64_t* in_ptr = input.data_ptr<int64_t>();
+
+#pragma omp parallel for
+  for (int64_t req_idx = 0; req_idx < batch_size; ++req_idx) {
+    int64_t start_idx = req_idx == 0 ? 0 : cu_tokens_ptr[req_idx - 1];
+    int64_t end_idx = cu_tokens_ptr[req_idx];
+    int64_t val = in_ptr[req_idx];
+
+    if (val == replace_from) {
+      val = replace_to;
+    }
+
+    for (int64_t i = start_idx; i < end_idx; ++i) {
+      out_ptr[i] = val;
+    }
+  }
+}
+
+void sample_recovered_tokens_kernel_impl(
+    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
+    const torch::Tensor& draft_token_ids,
+    const std::optional<torch::Tensor>& draft_probs,
+    const torch::Tensor& target_probs, const torch::Tensor& inv_q,
+    const int64_t vocab_size, const bool no_draft_probs) {
+  const int64_t batch_size = cu_num_draft_tokens.size(0);
+
+  int64_t* out_ptr = output_token_ids.data_ptr<int64_t>();
+  const int64_t* cu_draft_ptr = cu_num_draft_tokens.data_ptr<int64_t>();
+  const int64_t* draft_ids_ptr = draft_token_ids.data_ptr<int64_t>();
+  const float* draft_probs_ptr =
+      no_draft_probs ? nullptr : draft_probs.value().data_ptr<float>();
+  const float* target_probs_ptr = target_probs.data_ptr<float>();
+  const float* inv_q_ptr = inv_q.data_ptr<float>();
+
+  const int64_t target_stride = target_probs.stride(0);
+  const int64_t draft_probs_stride =
+      no_draft_probs ? 0 : draft_probs.value().stride(0);
+  const int64_t inv_q_stride = inv_q.stride(0);
+
+#pragma omp parallel for
+  for (int64_t req_idx = 0; req_idx < batch_size; ++req_idx) {
+    int64_t start_idx = req_idx == 0 ? 0 : cu_draft_ptr[req_idx - 1];
+    int64_t end_idx = cu_draft_ptr[req_idx];
+    int64_t num_draft_tokens = end_idx - start_idx;
+
+    const float* req_inv_q = inv_q_ptr + req_idx * inv_q_stride;
+
+    for (int64_t pos = 0; pos < num_draft_tokens; ++pos) {
+      int64_t token_idx = start_idx + pos;
+      int64_t draft_id = draft_ids_ptr[token_idx];
+
+      const float* token_target_probs =
+          target_probs_ptr + token_idx * target_stride;
+      const float* token_draft_probs =
+          no_draft_probs ? nullptr
+                         : (draft_probs_ptr + token_idx * draft_probs_stride);
+
+      int64_t best_id = 0;
+      float best_val = -1.0f;
+
+      for (int64_t v = 0; v < vocab_size; ++v) {
+        float prob = token_target_probs[v];
+        if (no_draft_probs) {
+          if (v == draft_id) prob = 0.0f;
+        } else {
+          float diff = prob - token_draft_probs[v];
+          prob = diff > 0.0f ? diff : 0.0f;
+        }
+
+        float val = prob * req_inv_q[v];
+        if (val > best_val) {
+          best_val = val;
+          best_id = v;
+        }
+      }
+      out_ptr[token_idx] = best_id;
+    }
+  }
+}
+
+}  // namespace cpu_utils
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index fbc9c65241b6..fcf7064f606a 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -138,6 +138,61 @@ void compute_slot_mapping_kernel_impl(const torch::Tensor query_start_loc,
                                       torch::Tensor slot_mapping,
                                       const int64_t block_size);
 
+namespace cpu_utils {
+void eagle_prepare_inputs_padded_kernel_impl(
+    const torch::Tensor& cu_num_draft_tokens,
+    const torch::Tensor& valid_sampled_tokens_count,
+    const torch::Tensor& query_start_loc_gpu,
+    torch::Tensor& token_indices_to_sample,
+    torch::Tensor& num_rejected_tokens_gpu, const int64_t num_reqs);
+void eagle_prepare_next_token_padded_kernel_impl(
+    const torch::Tensor& sampled_token_ids,
+    const torch::Tensor& discard_request_mask,
+    const torch::Tensor& backup_next_token_ids, torch::Tensor& next_token_ids,
+    torch::Tensor& valid_sampled_tokens_count, const int64_t vocab_size,
+    const int64_t num_sampled_tokens_per_req, const int64_t num_reqs);
+void eagle_step_slot_mapping_metadata_kernel_impl(
+    const torch::Tensor& positions, const torch::Tensor& block_table,
+    torch::Tensor& seq_lens, torch::Tensor& out_clamped_positions,
+    torch::Tensor& out_slot_mapping, const int64_t block_size,
+    const int64_t max_model_len, const int64_t PAD_ID);
+void copy_and_expand_eagle_inputs_kernel_impl(
+    const torch::Tensor& target_token_ids,
+    const torch::Tensor& target_positions, const torch::Tensor& next_token_ids,
+    torch::Tensor& out_input_ids, torch::Tensor& out_positions,
+    torch::Tensor& out_is_rejected_token_mask,
+    torch::Tensor& out_is_masked_token_mask,
+    torch::Tensor& out_new_token_indices,
+    torch::Tensor& out_hidden_state_mapping,
+    const torch::Tensor& query_start_loc, const torch::Tensor& query_end_loc,
+    const int64_t padding_token_id, const int64_t parallel_drafting_token_id,
+    const int64_t total_input_tokens,
+    const int64_t num_padding_slots_per_request, const bool shift_input_ids);
+void rejection_greedy_sample_kernel_impl(
+    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
+    const torch::Tensor& draft_token_ids, const torch::Tensor& target_argmax,
+    const torch::Tensor& bonus_token_ids,
+    const std::optional<torch::Tensor>& is_greedy, const int64_t max_spec_len);
+void rejection_random_sample_kernel_impl(
+    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
+    const torch::Tensor& draft_token_ids,
+    const std::optional<torch::Tensor>& draft_probs,
+    const torch::Tensor& target_probs, const torch::Tensor& bonus_token_ids,
+    const torch::Tensor& recovered_token_ids,
+    const torch::Tensor& uniform_probs,
+    const std::optional<torch::Tensor>& is_greedy, const int64_t max_spec_len,
+    const int64_t vocab_size, const bool no_draft_probs);
+void expand_kernel_impl(torch::Tensor& output, const torch::Tensor& input,
+                        const torch::Tensor& cu_num_tokens,
+                        const int64_t replace_from, const int64_t replace_to);
+void sample_recovered_tokens_kernel_impl(
+    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
+    const torch::Tensor& draft_token_ids,
+    const std::optional<torch::Tensor>& draft_probs,
+    const torch::Tensor& target_probs, const torch::Tensor& inv_q,
+    const int64_t vocab_size, const bool no_draft_probs);
+}  // namespace cpu_utils
+
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
 
@@ -363,6 +418,70 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "positions, Tensor block_table, Tensor(a3!) slot_mapping, SymInt "
       "block_size) -> ()",
       &compute_slot_mapping_kernel_impl);
+
+  // Speculative decoding kernels
+  ops.def(
+      "eagle_prepare_inputs_padded_kernel_impl(Tensor cu_num_draft_tokens, "
+      "Tensor valid_sampled_tokens_count, Tensor query_start_loc_gpu, "
+      "Tensor(a3!) token_indices_to_sample, "
+      "Tensor(a4!) num_rejected_tokens_gpu, "
+      "SymInt num_reqs) -> ()",
+      &cpu_utils::eagle_prepare_inputs_padded_kernel_impl);
+  ops.def(
+      "eagle_prepare_next_token_padded_kernel_impl("
+      "Tensor sampled_token_ids, Tensor discard_request_mask, "
+      "Tensor backup_next_token_ids, Tensor(a3!) next_token_ids, "
+      "Tensor(a4!) valid_sampled_tokens_count, SymInt vocab_size, "
+      "SymInt num_sampled_tokens_per_req, SymInt num_reqs) -> ()",
+      &cpu_utils::eagle_prepare_next_token_padded_kernel_impl);
+  ops.def(
+      "eagle_step_slot_mapping_metadata_kernel_impl("
+      "Tensor positions, Tensor block_table, Tensor(a2!) seq_lens, "
+      "Tensor(a3!) out_clamped_positions, Tensor(a4!) out_slot_mapping, "
+      "SymInt block_size, SymInt max_model_len, SymInt PAD_ID) -> ()",
+      &cpu_utils::eagle_step_slot_mapping_metadata_kernel_impl);
+  ops.def(
+      "copy_and_expand_eagle_inputs_kernel_impl("
+      "Tensor target_token_ids, Tensor target_positions, "
+      "Tensor next_token_ids, Tensor(a3!) out_input_ids, "
+      "Tensor(a4!) out_positions, "
+      "Tensor(a5!) out_is_rejected_token_mask, "
+      "Tensor(a6!) out_is_masked_token_mask, "
+      "Tensor(a7!) out_new_token_indices, "
+      "Tensor(a8!) out_hidden_state_mapping, "
+      "Tensor query_start_loc, Tensor query_end_loc, "
+      "SymInt padding_token_id, SymInt parallel_drafting_token_id, "
+      "SymInt total_input_tokens, SymInt num_padding_slots_per_request, "
+      "bool shift_input_ids) -> ()",
+      &cpu_utils::copy_and_expand_eagle_inputs_kernel_impl);
+  ops.def(
+      "rejection_greedy_sample_kernel_impl("
+      "Tensor(a0!) output_token_ids, Tensor cu_num_draft_tokens, "
+      "Tensor draft_token_ids, Tensor target_argmax, "
+      "Tensor bonus_token_ids, Tensor? is_greedy, "
+      "SymInt max_spec_len) -> ()",
+      &cpu_utils::rejection_greedy_sample_kernel_impl);
+  ops.def(
+      "rejection_random_sample_kernel_impl("
+      "Tensor(a0!) output_token_ids, Tensor cu_num_draft_tokens, "
+      "Tensor draft_token_ids, Tensor? draft_probs, "
+      "Tensor target_probs, Tensor bonus_token_ids, "
+      "Tensor recovered_token_ids, Tensor uniform_probs, "
+      "Tensor? is_greedy, SymInt max_spec_len, SymInt vocab_size, "
+      "bool no_draft_probs) -> ()",
+      &cpu_utils::rejection_random_sample_kernel_impl);
+  ops.def(
+      "expand_kernel_impl(Tensor(a0!) output, Tensor input, "
+      "Tensor cu_num_tokens, SymInt replace_from, "
+      "SymInt replace_to) -> ()",
+      &cpu_utils::expand_kernel_impl);
+  ops.def(
+      "sample_recovered_tokens_kernel_impl("
+      "Tensor(a0!) output_token_ids, Tensor cu_num_draft_tokens, "
+      "Tensor draft_token_ids, Tensor? 
draft_probs, " + "Tensor target_probs, Tensor inv_q, SymInt vocab_size, " + "bool no_draft_probs) -> ()", + &cpu_utils::sample_recovered_tokens_kernel_impl); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/vllm/utils/cpu_triton_utils.py b/vllm/utils/cpu_triton_utils.py index d956dde8b071..d823e8b14aba 100644 --- a/vllm/utils/cpu_triton_utils.py +++ b/vllm/utils/cpu_triton_utils.py @@ -45,3 +45,277 @@ def _compute_slot_mapping_kernel_impl( compute_slot_mapping_kernel = _FuncWrapper(_compute_slot_mapping_kernel_impl) + + +def _ensure_int64(t: torch.Tensor) -> torch.Tensor: + return t if t.dtype == torch.int64 else t.to(torch.int64) + + +def _eagle_prepare_inputs_padded_kernel_impl( + cu_num_draft_tokens, + valid_sampled_tokens_count, + query_start_loc_gpu, + token_indices_to_sample, + num_rejected_tokens_gpu, + num_reqs, +): + # C++ expects int64 for cu_num_draft_tokens, valid_sampled_tokens_count, + # and num_rejected_tokens_gpu, but Python allocates them as int32. + orig_rejected_dtype = num_rejected_tokens_gpu.dtype + rejected_i64 = ( + num_rejected_tokens_gpu + if orig_rejected_dtype == torch.int64 + else num_rejected_tokens_gpu.to(torch.int64) + ) + torch.ops._C.eagle_prepare_inputs_padded_kernel_impl( + _ensure_int64(cu_num_draft_tokens), + _ensure_int64(valid_sampled_tokens_count), + query_start_loc_gpu, + token_indices_to_sample, + rejected_i64, + num_reqs, + ) + if orig_rejected_dtype != torch.int64: + num_rejected_tokens_gpu.copy_(rejected_i64.to(orig_rejected_dtype)) + + +def _eagle_prepare_next_token_padded_kernel_impl( + sampled_token_ids, + discard_request_mask, + backup_next_token_ids, + next_token_ids, + valid_sampled_tokens_count, + vocab_size, + num_sampled_tokens_per_req, + num_reqs, + stride=None, + BLOCK_SIZE_TOKENS=None, +): + # C++ reads all integer tensors as int64_t*. Output tensors are written + # in-place so we create int64 copies, call C++, and copy back. 
+ orig_next_dtype = next_token_ids.dtype + orig_valid_dtype = valid_sampled_tokens_count.dtype + next_i64 = _ensure_int64(next_token_ids) + valid_i64 = _ensure_int64(valid_sampled_tokens_count) + torch.ops._C.eagle_prepare_next_token_padded_kernel_impl( + _ensure_int64(sampled_token_ids), + discard_request_mask, + _ensure_int64(backup_next_token_ids), + next_i64, + valid_i64, + vocab_size, + num_sampled_tokens_per_req, + num_reqs, + ) + if orig_next_dtype != torch.int64: + next_token_ids.copy_(next_i64.to(orig_next_dtype)) + if orig_valid_dtype != torch.int64: + valid_sampled_tokens_count.copy_(valid_i64.to(orig_valid_dtype)) + + +def _eagle_step_slot_mapping_metadata_kernel_impl( + positions, + block_table, + stride, + seq_lens, + out_clamped_positions, + out_slot_mapping, + block_size, + max_model_len, + n_blocks_per_req, + PAD_ID, + batch_size=None, +): + assert batch_size is None or batch_size == positions.shape[0], ( + f"batch_size mismatch: {batch_size} vs positions.shape[0]={positions.shape[0]}" + ) + torch.ops._C.eagle_step_slot_mapping_metadata_kernel_impl( + positions, + block_table, + seq_lens, + out_clamped_positions, + out_slot_mapping, + block_size, + max_model_len, + PAD_ID, + ) + + +def _copy_and_expand_eagle_inputs_kernel_impl( + target_token_ids_ptr, + target_positions_ptr, + next_token_ids_ptr, + out_input_ids_ptr, + out_positions_ptr, + out_is_rejected_token_mask_ptr, + out_is_masked_token_mask_ptr, + out_new_token_indices_ptr, + out_hidden_state_mapping_ptr, + query_start_loc_ptr, + query_end_loc_ptr, + padding_token_id, + parallel_drafting_token_id, + total_input_tokens, + num_padding_slots_per_request, + shift_input_ids, + BLOCK_SIZE_TOKENS=None, + BLOCK_SIZE_REQS=None, +): + """Adapter between Triton kernel call convention and C++ implementation. + + The Triton kernel uses '_ptr' suffixed parameter names and compile-time + constants (BLOCK_SIZE_TOKENS, BLOCK_SIZE_REQS) which are not needed by + the C++ implementation. C++ reads token id tensors as int64_t*. + Output tensors that are int32 need copy-back after C++ writes int64. + """ + orig_ids_dtype = out_input_ids_ptr.dtype + orig_pos_dtype = out_positions_ptr.dtype + out_ids_i64 = _ensure_int64(out_input_ids_ptr) + out_pos_i64 = _ensure_int64(out_positions_ptr) + torch.ops._C.copy_and_expand_eagle_inputs_kernel_impl( + _ensure_int64(target_token_ids_ptr), + _ensure_int64(target_positions_ptr), + _ensure_int64(next_token_ids_ptr), + out_ids_i64, + out_pos_i64, + out_is_rejected_token_mask_ptr, + out_is_masked_token_mask_ptr, + out_new_token_indices_ptr, + out_hidden_state_mapping_ptr, + query_start_loc_ptr, + query_end_loc_ptr, + padding_token_id, + parallel_drafting_token_id, + total_input_tokens, + num_padding_slots_per_request, + shift_input_ids, + ) + if orig_ids_dtype != torch.int64: + out_input_ids_ptr.copy_(out_ids_i64.to(orig_ids_dtype)) + if orig_pos_dtype != torch.int64: + out_positions_ptr.copy_(out_pos_i64.to(orig_pos_dtype)) + + +def _rejection_greedy_sample_kernel_impl( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + target_argmax, + bonus_token_ids, + is_greedy, + max_spec_len, +): + # C++ kernel expects int64 for all integer tensors. 
+ orig_dtype = output_token_ids.dtype + output_token_ids_i64 = _ensure_int64(output_token_ids) + torch.ops._C.rejection_greedy_sample_kernel_impl( + output_token_ids_i64, + _ensure_int64(cu_num_draft_tokens), + _ensure_int64(draft_token_ids), + _ensure_int64(target_argmax), + _ensure_int64(bonus_token_ids), + is_greedy, + max_spec_len, + ) + if orig_dtype != torch.int64: + output_token_ids.copy_(output_token_ids_i64.to(orig_dtype)) + + +def _rejection_random_sample_kernel_impl( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + bonus_token_ids, + recovered_token_ids, + uniform_probs, + is_greedy, + max_spec_len, + vocab_size, + NO_DRAFT_PROBS=False, +): + # C++ kernel expects int64 for all integer tensors and float32 for probs. + # uniform_probs is intentionally float64 in Python to avoid exact-zero + # samples; cast to float32 here for C++ compatibility. + orig_dtype = output_token_ids.dtype + output_token_ids_i64 = _ensure_int64(output_token_ids) + torch.ops._C.rejection_random_sample_kernel_impl( + output_token_ids_i64, + _ensure_int64(cu_num_draft_tokens), + _ensure_int64(draft_token_ids), + draft_probs, + target_probs, + _ensure_int64(bonus_token_ids), + _ensure_int64(recovered_token_ids), + uniform_probs.to(torch.float32), + is_greedy, + max_spec_len, + vocab_size, + NO_DRAFT_PROBS, + ) + if orig_dtype != torch.int64: + output_token_ids.copy_(output_token_ids_i64.to(orig_dtype)) + + +def _expand_kernel_impl( + output, + input_val, + cu_num_tokens, + replace_from, + replace_to, + MAX_NUM_TOKENS=None, +): + torch.ops._C.expand_kernel_impl( + _ensure_int64(output), + _ensure_int64(input_val), + _ensure_int64(cu_num_tokens), + replace_from, + replace_to, + ) + + +def _sample_recovered_tokens_kernel_impl( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + inv_q, + vocab_size, + BLOCK_SIZE=None, + NO_DRAFT_PROBS=False, +): + # C++ reads integer tensors as int64_t*; ensure correct dtype. 
+ orig_dtype = output_token_ids.dtype + output_i64 = _ensure_int64(output_token_ids) + torch.ops._C.sample_recovered_tokens_kernel_impl( + output_i64, + _ensure_int64(cu_num_draft_tokens), + _ensure_int64(draft_token_ids), + draft_probs, + target_probs, + inv_q, + vocab_size, + NO_DRAFT_PROBS, + ) + if orig_dtype != torch.int64: + output_token_ids.copy_(output_i64.to(orig_dtype)) + + +eagle_prepare_inputs_padded_kernel = _FuncWrapper( + _eagle_prepare_inputs_padded_kernel_impl +) +eagle_prepare_next_token_padded_kernel = _FuncWrapper( + _eagle_prepare_next_token_padded_kernel_impl +) +copy_and_expand_eagle_inputs_kernel = _FuncWrapper( + _copy_and_expand_eagle_inputs_kernel_impl +) +eagle_step_slot_mapping_metadata_kernel = _FuncWrapper( + _eagle_step_slot_mapping_metadata_kernel_impl +) +rejection_greedy_sample_kernel = _FuncWrapper(_rejection_greedy_sample_kernel_impl) +rejection_random_sample_kernel = _FuncWrapper(_rejection_random_sample_kernel_impl) +expand_kernel = _FuncWrapper(_expand_kernel_impl) +sample_recovered_tokens_kernel = _FuncWrapper(_sample_recovered_tokens_kernel_impl) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 3bacda18c2a7..77e6f64bab86 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -26,7 +26,6 @@ from vllm.model_executor.models.qwen3_dflash import DFlashQwen3ForCausalLM from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform -from vllm.triton_utils import triton from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backend import CommonAttentionMetadata from vllm.v1.attention.backends.registry import AttentionBackendEnum @@ -48,6 +47,7 @@ eagle_prepare_next_token_padded_kernel, eagle_step_update_slot_mapping_and_metadata, extend_all_queries_by_N, + next_power_of_2, ) from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.dp_utils import coordinate_batch_across_dp @@ -689,9 +689,7 @@ def set_inputs_first_pass( max_num_tokens_per_request = ( cad.max_query_len + self.net_num_new_slots_per_request ) - BLOCK_SIZE_TOKENS = min( - 256, triton.next_power_of_2(max_num_tokens_per_request) - ) + BLOCK_SIZE_TOKENS = min(256, next_power_of_2(max_num_tokens_per_request)) num_blocks = ( max_num_tokens_per_request + BLOCK_SIZE_TOKENS - 1 ) // BLOCK_SIZE_TOKENS @@ -717,6 +715,7 @@ def set_inputs_first_pass( query_end_loc = cad.query_start_loc[1:] - 1 if num_rejected_tokens_gpu is not None: query_end_loc = query_end_loc - num_rejected_tokens_gpu + copy_and_expand_eagle_inputs_kernel[grid]( # (Padded) Inputs from the target model target_token_ids_ptr=target_token_ids, @@ -899,7 +898,7 @@ def prepare_next_token_ids_padded( grid = (batch_size,) # Find the next power of 2 for block sizes - BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens) + BLOCK_SIZE_TOKENS = next_power_of_2(num_tokens) eagle_prepare_next_token_padded_kernel[grid]( sampled_token_ids, discard_request_mask, diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index fdeab36d91d1..cdcb3e05bfad 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -11,6 +11,20 @@ PADDING_SLOT_ID = -1 +def next_power_of_2(n: int) -> int: + """Return the smallest power of 2 >= n.""" + if n <= 0: + return 1 + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + return n + 1 + + @triton.jit def eagle_step_slot_mapping_metadata_kernel( positions_ptr, # [batch_size] - current positions (1D view for M-RoPE) @@ -102,8 
+116,8 @@ def eagle_step_update_slot_mapping_and_metadata( batch_size = positions_1d.shape[0] if input_batch_size is None: input_batch_size = batch_size - n_blocks_per_req = block_table_tensor.shape[1] + n_blocks_per_req = block_table_tensor.shape[1] eagle_step_slot_mapping_metadata_kernel[(input_batch_size,)]( positions_1d, block_table_tensor, diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 2d0bf28cf1ee..7067ef62a75a 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -11,6 +11,8 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.tracing import instrument +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -23,7 +25,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): super().__init__(vllm_config, device) assert device == torch.device("cpu") - assert self.speculative_config is None, "spec decode is not supported." + # Note: speculative decoding is now supported on CPU with C++ native impls self.use_cuda_graph = False self.cascade_attn_enabled = False @@ -61,6 +63,34 @@ def _postprocess_triton(self) -> None: cpu_tl.compute_slot_mapping_kernel ) + # Speculative decoding fallbacks + import vllm.v1.sample.rejection_sampler + import vllm.v1.spec_decode.eagle + import vllm.v1.spec_decode.utils + + vllm.v1.spec_decode.eagle.eagle_prepare_inputs_padded_kernel = ( + cpu_tl.eagle_prepare_inputs_padded_kernel + ) + vllm.v1.spec_decode.eagle.eagle_prepare_next_token_padded_kernel = ( + cpu_tl.eagle_prepare_next_token_padded_kernel + ) + vllm.v1.spec_decode.eagle.copy_and_expand_eagle_inputs_kernel = ( + cpu_tl.copy_and_expand_eagle_inputs_kernel + ) + vllm.v1.spec_decode.utils.eagle_step_slot_mapping_metadata_kernel = ( + cpu_tl.eagle_step_slot_mapping_metadata_kernel + ) + vllm.v1.sample.rejection_sampler.rejection_greedy_sample_kernel = ( + cpu_tl.rejection_greedy_sample_kernel + ) + vllm.v1.sample.rejection_sampler.rejection_random_sample_kernel = ( + cpu_tl.rejection_random_sample_kernel + ) + vllm.v1.sample.rejection_sampler.expand_kernel = cpu_tl.expand_kernel + vllm.v1.sample.rejection_sampler.sample_recovered_tokens_kernel = ( + cpu_tl.sample_recovered_tokens_kernel + ) + @instrument(span_name="Loading (CPU)") def load_model(self, load_dummy_weights: bool = False) -> None: if load_dummy_weights: @@ -74,6 +104,10 @@ def load_model(self, load_dummy_weights: bool = False) -> None: if self.lora_config: self.model = self.load_lora_model(self.model, self.vllm_config, self.device) + if hasattr(self, "drafter"): + logger.info_once("Loading drafter model...") + self.drafter.load_model(self.model) + def get_model(self) -> nn.Module: return self.model @@ -89,8 +123,29 @@ def warming_up_model(self) -> None: ) ) + # Warm up drafter for speculative decoding + if self.speculative_config and (self.speculative_config.uses_draft_model()): + from vllm.v1.spec_decode.draft_model import DraftModelProposer + + if isinstance(self.drafter, (DraftModelProposer)): + logger.info("Warming up drafter model...") + self.drafter.dummy_run(max(16, self.max_num_reqs)) + logger.info("Warming up done.") + def initialize_kv_cache( + self, + kv_cache_config: KVCacheConfig, + is_profiling: bool = False, + ) -> None: + super().initialize_kv_cache(kv_cache_config, is_profiling) + + if self.speculative_config: + if 
self.speculative_config.use_eagle(): + logger.info("EAGLE drafter KV cache initialized for CPU backend") + elif self.speculative_config.uses_draft_model(): + logger.info("Draft model KV cache initialized for CPU backend") + def _init_device_properties(self) -> None: pass @@ -102,6 +157,71 @@ def _zero_block_ids(self, block_ids: list[int]) -> None: # so stale KV cache data never affects computation. pass + # ========================================================================= + # CPU-safe overrides for speculative decoding methods + # These methods override GPU-specific implementations that use CUDA streams + # ========================================================================= + + def _copy_draft_token_ids_to_cpu( + self, scheduler_output: "SchedulerOutput", zeros_only: bool = False + ) -> None: + """CPU-safe version: no async copy needed, tensors already on CPU.""" + if self.use_async_scheduling and not ( + scheduler_output.has_structured_output_requests + or self.input_batch.sampling_metadata.output_token_ids + ): + return + self._draft_token_req_ids = self.input_batch.req_ids.copy() + + draft_token_ids: torch.Tensor = self._draft_token_ids + if not torch.is_tensor(draft_token_ids): + return + + num_reqs = draft_token_ids.shape[0] + if self.draft_token_ids_cpu is not None: + if not zeros_only: + self.draft_token_ids_cpu[:num_reqs].copy_(draft_token_ids) + else: + self.draft_token_ids_cpu[:num_reqs] = 0 + + def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]: + """CPU-safe version: no event synchronization needed.""" + if isinstance(self._draft_token_ids, list): + return self._draft_token_ids, self.input_batch.req_ids + req_ids = self._draft_token_req_ids + if req_ids is None: + return [], [] + if self.draft_token_ids_cpu is not None: + return self.draft_token_ids_cpu[: len(req_ids)].tolist(), req_ids + return [], [] + + def _copy_valid_sampled_token_count( + self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor + ) -> None: + """CPU-safe version: direct copy without CUDA streams.""" + if self.valid_sampled_token_count_cpu is None: + return + + counts = valid_sampled_tokens_count + counts_cpu = self.valid_sampled_token_count_cpu + counts_cpu[: counts.shape[0]].copy_(counts) + self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1) + + def _get_valid_sampled_token_count(self) -> list[int]: + """CPU-safe version: no event synchronization needed.""" + prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids + if prev_sampled_token_ids is None: + return [] + + counts_cpu = self.valid_sampled_token_count_cpu + if counts_cpu is None: + return [] + return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist() + + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: + """CPU-safe version: direct tolist() without CUDA events.""" + return sampled_token_ids.tolist() + @contextmanager def _torch_cuda_wrapper():