diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
index d3c125dbc3d..3d1c9da8329 100644
--- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp
+++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
@@ -2865,6 +2865,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
                 case GGML_UNARY_OP_NEG: return HTP_OP_UNARY_NEG;
                 case GGML_UNARY_OP_EXP: return HTP_OP_UNARY_EXP;
                 case GGML_UNARY_OP_SOFTPLUS: return HTP_OP_UNARY_SOFTPLUS;
+                case GGML_UNARY_OP_TANH: return HTP_OP_UNARY_TANH;
                 default:
                     break;
             }
@@ -3335,6 +3336,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
                 case GGML_UNARY_OP_EXP:
                 case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_SOFTPLUS:
+                case GGML_UNARY_OP_TANH:
                     supp = ggml_hexagon_supported_unary(sess, op);
                     break;
                 case GGML_UNARY_OP_SILU:
diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h
index 6203e3848b9..98db864dd42 100644
--- a/ggml/src/ggml-hexagon/htp/htp-ops.h
+++ b/ggml/src/ggml-hexagon/htp/htp-ops.h
@@ -62,6 +62,7 @@ enum htp_op_code {
     HTP_OP_UNARY_EXP,
     HTP_OP_UNARY_NEG,
     HTP_OP_UNARY_SOFTPLUS,
+    HTP_OP_UNARY_TANH,
     HTP_OP_GLU_SWIGLU,
     HTP_OP_GLU_SWIGLU_OAI,
     HTP_OP_GLU_GEGLU,
diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c
index fa1e0698f4a..883a31d6163 100644
--- a/ggml/src/ggml-hexagon/htp/main.c
+++ b/ggml/src/ggml-hexagon/htp/main.c
@@ -542,6 +542,7 @@ static int execute_op(struct htp_ops_context * octx) {
         case HTP_OP_UNARY_SIGMOID:
         case HTP_OP_UNARY_NEG:
         case HTP_OP_UNARY_EXP:
+        case HTP_OP_UNARY_TANH:
         case HTP_OP_L2_NORM:
             return op_unary(octx);
 
diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c
index 26a0e0bd793..d4ae89ee6f0 100644
--- a/ggml/src/ggml-hexagon/htp/unary-ops.c
+++ b/ggml/src/ggml-hexagon/htp/unary-ops.c
@@ -373,6 +373,21 @@ static void l2_norm_f32(const float * restrict src,
     }
 }
 
+static void tanh_f32(const float * restrict src,   // input rows; row ir starts ir * row_size bytes past src
+                     float * restrict dst,         // output rows; addressed with the same byte stride as src
+                     uint8_t * restrict spad,      // not referenced in this function; the call site in this patch passes NULL
+                     const uint32_t num_rows,      // number of rows processed by the loop below
+                     const uint32_t row_elems,     // per-row count forwarded to hvx_tanh_f32_aa
+                     const size_t row_size,        // byte stride between consecutive rows (used for both src and dst)
+                     int32_t * op_params) {        // not referenced in this function
+    for (uint32_t ir = 0; ir < num_rows; ir++) {  // one helper call per row
+        const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);  // byte-pointer arithmetic: row_size is in bytes, not floats
+        uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
+
+        hvx_tanh_f32_aa(dst_local, src_local, row_elems);  // helper is defined outside this patch; presumably element-wise tanh over row_elems floats, with "_aa" meaning aligned src and dst -- TODO confirm
+    }
+}
+
 static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
     const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
     struct htp_ops_context * octx = uctx->octx;
@@ -477,6 +492,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
             case HTP_OP_UNARY_SOFTPLUS:
                 softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
                 break;
+            case HTP_OP_UNARY_TANH:
+                tanh_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+                break;
             case HTP_OP_L2_NORM:
                 l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
                 break;
@@ -547,10 +565,12 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
         case HTP_OP_UNARY_SOFTPLUS:
             op_type = "softplus-f32";
             break;
+        case HTP_OP_UNARY_TANH:
+            op_type = "tanh-f32";
+            break;
         case HTP_OP_L2_NORM:
             op_type = "l2norm-f32";
             break;
-
         default:
             FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
             return HTP_STATUS_NO_SUPPORT;