Commit c4be849

add log_once and update deprecation log
Signed-off-by: jiahanc <[email protected]>
1 parent 64298f3 commit c4be849

2 files changed: +48, -16 lines
flashinfer/fused_moe/core.py

Lines changed: 19 additions & 15 deletions
@@ -18,7 +18,6 @@
 from enum import IntEnum
 from types import SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple, Union
-from typing_extensions import deprecated
 import torch
 
 from ..autotuner import (
@@ -1656,9 +1655,6 @@ def _fake_trtllm_fp4_block_scale_moe(
     )
 
 
-@deprecated(
-    "tile_tokens_dim is deprecated and will be removed in trtllm_fp8_per_tensor_scale_moe after v0.5.0"
-)
 def trtllm_fp8_per_tensor_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -1708,6 +1704,12 @@ def trtllm_fp8_per_tensor_scale_moe(
     Returns:
         torch.Tensor: Output tensor of shape [seq_len, hidden_size]
     """
+    if tile_tokens_dim is not None:
+        logger.warning_once(
+            "tile_tokens_dim in trtllm_fp8_per_tensor_scale_moe is planned for deprecation "
+            "in a future release. Please remove it from your code as tile_tokens_dim will no "
+            "longer be supported after v0.5.0."
+        )
     return get_trtllm_moe_sm100_module().trtllm_fp8_per_tensor_scale_moe(
         routing_logits,
         routing_bias,
@@ -1731,9 +1733,6 @@ def trtllm_fp8_per_tensor_scale_moe(
     )
 
 
-@deprecated(
-    "tile_tokens_dim is deprecated and will be removed in trtllm_fp8_block_scale_moe after v0.5.0"
-)
 def trtllm_fp8_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -1782,6 +1781,12 @@ def trtllm_fp8_block_scale_moe(
     Returns:
         torch.Tensor: Output tensor of shape [seq_len, hidden_size]
     """
+    if tile_tokens_dim is not None:
+        logger.warning_once(
+            "tile_tokens_dim in trtllm_fp8_block_scale_moe is planned for deprecation "
+            "in a future release. Please remove it from your code as tile_tokens_dim will no "
+            "longer be supported after v0.5.0."
+        )
     output = torch.empty(
         hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device
     )
@@ -1810,9 +1815,6 @@ def trtllm_fp8_block_scale_moe(
     )
 
 
-@deprecated(
-    "tile_tokens_dim is deprecated and will be removed in trtllm_fp4_block_scale_moe after v0.5.0"
-)
 def trtllm_fp4_block_scale_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -1908,7 +1910,12 @@ def trtllm_fp4_block_scale_moe(
         List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output.
             Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing.
     """
-
+    if tile_tokens_dim is not None:
+        logger.warning_once(
+            "tile_tokens_dim in trtllm_fp4_block_scale_moe is planned for deprecation "
+            "in a future release. Please remove it from your code as tile_tokens_dim will no "
+            "longer be supported after v0.5.0."
+        )
     return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe(
         routing_logits,
         None,
@@ -1945,9 +1952,6 @@ def trtllm_fp4_block_scale_moe(
     )
 
 
-@deprecated(
-    "tile_tokens_dim is deprecated and will be removed in trtllm_fp4_block_scale_routed_moe after v0.5.0"
-)
 def trtllm_fp4_block_scale_routed_moe(
     topk_ids: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
@@ -2046,7 +2050,7 @@ def trtllm_fp4_block_scale_routed_moe(
             Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing.
     """
     if tile_tokens_dim is not None:
-        logger.info(
+        logger.warning_once(
            "tile_tokens_dim in trtllm_fp4_block_scale_routed_moe is planned for deprecation "
            "in a future release. Please remove it from your code as tile_tokens_dim will no "
            "longer be supported after v0.5.0."
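
The same warn-once guard replaces the @deprecated decorator in all four entry points above. Below is a minimal sketch of the pattern, assuming logger is the FlashInferJITLogger instance exported from flashinfer.jit.core (which gains warning_once in this commit); the my_moe_kernel function and its arguments are illustrative only, not part of the library.

from typing import Optional

from flashinfer.jit.core import logger  # FlashInferJITLogger instance; warning_once added by this commit


def my_moe_kernel(hidden_states, tile_tokens_dim: Optional[int] = None):
    # Warn at most once per process if the soon-to-be-removed argument is still passed;
    # repeated calls with the same message are silently dropped by warning_once.
    if tile_tokens_dim is not None:
        logger.warning_once(
            "tile_tokens_dim in my_moe_kernel is planned for deprecation in a future release."
        )
    return hidden_states

Unlike the removed @deprecated decorator, which flags every call of the function, this guard only fires when the deprecated argument is actually supplied.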

flashinfer/jit/core.py

Lines changed: 29 additions & 1 deletion
@@ -1,10 +1,11 @@
 import dataclasses
+import functools
 import logging
 import os
 from contextlib import nullcontext
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, List, Optional, Sequence, Union
+from typing import Dict, List, Optional, Sequence, Union, Hashable
 
 import tvm_ffi
 from filelock import FileLock
@@ -60,6 +61,33 @@ def __init__(self, name):
             )
         )
 
+    def debug_once(self, msg: str, *args: Hashable) -> None:
+        """
+        As [`debug`][logging.Logger.debug], but subsequent calls with
+        the same message are silently dropped.
+        """
+        self._print_once(self.debug, msg, *args)
+
+    def info_once(self, msg: str, *args: Hashable) -> None:
+        """
+        As [`info`][logging.Logger.info], but subsequent calls with
+        the same message are silently dropped.
+        """
+        self._print_once(self.info, msg, *args)
+
+    def warning_once(self, msg: str, *args: Hashable) -> None:
+        """
+        As [`warning`][logging.Logger.warning], but subsequent calls with
+        the same message are silently dropped.
+        """
+        self._print_once(self.warning, msg, *args)
+
+    @functools.lru_cache(maxsize=None)
+    def _print_once(self, log_method, msg: str, *args: Hashable) -> None:
+        """Helper method to log messages only once per unique (msg, args) combination."""
+        # Note: stacklevel=3 to show the caller's location, not this helper method
+        log_method(msg, *args, stacklevel=3)
+
 
 logger = FlashInferJITLogger("flashinfer.jit")
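The once-only behaviour comes from functools.lru_cache memoizing _print_once: the first call with a given (log_method, msg, args) key reaches the underlying log method, and every later call is a cache hit that returns without logging, which is also why the extra arguments are typed Hashable. Because the cache wraps an instance method, self is part of the key, so separate logger instances deduplicate independently. Below is a self-contained sketch of the same idea using only the standard library; the names here are illustrative and not part of FlashInfer.

import functools
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("demo")


@functools.lru_cache(maxsize=None)
def log_once(level: int, msg: str) -> None:
    # Cached on (level, msg): only the first call for each pair emits a record.
    log.log(level, msg, stacklevel=2)


for _ in range(3):
    log_once(logging.WARNING, "tile_tokens_dim is deprecated")  # logged a single time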