python/sglang/jit_kernel/activation.py (58 changes: 50 additions & 8 deletions)
@@ -38,6 +38,10 @@ def _jit_activation_module(dtype: torch.dtype) -> Module:
extra_cuda_cflags=_fast_math_flags(),
cuda_wrappers=[
("run_activation", f"ActivationKernel<{args}>::run_activation"),
(
"run_activation_filtered",
f"ActivationKernel<{args}>::run_activation_filtered",
),
],
)

@@ -56,30 +60,68 @@ def _run_activation_inplace(
module.run_activation(input_2d, out_2d, op_name)


@register_custom_op(mutates_args=["out"])
def _run_activation_filtered_inplace(
op_name: str,
input: torch.Tensor,
out: torch.Tensor,
expert_ids: torch.Tensor,
expert_step: int,
) -> None:
hidden_size = input.shape[-1] // 2
module = _jit_activation_module(input.dtype)
input_2d = input.view(-1, hidden_size * 2)
out_2d = out.view(-1, hidden_size)
module.run_activation_filtered(input_2d, out_2d, expert_ids, expert_step, op_name)


def run_activation(
op_name: str, input: torch.Tensor, out: Optional[torch.Tensor]
op_name: str,
input: torch.Tensor,
out: Optional[torch.Tensor],
expert_ids: Optional[torch.Tensor] = None,
expert_step: int = 1,
) -> torch.Tensor:
"""Apply ``op_name`` activation followed by element-wise multiplication.

When ``expert_ids`` is provided, output rows are skipped for tokens whose
routed expert id is ``-1``. ``expert_step`` is 1 for per-token routing and
``BLOCK_SIZE_M`` for sorted/TMA routing — i.e. ``expert_ids[token_id //
expert_step]`` is consulted before computing each row.
"""
assert op_name in SUPPORTED_ACTIVATIONS, f"Unsupported activation: {op_name}"
hidden_size = input.shape[-1] // 2
if out is None:
out = input.new_empty(*input.shape[:-1], hidden_size)
_run_activation_inplace(op_name, input, out)
if expert_ids is None:
_run_activation_inplace(op_name, input, out)
else:
_run_activation_filtered_inplace(op_name, input, out, expert_ids, expert_step)
return out


def silu_and_mul(
input: torch.Tensor, out: Optional[torch.Tensor] = None
input: torch.Tensor,
out: Optional[torch.Tensor] = None,
expert_ids: Optional[torch.Tensor] = None,
expert_step: int = 1,
) -> torch.Tensor:
return run_activation("silu", input, out)
return run_activation("silu", input, out, expert_ids, expert_step)


def gelu_and_mul(
input: torch.Tensor, out: Optional[torch.Tensor] = None
input: torch.Tensor,
out: Optional[torch.Tensor] = None,
expert_ids: Optional[torch.Tensor] = None,
expert_step: int = 1,
) -> torch.Tensor:
return run_activation("gelu", input, out)
return run_activation("gelu", input, out, expert_ids, expert_step)


def gelu_tanh_and_mul(
input: torch.Tensor, out: Optional[torch.Tensor] = None
input: torch.Tensor,
out: Optional[torch.Tensor] = None,
expert_ids: Optional[torch.Tensor] = None,
expert_step: int = 1,
) -> torch.Tensor:
return run_activation("gelu_tanh", input, out)
return run_activation("gelu_tanh", input, out, expert_ids, expert_step)
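A minimal usage sketch of the new filtered entry points (not part of the diff; hypothetical shapes, assuming a CUDA device and that the module resolves as sglang.jit_kernel.activation):

    import torch

    from sglang.jit_kernel.activation import silu_and_mul

    # 6 tokens, hidden size 128; the last dim holds the two fused halves.
    x = torch.randn(6, 2 * 128, dtype=torch.bfloat16, device="cuda")

    # Per-token routing (expert_step=1): rows whose expert id is -1 are skipped.
    expert_ids = torch.tensor([0, 3, -1, 1, -1, 7], dtype=torch.int32, device="cuda")

    # With out=None the result is allocated via new_empty, so skipped rows hold
    # uninitialized memory; pass a pre-filled out tensor if they must be defined.
    out = silu_and_mul(x, expert_ids=expert_ids, expert_step=1)
    assert out.shape == (6, 128)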
python/sglang/jit_kernel/benchmark/bench_activation.py (71 changes: 71 additions & 0 deletions)
@@ -82,5 +82,76 @@ def f():
return run_benchmark(f, scale=NUM_LAYERS)


FILTER_OPS = ["silu", "gelu"]
FILTER_BS = get_benchmark_range(
full_range=[64, 256, 1024, 4096, 16384], ci_range=[1024]
)
FILTER_DIMS = get_benchmark_range(full_range=[1024, 4096, 8192], ci_range=[4096])
FILTER_RATIOS = get_benchmark_range(full_range=[0.0, 0.25, 0.5], ci_range=[0.25])
FILTER_CONFIGS = list(
itertools.product(FILTER_OPS, FILTER_DIMS, FILTER_BS, FILTER_RATIOS)
)


def _make_expert_ids(num_tokens: int, skip_ratio: float) -> torch.Tensor:
expert_ids = torch.randint(
low=0, high=8, size=(num_tokens,), dtype=torch.int32, device=DEFAULT_DEVICE
)
if skip_ratio > 0:
skip = torch.rand(num_tokens, device=DEFAULT_DEVICE) < skip_ratio
expert_ids[skip] = -1
return expert_ids


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["op_name", "dim", "batch_size", "skip_ratio"],
x_vals=FILTER_CONFIGS,
line_arg="provider",
line_vals=["unfiltered", "filtered"],
line_names=["JIT (no filter_expert)", "JIT (with expert_ids)"],
styles=[("blue", "--"), ("orange", "-")],
ylabel="us",
plot_name="activation-filter-expert",
args={},
)
)
def benchmark_filter(
op_name: str, dim: int, batch_size: int, skip_ratio: float, provider: str
):
x = torch.randn(
NUM_LAYERS,
batch_size,
2 * dim,
dtype=DEFAULT_DTYPE,
device=DEFAULT_DEVICE,
)
out = torch.empty(
NUM_LAYERS,
batch_size,
dim,
dtype=DEFAULT_DTYPE,
device=DEFAULT_DEVICE,
)
expert_ids = _make_expert_ids(batch_size, skip_ratio)

jit_fn = silu_and_mul_jit if op_name == "silu" else gelu_and_mul_jit

if provider == "unfiltered":

def f():
for i in range(NUM_LAYERS):
jit_fn(x[i], out[i])

else: # filtered

def f():
for i in range(NUM_LAYERS):
jit_fn(x[i], out[i], expert_ids=expert_ids, expert_step=1)

return run_benchmark(f, scale=NUM_LAYERS)


if __name__ == "__main__":
benchmark.run(print_data=True)
benchmark_filter.run(print_data=True)
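As a quick sanity check (a sketch, not part of the PR; assumes it runs inside this benchmark module on a CUDA device), the fraction of -1 entries produced by _make_expert_ids should track skip_ratio:

    torch.manual_seed(0)
    ids = _make_expert_ids(num_tokens=16384, skip_ratio=0.25)
    frac = (ids == -1).float().mean().item()
    # Mean of 16384 Bernoulli(0.25) draws; 0.02 is a generous tolerance.
    assert abs(frac - 0.25) < 0.02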
python/sglang/jit_kernel/csrc/elementwise/activation.cuh (77 changes: 60 additions & 17 deletions)
@@ -44,9 +44,14 @@ struct ActivationParams {
void* __restrict__ out;
uint32_t hidden_dim;
uint32_t num_tokens;
// Optional MoE expert filtering: when expert_ids != nullptr, a token is
// skipped if expert_ids[token_id / expert_step] == -1. expert_step is 1
// for per-token routing and BLOCK_SIZE_M for sorted/TMA routing.
const int32_t* __restrict__ expert_ids;
uint32_t expert_step;
};

template <typename T, ActivationKind kAct, bool kUsePDL>
template <typename T, ActivationKind kAct, bool kUsePDL, bool kFilterExpert>
__global__ void act_and_mul_kernel(const __grid_constant__ ActivationParams params) {
using namespace device;
constexpr auto kVecSize = kMaxVecBytes / sizeof(T);
@@ -56,6 +61,9 @@ __global__ void act_and_mul_kernel(const __grid_constant__ ActivationParams params) {
const auto token_id = tid / num_vecs;

if (token_id >= params.num_tokens) return;
if constexpr (kFilterExpert) {
if (params.expert_ids[token_id / params.expert_step] == -1) return;
}
const auto offset = tid % num_vecs;
const auto input_offset = token_id * (num_vecs * 2) + offset;
const auto output_offset = tid;
@@ -78,11 +86,33 @@ struct ActivationKernel {
static constexpr auto kVecSize = device::kMaxVecBytes / sizeof(T);
static constexpr auto kBlockSize = 256u;

template <ActivationKind kAct>
static constexpr auto activation_kernel = act_and_mul_kernel<T, kAct, kUsePDL>;
template <ActivationKind kAct, bool kFilterExpert>
static constexpr auto activation_kernel = act_and_mul_kernel<T, kAct, kUsePDL, kFilterExpert>;

static_assert(device::kMaxVecBytes % sizeof(T) == 0, "unsupported data type");
static void run_activation(const tvm::ffi::TensorView input, const tvm::ffi::TensorView out, std::string type) {

template <bool kFilterExpert>
static auto select_kernel(const std::string& type)
-> decltype(activation_kernel<ActivationKind::kSiLU, kFilterExpert>) {
using namespace host;
if (type == "silu") {
return activation_kernel<ActivationKind::kSiLU, kFilterExpert>;
} else if (type == "gelu") {
return activation_kernel<ActivationKind::kGELU, kFilterExpert>;
} else if (type == "gelu_tanh") {
return activation_kernel<ActivationKind::kGELUTanh, kFilterExpert>;
} else {
Panic("unsupported activation type: ", type);
}
return nullptr;
}

static void launch(
const tvm::ffi::TensorView& input,
const tvm::ffi::TensorView& out,
const std::string& type,
const int32_t* expert_ids,
uint32_t expert_step) {
using namespace host;

auto N = SymbolicSize{"num_tokens"};
@@ -106,18 +136,6 @@ struct ActivationKernel {
if (num_tokens == 0) return;
RuntimeCheck(hidden_size * 2 == D_in.unwrap(), "invalid activation dimension");
RuntimeCheck(hidden_size % kVecSize == 0, "hidden size must be divisible by vector size");
const auto kernel = [&]() -> decltype(activation_kernel<ActivationKind::kSiLU>) {
if (type == "silu") {
return activation_kernel<ActivationKind::kSiLU>;
} else if (type == "gelu") {
return activation_kernel<ActivationKind::kGELU>;
} else if (type == "gelu_tanh") {
return activation_kernel<ActivationKind::kGELUTanh>;
} else {
Panic("unsupported activation type: ", type);
}
return nullptr;
}();
// compute once up front to avoid per-launch overhead
const auto num_total_items = num_tokens * (hidden_size / kVecSize);
RuntimeCheck(num_total_items <= std::numeric_limits<uint32_t>::max(), "too many items for 32-bit indexing");
@@ -127,8 +145,33 @@ struct ActivationKernel {
.out = out.data_ptr(),
.hidden_dim = hidden_size,
.num_tokens = num_tokens,
.expert_ids = expert_ids,
.expert_step = expert_step,
};
LaunchKernel(num_blocks, kBlockSize, device).enable_pdl(kUsePDL)(kernel, params);
if (expert_ids != nullptr) {
RuntimeCheck(expert_step > 0, "expert_step must be positive");
const auto kernel = select_kernel<true>(type);
LaunchKernel(num_blocks, kBlockSize, device).enable_pdl(kUsePDL)(kernel, params);
} else {
const auto kernel = select_kernel<false>(type);
LaunchKernel(num_blocks, kBlockSize, device).enable_pdl(kUsePDL)(kernel, params);
}
}

static void run_activation(const tvm::ffi::TensorView input, const tvm::ffi::TensorView out, std::string type) {
launch(input, out, type, /*expert_ids=*/nullptr, /*expert_step=*/1);
}

static void run_activation_filtered(
const tvm::ffi::TensorView input,
const tvm::ffi::TensorView out,
const tvm::ffi::TensorView expert_ids,
int64_t expert_step,
std::string type) {
using namespace host;
RuntimeCheck(is_type<int32_t>(expert_ids.dtype()), "expert_ids must have dtype int32");
RuntimeCheck(expert_step >= 1, "expert_step must be positive");
launch(input, out, type, static_cast<const int32_t*>(expert_ids.data_ptr()), static_cast<uint32_t>(expert_step));
Review comment on lines +171 to +174 (Contributor, severity: medium):

The expert_ids tensor should be verified to be on the same device as the input and out tensors. Accessing a CPU tensor's data pointer from a CUDA kernel will lead to a segmentation fault or illegal memory access. Additionally, verifying that expert_ids is a 1D tensor ensures the indexing logic in the kernel remains valid.

    using namespace host;
    RuntimeCheck(is_type<int32_t>(expert_ids.dtype()), "expert_ids must have dtype int32");
    RuntimeCheck(expert_ids.device().device_type == input.device().device_type &&
                 expert_ids.device().device_id == input.device().device_id,
                 "expert_ids must be on the same device as input");
    RuntimeCheck(expert_ids.ndim() == 1, "expert_ids must be a 1D tensor");
    RuntimeCheck(expert_step >= 1, "expert_step must be positive");
    launch(input, out, type, static_cast<const int32_t*>(expert_ids.data_ptr()), static_cast<uint32_t>(expert_step));

}
};
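For reference, a pure-PyTorch sketch of the filtered kernel's semantics (hypothetical helper, not part of the PR; assumes the usual gate-first layout of the fused input):

    import torch
    import torch.nn.functional as F

    def activation_filtered_ref(x, expert_ids, expert_step, act=F.silu):
        # x: [num_tokens, 2 * hidden]; expert_ids: [ceil(num_tokens / expert_step)].
        num_tokens, two_h = x.shape
        h = two_h // 2
        out = torch.empty(num_tokens, h, dtype=x.dtype, device=x.device)
        keep = expert_ids[torch.arange(num_tokens, device=x.device) // expert_step] != -1
        gate, up = x[:, :h], x[:, h:]
        # Skipped rows stay untouched, mirroring the kernel's early return.
        out[keep] = (act(gate[keep].float()) * up[keep].float()).to(x.dtype)
        return out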

python/sglang/jit_kernel/tests/test_activation.py (80 changes: 80 additions & 0 deletions)
@@ -75,5 +75,85 @@ def test_activation_out_param(
torch.testing.assert_close(out, expected, atol=atol, rtol=rtol)


FILTER_SHAPES = get_ci_test_range(
full_range=[(83, 1024), (256, 2048), (1024, 4096)],
ci_range=[(83, 1024)],
)
EXPERT_STEPS = [1, 16]


@pytest.mark.parametrize("op_name", OPS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", FILTER_SHAPES)
@pytest.mark.parametrize("expert_step", EXPERT_STEPS)
def test_activation_filter_expert(
op_name: str,
dtype: torch.dtype,
shape: tuple[int, int],
expert_step: int,
) -> None:
"""expert_ids[token // expert_step] == -1 must leave the output row untouched."""
num_tokens = shape[0]
x = torch.randn(shape, dtype=dtype, device="cuda")
# Pre-fill out with a sentinel so we can detect untouched rows.
sentinel = float("nan")
out = torch.full(
shape[:-1] + (shape[-1] // 2,),
sentinel,
dtype=dtype,
device="cuda",
)

num_groups = (num_tokens + expert_step - 1) // expert_step
expert_ids = torch.randint(
low=0, high=8, size=(num_groups,), dtype=torch.int32, device="cuda"
)
skip_mask = torch.rand(num_groups, device="cuda") < 0.4
expert_ids[skip_mask] = -1

result = run_activation(op_name, x, out, expert_ids, expert_step)
assert result is out

token_skip = skip_mask[torch.arange(num_tokens, device="cuda") // expert_step]
expected = _reference(op_name, x)
atol, rtol = _tolerances(dtype)

kept = ~token_skip
if kept.any():
torch.testing.assert_close(out[kept], expected[kept], atol=atol, rtol=rtol)
if token_skip.any():
assert torch.isnan(
out[token_skip]
).all(), "filter_expert kernel touched rows whose expert_id is -1"


@pytest.mark.parametrize("op_name", OPS)
def test_activation_filter_expert_all_skipped(op_name: str) -> None:
"""If every expert id is -1, the output must be left entirely untouched."""
shape = (32, 512)
x = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
out = torch.full(
shape[:-1] + (shape[-1] // 2,),
float("nan"),
dtype=torch.bfloat16,
device="cuda",
)
expert_ids = torch.full((shape[0],), -1, dtype=torch.int32, device="cuda")
run_activation(op_name, x, out, expert_ids, 1)
assert torch.isnan(out).all()


@pytest.mark.parametrize("op_name", OPS)
def test_activation_filter_expert_none_skipped(op_name: str) -> None:
"""No -1 in expert_ids must yield bit-identical output to the unfiltered path."""
shape = (64, 512)
dtype = torch.bfloat16
x = torch.randn(shape, dtype=dtype, device="cuda")
expert_ids = torch.zeros((shape[0],), dtype=torch.int32, device="cuda")
out_filtered = run_activation(op_name, x, None, expert_ids, 1)
out_unfiltered = run_activation(op_name, x, None)
torch.testing.assert_close(out_filtered, out_unfiltered, atol=0.0, rtol=0.0)


if __name__ == "__main__":
sys.exit(pytest.main([__file__, "-v", "-s"]))
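A worked example of the expert_step grouping these tests exercise (a sketch, not part of the PR):

    import torch

    # With expert_step=16, tokens are filtered in blocks: tokens 0..15 consult
    # expert_ids[0], tokens 16..31 consult expert_ids[1], and so on.
    num_tokens, expert_step = 40, 16
    num_groups = (num_tokens + expert_step - 1) // expert_step  # ceil(40 / 16) == 3
    expert_ids = torch.tensor([5, -1, 2], dtype=torch.int32)
    token_skip = expert_ids[torch.arange(num_tokens) // expert_step] == -1
    assert token_skip[16:32].all() and not token_skip[:16].any()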