Skip to content

Commit 167501f

Browse files
authored
fix softmax arm fp16s sum error, fix #5340 (#5393)
1 parent 6595743 commit 167501f

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/layer/arm/softmax_arm_asimdhp.cpp

+4 -4
Original file line number | Diff line number | Diff line change
@@ -255,7 +255,7 @@ int Softmax_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt)
255255
float16x8_t _ss01 = vpaddq_f16(_p0, _p1);
256256
float16x8_t _ss23 = vpaddq_f16(_p2, _p3);
257257
float16x8_t _ss2 = vpaddq_f16(_ss01, _ss23);
258-
_sum = vadd_f16(_sum, vpmax_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
258+
_sum = vadd_f16(_sum, vpadd_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
259259
vst1_f16(sumptr, _sum);
260260
ptr += 32;
261261
maxptr += 4;
@@ -292,7 +292,7 @@ int Softmax_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt)
292292
vst1q_f16(ptr, _p0);
293293
vst1q_f16(ptr + 8, _p1);
294294
float16x8_t _ss2 = vpaddq_f16(_p0, _p1);
295-
_sum = vadd_f16(_sum, vpmax_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
295+
_sum = vadd_f16(_sum, vpadd_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
296296
vst1_f16(sumptr, _sum);
297297
ptr += 16;
298298
maxptr += 4;
@@ -743,7 +743,7 @@ int Softmax_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt)
743743
float16x8_t _ss01 = vpaddq_f16(_p0, _p1);
744744
float16x8_t _ss23 = vpaddq_f16(_p2, _p3);
745745
float16x8_t _ss2 = vpaddq_f16(_ss01, _ss23);
746-
_sum = vadd_f16(_sum, vpmax_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
746+
_sum = vadd_f16(_sum, vpadd_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
747747
vst1_f16(sumptr, _sum);
748748
ptr += 32;
749749
sumptr += 4;
@@ -768,7 +768,7 @@ int Softmax_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt)
768768
float16x8_t _p1 = vld1q_f16(ptr + 8);
769769
float16x4_t _sum = vld1_f16(sumptr);
770770
float16x8_t _ss2 = vpaddq_f16(_p0, _p1);
771-
_sum = vadd_f16(_sum, vpmax_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
771+
_sum = vadd_f16(_sum, vpadd_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
772772
vst1_f16(sumptr, _sum);
773773
ptr += 16;
774774
sumptr += 4;

0 commit comments

Comments
 (0)