Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def moe_gemm1(
by = gpu.block_id("x") # tile along inter_dim (N)
bx_persist = gpu.block_id("y") # persistent WG index

- if xcd_swizzle > 0:
+ if const_expr(xcd_swizzle > 0):
_NUM_XCDS_S1 = 8
_c1_sw = arith.constant(1, index=True)
_c_tn_sw = arith.constant(tile_n, index=True)
Expand Down Expand Up @@ -579,7 +579,7 @@ def moe_gemm1(
_lds_out_elem_type = (
T.f32 if _need_quant else (T.bf16 if out_is_bf16 else T.f16)
)
- if _split_lds_out and _use_cshuffle_epilog:
+ if const_expr(_split_lds_out and _use_cshuffle_epilog):
_half_out_elems = int(tile_m) * (int(tile_n) // 2)
lds_out = SmemPtr(
base_ptr_pong,
Expand Down Expand Up @@ -1296,7 +1296,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3):
)
for ikxdl in range_constexpr(pack_K):
k_idx = ku128 * pack_K + ikxdl
- if k_idx < ku_count:
+ if const_expr(k_idx < ku_count):
gate_bp0, gate_bp1 = gate_b_tile_in[k_idx]
if const_expr(not _single_b):
up_bp0, up_bp1 = up_b_tile_in[k_idx]
Expand Down Expand Up @@ -1863,7 +1863,7 @@ def _interleaved_half(
prefetch_x_to_lds(k_tail1, lds_x_ping)
else:
x_regs_ping = load_x_tile(k_tail1)
- if _pad_ku_skip > 0:
+ if const_expr(_pad_ku_skip > 0):
gate_w_ping, up_w_ping = load_b_tile(
k_tail1 // arith.constant(2, index=True),
ku_limit=_tail_ku,
Expand Down Expand Up @@ -1893,7 +1893,7 @@ def _interleaved_half(
store_x_tile_to_lds(x_regs_ping, lds_x_ping)
rocdl.s_waitcnt(0)
_barrier()
- if _pad_ku_skip > 0:
+ if const_expr(_pad_ku_skip > 0):
a_tile_ping = prefetch_full_a_from_lds(
lds_x_ping, ku_limit=_tail_ku
)
Expand Down Expand Up @@ -1974,7 +1974,7 @@ def _swiglu_mul_vec4(gate_v4, up_v4):

def _act_vec4(gate_v4, up_v4):
"""Dispatch activation based on `act` parameter."""
- if act == "swiglu":
+ if const_expr(act == "swiglu"):
return _swiglu_mul_vec4(gate_v4, up_v4)
else:
return _silu_mul_vec4(gate_v4, up_v4)
Expand Down Expand Up @@ -2320,7 +2320,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag):
_w,
)
out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx)
- if _e_vec == 2:
+ if const_expr(_e_vec == 2):
store_val = arith.TruncIOp(T.i16, packed_i32)
store_raw = (
store_val._value
Expand Down Expand Up @@ -4512,7 +4512,7 @@ def launch_mixed_moe_gemm2(
gx = (
n_in - _model_dim_pad_idx + _tile_n_idx - arith.constant(1, index=True)
) / _tile_n_idx
- if _persistent:
+ if const_expr(_persistent):
gy = arith.constant(_cu_num, index=True)
else:
_c_pm_l = arith.constant(persist_m, index=True)
Expand Down
2 changes: 1 addition & 1 deletion aiter/ops/flydsl/kernels/preshuffle_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3):
for imxdl in range_constexpr(_fp4_pack_M):
mi_idx = mi_p * _fp4_pack_M + imxdl
curr_row_a_lds = row_a_lds + (mi_idx * 16)
- if (
+ if const_expr(
(a0_prefetch is not None)
and (k_idx == 0)
and (mi_idx == 0)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ requires = [
"psutil",
"ninja",
"pandas",
- "flydsl==0.1.4.2"
+ "flydsl==0.1.5.dev504"
]

[tool.setuptools_scm]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ pyyaml
einops
pybind11>=3.0.1
ninja
- flydsl==0.1.4.2
+ flydsl==0.1.5.dev504
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
OPT_COMPILER_CONFIG = os.path.join(this_dir, "aiter", "jit", "optCompilerConfig.json")
PACKAGE_NAME = "amd-aiter"

- FLYDSL_VERSION = "flydsl==0.1.4.2"
+ FLYDSL_VERSION = "flydsl==0.1.5.dev504"

BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto")
PREBUILD_KERNELS = int(os.environ.get("PREBUILD_KERNELS", 0))
Expand Down
Loading