Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flashinfer/aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,7 @@ def register_default_modules() -> int:
return len(jit_specs)


def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="Ahead-of-Time (AOT) build all modules"
)
Expand Down
8 changes: 6 additions & 2 deletions flashinfer/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,11 +558,15 @@ class AutoTuner:
_instance = None
_class_lock = threading.Lock()

def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
def __init__(
self, warmup: int = 3, repeat: int = 10, stream_delay_micro_secs: int = 1000
) -> None:
self.repeat = repeat
self.warmup = warmup
self.stream_delay_micro_secs = stream_delay_micro_secs
self.profiling_cache = {}
self.profiling_cache: Dict[
Tuple[Any, ...], Tuple[int, Any, OptimizationProfile]
] = {}
self.is_tuning_mode = False
self._active_tuning_contexts = 0

Expand Down
13 changes: 11 additions & 2 deletions flashinfer/comm/allreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ def allreduce_fusion(
expanded_idx_to_permuted_idx: Optional[torch.Tensor] = None,
expert_scale_factor: Optional[torch.Tensor] = None,
shared_expert_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
"""
AllReduce + RMSNorm fusion operation.
Expand Down Expand Up @@ -548,6 +549,7 @@ def allreduce_fusion(
output row. Shape [token_num, top_k], dtype int32.
expert_scale_factor: Router weights for each selected expert [token_num, top_k]
shared_expert_output: Optional shared expert output to add [token_num, hidden_dim]
routed_scaling_factor: Optional scaling factor forwarded to MoE finalize fusion

Returns:
Output tensor (typically norm_out for fusion cases, output otherwise)
Expand Down Expand Up @@ -726,10 +728,17 @@ def allreduce_fusion(
eps=rms_eps,
shared_expert_output=shared_expert_output,
expert_scale_factor=expert_scale_factor,
routed_scaling_factor=None,
routed_scaling_factor=routed_scaling_factor,
)

return norm_out
if norm_out is not None:
return norm_out
elif quant_out is not None:
return quant_out
elif residual_out is not None:
return residual_out
else:
return input
Comment on lines +734 to +741
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

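This if/elif/else chain is more verbose than it needs to be. It can be simplified by iterating over the candidate outputs in priority order (norm_out, quant_out, residual_out) and returning the first one that is not None, falling back to input when all of them are None.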

for out in [norm_out, quant_out, residual_out]:
    if out is not None:
        return out
return input


# ---- Standard patterns (0-5) ----
# Extract shape from 2D input
Expand Down
2 changes: 1 addition & 1 deletion flashinfer/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def status(self) -> str:
class JitSpecRegistry:
"""Global registry to track all JitSpecs"""

def __init__(self):
def __init__(self) -> None:
self._specs: Dict[str, JitSpec] = {}
self._creation_times: Dict[str, datetime] = {}

Expand Down
Loading