diff --git a/iree/turbine/kernel/boo/op_exports/layer_norm.py b/iree/turbine/kernel/boo/op_exports/layer_norm.py index d8d2e0408..980473f7a 100644 --- a/iree/turbine/kernel/boo/op_exports/layer_norm.py +++ b/iree/turbine/kernel/boo/op_exports/layer_norm.py @@ -78,7 +78,7 @@ def output_shape(self) -> list[int]: @property def force_single_dispatch(self) -> bool: - return True + return self.mode == Mode.FORWARD @staticmethod def get( diff --git a/iree/turbine/kernel/boo/runtime/launch.py b/iree/turbine/kernel/boo/runtime/launch.py index 5a5734acb..8c0f068e7 100644 --- a/iree/turbine/kernel/boo/runtime/launch.py +++ b/iree/turbine/kernel/boo/runtime/launch.py @@ -231,6 +231,9 @@ def default_compiler_flags_callback(device: Device, cache_dir: Path) -> list[str flags.append( "--iree-dispatch-creation-enable-fuse-padding-into-linalg-consumer-ops" ) + flags.append("--iree-dispatch-creation-enable-split-reduction") + # Temporary flags to transpose filter layout without fusing into the computation dispatch. + flags.append("--iree-global-opt-experimental-enable-edge-reshape-propagation") flags.append( "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-convert-conv-filter-to-channels-last)" )