Skip to content

Commit

Permalink
[TKW] Fix sympy expr lowering and add some more igemm test shapes (#184)
Browse files Browse the repository at this point in the history
* Rework how we are lowering `rational` sympy expressions, instead of
delayed materialization via lambdas introduce `_Rational` type and
propagate `numerator/denominator` values independently. Division will
only be materialized on explicit `sympy.floor/ceiling` op.
* Rework how igemm test cases are generated and introduce few real
shapes.
* Use custom pytest markers to separate perf/non-perf tests

---------

Signed-off-by: Ivan Butygin <[email protected]>
  • Loading branch information
Hardcode84 authored Oct 3, 2024
1 parent 553e929 commit 9ed388a
Show file tree
Hide file tree
Showing 3 changed files with 346 additions and 79 deletions.
202 changes: 130 additions & 72 deletions shark_turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dataclasses import dataclass
import torch.fx as fx
import torch.utils._pytree as pytree
from collections import namedtuple

from ..compiler.ir import (
Attribute,
Expand Down Expand Up @@ -200,76 +201,139 @@ def add_emitter_subs(emitter: WaveEmitter) -> dict[IndexSymbol, Any]:
return dynamics


_Rational = namedtuple("_Rational", ["numerator", "denominator"])


def gen_sympy_index(dynamics: dict[IndexSymbol, Any], expr: sympy.Expr) -> OpResult:
stack: list[OpResult] = []

def _broadcast(a, b):
if not isinstance(a, (Value, OpResult)):
a = a.result
def _get_ir_value(arg):
if not isinstance(arg, (Value, OpResult)):
arg = arg.result

return arg

if not isinstance(b, (Value, OpResult)):
b = b.result
def _check_vec_scalar(a, b):
return isinstance(a.type, VectorType) and a.type.element_type == b.type

def _broadcast(a, b):
a = _get_ir_value(a)
b = _get_ir_value(b)

if a.type == b.type:
return a, b

if isinstance(a.type, VectorType) and isinstance(
b.type, (IndexType, IntegerType)
):
assert a.type.element_type == b.type
if _check_vec_scalar(a, b):
b = vector_d.splat(a.type, b)
return a, b

if isinstance(a.type, (IndexType, IntegerType)) and isinstance(
b.type, VectorType
):
assert b.type.element_type == a.type
if _check_vec_scalar(b, a):
a = vector_d.splat(b.type, a)
return a, b

raise CodegenError(f"Cannot broadcast {a.type} and {b.type}")

def _process_mul_add_ops(term, is_mul):
args = []
callables = []
for _ in range(len(term.args)):
val = stack.pop()
if callable(val):
callables.append(val)
else:
args.append(val)
operation = None
for arg in args:
if operation is None:
operation = arg
continue
def get_const_val(arg):
if isinstance(arg, OpResult):
arg = arg.owner.opview

if is_mul:
operation = arith_d.MulIOp(*_broadcast(operation, arg))
else:
operation = arith_d.AddIOp(*_broadcast(operation, arg))
if isinstance(arg, arith_d.ConstantOp):
value = arg.attributes["value"]
if isinstance(value, IntegerAttr):
return int(value)

for arg in callables:
operation = arg(operation, is_mul)
return None

stack.append(operation)
def muli_fold(lhs, rhs):
if get_const_val(lhs) == 1:
return rhs

if get_const_val(rhs) == 1:
return lhs

return arith_d.muli(lhs, rhs)

# `x + (a/b)` transformed into `(x*b + a) / b`
def _add(lhs, rhs):
is_rational_lhs = isinstance(lhs, _Rational)
is_rational_rhs = isinstance(rhs, _Rational)
if is_rational_lhs and not is_rational_rhs:
numerator = muli_fold(*_broadcast(lhs.denominator, rhs))
numerator = arith_d.addi(*_broadcast(numerator, lhs.numerator))
return _Rational(numerator, lhs.denominator)
elif not is_rational_lhs and is_rational_rhs:
numerator = muli_fold(*_broadcast(lhs, rhs.denominator))
numerator = arith_d.addi(*_broadcast(numerator, rhs.numerator))
return _Rational(numerator, rhs.denominator)
elif is_rational_lhs and is_rational_rhs:
lhs_numerator = muli_fold(*_broadcast(lhs.numerator, rhs.denominator))
rhs_numerator = muli_fold(*_broadcast(rhs.numerator, lhs.denominator))
numerator = arith_d.addi(*_broadcast(lhs_numerator, rhs_numerator))
denominator = muli_fold(*_broadcast(lhs.denominator, rhs.denominator))
return _Rational(numerator, denominator)
else:
return arith_d.addi(*_broadcast(lhs, rhs))

# `x * (a/b)` transformed into `(x * a) / b`
def _mul(lhs, rhs):
is_rational_lhs = isinstance(lhs, _Rational)
is_rational_rhs = isinstance(rhs, _Rational)
if is_rational_lhs and not is_rational_rhs:
numerator = muli_fold(*_broadcast(lhs.numerator, rhs))
return _Rational(numerator, lhs.denominator)
elif not is_rational_lhs and is_rational_rhs:
numerator = muli_fold(*_broadcast(lhs, rhs.numerator))
return _Rational(numerator, rhs.denominator)
elif is_rational_lhs and is_rational_rhs:
numerator = muli_fold(*_broadcast(lhs.numerator, rhs.numerator))
denominator = muli_fold(*_broadcast(lhs.denominator, rhs.denominator))
return _Rational(numerator, denominator)
else:
return muli_fold(*_broadcast(lhs, rhs))

def _get_mul(numerator):
return lambda x: arith_d.MulIOp(*_broadcast(x, numerator))
def _floor(value):
if isinstance(value, _Rational):
value = arith_d.divsi(*_broadcast(value.numerator, value.denominator))

def _get_add(numerator, denominator):
return lambda x: arith_d.AddIOp(
*_broadcast(arith_d.MulIOp(*_broadcast(x, denominator)), numerator)
)
return value

def _get_div(mul, add, denominator):
return lambda x, is_mul: arith_d.DivSIOp(
*_broadcast(mul(x) if is_mul else add(x), denominator)
)
def _ceiling(value):
if isinstance(value, _Rational):
value = arith_d.ceildivsi(*_broadcast(value.numerator, value.denominator))

return value

def _group_rationals(stack, count):
"""Group rationals and non-rationals args into 2 contiguous sets.
This allows to mul/add all non-rationals first, reducing total number of ops.
"""
rationals = []
non_rationals = []
for _ in range(count):
val = stack.pop()
if isinstance(val, _Rational):
rationals.append(val)
else:
non_rationals.append(val)

return non_rationals + rationals

def _apply(args, func):
assert len(args) > 0
value = args[0]
for val in args[1:]:
value = func(value, val)

return value

def _enforce_non_rational(val, term):
if isinstance(val, _Rational):
raise CodegenError(f"Rational is not supported yet in '{type(term)}'")

def _get_const(val):
if isinstance(val, int):
return arith_d.constant(IndexType.get(), res)
return arith_d.constant(IndexType.get(), val)

if isinstance(val, (tuple, list)):
vec_type = VectorType.get([len(val)], IndexType.get())
Expand All @@ -296,56 +360,50 @@ def _get_const(val):
else:
raise CodegenError(f"Unknown symbol {term}")
case sympy.Integer():
stack.append(arith_d.constant(IndexType.get(), int(term)))
stack.append(_get_const(int(term)))
case sympy.Mul():
_process_mul_add_ops(term, is_mul=True)
args = _group_rationals(stack, len(term.args))
stack.append(_apply(args, _mul))
case sympy.Add():
_process_mul_add_ops(term, is_mul=False)
args = _group_rationals(stack, len(term.args))
stack.append(_apply(args, _add))
case sympy.Mod():
rhs = stack.pop()
lhs = stack.pop()
mod = arith_d.RemSIOp(*_broadcast(lhs, rhs))
_enforce_non_rational(rhs, term)
_enforce_non_rational(lhs, term)
mod = arith_d.remsi(*_broadcast(lhs, rhs))
stack.append(mod)
case sympy.floor():
# TODO: Since divsi rounds to zero, this seems to work.
# But check whether floordivsi is needed.
stack.append(stack.pop())
stack.append(_floor(stack.pop()))
case sympy.ceiling():
stack.append(_ceiling(stack.pop()))
case sympy.Rational():
# `x * (a/b)` transformed into `(x * a) / b`
# `x + (a/b)` transformed into `(x*b + a) / b`
numerator = arith_d.constant(IndexType.get(), abs(term.p))
denominator = arith_d.constant(IndexType.get(), abs(term.q))
# Assumes that the negative term is always carried on the numerator
if abs(term.p) > term.p:
zero = arith_d.constant(IndexType.get(), int(0))
numerator = arith_d.SubIOp(*_broadcast(zero, numerator))
mul = lambda x: x
if abs(term.p) != 1:
mul = _get_mul(numerator)
add = _get_add(numerator, denominator)
operation = _get_div(mul, add, denominator)
stack.append(operation)
numerator = _get_const(term.p)
denominator = _get_const(term.q)
stack.append(_Rational(numerator, denominator))
case sympy.StrictLessThan():
rhs = stack.pop()
lhs = stack.pop()
_enforce_non_rational(rhs, term)
_enforce_non_rational(lhs, term)
res = arith_d.cmpi(arith_d.CmpIPredicate.slt, *_broadcast(lhs, rhs))
stack.append(res)
case sympy.And():
rhs = stack.pop()
lhs = stack.pop()
_enforce_non_rational(rhs, term)
_enforce_non_rational(lhs, term)
res = arith_d.andi(*_broadcast(lhs, rhs))
stack.append(res)
case sympy.ceiling():
value = stack.pop()
if not isinstance(value, arith_d.DivSIOp):
raise CodegenError(f"Cannot handle ceil({value}) yet")
stack.append(arith_d.CeilDivSIOp(value.lhs, value.rhs))
case sympy.UnevaluatedExpr():
continue
case _:
raise CodegenError(f"Can not handle {type(term)} : {term}")
if len(stack) != 1:

if len(stack) != 1 or isinstance(stack[0], _Rational):
raise CodegenError(f"Expected single result, got {len(stack)}")

return stack[0]


Expand Down
31 changes: 31 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import pytest


def pytest_addoption(parser):
parser.addoption(
"--runperf", action="store_true", default=False, help="run performance tests"
)


def pytest_configure(config):
config.addinivalue_line(
"markers", "perf_only: performace test, runs only with '--runperf'"
)


def pytest_collection_modifyitems(config, items):
run_perf = config.getoption("--runperf")
for item in items:
is_perf_only = next(item.iter_markers("perf_only"), None) is not None
if run_perf:
if not is_perf_only:
item.add_marker(pytest.mark.skip("skip non-perf test"))
else:
if is_perf_only:
item.add_marker(pytest.mark.skip("skip perf test"))
Loading

0 comments on commit 9ed388a

Please sign in to comment.