Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions flash_attn/ops/triton/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@
import triton
import triton.language as tl

def triton_autotune_configs():
    """Return autotune configs whose warp count is valid for the current device.

    One ``triton.Config({}, num_warps=w)`` is produced for every power-of-two
    warp count ``w`` such that ``w * warp_size`` does not exceed the
    per-block thread limit, so a device reporting a larger warp size yields
    fewer configs.

    This function is evaluated while the module is being imported (it feeds
    the ``@triton.autotune`` decorators), so it must not fail on hosts where
    no GPU is visible; in that case the default warp size of 32 is assumed.

    Returns:
        list of ``triton.Config`` objects, ordered by increasing ``num_warps``.
    """
    # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
    max_threads_per_block = 1024
    # Default to warp size 32 if not defined by the device, or if there is no device to ask
    warp_size = 32
    if torch.cuda.is_available():
        device_props = torch.cuda.get_device_properties(torch.cuda.current_device())
        warp_size = getattr(device_props, "warp_size", 32)
    # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit
    configs = []
    warp_count = 1
    while warp_count * warp_size <= max_threads_per_block:
        configs.append(triton.Config({}, num_warps=warp_count))
        warp_count *= 2
    return configs

def layer_norm_ref(
x,
Expand Down Expand Up @@ -126,14 +139,7 @@ def rms_norm_ref(


@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
configs=triton_autotune_configs(),
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
Expand Down Expand Up @@ -393,14 +399,7 @@ def _layer_norm_fwd(


@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
configs=triton_autotune_configs(),
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
Expand Down