Skip to content

Commit 775f022

Browse files
Funatiq and lancelly
authored and committed
refactor: Remove unused buffers and bindings from sampler (NVIDIA#6484)
Signed-off-by: Robin Kobus <[email protected]> Signed-off-by: Lanyu Liao <[email protected]>
1 parent 285ae6e commit 775f022

File tree

15 files changed

+18
-335
lines changed

15 files changed

+18
-335
lines changed

cpp/include/tensorrt_llm/runtime/gptDecoder.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
#include "tensorrt_llm/runtime/bufferManager.h"
2121
#include "tensorrt_llm/runtime/decodingInput.h"
2222
#include "tensorrt_llm/runtime/decodingOutput.h"
23-
#include "tensorrt_llm/runtime/request.h"
2423
#include "tensorrt_llm/runtime/samplingConfig.h"
2524

2625
#include <NvInferRuntime.h>

cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp

Lines changed: 3 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,13 @@
2020
#include "tensorrt_llm/batch_manager/assignReqSeqSlots.h"
2121
#include "tensorrt_llm/batch_manager/capacityScheduler.h"
2222
#include "tensorrt_llm/batch_manager/createNewDecoderRequests.h"
23-
#include "tensorrt_llm/batch_manager/handleContextLogits.h"
24-
#include "tensorrt_llm/batch_manager/handleGenerationLogits.h"
2523
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
2624
#include "tensorrt_llm/batch_manager/llmRequest.h"
2725
#include "tensorrt_llm/batch_manager/logitsPostProcessor.h"
28-
#include "tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h"
2926
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
3027
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
3128
#include "tensorrt_llm/batch_manager/pauseRequests.h"
3229
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
33-
#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
34-
#include "tensorrt_llm/batch_manager/updateDecoderBuffers.h"
3530
#include "tensorrt_llm/nanobind/common/customCasters.h"
3631
#include "tensorrt_llm/runtime/decoderState.h"
3732
#include "tensorrt_llm/runtime/torch.h"
@@ -94,48 +89,6 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
9489
nb::arg("generation_requests"), nb::arg("model_config"), nb::arg("cross_kv_cache_manager") = std::nullopt)
9590
.def("name", [](AllocateKvCache const&) { return AllocateKvCache::name; });
9691

97-
nb::class_<HandleContextLogits>(m, HandleContextLogits::name)
98-
.def(nb::init<>())
99-
.def(
100-
"__call__",
101-
[](HandleContextLogits const& self, DecoderInputBuffers& inputBuffers, RequestVector const& contextRequests,
102-
at::Tensor const& logits, std::vector<tr::SizeType32> const& numContextLogitsVec,
103-
tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
104-
OptionalRef<MedusaBuffers> medusaBuffers = std::nullopt)
105-
{
106-
return self(inputBuffers, contextRequests, tr::TorchView::of(logits), numContextLogitsVec, modelConfig,
107-
manager, medusaBuffers);
108-
},
109-
nb::arg("decoder_input_buffers"), nb::arg("context_requests"), nb::arg("logits"),
110-
nb::arg("num_context_logits"), nb::arg("model_config"), nb::arg("buffer_manager"),
111-
nb::arg("medusa_buffers") = std::nullopt)
112-
.def("name", [](HandleContextLogits const&) { return HandleContextLogits::name; });
113-
114-
nb::class_<HandleGenerationLogits>(m, HandleGenerationLogits::name)
115-
.def(nb::init<>())
116-
.def(
117-
"__call__",
118-
[](HandleGenerationLogits const& self, DecoderInputBuffers& inputBuffers,
119-
RequestVector const& generationRequests, at::Tensor const& logits, tr::SizeType32 logitsIndex,
120-
tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
121-
OptionalRef<RuntimeBuffers> genRuntimeBuffers = std::nullopt,
122-
OptionalRef<MedusaBuffers> medusaBuffers = std::nullopt)
123-
{
124-
self(inputBuffers, generationRequests, tr::TorchView::of(logits), logitsIndex, modelConfig, manager,
125-
genRuntimeBuffers, medusaBuffers);
126-
},
127-
nb::arg("decoder_input_buffers"), nb::arg("generation_requests"), nb::arg("logits"),
128-
nb::arg("logits_index"), nb::arg("model_config"), nb::arg("buffer_manager"),
129-
nb::arg("gen_runtime_buffers") = std::nullopt, nb::arg("medusa_buffers") = std::nullopt)
130-
.def("name", [](HandleGenerationLogits const&) { return HandleGenerationLogits::name; });
131-
132-
nb::class_<MakeDecodingBatchInputOutput>(m, MakeDecodingBatchInputOutput::name)
133-
.def(nb::init<>())
134-
.def("__call__", &MakeDecodingBatchInputOutput::operator(), nb::arg("decoder_input_buffers"),
135-
nb::arg("decoder_state"), nb::arg("model_config"), nb::arg("max_num_sequences"),
136-
nb::arg("fused_runtime_buffers") = std::nullopt)
137-
.def("name", [](MakeDecodingBatchInputOutput const&) { return MakeDecodingBatchInputOutput::name; });
138-
13992
nb::class_<LogitsPostProcessor>(m, LogitsPostProcessor::name)
14093
.def(nb::init<>())
14194
.def("__call__", &LogitsPostProcessor::operator(), nb::arg("decoder_input_buffers"),
@@ -154,8 +107,9 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
154107
DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
155108
tensorrt_llm::runtime::CudaStream const& runtimeStream,
156109
tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
157-
SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt)
110+
SizeType32 beamWidth)
158111
{
112+
OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt;
159113
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig,
160114
worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState,
161115
runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
@@ -166,13 +120,6 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
166120
nb::arg("model_config"), nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("context_requests"),
167121
nb::arg("buffer_manager"), nb::arg("logits_type"), nb::arg("decoder_input_buffers"),
168122
nb::arg("decoder_state"), nb::arg("runtime_stream"), nb::arg("decoder_stream"),
169-
nb::arg("max_sequence_length"), nb::arg("beam_width"), nb::arg("medusa_buffers") = std::nullopt)
123+
nb::arg("max_sequence_length"), nb::arg("beam_width"))
170124
.def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; });
171-
172-
nb::class_<UpdateDecoderBuffers>(m, UpdateDecoderBuffers::name)
173-
.def(nb::init<>())
174-
.def("__call__", &UpdateDecoderBuffers::operator(), nb::arg("model_config"), nb::arg("decoder_output_buffers"),
175-
nb::arg("copy_buffer_manager"), nb::arg("decoder_state"), nb::arg("return_log_probs"),
176-
nb::arg("decoder_finish_event"))
177-
.def("name", [](UpdateDecoderBuffers const&) { return UpdateDecoderBuffers::name; });
178125
}

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,9 @@
2020

2121
#include "tensorrt_llm/batch_manager/common.h"
2222
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
23-
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
2423
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
2524
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
2625
#include "tensorrt_llm/batch_manager/rnnStateManager.h"
27-
#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
2826
#include "tensorrt_llm/batch_manager/sequenceSlotManager.h"
2927
#include "tensorrt_llm/nanobind/common/bindTypes.h"
3028
#include "tensorrt_llm/runtime/gptDecoderBatched.h"
@@ -419,13 +417,6 @@ void initBindings(nb::module_& m)
419417
.def_rw("log_probs_host", &tb::SlotDecoderBuffers::logProbsHost)
420418
.def_rw("finish_reasons_host", &tb::SlotDecoderBuffers::finishReasonsHost);
421419

422-
nb::class_<tb::MedusaBuffers>(m, "MedusaBuffers")
423-
.def(nb::init<runtime::SizeType32, runtime::SizeType32, runtime::BufferManager const&,
424-
runtime::ModelConfig const&, runtime::WorldConfig const&, executor::DecodingConfig const&,
425-
runtime::TllmRuntime const&>(),
426-
nb::arg("max_beam_width"), nb::arg("max_seq_len"), nb::arg("buffer_manager"), nb::arg("model_config"),
427-
nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("runtime"));
428-
429420
m.def(
430421
"add_new_tokens_to_requests",
431422
[](std::vector<std::shared_ptr<tb::LlmRequest>>& requests,

cpp/tensorrt_llm/nanobind/bindings.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050

5151
namespace nb = nanobind;
5252
namespace tb = tensorrt_llm::batch_manager;
53-
namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager;
5453
namespace tpb = tensorrt_llm::nanobind::batch_manager;
5554
namespace tc = tensorrt_llm::common;
5655
namespace tr = tensorrt_llm::runtime;

cpp/tensorrt_llm/nanobind/common/customCasters.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,11 @@
2121
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
2222
#include "tensorrt_llm/common/optionalRef.h"
2323
#include "tensorrt_llm/runtime/cudaStream.h"
24-
#include "tensorrt_llm/runtime/request.h"
2524
#include "tensorrt_llm/runtime/samplingConfig.h"
2625
#include "tensorrt_llm/runtime/torch.h"
2726
#include "tensorrt_llm/runtime/torchView.h"
2827

2928
#include <ATen/DLConvertor.h>
30-
#include <deque>
31-
#include <filesystem>
3229
#include <nanobind/nanobind.h>
3330
#include <nanobind/stl/filesystem.h>
3431
#include <nanobind/stl/optional.h>
@@ -38,7 +35,8 @@
3835
#include <torch/csrc/autograd/variable.h>
3936
#include <torch/extension.h>
4037
#include <torch/torch.h>
41-
#include <vector>
38+
39+
#include <deque>
4240

4341
// Pybind requires to have a central include in order for type casters to work.
4442
// Opaque bindings add a type caster, so they have the same requirement.
@@ -47,7 +45,6 @@
4745
// Opaque bindings
4846
NB_MAKE_OPAQUE(tensorrt_llm::batch_manager::ReqIdsSet)
4947
NB_MAKE_OPAQUE(std::vector<tensorrt_llm::batch_manager::SlotDecoderBuffers>)
50-
NB_MAKE_OPAQUE(std::vector<tensorrt_llm::runtime::decoder_batch::Request>)
5148
NB_MAKE_OPAQUE(std::vector<tensorrt_llm::runtime::SamplingConfig>)
5249

5350
namespace nb = nanobind;

cpp/tensorrt_llm/nanobind/runtime/bindings.cpp

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
#include "tensorrt_llm/runtime/lookaheadBuffers.h"
3737
#include "tensorrt_llm/runtime/loraCache.h"
3838
#include "tensorrt_llm/runtime/mcastGPUBuffer.h"
39-
#include "tensorrt_llm/runtime/request.h"
4039
#include "tensorrt_llm/runtime/speculativeDecodingMode.h"
4140
#include "tensorrt_llm/runtime/tllmRuntime.h"
4241
#include "tensorrt_llm/runtime/torchView.h"
@@ -158,25 +157,6 @@ void initBindings(nb::module_& m)
158157
.def_prop_ro("logits_dtype_from_engine",
159158
[](tr::TllmRuntime& self) { return self.getEngine().getTensorDataType("logits"); });
160159

161-
nb::class_<tr::decoder_batch::Request>(m, "Request")
162-
.def(nb::init<tr::decoder_batch::Request::TensorConstPtr, tr::SizeType32, std::optional<tr::SizeType32>,
163-
std::optional<tr::SizeType32>>(),
164-
nb::arg("ids"), nb::arg("input_len"), nb::arg("max_new_tokens") = std::nullopt,
165-
nb::arg("end_id") = std::nullopt)
166-
.def_rw("ids", &tr::decoder_batch::Request::ids)
167-
.def_rw("input_len", &tr::decoder_batch::Request::inputLen)
168-
.def_rw("max_new_tokens", &tr::decoder_batch::Request::maxNewTokens)
169-
.def_rw("end_id", &tr::decoder_batch::Request::endId)
170-
.def_rw("draft_logits", &tr::decoder_batch::Request::draftLogits)
171-
.def_rw("embedding_bias", &tr::decoder_batch::Request::embeddingBias)
172-
.def_rw("bad_words_list", &tr::decoder_batch::Request::badWordsList)
173-
.def_rw("stop_words_list", &tr::decoder_batch::Request::stopWordsList)
174-
.def_rw("generated_tokens_per_engine_step", &tr::decoder_batch::Request::generatedTokensPerEngineStep)
175-
.def_rw("medusa_paths", &tr::decoder_batch::Request::medusaPaths)
176-
.def_rw("medusa_tree_ids", &tr::decoder_batch::Request::medusaTreeIds)
177-
.def_rw("lookahead_runtime_config", &tr::decoder_batch::Request::lookaheadRuntimeConfig);
178-
nb::bind_vector<std::vector<tr::decoder_batch::Request>>(m, "RequestVector");
179-
180160
nb::class_<tr::decoder_batch::Input>(m, "DecoderBatchInput")
181161
.def(nb::init<std::vector<std::vector<tr::ITensor::SharedConstPtr>>, tr::SizeType32>(), nb::arg("logits"),
182162
nb::arg("max_decoding_engine_tokens"))

cpp/tensorrt_llm/pybind/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ set(TRTLLM_PYBIND_MODULE
66
set(SRCS
77
batch_manager/algorithms.cpp
88
batch_manager/bindings.cpp
9-
batch_manager/buffers.cpp
109
batch_manager/cacheTransceiver.cpp
1110
batch_manager/kvCacheManager.cpp
1211
batch_manager/llmRequest.cpp

cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp

Lines changed: 4 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
* SPDX-License-Identifier: Apache-2.0
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,18 +20,13 @@
2020
#include "tensorrt_llm/batch_manager/assignReqSeqSlots.h"
2121
#include "tensorrt_llm/batch_manager/capacityScheduler.h"
2222
#include "tensorrt_llm/batch_manager/createNewDecoderRequests.h"
23-
#include "tensorrt_llm/batch_manager/handleContextLogits.h"
24-
#include "tensorrt_llm/batch_manager/handleGenerationLogits.h"
2523
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
2624
#include "tensorrt_llm/batch_manager/llmRequest.h"
2725
#include "tensorrt_llm/batch_manager/logitsPostProcessor.h"
28-
#include "tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h"
2926
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
3027
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
3128
#include "tensorrt_llm/batch_manager/pauseRequests.h"
3229
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
33-
#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
34-
#include "tensorrt_llm/batch_manager/updateDecoderBuffers.h"
3530
#include "tensorrt_llm/runtime/decoderState.h"
3631
#include "tensorrt_llm/runtime/torch.h"
3732
#include "tensorrt_llm/runtime/torchView.h"
@@ -96,48 +91,6 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
9691
py::arg("generation_requests"), py::arg("model_config"), py::arg("cross_kv_cache_manager") = std::nullopt)
9792
.def("name", [](AllocateKvCache const&) { return AllocateKvCache::name; });
9893

99-
py::class_<HandleContextLogits>(m, HandleContextLogits::name)
100-
.def(py::init())
101-
.def(
102-
"__call__",
103-
[](HandleContextLogits const& self, DecoderInputBuffers& inputBuffers, RequestVector const& contextRequests,
104-
at::Tensor const& logits, std::vector<tr::SizeType32> const& numContextLogitsVec,
105-
tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
106-
OptionalRef<MedusaBuffers> medusaBuffers = std::nullopt)
107-
{
108-
return self(inputBuffers, contextRequests, tr::TorchView::of(logits), numContextLogitsVec, modelConfig,
109-
manager, medusaBuffers);
110-
},
111-
py::arg("decoder_input_buffers"), py::arg("context_requests"), py::arg("logits"),
112-
py::arg("num_context_logits"), py::arg("model_config"), py::arg("buffer_manager"),
113-
py::arg("medusa_buffers") = std::nullopt)
114-
.def("name", [](HandleContextLogits const&) { return HandleContextLogits::name; });
115-
116-
py::class_<HandleGenerationLogits>(m, HandleGenerationLogits::name)
117-
.def(py::init())
118-
.def(
119-
"__call__",
120-
[](HandleGenerationLogits const& self, DecoderInputBuffers& inputBuffers,
121-
RequestVector const& generationRequests, at::Tensor const& logits, tr::SizeType32 logitsIndex,
122-
tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
123-
OptionalRef<RuntimeBuffers> genRuntimeBuffers = std::nullopt,
124-
OptionalRef<MedusaBuffers> medusaBuffers = std::nullopt)
125-
{
126-
self(inputBuffers, generationRequests, tr::TorchView::of(logits), logitsIndex, modelConfig, manager,
127-
genRuntimeBuffers, medusaBuffers);
128-
},
129-
py::arg("decoder_input_buffers"), py::arg("generation_requests"), py::arg("logits"),
130-
py::arg("logits_index"), py::arg("model_config"), py::arg("buffer_manager"),
131-
py::arg("gen_runtime_buffers") = std::nullopt, py::arg("medusa_buffers") = std::nullopt)
132-
.def("name", [](HandleGenerationLogits const&) { return HandleGenerationLogits::name; });
133-
134-
py::class_<MakeDecodingBatchInputOutput>(m, MakeDecodingBatchInputOutput::name)
135-
.def(py::init())
136-
.def("__call__", &MakeDecodingBatchInputOutput::operator(), py::arg("decoder_input_buffers"),
137-
py::arg("decoder_state"), py::arg("model_config"), py::arg("max_num_sequences"),
138-
py::arg("fused_runtime_buffers") = std::nullopt)
139-
.def("name", [](MakeDecodingBatchInputOutput const&) { return MakeDecodingBatchInputOutput::name; });
140-
14194
py::class_<LogitsPostProcessor>(m, LogitsPostProcessor::name)
14295
.def(py::init())
14396
.def("__call__", &LogitsPostProcessor::operator(), py::arg("decoder_input_buffers"),
@@ -156,8 +109,9 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
156109
DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
157110
tensorrt_llm::runtime::CudaStream const& runtimeStream,
158111
tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
159-
SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt)
112+
SizeType32 beamWidth)
160113
{
114+
OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt;
161115
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig,
162116
worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState,
163117
runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
@@ -168,13 +122,6 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
168122
py::arg("model_config"), py::arg("world_config"), py::arg("decoding_config"), py::arg("context_requests"),
169123
py::arg("buffer_manager"), py::arg("logits_type"), py::arg("decoder_input_buffers"),
170124
py::arg("decoder_state"), py::arg("runtime_stream"), py::arg("decoder_stream"),
171-
py::arg("max_sequence_length"), py::arg("beam_width"), py::arg("medusa_buffers") = std::nullopt)
125+
py::arg("max_sequence_length"), py::arg("beam_width"))
172126
.def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; });
173-
174-
py::class_<UpdateDecoderBuffers>(m, UpdateDecoderBuffers::name)
175-
.def(py::init())
176-
.def("__call__", &UpdateDecoderBuffers::operator(), py::arg("model_config"), py::arg("decoder_output_buffers"),
177-
py::arg("copy_buffer_manager"), py::arg("decoder_state"), py::arg("return_log_probs"),
178-
py::arg("decoder_finish_event"))
179-
.def("name", [](UpdateDecoderBuffers const&) { return UpdateDecoderBuffers::name; });
180127
}

0 commit comments

Comments (0)