Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/target/codegen_cutedsl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ std::string CodeGenTileLangCuTeDSL::CanonicalizeFastmathFunctionName_(
{"logf", "tl.log"}, {"log2", "tl.log2"}, {"log2f", "tl.log2"},
{"log10", "tl.log10"}, {"tan", "tl.tan"}, {"cos", "tl.cos"},
{"sin", "tl.sin"}, {"sqrt", "tl.sqrt"}, {"sqrtf", "tl.sqrt"},
{"tanh", "tl.tanh"}, {"tanhf", "tl.tanh"},
};

auto it = kFastMathMap.find(func_name);
Expand Down
28 changes: 27 additions & 1 deletion tilelang/contrib/cutedsl/math.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,35 @@
import cutlass.cute as cute
from cutlass.cute.typing import Union, Numeric
from cutlass.cute.tensor import TensorSSA
from cutlass._mlir.dialects import arith
from cutlass._mlir.dialects import arith, math
from cutlass.cute.math import exp, exp2, log, log2, log10, tan, cos, sin, sqrt # noqa: F401

from cutlass._mlir.dialects import llvm
from cutlass.base_dsl.typing import Float32
from cutlass.cutlass_dsl import T, dsl_user_op


def divf(x: Union[TensorSSA, Numeric], y: Union[TensorSSA, Numeric], fastmath: bool = False) -> Union[TensorSSA, Numeric]:
return cute.math._math_op(arith.divf, fastmath, x, y)


@dsl_user_op
def __tanhf(x: Union[float, Float32], *, fastmath, loc=None, ip=None) -> Float32:
return Float32(
llvm.inline_asm(
T.f32(),
[Float32(x).ir_value()],
"tanh.approx.f32 $0, $1;",
"=f,f",
has_side_effects=False,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
loc=loc,
ip=ip,
)
)


def tanh(x: Union[TensorSSA, Numeric], fastmath: bool = False) -> Union[TensorSSA, Numeric]:
tanh_op = __tanhf if fastmath else math.tanh
return cute.math._math_op(tanh_op, False, x)
Loading