1818from enum import IntEnum
1919from types import SimpleNamespace
2020from typing import Any , Dict , List , Optional , Tuple , Union
21- from typing_extensions import deprecated
2221import torch
2322
2423from ..autotuner import (
@@ -1656,9 +1655,6 @@ def _fake_trtllm_fp4_block_scale_moe(
16561655 )
16571656
16581657
1659- @deprecated (
1660- "tile_tokens_dim is deprecated and will be removed in trtllm_fp8_per_tensor_scale_moe after v0.5.0"
1661- )
16621658def trtllm_fp8_per_tensor_scale_moe (
16631659 routing_logits : torch .Tensor ,
16641660 routing_bias : Optional [torch .Tensor ],
@@ -1708,6 +1704,12 @@ def trtllm_fp8_per_tensor_scale_moe(
17081704 Returns:
17091705 torch.Tensor: Output tensor of shape [seq_len, hidden_size]
17101706 """
1707+ if tile_tokens_dim is not None :
1708+ logger .warning_once (
1709+ "tile_tokens_dim in trtllm_fp8_per_tensor_scale_moe is planned for deprecation "
1710+ "in a future release. Please remove it from your code as tile_tokens_dim will no "
1711+ "longer be supported after v0.5.0."
1712+ )
17111713 return get_trtllm_moe_sm100_module ().trtllm_fp8_per_tensor_scale_moe (
17121714 routing_logits ,
17131715 routing_bias ,
@@ -1731,9 +1733,6 @@ def trtllm_fp8_per_tensor_scale_moe(
17311733 )
17321734
17331735
1734- @deprecated (
1735- "tile_tokens_dim is deprecated and will be removed in trtllm_fp8_block_scale_moe after v0.5.0"
1736- )
17371736def trtllm_fp8_block_scale_moe (
17381737 routing_logits : torch .Tensor ,
17391738 routing_bias : Optional [torch .Tensor ],
@@ -1782,6 +1781,12 @@ def trtllm_fp8_block_scale_moe(
17821781 Returns:
17831782 torch.Tensor: Output tensor of shape [seq_len, hidden_size]
17841783 """
1784+ if tile_tokens_dim is not None :
1785+ logger .warning_once (
1786+ "tile_tokens_dim in trtllm_fp8_block_scale_moe is planned for deprecation "
1787+ "in a future release. Please remove it from your code as tile_tokens_dim will no "
1788+ "longer be supported after v0.5.0."
1789+ )
17851790 output = torch .empty (
17861791 hidden_states .shape , dtype = torch .bfloat16 , device = hidden_states .device
17871792 )
@@ -1810,9 +1815,6 @@ def trtllm_fp8_block_scale_moe(
18101815 )
18111816
18121817
1813- @deprecated (
1814- "tile_tokens_dim is deprecated and will be removed in trtllm_fp4_block_scale_moe after v0.5.0"
1815- )
18161818def trtllm_fp4_block_scale_moe (
18171819 routing_logits : torch .Tensor ,
18181820 routing_bias : Optional [torch .Tensor ],
@@ -1908,7 +1910,12 @@ def trtllm_fp4_block_scale_moe(
19081910 List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output.
19091911 Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing.
19101912 """
1911-
1913+ if tile_tokens_dim is not None :
1914+ logger .warning_once (
1915+ "tile_tokens_dim in trtllm_fp4_block_scale_moe is planned for deprecation "
1916+ "in a future release. Please remove it from your code as tile_tokens_dim will no "
1917+ "longer be supported after v0.5.0."
1918+ )
19121919 return get_trtllm_moe_sm100_module ().trtllm_fp4_block_scale_moe (
19131920 routing_logits ,
19141921 None ,
@@ -1945,9 +1952,6 @@ def trtllm_fp4_block_scale_moe(
19451952 )
19461953
19471954
1948- @deprecated (
1949- "tile_tokens_dim is deprecated and will be removed in trtllm_fp4_block_scale_routed_moe after v0.5.0"
1950- )
19511955def trtllm_fp4_block_scale_routed_moe (
19521956 topk_ids : torch .Tensor ,
19531957 routing_bias : Optional [torch .Tensor ],
@@ -2046,7 +2050,7 @@ def trtllm_fp4_block_scale_routed_moe(
20462050 Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing.
20472051 """
20482052 if tile_tokens_dim is not None :
2049- logger .info (
2053+ logger .warning_once (
20502054 "tile_tokens_dim in trtllm_fp4_block_scale_routed_moe is planned for deprecation "
20512055 "in a future release. Please remove it from your code as tile_tokens_dim will no "
20522056 "longer be supported after v0.5.0."
0 commit comments