11/* 
2-  * SPDX-FileCopyrightText: Copyright (c) 2022-2024  NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2+  * SPDX-FileCopyrightText: Copyright (c) 2022-2025  NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
33 * SPDX-License-Identifier: Apache-2.0 
44 * 
55 * Licensed under the Apache License, Version 2.0 (the "License"); 
2020#include  " tensorrt_llm/batch_manager/assignReqSeqSlots.h" 
2121#include  " tensorrt_llm/batch_manager/capacityScheduler.h" 
2222#include  " tensorrt_llm/batch_manager/createNewDecoderRequests.h" 
23- #include  " tensorrt_llm/batch_manager/handleContextLogits.h" 
24- #include  " tensorrt_llm/batch_manager/handleGenerationLogits.h" 
2523#include  " tensorrt_llm/batch_manager/kvCacheManager.h" 
2624#include  " tensorrt_llm/batch_manager/llmRequest.h" 
2725#include  " tensorrt_llm/batch_manager/logitsPostProcessor.h" 
28- #include  " tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h" 
2926#include  " tensorrt_llm/batch_manager/medusaBuffers.h" 
3027#include  " tensorrt_llm/batch_manager/microBatchScheduler.h" 
3128#include  " tensorrt_llm/batch_manager/pauseRequests.h" 
3229#include  " tensorrt_llm/batch_manager/peftCacheManager.h" 
33- #include  " tensorrt_llm/batch_manager/runtimeBuffers.h" 
34- #include  " tensorrt_llm/batch_manager/updateDecoderBuffers.h" 
3530#include  " tensorrt_llm/runtime/decoderState.h" 
3631#include  " tensorrt_llm/runtime/torch.h" 
3732#include  " tensorrt_llm/runtime/torchView.h" 
@@ -96,48 +91,6 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
9691            py::arg (" generation_requests" py::arg (" model_config" py::arg (" cross_kv_cache_manager" nullopt )
9792        .def (" name" const &) { return  AllocateKvCache::name; });
9893
99-     py::class_<HandleContextLogits>(m, HandleContextLogits::name)
100-         .def (py::init ())
101-         .def (
102-             " __call__" 
103-             [](HandleContextLogits const & self, DecoderInputBuffers& inputBuffers, RequestVector const & contextRequests,
104-                 at::Tensor const & logits, std::vector<tr::SizeType32> const & numContextLogitsVec,
105-                 tr::ModelConfig const & modelConfig, tr::BufferManager const & manager,
106-                 OptionalRef<MedusaBuffers> medusaBuffers = std::nullopt )
107-             {
108-                 return  self (inputBuffers, contextRequests, tr::TorchView::of (logits), numContextLogitsVec, modelConfig,
109-                     manager, medusaBuffers);
110-             },
111-             py::arg (" decoder_input_buffers" py::arg (" context_requests" py::arg (" logits" 
112-             py::arg (" num_context_logits" py::arg (" model_config" py::arg (" buffer_manager" 
113-             py::arg (" medusa_buffers" nullopt )
114-         .def (" name" const &) { return  HandleContextLogits::name; });
115- 
116-     py::class_<HandleGenerationLogits>(m, HandleGenerationLogits::name)
117-         .def (py::init ())
118-         .def (
119-             " __call__" 
120-             [](HandleGenerationLogits const & self, DecoderInputBuffers& inputBuffers,
121-                 RequestVector const & generationRequests, at::Tensor const & logits, tr::SizeType32 logitsIndex,
122-                 tr::ModelConfig const & modelConfig, tr::BufferManager const & manager,
123-                 OptionalRef<RuntimeBuffers> genRuntimeBuffers = std::nullopt ,
124-                 OptionalRef<MedusaBuffers> medusaBuffers = std::nullopt )
125-             {
126-                 self (inputBuffers, generationRequests, tr::TorchView::of (logits), logitsIndex, modelConfig, manager,
127-                     genRuntimeBuffers, medusaBuffers);
128-             },
129-             py::arg (" decoder_input_buffers" py::arg (" generation_requests" py::arg (" logits" 
130-             py::arg (" logits_index" py::arg (" model_config" py::arg (" buffer_manager" 
131-             py::arg (" gen_runtime_buffers" nullopt , py::arg (" medusa_buffers" nullopt )
132-         .def (" name" const &) { return  HandleGenerationLogits::name; });
133- 
134-     py::class_<MakeDecodingBatchInputOutput>(m, MakeDecodingBatchInputOutput::name)
135-         .def (py::init ())
136-         .def (" __call__" MakeDecodingBatchInputOutput::operator (), py::arg (" decoder_input_buffers" 
137-             py::arg (" decoder_state" py::arg (" model_config" py::arg (" max_num_sequences" 
138-             py::arg (" fused_runtime_buffers" nullopt )
139-         .def (" name" const &) { return  MakeDecodingBatchInputOutput::name; });
140- 
14194    py::class_<LogitsPostProcessor>(m, LogitsPostProcessor::name)
14295        .def (py::init ())
14396        .def (" __call__" LogitsPostProcessor::operator (), py::arg (" decoder_input_buffers" 
@@ -156,8 +109,9 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
156109                DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
157110                tensorrt_llm::runtime::CudaStream const & runtimeStream,
158111                tensorrt_llm::runtime::CudaStream const & decoderStream, SizeType32 maxSequenceLength,
159-                 SizeType32 beamWidth, OptionalRef<MedusaBuffers  const > medusaBuffers = std:: nullopt )
112+                 SizeType32 beamWidth)
160113            {
114+                 OptionalRef<MedusaBuffers const > medusaBuffers = std::nullopt ;
161115                auto  [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self (modelConfig,
162116                    worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState,
163117                    runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
@@ -168,13 +122,6 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
168122            py::arg (" model_config" py::arg (" world_config" py::arg (" decoding_config" py::arg (" context_requests" 
169123            py::arg (" buffer_manager" py::arg (" logits_type" py::arg (" decoder_input_buffers" 
170124            py::arg (" decoder_state" py::arg (" runtime_stream" py::arg (" decoder_stream" 
171-             py::arg (" max_sequence_length" py::arg (" beam_width" ,  py::arg ( " medusa_buffers " ) = std:: nullopt )
125+             py::arg (" max_sequence_length" py::arg (" beam_width" 
172126        .def (" name" const &) { return  CreateNewDecoderRequests::name; });
173- 
174-     py::class_<UpdateDecoderBuffers>(m, UpdateDecoderBuffers::name)
175-         .def (py::init ())
176-         .def (" __call__" UpdateDecoderBuffers::operator (), py::arg (" model_config" py::arg (" decoder_output_buffers" 
177-             py::arg (" copy_buffer_manager" py::arg (" decoder_state" py::arg (" return_log_probs" 
178-             py::arg (" decoder_finish_event" 
179-         .def (" name" const &) { return  UpdateDecoderBuffers::name; });
180127}
0 commit comments