diff --git a/flashinfer/aot.py b/flashinfer/aot.py index dfb05150a8..95b1619b77 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -917,7 +917,7 @@ def register_default_modules() -> int: return len(jit_specs) -def main(): +def main() -> None: parser = argparse.ArgumentParser( description="Ahead-of-Time (AOT) build all modules" ) diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index 36f37b1725..a22575abd9 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -558,11 +558,15 @@ class AutoTuner: _instance = None _class_lock = threading.Lock() - def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000): + def __init__( + self, warmup: int = 3, repeat: int = 10, stream_delay_micro_secs: int = 1000 + ) -> None: self.repeat = repeat self.warmup = warmup self.stream_delay_micro_secs = stream_delay_micro_secs - self.profiling_cache = {} + self.profiling_cache: Dict[ + Tuple[Any, ...], Tuple[int, Any, OptimizationProfile] + ] = {} self.is_tuning_mode = False self._active_tuning_contexts = 0 diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index ef048b5e54..07e7075ce5 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -480,6 +480,7 @@ def allreduce_fusion( expanded_idx_to_permuted_idx: Optional[torch.Tensor] = None, expert_scale_factor: Optional[torch.Tensor] = None, shared_expert_output: Optional[torch.Tensor] = None, + routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: """ AllReduce + RMSNorm fusion operation. @@ -548,6 +549,7 @@ def allreduce_fusion( output row. Shape [token_num, top_k], dtype int32. expert_scale_factor: Router weights for each selected expert [token_num, top_k] shared_expert_output: Optional shared expert output to add [token_num, hidden_dim] + routed_scaling_factor: Optional scaling factor forwarded to MoE finalize fusion Returns: Output tensor (typically norm_out for fusion cases, output otherwise) @@ -726,10 +728,17 @@ def allreduce_fusion( eps=rms_eps, shared_expert_output=shared_expert_output, expert_scale_factor=expert_scale_factor, - routed_scaling_factor=None, + routed_scaling_factor=routed_scaling_factor, ) - return norm_out + if norm_out is not None: + return norm_out + elif quant_out is not None: + return quant_out + elif residual_out is not None: + return residual_out + else: + return input # ---- Standard patterns (0-5) ---- # Extract shape from 2D input diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index 495b006716..04cf7f0f1d 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -160,7 +160,7 @@ def status(self) -> str: class JitSpecRegistry: """Global registry to track all JitSpecs""" - def __init__(self): + def __init__(self) -> None: self._specs: Dict[str, JitSpec] = {} self._creation_times: Dict[str, datetime] = {}