Skip to content

Commit 9901ca0

Browse files
committed
wip
1 parent f0ca3e7 commit 9901ca0

File tree

2 files changed

+40
-11
lines changed

2 files changed

+40
-11
lines changed

ggml/src/ggml-qnn/npu/device/op_glu.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ inline float dummy_load_coeff() {
2020
return 0;
2121
}
2222

23-
template <typename _TyData> inline float expf_fix(float x) {
24-
// Avoid overflow for large values, f32: log(3.4028234664e+38), f16: log(65504)
25-
constexpr float kMaxExp = std::is_same_v<_TyData, float> ? 88.02f : 11.0898664f;
23+
inline float expf_f16_guard_inf(float x) {
24+
// Avoid overflow for large values, f16: log(65504)
25+
constexpr float kMaxExp = 11.0898664f;
2626

2727
if (x >= kMaxExp) {
2828
// Avoid overflow for large values
@@ -33,14 +33,13 @@ template <typename _TyData> inline float expf_fix(float x) {
3333
return std::expf(x);
3434
}
3535

36-
template <typename _TyData>
37-
inline void glu_vec_op_impl(const _TyData * src0, const _TyData * src1, _TyData * dst, size_t count, float coeff) {
36+
inline void glu_vec_op_f16_f16(const __fp16 * src0, const __fp16 * src1, __fp16 * dst, size_t count, float coeff) {
3837
// TODO: use simd version, for some input hexagon intrinsics will generate nan instead of inf.
3938
for (uint32_t i = 0; i < count; ++i) {
4039
float x = src0[i];
4140
float g = src1[i];
4241

43-
dst[i] = (x / (1.0f + expf_fix<_TyData>(-x))) * g;
42+
dst[i] = (x / (1.0f + expf_f16_guard_inf(-x))) * g;
4443
}
4544
}
4645

@@ -153,7 +152,7 @@ bool glu_compute(hexagon::tensor * out, hexagon::compute_params * params) {
153152
if constexpr (_DataType == NPU_DATA_TYPE_F32) {
154153
return glu_impl<glu_vec_op_f32_f32, qhmath_load_div_sf_ltu>(out, params);
155154
} else if constexpr (_DataType == NPU_DATA_TYPE_F16) {
156-
return glu_impl<glu_vec_op_impl<__fp16>, dummy_load_coeff>(out, params);
155+
return glu_impl<glu_vec_op_f16_f16, dummy_load_coeff>(out, params);
157156
}
158157

159158
DEVICE_LOG_ERROR("Unsupported GLU data type: %s\n", hexagon::get_type_name(out->get_type()));

ggml/src/ggml-qnn/npu/device/vec_math.inl

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,28 +1127,58 @@ inline HVX_Vector_x2 hvx_vsf_convert_vhf(HVX_Vector vxl, HVX_Vector one) {
11271127
return ret;
11281128
}
11291129

1130+
/**
1131+
* @brief Calculates exponential (e^x) for vector elements with infinity guard
1132+
*
1133+
* This function computes the exponential value for each element in the input vector.
1134+
* For input values greater than kMaxExp (88.02f), the function returns the provided
1135+
* infinity value instead of attempting to calculate an exponential that would overflow.
1136+
*
1137+
* @param sline The input vector containing values to compute exponential for
1138+
* @param inf The vector containing the infinity representation to use for guarded values
1139+
* @return HVX_Vector containing exponential values, with values > kMaxExp replaced by inf
1140+
*
1141+
* @note Input values greater than 88.02f will return the specified infinity value
1142+
*/
11301143
inline HVX_Vector qhmath_hvx_exp_vf_guard_inf(HVX_Vector sline, const HVX_Vector inf) {
11311144
constexpr float kMaxExp = 88.02f;
11321145
const HVX_Vector max_exp = Q6_V_vsplat_R(reinterpret_cast<const uint32_t &>(kMaxExp));
11331146

1134-
HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(sline, max_exp);
1147+
HVX_VectorPred pred_gt_max_exp = Q6_Q_vcmp_gt_VsfVsf(sline, max_exp);
11351148

11361149
HVX_Vector out = qhmath_hvx_exp_vf(sline);
11371150

1138-
out = Q6_V_vmux_QVV(pred0, inf, out);
1151+
out = Q6_V_vmux_QVV(pred_gt_max_exp, inf, out);
11391152
return out;
11401153
}
11411154

1155+
/**
1156+
* @brief Vectorized division with guard for infinite denominators on HVX.
1157+
*
1158+
* Performs element-wise division num/denom using qhmath_hvx_div_vf and then
1159+
* masks out lanes where denom equals the provided inf value, forcing those
1160+
* lanes of the result to zero. This is a temporary guard until proper INF
1161+
* handling is implemented in the underlying division routine.
1162+
*
1163+
* @param num Numerator vector (per-lane).
1164+
* @param denom Denominator vector (per-lane); lanes equal to inf are zeroed in the output.
1165+
* @param coeffs Coefficients used by qhmath_hvx_div_vf for the reciprocal/division approximation.
1166+
* @param inf Lane value representing +INF to compare against denom.
1167+
* @return Vector of num/denom with lanes set to zero where denom == inf.
1168+
*
1169+
* @note NaNs, negative infinity, zero denominators, and subnormals are not explicitly handled.
1170+
* @see qhmath_hvx_div_vf
1171+
*/
11421172
inline HVX_Vector qhmath_hvx_div_vf_guard_inf(HVX_Vector num,
11431173
HVX_Vector denom,
11441174
HVX_VectorPair_x4 coeffs,
11451175
const HVX_Vector inf) {
1146-
HVX_VectorPred pred0 = Q6_Q_vcmp_eq_VwVw(denom, inf);
1176+
HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(denom, inf);
11471177

11481178
// TODO: fix the inf in div
11491179
HVX_Vector out = qhmath_hvx_div_vf(num, denom, coeffs);
11501180

1151-
out = Q6_V_vmux_QVV(pred0, Q6_V_vzero(), out);
1181+
out = Q6_V_vmux_QVV(pred_inf, Q6_V_vzero(), out);
11521182
return out;
11531183
}
11541184

0 commit comments

Comments (0)