diff --git a/python/triton/language/extra/hip/libdevice.py b/python/triton/language/extra/hip/libdevice.py index a9b23c0ea5ef..02e5d2d0b211 100644 --- a/python/triton/language/extra/hip/libdevice.py +++ b/python/triton/language/extra/hip/libdevice.py @@ -57,6 +57,15 @@ def exp2(arg0, _builder=None): }, is_pure=True, _builder=_builder) +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_exp_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_exp_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + @core.extern def fast_dividef(arg0, arg1, _builder=None): return core.extern_elementwise("", "", [arg0, arg1], { @@ -295,6 +304,15 @@ def atanh(arg0, _builder=None): }, is_pure=True, _builder=_builder) +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_expm1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_expm1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + @core.extern def hypot(arg0, arg1, _builder=None): return core.extern_elementwise( @@ -313,6 +331,15 @@ def j0(arg0, _builder=None): }, is_pure=True, _builder=_builder) +@core.extern +def j1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_j1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_j1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + @core.extern def y0(arg0, _builder=None): return core.extern_elementwise(