Update TensorRT-LLM (#1168)
* Update TensorRT-LLM

---------

Co-authored-by: Bhuvanesh Sridharan <[email protected]>
Co-authored-by: Shixiaowei02 <[email protected]>
3 people authored Feb 27, 2024
1 parent e4e09da commit 655524d
Showing 229 changed files with 4,791 additions and 4,078 deletions.
22 changes: 12 additions & 10 deletions benchmarks/cpp/gptManagerBenchmark.cpp
@@ -428,7 +428,7 @@ class GptServer
     , mStaticEmulatedTimeoutMs(staticEmulatedTimeoutMs)
     , mActiveCount(0)
 {
-    ReturnBatchManagerStatsCallback iterationDataCallback = [this, &logIterationData](std::string const& log)
+    ReturnBatchManagerStatsCallback iterationDataCallback = [this, logIterationData](std::string const& log)
     {
         if (logIterationData)
         {
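
The capture-list change above is the whole fix: logIterationData is a constructor parameter, so it dies when the constructor returns, while the stats callback keeps running for the server's lifetime; capturing it by reference leaves the closure holding a dangling reference. A minimal standalone sketch of this bug class (makeCallback and its body are illustrative, not TensorRT-LLM code):

#include <functional>
#include <iostream>

// A callback that outlives its creator must not capture a local by reference.
std::function<void()> makeCallback(bool logIterationData)
{
    // Bug: [&logIterationData] dangles once makeCallback returns.
    // Fix: capture by value so the flag is copied into the closure.
    return [logIterationData]
    {
        if (logIterationData)
        {
            std::cout << "iteration stats would be logged here\n";
        }
    };
}

int main()
{
    auto callback = makeCallback(true);
    callback(); // safe: reads the copied flag, not a dead stack slot
    return 0;
}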
@@ -563,16 +563,18 @@ class GptServer
     {
         auto numNewWorkItems = static_cast<int64_t>(rval.size());
         comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
-
-        std::vector<int64_t> packed;
-        for (auto const& ir : rval)
+        if (numNewWorkItems > 0)
         {
-            auto vpacked = ir->serialize();
-            packed.push_back(static_cast<int64_t>(vpacked.size()));
-            packed.insert(
-                packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
+            std::vector<int64_t> packed;
+            for (auto const& ir : rval)
+            {
+                auto vpacked = ir->serialize();
+                packed.push_back(static_cast<int64_t>(vpacked.size()));
+                packed.insert(
+                    packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
+            }
+            comm.bcast(packed, 0);
         }
-        comm.bcast(packed, 0);
     }
 }
 else
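
With the new guard, the item count is broadcast first and the payload broadcast happens only when there is something to send, so follower ranks never deserialize an empty buffer. A hedged sketch of the same pattern in raw MPI rather than TensorRT-LLM's mpi wrapper (broadcastWork and its shape are illustrative):

#include <mpi.h>

#include <cstddef>
#include <cstdint>
#include <vector>

// Rank 0 announces how many packed elements exist; the payload broadcast
// is skipped entirely when the count is zero. Assumes MPI_Init has run.
void broadcastWork(std::vector<int64_t>& packed, int rank)
{
    int64_t count = (rank == 0) ? static_cast<int64_t>(packed.size()) : 0;
    MPI_Bcast(&count, 1, MPI_INT64_T, 0, MPI_COMM_WORLD);
    if (count > 0)
    {
        if (rank != 0)
        {
            packed.resize(static_cast<std::size_t>(count)); // make room on followers
        }
        MPI_Bcast(packed.data(), static_cast<int>(count), MPI_INT64_T, 0, MPI_COMM_WORLD);
    }
}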
@@ -791,7 +793,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
     recorder->report();
     recorder->writeOpMetricsToCsv();
     // Send terminateReqId to terminate servers on all ranks
-    // Sever on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
+    // Server on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
     gptServer->enqueue(std::make_shared<InferenceRequest>(terminateReqId));
 }
 // Wait until benchmarking is done and batch manager is terminated
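
terminateReqId works as a sentinel: it is enqueued like any ordinary request, and a server shuts down when it dequeues it, with rank 0 forwarding the signal to the other ranks. A minimal single-consumer sketch of the sentinel pattern (WorkQueue and the reserved id are illustrative, not the benchmark's actual types):

#include <condition_variable>
#include <cstdint>
#include <deque>
#include <mutex>

constexpr uint64_t kTerminateReqId = 0; // assumption: 0 is never a real request id

class WorkQueue
{
public:
    void enqueue(uint64_t reqId)
    {
        {
            std::lock_guard<std::mutex> lock(mMutex);
            mQueue.push_back(reqId);
        }
        mCv.notify_one();
    }

    // Blocks for the next request; returns false once the sentinel arrives.
    bool dequeue(uint64_t& reqId)
    {
        std::unique_lock<std::mutex> lock(mMutex);
        mCv.wait(lock, [this] { return !mQueue.empty(); });
        reqId = mQueue.front();
        mQueue.pop_front();
        return reqId != kTerminateReqId;
    }

private:
    std::mutex mMutex;
    std::condition_variable mCv;
    std::deque<uint64_t> mQueue;
};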
13 changes: 11 additions & 2 deletions benchmarks/cpp/gptSessionBenchmark.cpp
@@ -114,9 +114,9 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
"benchmark on %d tokens",
maxNumTokens.value(), maxBatchSize * maxInputLength);
}
std::atomic_bool done = false;
try
{
std::atomic_bool done = false;
auto peakMemFuture = std::async(&monitorMemory, std::ref(done));
TLLM_LOG_INFO(memoryCounter.toString());

@@ -266,11 +266,14 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
     catch (std::runtime_error& e)
     {
         std::size_t found = std::string(e.what()).find("out of memory");
+        // We need to kill the memory monitor when OOM.
+        done = true;
 
         // Unexpected error; rethrow
         if (found == std::string::npos)
         {
-            throw;
+            TLLM_LOG_ERROR(e.what());
+            throw e;
         }
 
         // We can ignore the OOM exception and continue the rest of the benchmark
@@ -283,6 +286,12 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
             }
             continue;
         }
+        catch (...)
+        {
+            // We need to kill memory monitor when any other issue occurs
+            done = true;
+            throw;
+        }
     }
     TLLM_LOG_INFO(memoryCounter.toString());
 }
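
Taken together, the three hunks in this file make sure the memory monitor stops on every exit path: done now lives outside the try block so every catch clause can reach it, the OOM handler sets it before deciding whether to rethrow, and the new catch-all sets it before propagating anything else. Without that, the future returned by std::async would block in its destructor waiting on a monitor loop that never ends. A simplified, self-contained sketch (monitorMemory's body and runBenchmark are illustrative):

#include <atomic>
#include <chrono>
#include <future>
#include <stdexcept>
#include <thread>

// Spins until told to stop; stands in for the benchmark's GPU-memory poller.
void monitorMemory(std::atomic_bool& done)
{
    while (!done)
    {
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    }
}

void runBenchmark(bool triggerError)
{
    std::atomic_bool done{false}; // declared outside try: visible to every catch
    auto monitor = std::async(std::launch::async, &monitorMemory, std::ref(done));
    try
    {
        if (triggerError)
        {
            throw std::runtime_error("out of memory");
        }
    }
    catch (std::runtime_error const&)
    {
        done = true; // stop the monitor before handling or rethrowing
    }
    catch (...)
    {
        done = true; // any other failure must also release the monitor
        throw;       // safe: the future's destructor no longer deadlocks
    }
    done = true;
    monitor.get(); // join the monitor on the normal path
}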
22 changes: 22 additions & 0 deletions benchmarks/python/allowed_configs.py
@@ -85,6 +85,7 @@ class EncDecBuildConfig:
     max_decoder_input_len: Optional[int] = None
     max_output_len: Optional[int] = None
     builder_opt: Optional[int] = None
+    n_mels: Optional[int] = None
 
     def __post_init__(self) -> None:
         assert self.head_size is not None
@@ -1179,6 +1180,27 @@ class ModelConfig:
                     mamba_d_conv=4,
                     mamba_expand=2,
                 )),
+    "whisper_large_v3":
+    ModelConfig(name="whisper_large_v3",
+                family="whisper",
+                benchmark_type="enc_dec",
+                build_config=EncDecBuildConfig(
+                    num_layers=32,
+                    num_decoder_layers=32,
+                    num_heads=20,
+                    head_size=64,
+                    ffn_hidden_size=5120,
+                    hidden_size=1280,
+                    vocab_size=51866,
+                    hidden_act="gelu",
+                    n_positions=448,
+                    n_mels=128,
+                    max_batch_size=8,
+                    max_encoder_input_len=1500,
+                    max_decoder_input_len=1,
+                    max_output_len=200,
+                    builder_opt=None,
+                )),
 }
 
 
136 changes: 95 additions & 41 deletions benchmarks/python/build.py
@@ -795,7 +795,7 @@ def build_gpt(args):
         max_beam_width=max_beam_width)
     if family in [
             'opt', 'bloom', 'falcon', 'llama', 'internlm', 'gptneox',
-            'gptj', 'mamba', 'baichuan'
+            'gptj', 'mamba', 'baichuan', 'chatglm', 'chatglm2', 'chatglm3'
     ]:
         tensorrt_llm_model(**inputs)
     else:
@@ -957,6 +957,8 @@ def enc_dec_build_helper(component, config, args):
     torch.cuda.set_device(runtime_rank)
 
     family = get_model_family(args.model)
+    logits_dtype = 'float32'
+    n_mels = 0
     if family == 'bart':
         q_scaling = 1.0
         has_attention_qkvo_bias = True
@@ -969,6 +971,19 @@ def enc_dec_build_helper(component, config, args):
         layernorm_position = LayerNormPositionType.pre_layernorm if config.get(
             'normalize_before', True) else LayerNormPositionType.post_layernorm
         rescale_before_lm_head = False
+    elif family == 'whisper':
+        q_scaling = 1.0
+        has_position_embedding = True
+        relative_attention = False
+        has_embedding_layernorm = False
+        has_attention_qkvo_bias = True
+        has_mlp_bias = True
+        has_model_final_layernorm = True
+        layernorm_position = LayerNormPositionType.pre_layernorm
+        layernorm_type = LayerNormType.LayerNorm
+        rescale_before_lm_head = False
+        logits_dtype = str_dtype_to_trt(args.dtype)
+        n_mels = config['n_mels']
     else:
         q_scaling = 1 / config['head_size']**.5
         has_attention_qkvo_bias = False
@@ -984,6 +999,9 @@ def enc_dec_build_helper(component, config, args):
     else:
         rescale_before_lm_head = False
 
+    quant_mode, _, _ = get_quant_mode(args.quantization)
+    use_weight_only = quant_mode.is_weight_only()
+
     builder = Builder()
     builder_config = builder.create_builder_config(
         name=args.model,
@@ -1011,6 +1029,10 @@ def enc_dec_build_helper(component, config, args):
         has_token_type_embedding=False, # by default
         strongly_typed=False, # by default
         gather_all_token_logits=False, # by default
+        int8=(quant_mode.has_act_and_weight_quant()
+              or quant_mode.is_int8_weight_only()),
+        quant_mode=quant_mode,
+        n_mels=n_mels,
     )
 
     # build engine
@@ -1024,34 +1046,45 @@ def enc_dec_build_helper(component, config, args):
     fp16_clamping = (args.dtype == 'float16') and ('t5' in family)
 
     if component == 'encoder':
-        tllm_model = tensorrt_llm.models.EncoderModel(
-            num_layers=config['num_layers'],
-            num_heads=config['num_heads'],
-            num_kv_heads=config['num_heads'],
-            head_size=config['head_size'],
-            hidden_size=config['hidden_size'],
-            ffn_hidden_size=config['ffn_hidden_size'],
-            vocab_size=config['vocab_size'],
-            max_position_embeddings=config.get('n_positions', 0),
-            has_position_embedding=has_position_embedding,
-            relative_attention=relative_attention,
-            max_distance=config.get('max_distance', 0),
-            num_buckets=config.get('num_buckets', 0),
-            has_embedding_layernorm=has_embedding_layernorm,
-            has_embedding_scale=config.get('has_embedding_scale', False),
-            q_scaling=q_scaling,
-            has_attention_qkvo_bias=has_attention_qkvo_bias,
-            has_mlp_bias=has_mlp_bias,
-            has_model_final_layernorm=has_model_final_layernorm,
-            layernorm_eps=1e-6,
-            layernorm_position=layernorm_position,
-            layernorm_type=layernorm_type,
-            hidden_act=config['hidden_act'],
-            dtype=dtype,
-            use_parallel_embedding=False, # by default
-            embedding_sharding_dim=0, # by default
-            mapping=mapping,
-            fp16_clamping=fp16_clamping)
+        if family == 'whisper':
+            tllm_model = tensorrt_llm.models.WhisperEncoder(
+                n_mels=config['n_mels'],
+                n_ctx=1500, # n_audio_ctx
+                n_state=config['hidden_size'],
+                n_head=config['num_heads'],
+                n_layer=config['num_layers'],
+                dtype=dtype)
+            if use_weight_only:
+                tllm_model = quantize_model(tllm_model, quant_mode)
+        else:
+            tllm_model = tensorrt_llm.models.EncoderModel(
+                num_layers=config['num_layers'],
+                num_heads=config['num_heads'],
+                num_kv_heads=config['num_heads'],
+                head_size=config['head_size'],
+                hidden_size=config['hidden_size'],
+                ffn_hidden_size=config['ffn_hidden_size'],
+                vocab_size=config['vocab_size'],
+                max_position_embeddings=config.get('n_positions', 0),
+                has_position_embedding=has_position_embedding,
+                relative_attention=relative_attention,
+                max_distance=config.get('max_distance', 0),
+                num_buckets=config.get('num_buckets', 0),
+                has_embedding_layernorm=has_embedding_layernorm,
+                has_embedding_scale=config.get('has_embedding_scale', False),
+                q_scaling=q_scaling,
+                has_attention_qkvo_bias=has_attention_qkvo_bias,
+                has_mlp_bias=has_mlp_bias,
+                has_model_final_layernorm=has_model_final_layernorm,
+                layernorm_eps=1e-6,
+                layernorm_position=layernorm_position,
+                layernorm_type=layernorm_type,
+                hidden_act=config['hidden_act'],
+                dtype=dtype,
+                use_parallel_embedding=False, # by default
+                embedding_sharding_dim=0, # by default
+                mapping=mapping,
+                fp16_clamping=fp16_clamping)
     elif component == 'decoder':
         tllm_model = tensorrt_llm.models.DecoderModel(
             num_layers=config['num_layers'],
@@ -1084,8 +1117,10 @@ def enc_dec_build_helper(component, config, args):
             embedding_sharding_dim=0, # by default
             mapping=mapping,
             rescale_before_lm_head=rescale_before_lm_head,
-            logits_dtype='float32', # by default
+            logits_dtype=logits_dtype, # by default
             fp16_clamping=fp16_clamping)
+        if use_weight_only and family == 'whisper':
+            tllm_model = quantize_model(tllm_model, quant_mode)
 
     # Module -> Network
     engine_name = get_engine_name(args.model, args.dtype, world_size,
@@ -1099,6 +1134,12 @@ def enc_dec_build_helper(component, config, args):
         network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
         network.plugin_config.set_gemm_plugin(dtype=args.dtype)
         network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
+        if use_weight_only:
+            network.plugin_config.set_weight_only_quant_matmul_plugin(
+                dtype=args.dtype)
+    elif args.mode == 'ootb-except-mha':
+        network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
+        network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
 
     if world_size > 1:
         network.plugin_config.set_nccl_plugin(
@@ -1110,18 +1151,31 @@ def enc_dec_build_helper(component, config, args):

     # Forward
     if component == 'encoder':
-        inputs = tllm_model.prepare_inputs(
-            max_batch_size=config['max_batch_size'],
-            max_input_len=config['max_encoder_input_len'],
-        )
+        if family == 'whisper':
+            inputs = tllm_model.prepare_inputs(
+                max_batch_size=config['max_batch_size'], )
+        else:
+            inputs = tllm_model.prepare_inputs(
+                max_batch_size=config['max_batch_size'],
+                max_input_len=config['max_encoder_input_len'],
+            )
     elif component == 'decoder':
-        inputs = tllm_model.prepare_inputs(
-            max_batch_size=config['max_batch_size'],
-            max_beam_width=config['max_beam_width'],
-            max_decoder_input_len=config['max_decoder_input_len'],
-            max_new_tokens=config['max_output_len'],
-            max_encoder_input_len=config['max_encoder_input_len'],
-        )
+        if family == 'whisper':
+            inputs = tllm_model.prepare_inputs(
+                max_batch_size=config['max_batch_size'],
+                max_beam_width=config['max_beam_width'],
+                max_decoder_input_len=config['max_decoder_input_len'],
+                max_new_tokens=config['max_output_len'],
+                max_encoder_input_len=1500, # n_audio_ctx
+            )
+        else:
+            inputs = tllm_model.prepare_inputs(
+                max_batch_size=config['max_batch_size'],
+                max_beam_width=config['max_beam_width'],
+                max_decoder_input_len=config['max_decoder_input_len'],
+                max_new_tokens=config['max_output_len'],
+                max_encoder_input_len=config['max_encoder_input_len'],
+            )
 
     tllm_model(*inputs)
 
