 from tensorrt_llm.llmapi.llm_utils import QuantConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization.mode import QuantAlgo
-from tensorrt_llm.bench.build.dataclasses import ModelConfig
+from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig
 from .utils import get_device_memory
 import math
 
@@ -55,7 +55,11 @@ def calc_engine_setting(
 
     # Each GPU in TP group has at least 1 kv head
     adjusted_num_kv_heads = max(tp_size, model_config.num_key_value_heads)
-    byte_per_token = 2 * model_config.num_hidden_layers * adjusted_num_kv_heads \
+
+    logger.info(
+        f"Number of attention layers: {model_config.num_attention_layers}")
+
+    gb_per_token = 2 * model_config.num_attention_layers * adjusted_num_kv_heads \
         * model_config.head_size * byte_per_kv_elem / (1024 ** 3)
 
     # Number of GPU used for this run.
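
Note: gb_per_token above is the per-token KV cache footprint in GB. A minimal sketch of the same formula with made-up model dimensions (all numbers below are illustrative, not from this commit):

# Hypothetical dims: 32 attention layers, 8 KV heads, head size 128, FP16 cache.
num_attention_layers = 32   # mamba layers hold no KV cache, so they are excluded
adjusted_num_kv_heads = 8   # max(tp_size, num_key_value_heads)
head_size = 128
byte_per_kv_elem = 2        # FP16/BF16 KV cache

# The leading 2 accounts for the separate K and V tensors per layer.
gb_per_token = (2 * num_attention_layers * adjusted_num_kv_heads *
                head_size * byte_per_kv_elem) / (1024 ** 3)
print(f"{gb_per_token:.6f} GB/token")  # ~0.000122 GB, i.e. 128 KiB per token
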
@@ -70,19 +74,33 @@ def calc_engine_setting(
                 f"{available_memory:.2f} GB")
 
     # Calculate max requests in KV cache based on target ISL and OSL.
-    kv_cache_memory = available_memory * kv_cache_gpu_mem_fraction
-    kv_cache_max_tokens = kv_cache_memory / byte_per_token
-    kv_cache_max_requests = kv_cache_max_tokens / (target_input_len +
-                                                   target_output_len)
-    logger.info(f"Estimated total KV cache memory: {kv_cache_memory:.2f} GB")
+    target_seq_len = target_input_len + target_output_len
+    cache_memory = available_memory * model_config.cache_memory_fraction(
+        kv_cache_gpu_mem_fraction)
+    gb_per_extra_cache = model_config.extra_model_cache_in_gb(
+        BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT), target_seq_len)
+    kv_cache_max_requests = cache_memory / (gb_per_token * target_seq_len +
+                                            gb_per_extra_cache)
+    extra_cache_memory = gb_per_extra_cache * kv_cache_max_requests
+    kv_cache_memory = cache_memory - extra_cache_memory
+    kv_cache_max_tokens = kv_cache_memory / gb_per_token
+
+    logger.info(
+        f"Estimated total cache memory: {cache_memory:.2f} GB, KV cache: {kv_cache_memory:.2f} GB, extra cache: {extra_cache_memory:.2f} GB"
+    )
+    logger.info(f"Estimated kv cache max tokens: {kv_cache_max_tokens:.2f}")
     logger.info("Estimated max number of requests in KV cache memory: "
                 f"{kv_cache_max_requests:.2f}")
 
     # Fine-tune the max batch size and num token setting for performance.
-    max_batch_size, max_num_tokens = finetune_setting(kv_cache_max_requests,
-                                                      target_input_len,
-                                                      target_output_len,
-                                                      pp_size)
+    # For mamba-attn hybrid models, we disable optimistic tuning because the mamba cache leaves less memory for the KV cache
+    max_batch_size, max_num_tokens = finetune_setting(
+        kv_cache_max_requests,
+        target_input_len,
+        target_output_len,
+        pp_size,
+        disable_optimistic_tuning=isinstance(model_config,
+                                             NemotronHybridConfig))
 
     # Functional and performance
     if total_gpu_memory < engine_size:
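
To make the new budget split concrete, here is a standalone sketch of the arithmetic in the hunk above. The cache fraction and per-request extra-cache size are assumptions for illustration, not values from the commit:

available_memory = 60.0       # GB free for caches (assumed)
cache_fraction = 0.90         # stands in for model_config.cache_memory_fraction(...)
gb_per_token = 0.000122       # from the previous sketch
target_seq_len = 2048 + 128   # target ISL + OSL
gb_per_extra_cache = 0.05     # assumed per-request mamba/extra cache, in GB

cache_memory = available_memory * cache_fraction
# Each request costs KV for its full sequence plus one extra-cache slot.
kv_cache_max_requests = cache_memory / (gb_per_token * target_seq_len +
                                        gb_per_extra_cache)
extra_cache_memory = gb_per_extra_cache * kv_cache_max_requests
kv_cache_memory = cache_memory - extra_cache_memory
print(round(kv_cache_max_requests))   # ~171 requests under these assumptions
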
@@ -107,7 +125,7 @@ def calc_engine_setting(
     if kv_cache_max_requests < 1:
         raise RuntimeError("The amount of KV cache memory is insufficient to "
                            "run this model. Please try with more GPUs.")
-    if kv_cache_memory / n_gpus < 10.0:
+    if cache_memory / n_gpus < 10.0:
         logger.warning(
             f"The KV cache memory per GPU is less than 10 GB. "
             "Performance may be undesirable. Please consider using a different "
@@ -126,6 +144,7 @@ def finetune_setting(
     input_len: int,
     output_len: int,
     pp_size: int,
+    disable_optimistic_tuning: bool = False,
 ) -> Tuple[int, int]:
     """ Calculate and fine-tune the engine build settings (max batch size and
         max num tokens). Both max batch size and max num tokens are fine-tuned
@@ -137,6 +156,7 @@ def finetune_setting(
         input_len (int): Input sequence length to compile the engine.
         output_len (int): Output sequence length to compile the engine.
         pp_size (int): Number of pipeline parallel stages.
+        disable_optimistic_tuning (bool): Whether to disable optimistic tuning.
 
     Returns:
         Tuple[int, int]: Tuple containing fine-tuned values for engine
@@ -148,13 +168,16 @@ def finetune_setting(
     raw_token = min(raw_bs * (1 + input_len / output_len), 32768)
 
     # Fine-tune the max batch size.
-    # Set min BS to be 64.
-    if raw_bs < 256:
-        max_bs = max(64, 32 * math.ceil(raw_bs / 32))
-    elif raw_bs < 1024:
-        max_bs = 128 * math.ceil(raw_bs / 128)
+    if disable_optimistic_tuning:
+        max_bs = 2 * math.floor(raw_bs / 2)
     else:
-        max_bs = 256 * math.ceil(raw_bs / 256)
+        # Set min BS to be 64.
+        if raw_bs < 256:
+            max_bs = max(64, 32 * math.ceil(raw_bs / 32))
+        elif raw_bs < 1024:
+            max_bs = 128 * math.ceil(raw_bs / 128)
+        else:
+            max_bs = 256 * math.ceil(raw_bs / 256)
 
     # Fine-tune the max num tokens.
     # Set min to 2048 to ensure Ctx/Gen overlap efficiency
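
The effect of disable_optimistic_tuning is easiest to see in isolation. A hypothetical helper that mirrors the rounding above (not part of the real API):

import math

def round_max_bs(raw_bs: float, disable_optimistic_tuning: bool = False) -> int:
    if disable_optimistic_tuning:
        # Conservative: round DOWN to an even batch size.
        return 2 * math.floor(raw_bs / 2)
    # Optimistic: round UP to the next bucket, with a floor of 64.
    if raw_bs < 256:
        return max(64, 32 * math.ceil(raw_bs / 32))
    elif raw_bs < 1024:
        return 128 * math.ceil(raw_bs / 128)
    return 256 * math.ceil(raw_bs / 256)

print(round_max_bs(100))        # 128: rounded up to a multiple of 32
print(round_max_bs(100, True))  # 100: rounded down to even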