Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions flashinfer/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,13 @@ def choose_one(
# Record the total configs to try
self.stats.tuned_op_total_configs[custom_op] = len(profiles)

# Pre-compute runner arg names to avoid calling inspect.signature in the loop
runner_arg_names_map = {}
for r in runners:
runner_arg_names_map[r] = {
param.name for param in inspect.signature(r.forward).parameters.values()
}

for p in profiles:
tensors = self._prepare_input_tensors(p, inputs)
is_cache_hit, runner_id, tactic, _ = self.search_cache(
Expand All @@ -470,9 +477,7 @@ def choose_one(
for r_id, r in enumerate(runners):
# TODO: use FakeTensor here.
valid_tactics = r.get_valid_tactics(tensors, p)
runner_arg_names = {
p.name for p in inspect.signature(r.forward).parameters.values()
}
runner_arg_names = runner_arg_names_map[r]
if "do_preparation" in runner_arg_names and len(valid_tactics) > 0:
r(tensors, tactic=-1, do_preparation=True, **kwargs)
for tac in valid_tactics:
Expand Down
120 changes: 74 additions & 46 deletions flashinfer/gemm/gemm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,25 @@ def forward(
)


# Shared autotuning config for fp8_gemm_sm100, hoisted to module level so the
# TuningConfig is built once instead of on every call.
# Tensor indices refer to positions in the `inputs` list passed to
# AutoTuner.choose_one: [a, b, scale_a, scale_b, out, workspace_buffer].
_FP8_GEMM_SM100_TUNING_CONFIG = TuningConfig(
    dynamic_tensor_specs=(
        # Tune over dim -2 of `a` (presumably the token/M dimension — the
        # bucketing helpers round it to powers of two).
        DynamicTensorSpec(
            (0,),  # a_tensor_index
            (-2,),
            get_last_power_of_2_num_tokens_buckets,
            last_positive_power_of_2,
        ),
    ),
    constraint_specs=(
        # Keep dim -2 of `out` equal to dim -2 of `a` as `a` is re-bucketed.
        ConstraintSpec(
            4,  # out_tensor_index
            -2,
            lambda shapes: shapes[0][-2],
        ),
    ),
)


def fp8_gemm_sm100(
a: torch.Tensor,
b: torch.Tensor,
Expand All @@ -376,29 +395,12 @@ def fp8_gemm_sm100(
runners.append(_cudnn_gemm_fp8_runner())
assert runners, "No suitable runners found"
tuner = AutoTuner.get()
a_tensor_index = 0
out_tensor_index = 4
tuning_config = TuningConfig(
dynamic_tensor_specs=(
DynamicTensorSpec(
(a_tensor_index,),
(-2,),
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2,
),
),
constraint_specs=(
ConstraintSpec(
out_tensor_index, -2, lambda shapes: shapes[a_tensor_index][-2]
),
),
)

inputs = [a, b, scale_a, scale_b, out, workspace_buffer]
runner, tactic = tuner.choose_one(
"fp8_gemm",
runners,
tuning_config,
_FP8_GEMM_SM100_TUNING_CONFIG,
inputs,
)

Expand Down Expand Up @@ -2019,6 +2021,58 @@ def _heuristic_func_mm_fp4(
return [c for c in candidate_backends if c in suitable_backends]


def _pad_up(x, y):
return ((x + y - 1) // y) * y


# Shared autotuning config for mm_fp4 when use_8x4_sf_layout is True, hoisted
# to module level so it is constructed once rather than per call.
# Tensor indices refer to positions in the `inputs` list handed to
# AutoTuner.choose_one (a=0, a_scale=2, out=6 per the inline index comments).
_MM_FP4_TUNING_CONFIG_8x4 = TuningConfig(
    dynamic_tensor_specs=(
        # Tune over dim 0 of `a`, bucketed to powers of two by the helpers.
        DynamicTensorSpec(
            (0,),  # a_tensor_index
            (0,),
            get_last_power_of_2_num_tokens_buckets,
            last_positive_power_of_2,
        ),
    ),
    constraint_specs=(
        # a_scale dim 0 tracks a's dim 0 padded to a multiple of 8
        # (the 8x4 scale-factor layout's row granularity).
        ConstraintSpec(
            2,  # a_scale_tensor_index
            0,
            lambda shapes: _pad_up(shapes[0][0], 8),
        ),
        # out dim 0 must equal a's dim 0 (unpadded).
        ConstraintSpec(
            6,  # out_tensor_index
            0,
            lambda shapes: shapes[0][0],
        ),
    ),
)


# Shared autotuning config for mm_fp4 when use_8x4_sf_layout is False
# (128x4 scale-factor layout). Identical to _MM_FP4_TUNING_CONFIG_8x4 except
# a_scale's dim 0 is padded to a multiple of 128 instead of 8.
_MM_FP4_TUNING_CONFIG_128x4 = TuningConfig(
    dynamic_tensor_specs=(
        # Tune over dim 0 of `a`, bucketed to powers of two by the helpers.
        DynamicTensorSpec(
            (0,),  # a_tensor_index
            (0,),
            get_last_power_of_2_num_tokens_buckets,
            last_positive_power_of_2,
        ),
    ),
    constraint_specs=(
        # a_scale dim 0 tracks a's dim 0 padded to a multiple of 128.
        ConstraintSpec(
            2,  # a_scale_tensor_index
            0,
            lambda shapes: _pad_up(shapes[0][0], 128),
        ),
        # out dim 0 must equal a's dim 0 (unpadded).
        ConstraintSpec(
            6,  # out_tensor_index
            0,
            lambda shapes: shapes[0][0],
        ),
    ),
)


@backend_requirement(
{
"cudnn": _cudnn_gemm_fp4_requirement,
Expand Down Expand Up @@ -2138,34 +2192,8 @@ def mm_fp4(
# Now we have a list of runners for desired & supported backends.
tuner = AutoTuner.get()

a_tensor_index = 0
a_scale_tensor_index = 2
out_tensor_index = 6

def pad_up(x, y):
return ((x + y - 1) // y) * y

tuning_config = TuningConfig(
dynamic_tensor_specs=(
DynamicTensorSpec(
(a_tensor_index,),
(0,),
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2,
),
),
constraint_specs=(
ConstraintSpec(
a_scale_tensor_index,
0,
lambda shapes: pad_up(
shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128
),
),
ConstraintSpec(
out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0]
),
),
tuning_config = (
_MM_FP4_TUNING_CONFIG_8x4 if use_8x4_sf_layout else _MM_FP4_TUNING_CONFIG_128x4
)

inputs = [
Expand Down