diff --git a/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py b/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py index 81d7bfeb4e..75089a0286 100644 --- a/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py +++ b/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py @@ -520,7 +520,7 @@ def moe_gemm1( by = gpu.block_id("x") # tile along inter_dim (N) bx_persist = gpu.block_id("y") # persistent WG index - if xcd_swizzle > 0: + if const_expr(xcd_swizzle > 0): _NUM_XCDS_S1 = 8 _c1_sw = arith.constant(1, index=True) _c_tn_sw = arith.constant(tile_n, index=True) @@ -579,7 +579,7 @@ def moe_gemm1( _lds_out_elem_type = ( T.f32 if _need_quant else (T.bf16 if out_is_bf16 else T.f16) ) - if _split_lds_out and _use_cshuffle_epilog: + if const_expr(_split_lds_out and _use_cshuffle_epilog): _half_out_elems = int(tile_m) * (int(tile_n) // 2) lds_out = SmemPtr( base_ptr_pong, @@ -1296,7 +1296,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): ) for ikxdl in range_constexpr(pack_K): k_idx = ku128 * pack_K + ikxdl - if k_idx < ku_count: + if const_expr(k_idx < ku_count): gate_bp0, gate_bp1 = gate_b_tile_in[k_idx] if const_expr(not _single_b): up_bp0, up_bp1 = up_b_tile_in[k_idx] @@ -1863,7 +1863,7 @@ def _interleaved_half( prefetch_x_to_lds(k_tail1, lds_x_ping) else: x_regs_ping = load_x_tile(k_tail1) - if _pad_ku_skip > 0: + if const_expr(_pad_ku_skip > 0): gate_w_ping, up_w_ping = load_b_tile( k_tail1 // arith.constant(2, index=True), ku_limit=_tail_ku, @@ -1893,7 +1893,7 @@ def _interleaved_half( store_x_tile_to_lds(x_regs_ping, lds_x_ping) rocdl.s_waitcnt(0) _barrier() - if _pad_ku_skip > 0: + if const_expr(_pad_ku_skip > 0): a_tile_ping = prefetch_full_a_from_lds( lds_x_ping, ku_limit=_tail_ku ) @@ -1974,7 +1974,7 @@ def _swiglu_mul_vec4(gate_v4, up_v4): def _act_vec4(gate_v4, up_v4): """Dispatch activation based on `act` parameter.""" - if act == "swiglu": + if const_expr(act == "swiglu"): return _swiglu_mul_vec4(gate_v4, up_v4) else: return _silu_mul_vec4(gate_v4, up_v4) @@ -2320,7 +2320,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): _w, ) out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) - if _e_vec == 2: + if const_expr(_e_vec == 2): store_val = arith.TruncIOp(T.i16, packed_i32) store_raw = ( store_val._value @@ -4512,7 +4512,7 @@ def launch_mixed_moe_gemm2( gx = ( n_in - _model_dim_pad_idx + _tile_n_idx - arith.constant(1, index=True) ) / _tile_n_idx - if _persistent: + if const_expr(_persistent): gy = arith.constant(_cu_num, index=True) else: _c_pm_l = arith.constant(persist_m, index=True) diff --git a/aiter/ops/flydsl/kernels/preshuffle_gemm.py b/aiter/ops/flydsl/kernels/preshuffle_gemm.py index 932b9c001a..e043ac088c 100644 --- a/aiter/ops/flydsl/kernels/preshuffle_gemm.py +++ b/aiter/ops/flydsl/kernels/preshuffle_gemm.py @@ -907,7 +907,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): for imxdl in range_constexpr(_fp4_pack_M): mi_idx = mi_p * _fp4_pack_M + imxdl curr_row_a_lds = row_a_lds + (mi_idx * 16) - if ( + if const_expr( (a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0) diff --git a/pyproject.toml b/pyproject.toml index 082b4c37cc..7fdb43384d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires = [ "psutil", "ninja", "pandas", - "flydsl==0.1.4.2" + "flydsl==0.1.5.dev504" ] [tool.setuptools_scm] diff --git a/requirements.txt b/requirements.txt index 8a863d7694..b986f2909d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,4 @@ pyyaml einops pybind11>=3.0.1 ninja -flydsl==0.1.4.2 +flydsl==0.1.5.dev504 diff --git a/setup.py b/setup.py index bf78f9b177..6a296edfb1 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ OPT_COMPILER_CONFIG = os.path.join(this_dir, "aiter", "jit", "optCompilerConfig.json") PACKAGE_NAME = "amd-aiter" -FLYDSL_VERSION = "flydsl==0.1.4.2" +FLYDSL_VERSION = "flydsl==0.1.5.dev504" BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto") PREBUILD_KERNELS = int(os.environ.get("PREBUILD_KERNELS", 0))