
Commit 5205cfe

Merge branch 'release/1.1' into user/yunruis/fix_bug_5606268
2 parents 32ab7a6 + a1d9126


18 files changed: +137 −43 lines


cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu

Lines changed: 0 additions & 4 deletions
@@ -131,10 +131,6 @@ void moeSetSignalForCpuStageForTest(MoeLoadBalanceSingleLayerSignal* signal)
 template <typename TYPE>
 __global__ void zeroExpertTokenCountKernel(MoeLoadBalanceMetaInfo metaInfo, int* const enabled, int* expertTokenCount)
 {
-    if (*enabled == 0)
-    {
-        return;
-    }
     TYPE oldExpertTokenCount = {0};
     int* expertTokenCountPtr = expertTokenCount + metaInfo.expertCount * blockIdx.x;
     TYPE* typedExpertTokenCountPtr = reinterpret_cast<TYPE*>(expertTokenCountPtr);

cpp/tensorrt_llm/thop/fp8BlockScalingGemm.cpp

Lines changed: 2 additions & 2 deletions
@@ -343,7 +343,7 @@ torch::Tensor fp8_block_scaling_bmm(torch::Tensor const& mat1, torch::Tensor con

 TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
-    m.def("fp8_block_scaling_gemm(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale) -> Tensor");
+    m.def("fp8_block_scaling_gemm_impl(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale) -> Tensor");
     m.def(
         "fp8_block_scaling_bmm(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale, ScalarType? "
         "out_dtype=None) -> Tensor");
@@ -357,7 +357,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)

 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
 {
-    m.impl("fp8_block_scaling_gemm", &torch_ext::fp8_block_scaling_gemm);
+    m.impl("fp8_block_scaling_gemm_impl", &torch_ext::fp8_block_scaling_gemm);
     m.impl("fp8_block_scaling_bmm", &torch_ext::fp8_block_scaling_bmm);
     m.impl("fp8_block_scaling_bmm_out", &torch_ext::fp8_block_scaling_bmm_out);
     m.impl("fp8_block_scaling_moe_gemm", &torch_ext::fp8_block_scaling_moe_gemm);

docker/Dockerfile.multi

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ COPY cpp cpp
 COPY scripts scripts
 COPY tensorrt_llm tensorrt_llm
 COPY 3rdparty 3rdparty
-COPY .gitmodules setup.py requirements.txt requirements-dev.txt constraints.txt ./
+COPY .gitmodules setup.py requirements.txt requirements-dev.txt constraints.txt README.md ./

 # Create cache directories for pip and ccache
 RUN mkdir -p /root/.cache/pip /root/.cache/ccache
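Note: a likely reason README.md now has to be in the copied build context is that setup.py reads it at wheel-build time (for example as the long_description), so the build inside the image would otherwise fail on a missing file. The sketch below only illustrates that common pattern and is an assumption, not this repository's actual setup.py:

from pathlib import Path
from setuptools import setup, find_packages

# Hypothetical sketch: if setup.py opens README.md like this, the file must be
# COPY'd into the Docker build context added in the diff above.
long_description = (Path(__file__).parent / "README.md").read_text(encoding="utf-8")

setup(
    name="tensorrt_llm",
    packages=find_packages(),
    long_description=long_description,
    long_description_content_type="text/markdown",
)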

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 12 additions & 2 deletions
@@ -536,6 +536,7 @@ def is_nvfp4_output_kernel_available(
 @dataclass(kw_only=True)
 class TrtllmAttentionMetadata(AttentionMetadata):
     workspace: Optional[torch.Tensor] = None
+    cuda_graph_workspace: Optional[torch.Tensor] = None

     # TrtllmAttention needs to know the beam width to access to the cache indirection buffer,
     # when beam search is enabled.
@@ -693,6 +694,14 @@ def get_empty_like(like_tensor: torch.Tensor,
                 device='cuda',
                 dtype=torch.int8,
             )
+
+        if self.cuda_graph_workspace is None:
+            self.cuda_graph_workspace = torch.empty(
+                (0, ),
+                device='cuda',
+                dtype=torch.int8,
+            )
+
         if self.kv_cache_manager is not None:
             self.kv_cache_block_offsets = get_empty(
                 [
@@ -1276,8 +1285,9 @@ def forward(
             host_kv_cache_pool_pointers=metadata.host_kv_cache_pool_pointers,
             host_kv_cache_pool_mapping=metadata.host_kv_cache_pool_mapping,
             block_ids_per_seq=metadata.block_ids_per_seq,
-            workspace=metadata.
-            workspace,  # re-enable it, if pass None to it, fp8 mla will encounter invalid cuda free issue.
+            # re-enable it, if pass None to it, fp8 mla will encounter invalid cuda free issue.
+            workspace=metadata.workspace
+            if not metadata.is_cuda_graph else metadata.cuda_graph_workspace,
             cache_indirection=metadata.cache_indirection,
             kv_scale_orig_quant=self.kv_scale_orig_quant,
             kv_scale_quant_orig=self.kv_scale_quant_orig,
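Note: the new cuda_graph_workspace is a dedicated tensor whose lifetime is tied to the metadata object, so the workspace pointer captured into a CUDA graph stays valid instead of being freed or resized along with the regular workspace. A minimal sketch of just the selection logic; the metadata class is simplified and only workspace, cuda_graph_workspace, and is_cuda_graph come from this diff:

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class _Meta:  # illustrative stand-in for TrtllmAttentionMetadata
    workspace: Optional[torch.Tensor] = None
    cuda_graph_workspace: Optional[torch.Tensor] = None
    is_cuda_graph: bool = False

def pick_workspace(metadata: _Meta) -> Optional[torch.Tensor]:
    # During CUDA graph capture/replay the captured pointer must stay stable,
    # so a separate, never-freed tensor is passed instead of the regular workspace.
    return metadata.cuda_graph_workspace if metadata.is_cuda_graph else metadata.workspace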

tensorrt_llm/_torch/autotuner.py

Lines changed: 6 additions & 4 deletions
@@ -410,22 +410,24 @@ def choose_one(
                 f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
             )
         else:
-            logger.warning(
+            logger.warning_once(
                 f"[Autotuner] No valid runner/tactic was found for custom_op={custom_op}, input_shapes={input_shapes}. "
                 f"At least one valid (runner, tactic) pair is required. "
                 f"If get_valid_tactics is intended to return empty list, please ensure that this profile is not valid for the custom_op "
-                f"and should not occurs during the inference stage, or fallback tactic is implemented. Otherwise, the the tuning process will crash."
+                f"and should not occurs during the inference stage, or fallback tactic is implemented. Otherwise, the the tuning process will crash.",
+                key=custom_op,
             )
         new_tuning_failure_occured = new_tuning_failure_occured or has_tuning_failure_occured

         # If failed profiling tactics occurs, log the error.
         if new_tuning_failure_occured:
-            logger.warning(
+            logger.warning_once(
                 f"[Autotuner] New tuning error occurs:"
                 f"Total failed profiling tactics occurs: {len(self.stats.failed_profiling_count[custom_op])} for custom_op={custom_op}. "
                 f"This will not block the tuning process. "
                 f"Please set TLLM_LOG_LEVEL=WARNING to find out when the tactic profiling fails. "
-                f"Set TLLM_LOG_LEVEL=DEBUG to get more details of the failures."
+                f"Set TLLM_LOG_LEVEL=DEBUG to get more details of the failures.",
+                key=custom_op,
            )

         # Get the best runner and tactic from cache
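Note: switching to logger.warning_once(..., key=custom_op) deduplicates these messages per custom op, which matters because choose_one can fire for every shape bucket of every op during tuning. A minimal sketch of keyed once-only logging, as an illustration of the behavior rather than TensorRT-LLM's logger implementation:

import logging

_emitted_keys: set = set()

def warning_once(logger: logging.Logger, msg: str, *, key: str) -> None:
    # Emit the warning only the first time this key is seen, so repeated
    # tuning failures for the same op do not flood the log.
    if key in _emitted_keys:
        return
    _emitted_keys.add(key)
    logger.warning(msg)

log = logging.getLogger("autotuner")
warning_once(log, "no valid tactic", key="trtllm::fp8_block_scaling_gemm")
warning_once(log, "no valid tactic", key="trtllm::fp8_block_scaling_gemm")  # suppressed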

tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ def _(scores, scores_with_bias, n_group, topk_group, topk,
 def _(input, force_applying_finalize):
     return torch.empty_like(input)

-@torch.library.register_fake("trtllm::fp8_block_scaling_gemm")
+@torch.library.register_fake("trtllm::fp8_block_scaling_gemm_impl")
 def _(a, b, a_scale, b_scale):
     m = a.shape[0]
     n = b.shape[0]
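Note: the fake (meta) registration has to follow the op rename so that shape inference during tracing keeps working without launching the CUDA kernel. A quick way to exercise it, assuming that importing tensorrt_llm registers the trtllm ops and their fakes; the tensor shapes and scale layouts below are illustrative only:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
import tensorrt_llm  # assumption: this import registers trtllm::* ops and their fake impls

with FakeTensorMode():
    a = torch.empty(128, 512, dtype=torch.float8_e4m3fn, device="cuda")
    b = torch.empty(256, 512, dtype=torch.float8_e4m3fn, device="cuda")
    a_scale = torch.empty(4, 128, device="cuda")   # illustrative scale shapes
    b_scale = torch.empty(2, 4, device="cuda")
    out = torch.ops.trtllm.fp8_block_scaling_gemm_impl(a, b, a_scale, b_scale)
    print(out.shape)  # (128, 256), computed by the fake; no real GEMM runs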

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 74 additions & 2 deletions
@@ -903,7 +903,7 @@ def _(
     return input.new_empty((M, N), dtype=output_dtype)


-def fp8_swap_ab_gen_tuning_buckets(x: int):
+def deep_gemm_gen_tuning_buckets(x: int):
     buckets = tuple(range(8, 128, 8))
     if x >= 128:
         buckets += tuple(range(128, x, 128))
@@ -913,7 +913,7 @@ def fp8_swap_ab_gen_tuning_buckets(x: int):
 class fp8SwapABGemmRunner(TunableRunner):
     tuning_config = TuningConfig(
         dynamic_tensor_specs=(DynamicTensorSpec(
-            0, 0, fp8_swap_ab_gen_tuning_buckets), ),
+            0, 0, deep_gemm_gen_tuning_buckets), ),
         tune_max_num_tokens=4096,
     )

@@ -992,6 +992,78 @@ def _(
     return input.new_empty((input.size(0), weight.size(0)), dtype=output_dtype)


+# The runner is used to trigger deepgemm jit during autotune.
+class Fp8BlockScalingGemmRunner(TunableRunner):
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, deep_gemm_gen_tuning_buckets), ),
+        tune_max_num_tokens=4096,
+    )
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+    ) -> List[int]:
+        return [0]
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = -1,
+    ) -> torch.Tensor:
+        a, b, a_scale, b_scale = inputs
+        return torch.ops.trtllm.fp8_block_scaling_gemm_impl(
+            a, b, a_scale, b_scale)
+
+
+def get_fp8_block_scaling_gemm_constraint_spec():
+    # The implementation aligns with the fp8_quantize_1x128 custom op.
+    def fp8_quantize_1x128_sm90_constrant(inputs: List[List[int]]):
+        pad_m = fp4_utils.pad_up(inputs[0][0], 4)
+        blocked_n = (inputs[0][1] + 127) // 128
+        return fp4_utils.pad_up(pad_m * blocked_n * 4, 128) // 4
+
+    if get_sm_version() >= 100:
+        return (ConstraintSpec(2, 1, lambda inputs: inputs[0][0]), )
+    else:
+        return (ConstraintSpec(2, 0, fp8_quantize_1x128_sm90_constrant), )
+
+
+@torch.library.custom_op("trtllm::fp8_block_scaling_gemm", mutates_args=())
+def fp8_block_scaling_gemm(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_scale: torch.Tensor,
+    b_scale: torch.Tensor,
+    tune_max_num_tokens: int = 4096,
+) -> torch.Tensor:
+    tuner = AutoTuner.get()
+    fp8_block_scaling_gemm_runner = Fp8BlockScalingGemmRunner()
+    Fp8BlockScalingGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
+
+    Fp8BlockScalingGemmRunner.tuning_config.constraint_specs = get_fp8_block_scaling_gemm_constraint_spec(
+    )
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::fp8_block_scaling_gemm",
+        [fp8_block_scaling_gemm_runner],
+        Fp8BlockScalingGemmRunner.tuning_config,
+        [a, b, a_scale, b_scale],
+    )
+    return fp8_block_scaling_gemm_runner(
+        inputs=[a, b, a_scale, b_scale],
+        tactic=best_tactic,
+    )
+
+
+@fp8_block_scaling_gemm.register_fake
+def _(a, b, a_scale, b_scale, tune_max_num_tokens=4096):
+    m = a.shape[0]
+    n = b.shape[0]
+    return a.new_empty((m, n), dtype=torch.bfloat16)
+
+
 @torch.library.custom_op("trtllm::silu_and_mul", mutates_args=())
 def silu_and_mul(x: torch.Tensor,
                  scale: Optional[torch.Tensor] = None,
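Note: the new Python-level trtllm::fp8_block_scaling_gemm wraps the renamed C++ kernel in a TunableRunner so a tuning pass exercises (and JIT-compiles) DeepGEMM for each shape bucket up to tune_max_num_tokens. A hedged usage sketch; the autotune import path and the scale layouts are assumptions, not shown in this diff:

import torch
import tensorrt_llm  # assumption: registers the trtllm::* custom ops
from tensorrt_llm._torch.autotuner import autotune  # assumed location of the tuning context

M, K, N = 256, 512, 1024
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
a_scale = torch.ones(K // 128, M, device="cuda")         # assumed 1x128 activation scales
b_scale = torch.ones(N // 128, K // 128, device="cuda")  # assumed 128x128 weight scales

with autotune():
    # Tuning pass: Fp8BlockScalingGemmRunner.forward runs once per shape bucket,
    # which also triggers the DeepGEMM JIT off the inference critical path.
    torch.ops.trtllm.fp8_block_scaling_gemm(a, b, a_scale, b_scale)

out = torch.ops.trtllm.fp8_block_scaling_gemm(a, b, a_scale, b_scale)  # replays the cached choice
print(out.shape, out.dtype)  # (M, N), bfloat16 per the registered fake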

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 29 additions & 17 deletions
@@ -531,14 +531,17 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
             return result

         def get_warmup_request(num_tokens: int,
-                               num_gen_tokens: int,
+                               num_gen_requests: int,
                                least_requests: bool = True):
             available_tokens = kv_cache_manager.get_num_available_tokens(
                 self.runtime_draft_len)
             available_blocks = kv_cache_manager.get_num_free_blocks()
             if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
                 return None
-            if num_gen_tokens > self.batch_size:
+            if num_gen_requests > self.batch_size:
+                return None
+            num_gen_tokens = num_gen_requests * (1 + self.runtime_draft_len)
+            if num_gen_tokens > self.max_num_tokens:
                 return None

             num_extra_decoding_steps = get_num_extra_decoding_steps()
@@ -548,7 +551,8 @@ def get_warmup_request(num_tokens: int,
                 # during warmup.
                 return None

-            num_ctx_tokens = num_tokens - num_gen_tokens
+            num_ctx_tokens = num_tokens - num_gen_requests * (
+                1 + self.runtime_draft_len)
             num_ctx_requests = 0
             ctx_requests = []
             gen_requests = []
@@ -557,7 +561,7 @@ def get_warmup_request(num_tokens: int,
             num_full_seqs = 0
             num_left_over_tokens = 0

-            max_context_requests = self.batch_size - num_gen_tokens
+            max_context_requests = self.batch_size - num_gen_requests
             if max_context_requests * max_seq_len < num_ctx_tokens:
                 return None

@@ -572,7 +576,7 @@ def get_warmup_request(num_tokens: int,

             else:
                 max_bs = min(num_ctx_tokens,
-                             self.batch_size - num_gen_tokens)
+                             self.batch_size - num_gen_requests)
                 if num_ctx_tokens % max_bs == 0:
                     num_full_seqs = max_bs
                 else:
@@ -583,13 +587,13 @@ def get_warmup_request(num_tokens: int,
                                   > 0 else 0)

             # We do not have enough batch to fill the request
-            if num_ctx_requests + num_gen_tokens > self.batch_size:
+            if num_ctx_requests + num_gen_requests > self.batch_size:
                 return None

             blocks_to_use = num_full_seqs * math.ceil(
                 max_seq_len / kv_cache_manager.tokens_per_block) + math.ceil(
                     num_left_over_tokens /
-                    kv_cache_manager.tokens_per_block) + num_gen_tokens
+                    kv_cache_manager.tokens_per_block) + num_gen_requests

             if blocks_to_use > available_blocks:
                 return None
@@ -604,25 +608,29 @@ def get_warmup_request(num_tokens: int,
                     token_nums=ctx_token_nums,
                     is_gen=False,
                     max_num_draft_tokens=self.runtime_draft_len,
-                    use_mrope=self.use_mrope)
+                    use_mrope=self.use_mrope,
+                    max_beam_width=self.max_beam_width,
+                    num_extra_decoding_steps=num_extra_decoding_steps)

                 if spec_resource_manager is not None:
                     spec_resource_manager.add_dummy_requests(
                         request_ids=list(range(num_ctx_requests)))

-            if num_gen_tokens > 0:
+            if num_gen_requests > 0:
                 gen_requests = kv_cache_manager.add_dummy_requests(
                     list(
                         range(num_ctx_requests,
-                              num_ctx_requests + num_gen_tokens)),
-                    token_nums=[1] * num_gen_tokens,
+                              num_ctx_requests + num_gen_requests)),
+                    token_nums=[1] * num_gen_requests,
                     is_gen=True,
                     max_num_draft_tokens=self.max_draft_len,
-                    use_mrope=self.use_mrope)
+                    use_mrope=self.use_mrope,
+                    max_beam_width=self.max_beam_width,
+                    num_extra_decoding_steps=num_extra_decoding_steps)
                 if spec_resource_manager is not None:
                     spec_resource_manager.add_dummy_requests(request_ids=list(
                         range(num_ctx_requests, num_ctx_requests +
-                              num_gen_tokens)))
+                              num_gen_requests)))

             result = ScheduledRequests()
             result.context_requests = ctx_requests
@@ -655,15 +663,18 @@ def release_batch(result: ScheduledRequests | None):
                 return

         def general_warmup(reverse: bool = False):
+            max_batch_size = min(
+                self.batch_size,
+                curr_max_num_tokens // (1 + self.runtime_draft_len))
             warmup_requests = set([
                 (1, 1),  # Specialize for 1 token.
-                (self.batch_size,
-                 self.batch_size),  # max_batch_size, pure generation
+                (max_batch_size,
+                 max_batch_size),  # max_batch_size, pure generation
                 (2, 0),  # Non-one, pure context
                 (curr_max_num_tokens, 0),  # max_num_tokens, pure context
             ])
-            if reverse:
-                warmup_requests = sorted(list(warmup_requests), reverse=reverse)
+
+            warmup_requests = sorted(list(warmup_requests), reverse=reverse)

             for warmup_num_tokens, warmup_num_gen_tokens in warmup_requests:
                 with release_batch(
@@ -817,6 +828,7 @@ def _update_draft_inference_state(is_first_draft: bool,
         # Also, we run a general warmup from large to small to make sure that blocks are allocated well.
         # The cudagraph and piecewise cuda graph capture calls torch.cuda.empty_cache() and block may already
        # be freed even we calls general_warmup for torch compile.
+        # Also the additional warmup helps trigger the runtime jit to avoid runtime jit overhead.
         general_warmup(reverse=True)

         # Set the value back to the original value
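Note: the parameter is now a request count, and with speculative drafting each generation request contributes 1 + runtime_draft_len tokens, which is why the token and KV-block budgets above are recomputed from num_gen_requests. A standalone restatement of that arithmetic with illustrative numbers, not the engine code itself:

import math

# Illustrative configuration
batch_size, max_num_tokens = 64, 8192
runtime_draft_len = 3            # draft tokens appended per generation step
tokens_per_block, max_seq_len = 32, 4096

num_tokens, num_gen_requests = 4096, 16
num_gen_tokens = num_gen_requests * (1 + runtime_draft_len)   # 64 tokens, not 16
num_ctx_tokens = num_tokens - num_gen_tokens                  # 4032 tokens left for context

# Full-length context sequences plus the leftover chunk, plus one block per generation request.
num_full_seqs, num_left_over_tokens = divmod(num_ctx_tokens, max_seq_len)
blocks_to_use = (num_full_seqs * math.ceil(max_seq_len / tokens_per_block)
                 + math.ceil(num_left_over_tokens / tokens_per_block)
                 + num_gen_requests)
print(num_gen_tokens, num_ctx_tokens, blocks_to_use)          # 64 4032 142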

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 2 additions & 2 deletions
@@ -470,9 +470,9 @@ def _update_target_inputs_with_draft_tokens(
                 continue

             # Get the index of the draft/target tokens in the device tensor
-            draft_idx = req_idx if self.use_static_draft_loop else request.py_batch_idx
+            draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
             target_idx = req_id_to_old_request[
-                request.py_request_id].py_batch_idx
+                request.py_request_id].py_seq_slot
             target_inputs.new_tokens[draft_position + 1:draft_position +
                                      draft_length + 1, target_idx,
                                      0] = draft_tensors[0:draft_length,

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 0 additions & 1 deletion
@@ -404,7 +404,6 @@ def test_auto_dtype(self, disable_overlap_scheduler):
         task.evaluate(llm)

     @pytest.mark.skip_less_device(2)
-    @skip_pre_hopper
     def test_ngram(self):
         speculative_decoding_config = {
             "decoding_type": "NGram",
