18 changes: 11 additions & 7 deletions flashinfer/prefill.py
@@ -111,17 +111,21 @@ def _create_scale_bmm2_d_tensor(
"""
if data_dtype == torch.float16:
# Create int32 buffer on device, write FP16 value to lower 16 bits via view
result = torch.zeros(1, dtype=torch.int32, device=device)
result.view(torch.float16)[0] = scale_bmm2
return result
return (
torch.full((1,), scale_bmm2, dtype=torch.float16, device=device)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this still causes a tiny Fill kernel.

If we want to eliminate this kernel as well, the solution would be to accept bmm1_scale and bmm2_scale as a torch.Tensor so that the framework (like SGLang) can provide the scales as device tensors directly (and framework can cache that across decoding steps).

https://github.com/akhilg-nv/flashinfer/blob/bdf29115facde5097b050c5ffdf60f0eae9826f9/flashinfer/prefill.py#L4088-L4089

See this as an example: https://github.com/akhilg-nv/flashinfer/blob/bdf29115facde5097b050c5ffdf60f0eae9826f9/flashinfer/prefill.py#L3725-L3726

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, Jimmy has this draft PR up which allows the scale to be Union[float, torch.Tensor]. I think we will still want to keep the logic in this PR for the case where the input is a float, but perhaps it may be better to force it to be a tensor?

.view(torch.uint16)
.to(torch.int32)
)
elif data_dtype == torch.bfloat16:
# Create int32 buffer on device, write BF16 value to lower 16 bits via view
result = torch.zeros(1, dtype=torch.int32, device=device)
result.view(torch.bfloat16)[0] = scale_bmm2
return result
return (
torch.full((1,), scale_bmm2, dtype=torch.bfloat16, device=device)
.view(torch.uint16)
.to(torch.int32)
)
else:
# FP8, INT8, etc. use FP32 accumulation - create FP32 tensor and view as int32
return torch.tensor([scale_bmm2], dtype=torch.float32, device=device).view(
return torch.full((1,), scale_bmm2, dtype=torch.float32, device=device).view(
torch.int32
)
