@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,18 +20,13 @@
 #include "tensorrt_llm/batch_manager/assignReqSeqSlots.h"
 #include "tensorrt_llm/batch_manager/capacityScheduler.h"
 #include "tensorrt_llm/batch_manager/createNewDecoderRequests.h"
-#include "tensorrt_llm/batch_manager/handleContextLogits.h"
-#include "tensorrt_llm/batch_manager/handleGenerationLogits.h"
 #include "tensorrt_llm/batch_manager/kvCacheManager.h"
 #include "tensorrt_llm/batch_manager/llmRequest.h"
 #include "tensorrt_llm/batch_manager/logitsPostProcessor.h"
-#include "tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h"
 #include "tensorrt_llm/batch_manager/medusaBuffers.h"
 #include "tensorrt_llm/batch_manager/microBatchScheduler.h"
 #include "tensorrt_llm/batch_manager/pauseRequests.h"
 #include "tensorrt_llm/batch_manager/peftCacheManager.h"
-#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
-#include "tensorrt_llm/batch_manager/updateDecoderBuffers.h"
 #include "tensorrt_llm/runtime/decoderState.h"
 #include "tensorrt_llm/runtime/torch.h"
 #include "tensorrt_llm/runtime/torchView.h"
@@ -96,48 +91,6 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
         py::arg("generation_requests"), py::arg("model_config"), py::arg("cross_kv_cache_manager") = std::nullopt)
         .def("name", [](AllocateKvCache const&) { return AllocateKvCache::name; });
 
-    py::class_<HandleContextLogits>(m, HandleContextLogits::name)
-        .def(py::init())
-        .def(
-            "__call__",
-            [](HandleContextLogits const& self, DecoderInputBuffers& inputBuffers, RequestVector const& contextRequests,
-                at::Tensor const& logits, std::vector<tr::SizeType32> const& numContextLogitsVec,
-                tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
-                OptionalRef<MedusaBuffers> medusaBuffers = std::nullopt)
-            {
-                return self(inputBuffers, contextRequests, tr::TorchView::of(logits), numContextLogitsVec, modelConfig,
-                    manager, medusaBuffers);
-            },
-            py::arg("decoder_input_buffers"), py::arg("context_requests"), py::arg("logits"),
-            py::arg("num_context_logits"), py::arg("model_config"), py::arg("buffer_manager"),
-            py::arg("medusa_buffers") = std::nullopt)
-        .def("name", [](HandleContextLogits const&) { return HandleContextLogits::name; });
-
-    py::class_<HandleGenerationLogits>(m, HandleGenerationLogits::name)
-        .def(py::init())
-        .def(
-            "__call__",
-            [](HandleGenerationLogits const& self, DecoderInputBuffers& inputBuffers,
-                RequestVector const& generationRequests, at::Tensor const& logits, tr::SizeType32 logitsIndex,
-                tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
-                OptionalRef<RuntimeBuffers> genRuntimeBuffers = std::nullopt,
-                OptionalRef<MedusaBuffers> medusaBuffers = std::nullopt)
-            {
-                self(inputBuffers, generationRequests, tr::TorchView::of(logits), logitsIndex, modelConfig, manager,
-                    genRuntimeBuffers, medusaBuffers);
-            },
-            py::arg("decoder_input_buffers"), py::arg("generation_requests"), py::arg("logits"),
-            py::arg("logits_index"), py::arg("model_config"), py::arg("buffer_manager"),
-            py::arg("gen_runtime_buffers") = std::nullopt, py::arg("medusa_buffers") = std::nullopt)
-        .def("name", [](HandleGenerationLogits const&) { return HandleGenerationLogits::name; });
-
-    py::class_<MakeDecodingBatchInputOutput>(m, MakeDecodingBatchInputOutput::name)
-        .def(py::init())
-        .def("__call__", &MakeDecodingBatchInputOutput::operator(), py::arg("decoder_input_buffers"),
-            py::arg("decoder_state"), py::arg("model_config"), py::arg("max_num_sequences"),
-            py::arg("fused_runtime_buffers") = std::nullopt)
-        .def("name", [](MakeDecodingBatchInputOutput const&) { return MakeDecodingBatchInputOutput::name; });
-
     py::class_<LogitsPostProcessor>(m, LogitsPostProcessor::name)
         .def(py::init())
         .def("__call__", &LogitsPostProcessor::operator(), py::arg("decoder_input_buffers"),
@@ -156,8 +109,9 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
                 DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
                 tensorrt_llm::runtime::CudaStream const& runtimeStream,
                 tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
-                SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt)
+                SizeType32 beamWidth)
             {
+                OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt;
                 auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig,
                     worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState,
                     runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
@@ -168,13 +122,6 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
         py::arg("model_config"), py::arg("world_config"), py::arg("decoding_config"), py::arg("context_requests"),
         py::arg("buffer_manager"), py::arg("logits_type"), py::arg("decoder_input_buffers"),
         py::arg("decoder_state"), py::arg("runtime_stream"), py::arg("decoder_stream"),
-        py::arg("max_sequence_length"), py::arg("beam_width"), py::arg("medusa_buffers") = std::nullopt)
+        py::arg("max_sequence_length"), py::arg("beam_width"))
         .def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; });
-
-    py::class_<UpdateDecoderBuffers>(m, UpdateDecoderBuffers::name)
-        .def(py::init())
-        .def("__call__", &UpdateDecoderBuffers::operator(), py::arg("model_config"), py::arg("decoder_output_buffers"),
-            py::arg("copy_buffer_manager"), py::arg("decoder_state"), py::arg("return_log_probs"),
-            py::arg("decoder_finish_event"))
-        .def("name", [](UpdateDecoderBuffers const&) { return UpdateDecoderBuffers::name; });
 }
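For reference, a minimal Python-side sketch of calling the updated CreateNewDecoderRequests binding after this change. The keyword names mirror the py::arg declarations above; everything else (an already-constructed instance, the surrounding runtime objects, and how the return value is unpacked) is an assumption for illustration, not part of this diff.

# Hypothetical usage sketch: keyword names follow the py::arg declarations above.
# `create_new_decoder_requests` and the argument objects are assumed to exist in scope;
# they are not defined by this diff.
result = create_new_decoder_requests(
    model_config=model_config,
    world_config=world_config,
    decoding_config=decoding_config,
    context_requests=context_requests,
    buffer_manager=buffer_manager,
    logits_type=logits_type,
    decoder_input_buffers=decoder_input_buffers,
    decoder_state=decoder_state,
    runtime_stream=runtime_stream,
    decoder_stream=decoder_stream,
    max_sequence_length=max_sequence_length,
    beam_width=beam_width,
)  # note: `medusa_buffers` is no longer accepted as a keyword argument here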