Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions vllm/model_executor/layers/batch_invariant.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,29 +930,29 @@ def enable_batch_invariant_mode():
_batch_invariant_MODE = True
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")

if current_platform.is_device_capability_family(
100
) or current_platform.is_device_capability_family(80):
# For PyTorch 2.9, B200 uses GEMV for bs=1
# Requires https://github.com/pytorch/pytorch/pull/166735
if current_platform.is_device_capability_family(80):
# SM80 (Ampere) cannot rely on cuBLASLt-only determinism; install the
# triton persistent matmul overrides for mm/addmm/matmul/linear.
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")

# Query the shared memory size and set block size
# accordingly to avoid triton OutOfResources
_fp16_block_size_n = 256 if get_max_shared_memory_bytes() > 106496 else 128
else:
# Only source of batch invariance for Hopper is split-k, can disable through
# cuBLAS workspace config
# Hopper (SM90) and Blackwell (SM100): the only source of batch
# variance is split-k, which we disable via the cuBLAS workspace
# config.
_original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
_original_cublaslt_workspace_size = os.environ.get(
"CUBLASLT_WORKSPACE_SIZE", None
)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"

# Triton bmm/persistent-matmul kernels read this for the FP16 N-tile size.
# Set it outside the capability-specific branches above (guarded only by
# is_cuda()) because bmm is overridden on all CUDA platforms.
if current_platform.is_cuda():
_fp16_block_size_n = 256 if get_max_shared_memory_bytes() > 106496 else 128

_batch_invariant_LIB.impl(
"aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
)
Expand Down
Loading