Skip to content

Commit 7f56208

Browse files
committed
refactor: Remove MedusaBuffers from bindings and update function signatures
- Removed the MedusaBuffers class bindings from both the nanobind and pybind implementations.
- Updated the bound function signatures in initBindings to drop the medusaBuffers parameter; it is now a local variable set to std::nullopt inside the function body.

Signed-off-by: Robin Kobus <[email protected]>
1 parent 765c942 commit 7f56208

File tree

4 files changed

6 additions and 20 deletions

4 files changed

6 additions and 20 deletions

cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,9 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
107107
DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
108108
tensorrt_llm::runtime::CudaStream const& runtimeStream,
109109
tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
110-
SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt)
110+
SizeType32 beamWidth)
111111
{
112+
OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt;
112113
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig,
113114
worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState,
114115
runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
@@ -119,6 +120,6 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
119120
nb::arg("model_config"), nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("context_requests"),
120121
nb::arg("buffer_manager"), nb::arg("logits_type"), nb::arg("decoder_input_buffers"),
121122
nb::arg("decoder_state"), nb::arg("runtime_stream"), nb::arg("decoder_stream"),
122-
nb::arg("max_sequence_length"), nb::arg("beam_width"), nb::arg("medusa_buffers") = std::nullopt)
123+
nb::arg("max_sequence_length"), nb::arg("beam_width"))
123124
.def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; });
124125
}

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
#include "tensorrt_llm/batch_manager/common.h"
2222
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
23-
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
2423
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
2524
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
2625
#include "tensorrt_llm/batch_manager/rnnStateManager.h"
@@ -419,13 +418,6 @@ void initBindings(nb::module_& m)
419418
.def_rw("log_probs_host", &tb::SlotDecoderBuffers::logProbsHost)
420419
.def_rw("finish_reasons_host", &tb::SlotDecoderBuffers::finishReasonsHost);
421420

422-
nb::class_<tb::MedusaBuffers>(m, "MedusaBuffers")
423-
.def(nb::init<runtime::SizeType32, runtime::SizeType32, runtime::BufferManager const&,
424-
runtime::ModelConfig const&, runtime::WorldConfig const&, executor::DecodingConfig const&,
425-
runtime::TllmRuntime const&>(),
426-
nb::arg("max_beam_width"), nb::arg("max_seq_len"), nb::arg("buffer_manager"), nb::arg("model_config"),
427-
nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("runtime"));
428-
429421
m.def(
430422
"add_new_tokens_to_requests",
431423
[](std::vector<std::shared_ptr<tb::LlmRequest>>& requests,

cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,9 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
109109
DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
110110
tensorrt_llm::runtime::CudaStream const& runtimeStream,
111111
tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
112-
SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt)
112+
SizeType32 beamWidth)
113113
{
114+
OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt;
114115
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig,
115116
worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState,
116117
runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
@@ -121,6 +122,6 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
121122
py::arg("model_config"), py::arg("world_config"), py::arg("decoding_config"), py::arg("context_requests"),
122123
py::arg("buffer_manager"), py::arg("logits_type"), py::arg("decoder_input_buffers"),
123124
py::arg("decoder_state"), py::arg("runtime_stream"), py::arg("decoder_stream"),
124-
py::arg("max_sequence_length"), py::arg("beam_width"), py::arg("medusa_buffers") = std::nullopt)
125+
py::arg("max_sequence_length"), py::arg("beam_width"))
125126
.def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; });
126127
}

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
#include "tensorrt_llm/batch_manager/common.h"
2121
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
22-
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
2322
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
2423
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
2524
#include "tensorrt_llm/batch_manager/rnnStateManager.h"
@@ -425,13 +424,6 @@ void initBindings(pybind11::module_& m)
425424
.def_readwrite("log_probs_host", &tb::SlotDecoderBuffers::logProbsHost)
426425
.def_readwrite("finish_reasons_host", &tb::SlotDecoderBuffers::finishReasonsHost);
427426

428-
py::class_<tb::MedusaBuffers>(m, "MedusaBuffers")
429-
.def(py::init<runtime::SizeType32, runtime::SizeType32, runtime::BufferManager const&,
430-
runtime::ModelConfig const&, runtime::WorldConfig const&, executor::DecodingConfig const&,
431-
runtime::TllmRuntime const&>(),
432-
py::arg("max_beam_width"), py::arg("max_seq_len"), py::arg("buffer_manager"), py::arg("model_config"),
433-
py::arg("world_config"), py::arg("decoding_config"), py::arg("runtime"));
434-
435427
m.def(
436428
"add_new_tokens_to_requests",
437429
[](std::vector<std::shared_ptr<tb::LlmRequest>>& requests,

0 commit comments

Comments
 (0)