Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file modified .pre-commit-config.yaml
100644 β†’ 100755
Empty file.
4 changes: 2 additions & 2 deletions flashinfer/jit/attention/fmha_v2/generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3711,10 +3711,10 @@ def generate_files(specs_names):
]
if "CUDA_PATH" in os.environ:
cmd[0] = os.environ["CUDA_PATH"] + "/bin/" + cmd[0]
print('Running command "{}" to build "bin/print_traits.exe":'.format(" ".join(cmd)))
# print('Running command "{}" to build "bin/print_traits.exe":'.format(" ".join(cmd)))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these changes relevant to the PR?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really, just commenting these out from the original trtllm script to clean up the stdout

process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
output, error = process.communicate()
print('Running "bin/print_traits.exe":')
# print('Running "bin/print_traits.exe":')
Comment on lines +3714 to +3717
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of commenting out these print statements, consider using the logging module. This allows for more flexible control over verbosity (e.g., via log levels like INFO or DEBUG) and is a better practice for maintainability. The information about the commands being run is valuable for debugging the kernel generation process.

process = subprocess.Popen(
"bin/print_traits.exe", stdin=subprocess.PIPE, stdout=subprocess.PIPE
)
Expand Down
3 changes: 2 additions & 1 deletion flashinfer/jit/attention/modules.py
100644 β†’ 100755
Original file line number Diff line number Diff line change
Expand Up @@ -1901,9 +1901,10 @@ def gen_trtllm_fmha_v2_module() -> JitSpec:
source_paths = kernel_paths + [binding_source_path]

nvcc_flags = current_compilation_context.get_nvcc_flags_list(
supported_major_versions=[10, 11, 12]
supported_major_versions=[12]
)
nvcc_flags.append(f"-I{jit_env.FLASHINFER_CSRC_DIR / 'fmha_v2'}")
nvcc_flags.append("-Wno-deprecated-gpu-targets")

return gen_jit_spec(
uri,
Expand Down
2 changes: 2 additions & 0 deletions flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -3603,6 +3603,8 @@ def fmha_v2_prefill_deepseek(
If return_lse is True, the output will be a tuple of two tensors, the first is the output tensor, the second is the lse tensor.
If return_lse is False, the output will be a single tensor.
"""
if not is_sm120a_supported(query.device):
raise ValueError("fmha_v2_prefill_deepseek is only supported on SM120 GPUs.")
assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, (
"currently only support deepseek r1 192 query and 128 value"
)
Expand Down
3 changes: 3 additions & 0 deletions tests/attention/test_fmha_v2_prefill_deepseek.py
100644 β†’ 100755
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from flashinfer.prefill import fmha_v2_prefill_deepseek
from tests.utils_fp8 import to_float8
from flashinfer.utils import is_sm120a_supported


def attention_ref(
Expand Down Expand Up @@ -56,6 +57,8 @@ def attention_ref(
def test_fmha_v2_prefill_deepseek(
batch_size, num_heads, head_dim_qk, head_dim_v, seq_len, qkv_dtype, o_dtype
):
if not is_sm120a_supported(torch.device("cuda")):
pytest.skip("fmha_v2_prefill_deepseek is only supported on SM120 GPUs.")
torch.manual_seed(42)

def initialize_tensors(batch_size, num_heads, head_dim_qk, head_dim_v, seq_len):
Expand Down