
Commit 5aa958a

[TRTLLM-5838][fix] fix max batch size and max tokens in kv cache estimations for Nemotron-H (#5371)
Signed-off-by: Tomer Asida <[email protected]>
1 parent 10e6864 commit 5aa958a

File tree

7 files changed: +145 -37 lines changed

tensorrt_llm/_torch/model_config.py
tensorrt_llm/_torch/pyexecutor/_util.py
tensorrt_llm/_torch/pyexecutor/resource_manager.py
tensorrt_llm/bench/benchmark/utils/general.py
tensorrt_llm/bench/build/build.py
tensorrt_llm/bench/build/dataclasses.py
tensorrt_llm/bench/build/tuning.py

tensorrt_llm/_torch/model_config.py

Lines changed: 8 additions & 1 deletion
@@ -7,6 +7,7 @@
 import transformers
 
 from tensorrt_llm import logger
+from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
 from tensorrt_llm._utils import torch_dtype_to_binding
 from tensorrt_llm.bindings import LayerType as LayerTypeCpp
 from tensorrt_llm.functional import AllReduceStrategy
@@ -298,7 +299,7 @@ def get_bindings_model_config(self,
         model_config_cpp = ModelConfigCpp(
             vocab_size=self.pretrained_config.vocab_size,
             num_layers=self.pretrained_config.num_hidden_layers,
-            num_attention_layers=self.pretrained_config.num_hidden_layers,
+            num_attention_layers=self.get_num_attention_layers(),
             num_rnn_layers=0,
             num_heads=num_heads,
             hidden_size=hidden_size,
@@ -376,3 +377,9 @@ def get_layer_types(self) -> Optional[List[LayerTypeCpp]]:
             ] * self.pretrained_config.num_hidden_layers
         else:
             return None
+
+    def get_num_attention_layers(self):
+        if is_nemotron_hybrid(self.pretrained_config):
+            return self.pretrained_config.hybrid_override_pattern.count("*")
+        else:
+            return self.pretrained_config.num_hidden_layers

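For readers unfamiliar with the Nemotron-H config format, a hedged illustration of what get_num_attention_layers counts: hybrid checkpoints describe their layer stack as a hybrid_override_pattern string in which, per the convention used throughout this change, "*" marks an attention layer and "M" a Mamba layer. The pattern below is made up for the example and not taken from any real checkpoint.

# Hypothetical pattern, invented for illustration only.
hybrid_override_pattern = "M*MMM*MMM*MM"  # '*' = attention layer, 'M' = Mamba layer

num_attention_layers = hybrid_override_pattern.count("*")  # 3
num_mamba_layers = hybrid_override_pattern.count("M")      # 9
print(num_attention_layers, num_mamba_layers)

Only the attention layers hold KV cache, which is why this count feeds the KV cache size estimate in the next file.
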
tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 3 additions & 3 deletions
@@ -82,9 +82,9 @@ def _get_cache_size_per_token(model_config: ModelConfig,
     ) * num_key_value_heads // tp_size
 
     # provide at least 1 layer to prevent division by zero cache size
-    num_hidden_layers = max(
-        len(mapping.pp_layers(config.num_hidden_layers)), 1)
-    mem_per_token *= num_hidden_layers * head_dim
+    num_attention_layers = max(
+        len(mapping.pp_layers(model_config.get_num_attention_layers())), 1)
+    mem_per_token *= num_attention_layers * head_dim
     # K and V
     mem_per_token *= kv_factor
     return mem_per_token

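The point of this hunk is that only attention layers hold K/V state, so scaling the per-token estimate by every hidden layer overestimates memory for hybrid models in which most layers are Mamba. A rough, self-contained sketch of the estimate (the helper name, argument list, and dtype size are assumptions for illustration, not the exact code in _util.py):

def kv_cache_bytes_per_token(num_attention_layers: int,
                             num_key_value_heads: int,
                             head_dim: int,
                             bytes_per_elem: int = 2,  # assume fp16/bf16 KV cache
                             kv_factor: int = 2) -> int:  # K and V
    # Only attention layers contribute to the KV cache footprint.
    return (num_attention_layers * num_key_value_heads * head_dim *
            bytes_per_elem * kv_factor)

# Made-up hybrid model: 52 hidden layers, only 4 of them attention.
print(kv_cache_bytes_per_token(52, 8, 128))  # old-style estimate: 212992 bytes/token
print(kv_cache_bytes_per_token(4, 8, 128))   # attention-only estimate: 16384 bytes/token
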
tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 23 additions & 14 deletions
@@ -818,7 +818,7 @@ def __init__(
             device=device,
             dtype=torch.int32)
 
-    def prepare_mamba_cache_blocks(self, request_ids: List[int]):
+    def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
         state_indices = []
         for r in request_ids:
             # cache hit
@@ -834,23 +834,21 @@ def prepare_mamba_cache_blocks(self, request_ids: List[int]):
         self.state_indices[:len(state_indices)] = torch.as_tensor(
             state_indices, dtype=torch.int32, device=self.ssm_states.device)
 
-    def free_mamba_cache_blocks(self, request_id: int):
-        if request_id in self.mamba_cache_index:
-            block = self.mamba_cache_index.pop(request_id)
-            self.mamba_cache_free_blocks.append(block)
-
-    def prepare_mamba_resources(self, scheduled_batch: ScheduledRequests):
+    def prepare_resources(self, scheduled_batch: ScheduledRequests):
         context_ids = [
             i.py_request_id for i in scheduled_batch.context_requests
         ]
         generation_ids = [
             i.py_request_id for i in scheduled_batch.generation_requests
         ]
         request_ids = context_ids + generation_ids
-        self.prepare_mamba_cache_blocks(request_ids)
+        self._prepare_mamba_cache_blocks(request_ids)
 
-    def free_mamba_resources(self, request: LlmRequest):
-        self.free_mamba_cache_blocks(request.py_request_id)
+    def free_resources(self, request: LlmRequest):
+        request_id = request.py_request_id
+        if request_id in self.mamba_cache_index:
+            block = self.mamba_cache_index.pop(request_id)
+            self.mamba_cache_free_blocks.append(block)
 
     def get_state_indices(self) -> torch.Tensor:
         return self.state_indices
@@ -863,6 +861,13 @@ def get_ssm_states(self, layer_idx: int) -> torch.Tensor:
         layer_offset = self.mamba_layer_offsets[layer_idx]
         return self.ssm_states[layer_offset]
 
+    def shutdown(self):
+        # release tensor memory, keeping python references as tensors
+        self.conv_states = torch.tensor([])
+        self.ssm_states = torch.tensor([])
+        self.state_indices = torch.tensor([])
+        torch.cuda.empty_cache()
+
 
 class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
 
@@ -933,12 +938,16 @@ def __init__(
         )
 
     def prepare_resources(self, scheduled_batch: ScheduledRequests):
-        self.prepare_mamba_resources(scheduled_batch)
-        super().prepare_resources(scheduled_batch)
+        MambaCacheManager.prepare_resources(self, scheduled_batch)
+        KVCacheManager.prepare_resources(self, scheduled_batch)
 
     def free_resources(self, request: LlmRequest):
-        self.free_mamba_resources(request)
-        super().free_resources(request)
+        MambaCacheManager.free_resources(self, request)
+        KVCacheManager.free_resources(self, request)
+
+    def shutdown(self):
+        MambaCacheManager.shutdown(self)
+        KVCacheManager.shutdown(self)
 
 
 class SlotManager:

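Worth noting: after this rename, MambaCacheManager and KVCacheManager both expose prepare_resources, free_resources, and shutdown under the same names, and MambaHybridCacheManager inherits from both, so a single super() call would resolve to only the first base in the MRO unless every class chained super() cooperatively. Calling each base explicitly guarantees both cache pools are handled. A minimal standalone sketch of that pitfall (class names are illustrative, not the real managers):

class KVCache:
    def shutdown(self):
        print("kv cache released")

class MambaCache:
    def shutdown(self):
        print("mamba cache released")

class HybridCache(KVCache, MambaCache):
    def shutdown(self):
        # super().shutdown() would stop at KVCache.shutdown (first in the MRO)
        # because KVCache does not chain super(); calling each base explicitly
        # guarantees both caches are torn down.
        MambaCache.shutdown(self)
        KVCache.shutdown(self)

HybridCache().shutdown()  # prints both messages
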
tensorrt_llm/bench/benchmark/utils/general.py

Lines changed: 1 addition & 0 deletions
@@ -130,6 +130,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
         params.get("pp"),
         dataset_metadata.avg_isl,
         dataset_metadata.avg_osl,
+        params.get("kv_cache_free_gpu_mem_fraction"),
     )
 
     logger.info(

tensorrt_llm/bench/build/build.py

Lines changed: 9 additions & 1 deletion
@@ -1,10 +1,12 @@
 from __future__ import annotations
+from transformers import AutoConfig
 
 from pathlib import Path
 from typing import Tuple, get_args
 import click
 from click_option_group import AllOptionGroup, optgroup
 
+from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
 from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
 from tensorrt_llm.bench.utils.data import create_dataset_from_stream, initialize_tokenizer
 from tensorrt_llm.bench.utils import VALID_QUANT_ALGOS
@@ -13,7 +15,7 @@
 from tensorrt_llm.llmapi.llm_utils import QuantConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization.mode import QuantAlgo
-from tensorrt_llm.bench.build.dataclasses import ModelConfig
+from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig
 from tensorrt_llm.bench.build.tuning import calc_engine_setting
 
 TUNED_QUANTS = {
@@ -31,6 +33,7 @@ def get_benchmark_engine_settings(
     pp_size: int,
     target_input_len: int,
     target_output_len: int,
+    kv_cache_gpu_mem_fraction: float = 0.95,
 ) -> Tuple[int, int]:
     """ Retrieve benchmark settings for a specific model + configuration.
 
@@ -58,6 +61,7 @@ def get_benchmark_engine_settings(
             pp_size,
             target_input_len,
             target_output_len,
+            kv_cache_gpu_mem_fraction,
         )
     else:
         max_batch_size = DEFAULT_MAX_BATCH_SIZE
@@ -82,6 +86,10 @@ def get_model_config(model_name: str, model_path: Path = None) -> ModelConfig:
     Raises:
         ValueError: When model is not supported.
     """
+    if is_nemotron_hybrid(
+            AutoConfig.from_pretrained(model_path or model_name,
+                                       trust_remote_code=True)):
+        return NemotronHybridConfig.from_hf(model_name, model_path)
     return ModelConfig.from_hf(model_name, model_path)
 

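is_nemotron_hybrid is imported here rather than defined; judging from how the rest of this change relies on hybrid_override_pattern, it presumably amounts to a presence check along the following lines (an assumption for illustration, not the actual body in config_utils.py):

def is_nemotron_hybrid(config) -> bool:
    # Hybrid Nemotron-H configs carry a hybrid_override_pattern field;
    # plain transformer configs do not.
    return getattr(config, "hybrid_override_pattern", None) is not None
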
tensorrt_llm/bench/build/dataclasses.py

Lines changed: 60 additions & 0 deletions
@@ -124,6 +124,7 @@ class ModelConfig(BaseModel):
         AliasPath("text_config", "num_hidden_layers"),
         AliasPath("language_config", "num_hidden_layers"),
     ))
+    num_attention_layers: Optional[int] = Field(default=None)
     num_attention_heads: int = Field(validation_alias=AliasChoices(
         "num_attention_heads",
         "n_head",
@@ -148,6 +149,7 @@ class ModelConfig(BaseModel):
         validation_alias=AliasChoices(
             "head_size",
             "head_dim",
+            "attention_head_dim",
             AliasPath("text_config", "head_dim"),
         ))
     max_position_embeddings: Optional[int] = Field(
@@ -171,6 +173,8 @@ def set_values_if_none(self):
             self.num_key_value_heads = self.num_attention_heads
         if self.head_size is None:
             self.head_size = self.hidden_size // self.num_attention_heads
+        if self.num_attention_layers is None:
+            self.num_attention_layers = self.num_hidden_layers
         return self
 
     @classmethod
@@ -194,3 +198,59 @@ def from_hf(cls, model_hf_name, hf_model_path):
         param_count = cls.get_param_count(model_hf_name, hf_model_path)
 
         return cls(name=model_hf_name, param_count=param_count, **hf_config)
+
+    def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
+        return 0
+
+    def cache_memory_fraction(self, cache_memory_fraction):
+        return cache_memory_fraction
+
+
+class NemotronHybridConfig(ModelConfig):
+    hybrid_override_pattern: str
+    d_state: int = Field(validation_alias=AliasChoices(
+        "d_state",
+        "mamba_d_state",
+        "ssm_state_size",
+    ))
+    d_conv: int = Field(validation_alias=AliasChoices(
+        "d_conv",
+        "mamba_d_conv",
+        "conv_kernel",
+    ))
+    expand: int = Field(validation_alias=AliasChoices(
+        "expand",
+        "mamba_expand",
+    ))
+    n_groups: int
+    mamba_head_dim: int
+    d_inner: Optional[int] = Field(default=None)
+    mamba_num_heads: Optional[int] = Field(default=None)
+    num_mamba_layers: Optional[int] = Field(default=None)
+
+    @model_validator(mode="after")
+    def set_values_if_none(self):
+        """ Set the values if cannot get values from HF config.json. """
+        if not self.d_inner:
+            self.d_inner = self.hidden_size * self.expand
+        if not self.mamba_num_heads:
+            self.mamba_num_heads = self.d_inner // self.mamba_head_dim
+        if self.num_mamba_layers is None:
+            self.num_mamba_layers = self.hybrid_override_pattern.count("M")
+        if self.num_attention_layers is None:
+            self.num_attention_layers = self.hybrid_override_pattern.count("*")
+
+        super().set_values_if_none()
+        return self
+
+    def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
+        conv_dim = self.d_inner + 2 * self.n_groups * self.d_state
+        conv_state_elems = conv_dim * (self.d_conv - 1)
+        ssm_state_elems = self.mamba_num_heads * self.mamba_head_dim * self.d_state
+        gb_per_mamba_cache = bytes_per_elem * self.num_mamba_layers * (
+            conv_state_elems + ssm_state_elems) / (1024**3)
+        return gb_per_mamba_cache
+
+    def cache_memory_fraction(self, cache_memory_fraction):
+        # Each mamba cache entry is pretty large (~50MB for 8B model), so we are more conservative when estimating the max batch size
+        return cache_memory_fraction**2

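To make that sizing comment concrete, the per-request Mamba cache estimate can be written as a standalone helper and evaluated with illustrative numbers (the values below are invented for the example, not the real Nemotron-H 8B configuration; they merely land in the same tens-of-MB range the comment refers to, assuming 2 bytes per element):

def mamba_cache_gb_per_request(d_inner, n_groups, d_state, d_conv,
                               mamba_num_heads, mamba_head_dim,
                               num_mamba_layers, bytes_per_elem=2):
    # Mirrors NemotronHybridConfig.extra_model_cache_in_gb above:
    # a fixed-size conv state plus a fixed-size SSM state per Mamba layer.
    conv_dim = d_inner + 2 * n_groups * d_state
    conv_state_elems = conv_dim * (d_conv - 1)
    ssm_state_elems = mamba_num_heads * mamba_head_dim * d_state
    return (bytes_per_elem * num_mamba_layers *
            (conv_state_elems + ssm_state_elems) / (1024**3))

# Invented configuration values; the cache cost is per request and does not
# grow with sequence length, unlike the KV cache.
print(f"{mamba_cache_gb_per_request(8192, 8, 128, 4, 128, 64, 26):.3f} GB")  # ~0.052 GB

Because this per-request cost is fixed regardless of sequence length, the tuning code below charges it per request rather than per token.
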
tensorrt_llm/bench/build/tuning.py

Lines changed: 41 additions & 18 deletions
@@ -3,7 +3,7 @@
 from tensorrt_llm.llmapi.llm_utils import QuantConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization.mode import QuantAlgo
-from tensorrt_llm.bench.build.dataclasses import ModelConfig
+from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig
 from .utils import get_device_memory
 import math
 
@@ -55,7 +55,11 @@ def calc_engine_setting(
 
     # Each GPU in TP group has at least 1 kv head
     adjusted_num_kv_heads = max(tp_size, model_config.num_key_value_heads)
-    byte_per_token = 2 * model_config.num_hidden_layers * adjusted_num_kv_heads \
+
+    logger.info(
+        f"Number of attention layers: {model_config.num_attention_layers}")
+
+    gb_per_token = 2 * model_config.num_attention_layers * adjusted_num_kv_heads \
         * model_config.head_size * byte_per_kv_elem / (1024 ** 3)
 
     # Number of GPU used for this run.
@@ -70,19 +74,33 @@ def calc_engine_setting(
                 f"{available_memory:.2f} GB")
 
     # Calculate max requests in KV cache based on target ISL and OSL.
-    kv_cache_memory = available_memory * kv_cache_gpu_mem_fraction
-    kv_cache_max_tokens = kv_cache_memory / byte_per_token
-    kv_cache_max_requests = kv_cache_max_tokens / (target_input_len +
-                                                   target_output_len)
-    logger.info(f"Estimated total KV cache memory: {kv_cache_memory:.2f} GB")
+    target_seq_len = target_input_len + target_output_len
+    cache_memory = available_memory * model_config.cache_memory_fraction(
+        kv_cache_gpu_mem_fraction)
+    gb_per_extra_cache = model_config.extra_model_cache_in_gb(
+        BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT), target_seq_len)
+    kv_cache_max_requests = cache_memory / (gb_per_token * target_seq_len +
+                                            gb_per_extra_cache)
+    extra_cache_memory = gb_per_extra_cache * kv_cache_max_requests
+    kv_cache_memory = cache_memory - extra_cache_memory
+    kv_cache_max_tokens = kv_cache_memory / gb_per_token
+
+    logger.info(
+        f"Estimated total cache memory: {cache_memory:.2f} GB. KV cache: {kv_cache_memory:.2f} GB, Extra cache: {extra_cache_memory:.2f} GB"
+    )
+    logger.info(f"Estimated kv cache max tokens: {kv_cache_max_tokens:.2f}")
     logger.info("Estimated max number of requests in KV cache memory: "
                 f"{kv_cache_max_requests:.2f}")
 
     # Fine-tune the max batch size and num token setting for performance.
-    max_batch_size, max_num_tokens = finetune_setting(kv_cache_max_requests,
-                                                      target_input_len,
-                                                      target_output_len,
-                                                      pp_size)
+    # For mamba-attn hybrid models, we disable optimistic tuning because the mamba cache leaves less memory for the KV cache
+    max_batch_size, max_num_tokens = finetune_setting(
+        kv_cache_max_requests,
+        target_input_len,
+        target_output_len,
+        pp_size,
+        disable_optimistic_tuning=isinstance(model_config,
+                                             NemotronHybridConfig))
 
     # Functional and performance
     if total_gpu_memory < engine_size:
@@ -107,7 +125,7 @@ def calc_engine_setting(
     if kv_cache_max_requests < 1:
         raise RuntimeError("The amount of KV cache memory is insufficient to "
                            "run this model. Please try with more GPUs.")
-    if kv_cache_memory / n_gpus < 10.0:
+    if cache_memory / n_gpus < 10.0:
         logger.warning(
             f"The KV cache memory per GPU is less than 10 GB. "
             "Performance may be undesirable. Please consider using a different "
@@ -126,6 +144,7 @@ def finetune_setting(
     input_len: int,
     output_len: int,
     pp_size: int,
+    disable_optimistic_tuning: bool = False,
 ) -> Tuple[int, int]:
     """ Calculate and fine-tune the engine build settings (max batch size and
     max num tokens). Both max batch size and max num tokens are fine-tuned
@@ -137,6 +156,7 @@ def finetune_setting(
         input_len (int): Input sequence length to compile the engine.
         output_len (int): Output sequence length to compile the engine.
        pp_size (int): Number of pipeline parallel stages.
+        disable_optimistic_tuning (bool): Whether to disable optimistic tuning.
 
     Returns:
         Tuple[int, int]: Tuple containing fine-tuned values for engine
@@ -148,13 +168,16 @@ def finetune_setting(
     raw_token = min(raw_bs * (1 + input_len / output_len), 32768)
 
     # Fine-tune the max batch size.
-    # Set min BS to be 64.
-    if raw_bs < 256:
-        max_bs = max(64, 32 * math.ceil(raw_bs / 32))
-    elif raw_bs < 1024:
-        max_bs = 128 * math.ceil(raw_bs / 128)
+    if disable_optimistic_tuning:
+        max_bs = 2 * math.floor(raw_bs / 2)
     else:
-        max_bs = 256 * math.ceil(raw_bs / 256)
+        # Set min BS to be 64.
+        if raw_bs < 256:
+            max_bs = max(64, 32 * math.ceil(raw_bs / 32))
+        elif raw_bs < 1024:
+            max_bs = 128 * math.ceil(raw_bs / 128)
+        else:
+            max_bs = 256 * math.ceil(raw_bs / 256)
 
     # Fine-tune the max num tokens.
     # Set min to 2048 to ensure Ctx/Gen overlap efficiency

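Putting the tuning changes together: every in-flight request now costs gb_per_token * target_seq_len of KV cache plus a fixed gb_per_extra_cache of Mamba state, and the total has to fit in the cache budget. A hedged restatement of that arithmetic with invented inputs (variable names mirror the diff; the numbers do not come from any real model):

def estimate_cache_budget(available_memory_gb, cache_fraction,
                          gb_per_token, target_seq_len, gb_per_extra_cache):
    # Total budget; for hybrid models the fraction is already squared upstream.
    cache_memory = available_memory_gb * cache_fraction
    # Solve: requests * (KV cache per request + extra cache per request) = budget.
    kv_cache_max_requests = cache_memory / (
        gb_per_token * target_seq_len + gb_per_extra_cache)
    extra_cache_memory = gb_per_extra_cache * kv_cache_max_requests
    kv_cache_memory = cache_memory - extra_cache_memory
    kv_cache_max_tokens = kv_cache_memory / gb_per_token
    return kv_cache_max_requests, kv_cache_max_tokens

# Invented inputs: 80 GB free, hybrid fraction 0.95**2, 0.0001 GB of KV cache
# per token, 4096-token target sequences, 0.05 GB of Mamba cache per request.
reqs, toks = estimate_cache_budget(80.0, 0.95**2, 1e-4, 4096, 0.05)
print(f"{reqs:.1f} requests, {toks:.0f} KV cache tokens")

With disable_optimistic_tuning the batch size is rounded down to the nearest even number instead of up to a 32/128/256 multiple, so hybrid models are not over-committed relative to the reduced KV cache budget.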