4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/thop/fp8BlockScalingGemm.cpp
@@ -411,7 +411,7 @@ TRTLLM_NAMESPACE_END

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("fp8_block_scaling_gemm(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale) -> Tensor");
m.def("fp8_block_scaling_gemm_impl(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale) -> Tensor");
m.def(
"fp8_block_scaling_bmm(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale, ScalarType? "
"out_dtype=None) -> Tensor");
@@ -425,7 +425,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("fp8_block_scaling_gemm", &tensorrt_llm::torch_ext::fp8_block_scaling_gemm);
m.impl("fp8_block_scaling_gemm_impl", &tensorrt_llm::torch_ext::fp8_block_scaling_gemm);
m.impl("fp8_block_scaling_bmm", &tensorrt_llm::torch_ext::fp8_block_scaling_bmm);
m.impl("fp8_block_scaling_bmm_out", &tensorrt_llm::torch_ext::fp8_block_scaling_bmm_out);
m.impl("fp8_block_scaling_moe_gemm", &tensorrt_llm::torch_ext::fp8_block_scaling_moe_gemm);
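Only the registered name changes on the C++ side; the kernel implementation is untouched. A minimal sketch of how the two entry points relate after this PR, assuming the TensorRT-LLM torch extension has been imported so the trtllm ops are available (the helper function below is illustrative, not part of the change):

import torch

def run_fp8_block_scaling_gemm(a, b, a_scale, b_scale, use_autotuner=True):
    # Illustrative dispatch helper, not part of this PR.
    if use_autotuner:
        # Public name, re-registered later in this PR (torch_custom_ops.py)
        # as an autotuned Python wrapper.
        return torch.ops.trtllm.fp8_block_scaling_gemm(a, b, a_scale, b_scale)
    # Raw C++ binding after the rename; bypasses the AutoTuner entirely.
    return torch.ops.trtllm.fp8_block_scaling_gemm_impl(a, b, a_scale, b_scale)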
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
@@ -201,7 +201,7 @@ def _(logits, seq_lens, indices, next_n, index_topk):
def _(input, force_applying_finalize):
return torch.empty_like(input)

@torch.library.register_fake("trtllm::fp8_block_scaling_gemm")
@torch.library.register_fake("trtllm::fp8_block_scaling_gemm_impl")
def _(a, b, a_scale, b_scale):
m = a.shape[0]
n = b.shape[0]
76 changes: 74 additions & 2 deletions tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -1441,7 +1441,7 @@ def _(
return input.new_empty((M, N), dtype=output_dtype)


- def fp8_swap_ab_gen_tuning_buckets(x: int):
def deep_gemm_gen_tuning_buckets(x: int):
buckets = tuple(range(8, 128, 8))
if x >= 128:
buckets += tuple(range(128, x, 128))
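The renamed generator yields fine-grained buckets in steps of 8 tokens below 128 and coarse 128-token buckets above that. A small self-contained check, assuming the elided tail of the function simply returns buckets:

def _gen_buckets(x: int):
    # Local copy of deep_gemm_gen_tuning_buckets, with the assumed `return buckets` tail.
    buckets = tuple(range(8, 128, 8))
    if x >= 128:
        buckets += tuple(range(128, x, 128))
    return buckets

assert _gen_buckets(300)[:3] == (8, 16, 24)   # 8-token steps below 128
assert _gen_buckets(300)[-2:] == (128, 256)   # 128-token steps from 128 up to x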
@@ -1451,7 +1451,7 @@ def fp8_swap_ab_gen_tuning_buckets(x: int):
class fp8SwapABGemmRunner(TunableRunner):
tuning_config = TuningConfig(
dynamic_tensor_specs=(DynamicTensorSpec(
- 0, 0, fp8_swap_ab_gen_tuning_buckets), ),
0, 0, deep_gemm_gen_tuning_buckets), ),
tune_max_num_tokens=4096,
)

@@ -1536,6 +1536,78 @@ def _(
return input.new_empty((input.size(0), weight.size(0)), dtype=output_dtype)


# This runner is used to trigger the DeepGEMM JIT compilation during autotuning.
class Fp8BlockScalingGemmRunner(TunableRunner):
tuning_config = TuningConfig(
dynamic_tensor_specs=(DynamicTensorSpec(
0, 0, deep_gemm_gen_tuning_buckets), ),
tune_max_num_tokens=4096,
)

def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
return [0]

def forward(
self,
inputs: List[torch.Tensor],
tactic: int = -1,
) -> torch.Tensor:
a, b, a_scale, b_scale = inputs
return torch.ops.trtllm.fp8_block_scaling_gemm_impl(
a, b, a_scale, b_scale)


def get_fp8_block_scaling_gemm_constraint_spec():
# This mirrors the scale shape logic of the fp8_quantize_1x128 custom op.
def fp8_quantize_1x128_sm90_constraint(inputs: List[List[int]]):
pad_m = fp4_utils.pad_up(inputs[0][0], 4)
blocked_n = (inputs[0][1] + 127) // 128
return fp4_utils.pad_up(pad_m * blocked_n * 4, 128) // 4

if get_sm_version() >= 100:
return (ConstraintSpec(2, 1, lambda inputs: inputs[0][0]), )
else:
return (ConstraintSpec(2, 0, fp8_quantize_1x128_sm90_constraint), )
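A worked example of the SM90 branch, assuming fp4_utils.pad_up(v, m) rounds v up to the next multiple of m: for an activation of shape (100, 512), the constrained a_scale dimension comes out to 416.

def _pad_up(v: int, m: int) -> int:
    # Assumed behaviour of fp4_utils.pad_up: round v up to a multiple of m.
    return (v + m - 1) // m * m

pad_m = _pad_up(100, 4)            # 100
blocked_n = (512 + 127) // 128     # 4
dim0 = _pad_up(pad_m * blocked_n * 4, 128) // 4
assert dim0 == 416                 # constrained a_scale dim 0 for a (100, 512) activation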


@torch.library.custom_op("trtllm::fp8_block_scaling_gemm", mutates_args=())
def fp8_block_scaling_gemm(
a: torch.Tensor,
b: torch.Tensor,
a_scale: torch.Tensor,
b_scale: torch.Tensor,
tune_max_num_tokens: int = 4096,
) -> torch.Tensor:
tuner = AutoTuner.get()
fp8_block_scaling_gemm_runner = Fp8BlockScalingGemmRunner()
Fp8BlockScalingGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens

Fp8BlockScalingGemmRunner.tuning_config.constraint_specs = get_fp8_block_scaling_gemm_constraint_spec(
)

_, best_tactic = tuner.choose_one(
"trtllm::fp8_block_scaling_gemm",
[fp8_block_scaling_gemm_runner],
Fp8BlockScalingGemmRunner.tuning_config,
[a, b, a_scale, b_scale],
)
return fp8_block_scaling_gemm_runner(
inputs=[a, b, a_scale, b_scale],
tactic=best_tactic,
)


@fp8_block_scaling_gemm.register_fake
def _(a, b, a_scale, b_scale, tune_max_num_tokens=4096):
m = a.shape[0]
n = b.shape[0]
return a.new_empty((m, n), dtype=torch.bfloat16)
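A hedged usage sketch of the new public op. The shapes, dtypes, and scale layouts below are placeholders chosen for illustration; real callers should obtain the scales from the fp8_quantize_1x128 path referenced above, and the trtllm extension must already be loaded.

import torch

m, k, n = 256, 7168, 2048
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)
# Placeholder scale layouts; the real layouts must match what fp8_quantize_1x128 produces.
a_scale = torch.ones(k // 128, m, device="cuda", dtype=torch.float32)
b_scale = torch.ones(n // 128, k // 128, device="cuda", dtype=torch.float32)

# The first tuning pass triggers the DeepGEMM JIT through Fp8BlockScalingGemmRunner;
# afterwards the single tactic (0) is cached and reused.
out = torch.ops.trtllm.fp8_block_scaling_gemm(a, b, a_scale, b_scale,
                                              tune_max_num_tokens=4096)
print(out.shape, out.dtype)  # torch.Size([256, 2048]) torch.bfloat16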


@torch.library.custom_op("trtllm::silu_and_mul", mutates_args=())
def silu_and_mul(x: torch.Tensor,
scale: Optional[torch.Tensor] = None,
43 changes: 37 additions & 6 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -155,17 +155,26 @@ def _create_dummy_mm_context_request(
dummy_mm_prompt = input_processor.get_dummy_prompt(input_seq_len)

if dummy_mm_prompt is not None:
- prompt_token_ids, extra_processed_inputs = self._model_engine.input_processor(
prompt_token_ids, extra_processed_inputs = self._model_engine.input_processor_with_hash(
dummy_mm_prompt, sampling_params=None)

multimodal_input = extra_processed_inputs.get(
'multimodal_input')
multimodal_data = extra_processed_inputs.get('multimodal_data')
req_mm_input = trtllm.MultimodalInput(
multimodal_hashes=multimodal_input.multimodal_hashes,
multimodal_positions=multimodal_input.multimodal_positions,
multimodal_lengths=multimodal_input.multimodal_lengths
) if multimodal_input else None

request = trtllm.Request(prompt_token_ids,
max_tokens=1,
streaming=False,
sampling_config=trtllm.SamplingConfig(
beam_width=max_beam_width, ),
output_config=trtllm.OutputConfig(),
- end_id=-1)
end_id=-1,
multimodal_input=req_mm_input)
request.py_multimodal_data = multimodal_data
else:
# Fall back to a text-only prompt when a suitable small image size could not be found.
@@ -266,9 +275,29 @@ def _get_token_num_for_estimation(self) -> int:
# Requests cannot share KV cache blocks. Round up to nearest integer multiple of block size.
num_cache_blocks += (num_req_tokens + self._tokens_per_block -
1) // self._tokens_per_block

# Maximum number of tokens required for CUDA graph warmup
max_cuda_graph_bs = min(self._model_engine.batch_size,
self._model_engine._max_cuda_graph_batch_size)
cuda_graph_warmup_block = (
self._model_engine.max_seq_len +
1) // self._tokens_per_block + max_cuda_graph_bs - 1
num_cache_blocks = max(cuda_graph_warmup_block, num_cache_blocks)

# This is the minimum number of blocks required to run at the maximum batch size.
# If self._model_engine.batch_size blocks cannot be allocated, the max batch size should be adjusted.
num_cache_blocks = max(num_cache_blocks, self._model_engine.batch_size)

free_mem, total_mem = torch.cuda.mem_get_info()
max_memory = self._kv_cache_config.free_gpu_memory_fraction * free_mem
max_num_tokens_in_memory = max_memory // self._get_kv_size_per_token(
) // self._tokens_per_block * self._tokens_per_block

# Multiply by beam width so that beam width does not cause max_seq_len to be rescaled while preparing for KV cache estimation.
- return num_cache_blocks * self._tokens_per_block * self._dummy_reqs[
- 0].sampling_config.beam_width
return min(
num_cache_blocks * self._tokens_per_block *
self._dummy_reqs[0].sampling_config.beam_width,
max_num_tokens_in_memory)
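To make the new lower and upper bounds concrete, a numeric sketch of the block math above; every figure here is an illustrative assumption, not a value taken from the PR.

tokens_per_block = 32
max_seq_len = 4096
batch_size = 64                 # self._model_engine.batch_size
max_cuda_graph_bs = 64          # min(batch_size, _max_cuda_graph_batch_size)
num_cache_blocks = 80           # accumulated from the dummy requests above
beam_width = 1

# Enough blocks for a CUDA graph warmup run at the largest graph batch size.
cuda_graph_warmup_block = (max_seq_len + 1) // tokens_per_block + max_cuda_graph_bs - 1
num_cache_blocks = max(cuda_graph_warmup_block, num_cache_blocks)    # 191

# Never report fewer blocks than the configured max batch size.
num_cache_blocks = max(num_cache_blocks, batch_size)                 # still 191

# Cap the estimate by what free GPU memory can actually hold.
max_num_tokens_in_memory = 200_000   # from torch.cuda.mem_get_info() in the real code
estimate = min(num_cache_blocks * tokens_per_block * beam_width,
               max_num_tokens_in_memory)                             # 6112 tokens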

def try_prepare_estimation(self) -> bool:
"""Prepare for possible KV cache capacity estimation.
@@ -279,8 +308,10 @@ def try_prepare_estimation(self) -> bool:
estimating_kv_cache = False
if 'cp_type' not in self._mapping.cp_config:
estimating_kv_cache = True
- self._kv_cache_config.max_tokens = self._get_token_num_for_estimation(
- )
estimate_max_tokens = self._get_token_num_for_estimation()
self._kv_cache_config.max_tokens = min(
estimate_max_tokens, self._kv_cache_config.max_tokens
) if self._kv_cache_config.max_tokens is not None else estimate_max_tokens
model_config = self._model_engine.model.model_config
if model_config.attn_backend == "VANILLA":
logger.info(