diff --git a/ggml/src/ggml-qnn/npu/device/device.cpp b/ggml/src/ggml-qnn/npu/device/device.cpp index ff2819bae65e5..db987217fa4c5 100644 --- a/ggml/src/ggml-qnn/npu/device/device.cpp +++ b/ggml/src/ggml-qnn/npu/device/device.cpp @@ -4,7 +4,6 @@ #include #include -#include #include "graph.hpp" #include "hexagon_npu.h" @@ -69,20 +68,28 @@ struct npu_device_context { } }; -inline hexagon::tensor * tensor_from_handle(npu_device_graph_handle_t h) { +inline hexagon::tensor * tensor_from_handle(npu_device_tensor_handle_t h) { + if (h == npu_device_INVALID_DEVICE_TENSOR_HANDLE) { + return nullptr; + } + return reinterpret_cast<hexagon::tensor *>(h); } -inline npu_device_graph_handle_t tensor_to_handle(hexagon::tensor * tensor) { - return reinterpret_cast<npu_device_graph_handle_t>(tensor); +inline npu_device_tensor_handle_t tensor_to_handle(hexagon::tensor * tensor) { + return reinterpret_cast<npu_device_tensor_handle_t>(tensor); } -inline hexagon::graph * graph_from_handle(npu_device_tensor_handle_t h) { +inline hexagon::graph * graph_from_handle(npu_device_graph_handle_t h) { + if (h == npu_device_INVALID_DEVICE_GRAPH_HANDLE) { + return nullptr; + } + return reinterpret_cast<hexagon::graph *>(h); } -inline npu_device_tensor_handle_t graph_to_handle(hexagon::graph * graph) { - return reinterpret_cast<npu_device_tensor_handle_t>(graph); +inline npu_device_graph_handle_t graph_to_handle(hexagon::graph * graph) { + return reinterpret_cast<npu_device_graph_handle_t>(graph); } inline npu_device_context * device_context_from_handle(remote_handle64 h) { @@ -93,12 +100,7 @@ inline npu_device_context * device_context_from_handle(remote_handle64 h) { int npu_device_open(const char * uri, remote_handle64 * h) { // TODO: should we have a device context here? - auto * context = new (std::nothrow) npu_device_context(); - if (!context) { - DEVICE_LOG_ERROR("Failed to allocate memory for the npu_device_context"); - return AEE_ENOMEMORY; - } - + auto * context = new npu_device_context(); if (!context->init()) { DEVICE_LOG_ERROR("Failed to initialize npu_device_context"); delete context; @@ -144,12 +146,7 @@ AEEResult npu_device_device_support_op(remote_handle64 _h, npu_device_tensor_op AEEResult npu_device_tensor_init(remote_handle64 _h, const npu_device_tensor_config * info, npu_device_tensor_handle_t * tensor_handle) { NPU_UNUSED(_h); - auto * tensor = new (std::nothrow) hexagon::tensor(*info); - if (!tensor) { - DEVICE_LOG_ERROR("Failed to allocate memory for the tensor"); - return AEE_ENOMEMORY; - } - + auto * tensor = new hexagon::tensor(*info); *tensor_handle = tensor_to_handle(tensor); return AEE_SUCCESS; } @@ -177,13 +174,29 @@ AEEResult npu_device_tensor_free(remote_handle64 _h, npu_device_tensor_handle_t return AEE_SUCCESS; } -AEEResult npu_device_graph_init(remote_handle64 _h, npu_device_graph_handle_t * graph_handle) { +AEEResult npu_device_tensors_free(remote_handle64 _h, const npu_device_tensor_handle_t * tensor_handles, + int tensor_handlesLen) { NPU_UNUSED(_h); - auto * graph = new (std::nothrow) hexagon::graph(); - if (!graph) { - return AEE_ENOMEMORY; + if (!tensor_handles || tensor_handlesLen < 0) { + DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid arguments"); + return AEE_EINVARGS; + } + + for (int i = 0; i < tensor_handlesLen; ++i) { + auto * tensor = tensor_from_handle(tensor_handles[i]); + if (tensor) { + delete tensor; + } else { + DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid tensor handle at index %d", i); + } } + return AEE_SUCCESS; +} + +AEEResult npu_device_graph_init(remote_handle64 _h, npu_device_graph_handle_t * graph_handle) { + NPU_UNUSED(_h); + auto * graph = new hexagon::graph(); *graph_handle =
graph_to_handle(graph); return AEE_SUCCESS; } diff --git a/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp b/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp index 9c264654c1c9e..5beea614a308c 100644 --- a/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp +++ b/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp @@ -13,10 +13,19 @@ inline float f16_to_f32(const npu_device_fp16_t src) { } // From: ggml/src/ggml-cpu/ops.cpp +template <bool _IsKvF16> void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hexagon::tensor * k, const hexagon::tensor * v, const hexagon::tensor * mask, hexagon::compute_params * params) { static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count"); + constexpr const npu_device_tensor_data_type kKvDataType = _IsKvF16 ? NPU_DATA_TYPE_F16 : NPU_DATA_TYPE_F32; + + if (k->get_type() != kKvDataType || v->get_type() != k->get_type()) { + DEVICE_LOG_ERROR("flash_attn_impl: k and v must have same type, got k: %s, v: %s\n", + hexagon::get_type_name(k->get_type()), hexagon::get_type_name(v->get_type())); + return; + } + float scale = out->get_op_param(0); const float max_bias = out->get_op_param(1); const float logit_softcap = out->get_op_param(2); @@ -37,9 +46,11 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex const float m0 = powf(2.0f, -(max_bias) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - const auto q_to_vec_dot = hexagon::get_type_traits(k->get_type()).from_float; // TODO: fix this - const auto kq_vec_dot = hexagon::get_type_traits(k->get_type()).vec_dot; - if (!q_to_vec_dot || !kq_vec_dot) { + const auto & k_type_traits = hexagon::get_type_traits(kKvDataType); + const auto q_to_vec_dot = k_type_traits.from_float; + constexpr const auto kq_vec_dot = _IsKvF16 ?
hexagon::type_erase_dot_func : + hexagon::type_erase_dot_func; + if (!q_to_vec_dot) { DEVICE_LOG_ERROR("flash_attn_impl: unsupported data type for q, k, or v\n"); return; } @@ -50,12 +61,12 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex const auto DK = k->get_ne(0); const auto DV = v->get_ne(0); const auto row_bytes_q = q->get_ne(0) * hexagon::get_type_traits(q->get_type()).type_size; - const auto row_bytes_k = DK * hexagon::get_type_traits(k->get_type()).type_size; + const auto row_bytes_k = DK * k_type_traits.type_size; const auto row_bytes_v = DV * hexagon::get_type_traits(v->get_type()).type_size; - constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float); - const auto aligned_dk = (DK + kFloatsPerVector - 1) / kFloatsPerVector * kFloatsPerVector; - const auto aligned_dv = (DV + kFloatsPerVector - 1) / kFloatsPerVector * kFloatsPerVector; + constexpr const size_t kFloatsPerVectorPair = hexagon::kBytesPerVector * 2 / sizeof(float); + const auto aligned_dk = (DK + kFloatsPerVectorPair - 1) / kFloatsPerVectorPair * kFloatsPerVectorPair; + const auto aligned_dv = (DV + kFloatsPerVectorPair - 1) / kFloatsPerVectorPair * kFloatsPerVectorPair; size_t total_cache_size = sizeof(float) * (aligned_dk + 2 * aligned_dv); auto * cache_ptr = params->get_vtcm_cache(total_cache_size); if (!cache_ptr) { @@ -64,11 +75,10 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex } // loop over n_batch and n_head - const auto rows_per_batch = q->get_ne(2) * q->get_ne(1); - const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1); - const bool is_v_f16 = - v->get_type() == NPU_DATA_TYPE_F16; // check if V is in FP16 format, otherwise it is in FP32 format - uint8_t * dst_ptr = out->get_write_buffer(); + constexpr bool is_v_f16 = _IsKvF16; // check if V is in FP16 format, otherwise it is in FP32 format + const auto rows_per_batch = q->get_ne(2) * q->get_ne(1); + const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1); + uint8_t * dst_ptr = out->get_write_buffer(); if (!dst_ptr) { DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out, hexagon::get_type_name(out->get_type())); @@ -80,6 +90,10 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex const uint8_t * k_ptr = k->get_read_buffer(); const uint8_t * v_ptr = v->get_read_buffer(); const uint8_t * mask_ptr = mask ? mask->get_read_buffer() : nullptr; + float * VKQ32 = reinterpret_cast(cache_ptr); // FP32 VKQ accumulator + auto * VKQ16 = reinterpret_cast(VKQ32 + aligned_dv); // (temporary) FP16 VKQ accumulator + auto * Q_q = reinterpret_cast( + VKQ32 + 2 * aligned_dv); // (temporary) buffer for Q converted to quantized/FP16 for (auto ir = start_end_row.first; ir < start_end_row.second; ++ir) { // q indices const auto iq3 = ir / rows_per_batch; @@ -90,15 +104,13 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? 
powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; - float S = 0.0f; // sum - float M = -INFINITY; // maximum KQ value + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value - float * VKQ32 = reinterpret_cast(cache_ptr); // FP32 VKQ accumulator - auto * VKQ16 = reinterpret_cast(VKQ32 + aligned_dv); // (temporary) FP16 VKQ accumulator - auto * Q_q = reinterpret_cast( - VKQ32 + 2 * aligned_dv); // (temporary) buffer for Q converted to quantized/FP16 + const auto * q_data = q_ptr + (iq1 * q->get_nb(1) + iq2 * q->get_nb(2) + iq3 * q->get_nb(3)); + hexagon::l2fetch_row(q_data, row_bytes_q); - if (is_v_f16) { + if constexpr (is_v_f16) { memset(VKQ16, 0, DV * sizeof(npu_device_fp16_t)); } else { memset(VKQ32, 0, DV * sizeof(float)); @@ -117,16 +129,13 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex const int iv3 = iq3 / rv3; const int iv2 = iq2 / rv2; - const auto * q_data = q_ptr + (iq1 * q->get_nb(1) + iq2 * q->get_nb(2) + iq3 * q->get_nb(3)); - if (iq1 < q->get_ne(1) - 1) { - hexagon::l2fetch_row(q_data + q->get_nb(1), row_bytes_q); - } - q_to_vec_dot(reinterpret_cast(q_data), Q_q, DK); // online softmax / attention // loop over n_kv and n_head_kv // ref: https://arxiv.org/pdf/2112.05682.pdf + const auto * k_plane_ptr = k_ptr + ik2 * k->get_nb(2) + ik3 * k->get_nb(3); + const auto * v_plane_ptr = v_ptr + iv2 * v->get_nb(2) + iv3 * v->get_nb(3); for (int64_t ic = 0; ic < k->get_ne(1); ++ic) { DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 0, loop); float mv = mp ? (slope * f16_to_f32(mp[ic])) : 0.0f; @@ -137,7 +146,7 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex float s = 0.f; { DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 1, kq_dot); - const auto * k_data = k_ptr + (ic * k->get_nb(1) + ik2 * k->get_nb(2) + ik3 * k->get_nb(3)); + const auto * k_data = k_plane_ptr + ic * k->get_nb(1); if (ic < k->get_ne(1) - 1) { hexagon::l2fetch_row(k_data + k->get_nb(1), row_bytes_k); } @@ -156,12 +165,12 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value float vs = 1.0f; // post-softmax KQ value, expf(s - M) - const auto * v_data = v_ptr + (ic * v->get_nb(1) + iv2 * v->get_nb(2) + iv3 * v->get_nb(3)); + const auto * v_data = v_plane_ptr + ic * v->get_nb(1); if (ic < v->get_ne(1)) { hexagon::l2fetch_row(v_data, row_bytes_v); } - if (is_v_f16) { + if constexpr (is_v_f16) { if (s > M) { // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f M = s; @@ -201,7 +210,7 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex S = S * ms + vs; // scale and increment sum with partial sum } - if (is_v_f16) { + if constexpr (is_v_f16) { // TODO: use a more efficient conversion for (int64_t d = 0; d < DV; ++d) { VKQ32[d] = f16_to_f32(VKQ16[d]); @@ -218,7 +227,10 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex const int i3 = iq3; // permute(0, 2, 1, 3) - memcpy(dst_ptr + (i3 * out_rows_per_batch + i2 + i1 * out->get_ne(1)) * out->get_nb(1), VKQ32, out->get_nb(1)); + hexagon::vec_cpy_f32( + reinterpret_cast(VKQ32), + reinterpret_cast(dst_ptr + (i3 * out_rows_per_batch + i2 + i1 * out->get_ne(1)) * out->get_nb(1)), + out->get_ne(0)); } out->release_write_buffer(); // mark the output tensor as modified @@ -244,7 +256,11 @@ bool flash_attn_f32(tensor * out, compute_params * params) { return false; } - 
flash_attn_impl(out, q, k, v, mask, params); + if (k->get_type() == NPU_DATA_TYPE_F16) { + flash_attn_impl(out, q, k, v, mask, params); + } else { + flash_attn_impl(out, q, k, v, mask, params); + } return true; } diff --git a/ggml/src/ggml-qnn/npu/device/op_impl.cpp b/ggml/src/ggml-qnn/npu/device/op_impl.cpp index 6f89f454598ba..a794a8b750138 100644 --- a/ggml/src/ggml-qnn/npu/device/op_impl.cpp +++ b/ggml/src/ggml-qnn/npu/device/op_impl.cpp @@ -12,64 +12,10 @@ namespace { -template -inline void vec_op_impl(const _TyData * src0, const _TyData * src1, size_t count, _TyData * dst) { - constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData); - - HVX_Vector * iptr0 = ((HVX_Vector *) src0); - HVX_Vector * const iptr0_end = ((HVX_Vector *) src0) + (count / kElementsPerVector); - HVX_Vector * iptr1 = ((HVX_Vector *) src1); - HVX_Vector * optr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned - HVX_Vector prev0 = *iptr0++; - HVX_Vector prev1 = *iptr1++; - - while (iptr0 < iptr0_end) { - HVX_Vector curr0 = *iptr0++; - HVX_Vector curr1 = *iptr1++; - HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - *optr++ = _OpIntrinsic(s0, s1); - prev0 = curr0; - prev1 = curr1; - } - - const size_t leftover = count % kElementsPerVector; - if ((iptr0_end - ((HVX_Vector *) src0)) > 0) { - // handle the last vector - // see also: - // https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147 - // or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c - bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(iptr0); - bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(iptr1); - HVX_Vector curr0 = should_fetch_src0 ? *iptr0 : prev0; - HVX_Vector curr1 = should_fetch_src1 ? *iptr1 : prev1; - iptr0 += should_fetch_src0 ? 1 : 0; - iptr1 += should_fetch_src1 ? 1 : 0; - HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - *optr++ = _OpIntrinsic(s0, s1); - prev0 = curr0; - prev1 = curr1; - } - - const size_t leftover_bytes = leftover * sizeof(_TyData); - if (leftover > 0) { - // handle the leftover elements - HVX_Vector curr0 = - (leftover_bytes + hexagon::unaligned_bytes(iptr0) > hexagon::kBytesPerVector) ? *iptr0 : prev0; - curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - - HVX_Vector curr1 = - (leftover_bytes + hexagon::unaligned_bytes(iptr1) > hexagon::kBytesPerVector) ? 
*iptr1 : prev1; - curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - - hexagon::q6op_vstu_variable_ARV(optr, leftover_bytes, _OpIntrinsic(curr0, curr1)); - } -} - -template +template inline void vec_op_f32_f32(const float * src0, const float * src1, size_t count, float * dst) { - vec_op_impl<_OpIntrinsic, float>(src0, src1, count, dst); + using namespace hexagon::vec; + vec_trans_op_impl<_OpBinaryTransform, float>(src0, src1, count, dst); } inline HVX_Vector vadd_f32_f32(HVX_Vector a, HVX_Vector b) { @@ -84,10 +30,11 @@ inline HVX_Vector vmul_f32_f32(HVX_Vector a, HVX_Vector b) { return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)); } -template +template inline void vec_op_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count, npu_device_fp16_t * dst) { - vec_op_impl<_OpIntrinsic, npu_device_fp16_t>(src0, src1, count, dst); + using namespace hexagon::vec; + vec_trans_op_impl<_OpBinaryTransform, npu_device_fp16_t>(src0, src1, count, dst); } inline HVX_Vector vadd_f16_f16(HVX_Vector a, HVX_Vector b) { @@ -252,10 +199,10 @@ void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) { prev = curr; } - const size_t leftover_bytes = leftover * sizeof(float); if (leftover > 0) { // handle the leftover elements - HVX_Vector curr = + const size_t leftover_bytes = leftover * sizeof(float); + HVX_Vector curr = (leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? *src_vec_ptr : prev; curr = Q6_V_valign_VVR(curr, prev, (size_t) src); sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, diff --git a/ggml/src/ggml-qnn/npu/device/op_mul_mat.cpp b/ggml/src/ggml-qnn/npu/device/op_mul_mat.cpp index ff1335ace2731..e7ca2ea4404c5 100644 --- a/ggml/src/ggml-qnn/npu/device/op_mul_mat.cpp +++ b/ggml/src/ggml-qnn/npu/device/op_mul_mat.cpp @@ -6,26 +6,37 @@ namespace { -template struct get_data_type {}; +template struct get_data_type {}; -template -struct get_data_type { - using data_type0 = _TyData0; - using data_type1 = _TyData1; +template +struct get_data_type { + using data_type0 = _TData0; + using data_type1 = _TData1; }; -template +template struct convert_vector {}; + +template <> struct convert_vector { + static float convert(HVX_Vector vec) { return hexagon::get_flt0_from_fltv(Q6_Vsf_equals_Vqf32(vec)); } +}; + +template <> struct convert_vector { + static float convert(HVX_Vector vec) { + HVX_Vector vect = Q6_Vhf_equals_Vqf16(vec); + uint16_t i = (vect[0] & 0xffff); + return reinterpret_cast<__fp16 &>(i); + } +}; + +template void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tensor * dst, hexagon::compute_params * params) { using data_type0 = typename get_data_type::data_type0; using data_type1 = typename get_data_type::data_type1; - static_assert(!_IsQuantized || std::is_same_v, - "data_type0 must be the same as hexagon::dequant_target_type"); - const auto src0_actual_row_size = hexagon::get_dequantized_row_size(src0); auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float; - if (_IsQuantized && dequantize_row_func == nullptr) { + if (_ShouldCacheSrc0 && dequantize_row_func == nullptr) { DEVICE_LOG_ERROR("Unsupported quantized src0 type: %d, dequantize_row_func is null\n", src0->get_type()); return; } @@ -61,7 +72,7 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso size_t src0_plane_cache_size = 0; uint8_t * src0_plane_cache_ptr = nullptr; const uint8_t * last_cached_plane_ptr = nullptr; - if constexpr (_IsQuantized) { + if constexpr 
(_ShouldCacheSrc0) { src0_plane_slice_row_count = std::min(params->get_vtcm_quota_size() / src0_actual_row_size, src0_plane_slice_row_count); src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count; @@ -78,11 +89,12 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso DEVICE_LOG_DEBUG( "mul_mat_impl src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, is_quantized: %d, vtcm_mem: " "%p(%zu)\n", - src0_actual_row_size, src0_plane_slice_row_count, _IsQuantized, (void *) src0_plane_cache_ptr, + src0_actual_row_size, src0_plane_slice_row_count, _ShouldCacheSrc0, (void *) src0_plane_cache_ptr, src0_plane_cache_size); const size_t valid_row0_bytes = src0->get_ne(0) * sizeof(data_type0); - const size_t valid_row1_bytes = src1->get_ne(0) * sizeof(data_type1); + const size_t valid_row1_bytes = + src0->get_ne(0) * sizeof(data_type1); // src0 and src1 should have the same element count in the 1st dimension DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat); uint8_t * dst_ptr = dst->get_write_buffer(); @@ -92,7 +104,7 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso return; } - constexpr bool should_fetch_src0_row = !_IsQuantized; + constexpr bool should_fetch_src0_row = !_ShouldCacheSrc0; const uint8_t * src0_ptr = src0->get_read_buffer(); const uint8_t * src1_ptr = src1->get_read_buffer(); for (int64_t ip = start_end_plane.first; ip < start_end_plane.second; ip++) { @@ -102,24 +114,24 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso auto * dst_plane = dst_ptr + i3 * dst->get_nb(3) + i2 * dst->get_nb(2); for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second; col_idx += src0_plane_slice_row_count) { - const auto actual_row_count = + const int64_t actual_row_count = std::min(src0_plane_slice_row_count, start_end_element.second - col_idx); // number of rows in this slice const uint8_t * src0_plane = src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2) + col_idx * src0->get_nb(1); - if constexpr (_IsQuantized) { + if constexpr (_ShouldCacheSrc0) { if (last_cached_plane_ptr != src0_plane) { DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant); - for (int64_t ir = 0; ir < (int64_t) actual_row_count; ir++) { + hexagon::l2fetch_row(src0_plane, src0->get_nb(1)); + for (int64_t ir = 0; ir < actual_row_count; ir++) { auto * src0_row = src0_plane + ir * src0->get_nb(1); if (ir + 1 < actual_row_count) { hexagon::l2fetch_row(src0_row + src0->get_nb(1), src0->get_nb(1)); } - auto * dst_row = reinterpret_cast(src0_plane_cache_ptr + - ir * src0_actual_row_size); - dequantize_row_func(src0_row, reinterpret_cast(dst_row), + auto * cached_row_ptr = src0_plane_cache_ptr + ir * src0_actual_row_size; + dequantize_row_func(src0_row, reinterpret_cast(cached_row_ptr), src0->get_ne(0)); } @@ -138,34 +150,45 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso auto * src1_row = src1_plane + i1 * src1->get_nb(1); auto * dst_row = reinterpret_cast(dst_plane + i1 * dst->get_nb(1)) + col_idx; int64_t i0 = 0; - for (; i0 + 1 < (int64_t) actual_row_count; i0 += 2) { + for (; i0 + 1 < actual_row_count; i0 += 2) { auto * src0_row = src0_plane + i0 * src0_actual_row_size; if constexpr (should_fetch_src0_row) { hexagon::l2fetch_row(src0_row + src0_actual_row_size, valid_row0_bytes); } // TODO: figure dst how to handle a entire row - dst_row[i0] = 
_DotFunc(reinterpret_cast(src0_row), - reinterpret_cast(src1_row), (size_t) src0->get_ne(0)); + auto res0 = _DotFunc(reinterpret_cast(src0_row), + reinterpret_cast(src1_row), (size_t) src0->get_ne(0)); + + { + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store); + dst_row[i0] = convert_vector::convert(res0); + } - if (should_fetch_src0_row && i0 + 2 < (int64_t) actual_row_count) { + if (should_fetch_src0_row && i0 + 2 < actual_row_count) { hexagon::l2fetch_row(src0_row + src0_actual_row_size + src0_actual_row_size, valid_row0_bytes); } // TODO: figure dst how to handle a entire row - dst_row[i0 + 1] = - _DotFunc(reinterpret_cast(src0_row + src0_actual_row_size), - reinterpret_cast(src1_row), (size_t) src0->get_ne(0)); + auto res1 = _DotFunc(reinterpret_cast(src0_row + src0_actual_row_size), + reinterpret_cast(src1_row), (size_t) src0->get_ne(0)); + + { + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store); + dst_row[i0 + 1] = convert_vector::convert(res1); + } } if (ip + 1 < start_end_plane.second) { hexagon::l2fetch_row(src1_row + src1->get_nb(1), valid_row1_bytes); } - if (i0 < (int64_t) actual_row_count) { + if (i0 < actual_row_count) { auto * src0_row = src0_plane + i0 * src0_actual_row_size; - dst_row[i0] = _DotFunc(reinterpret_cast(src0_row), + auto res = _DotFunc(reinterpret_cast(src0_row), reinterpret_cast(src1_row), (size_t) src0->get_ne(0)); + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, store); + dst_row[i0] = convert_vector::convert(res); } } } @@ -174,6 +197,25 @@ void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tenso dst->release_write_buffer(); // mark the output tensor as modified } +bool is_row_size_cacheable(const npu_device_tensor_spec & src) { + const auto & type_traits = hexagon::get_type_traits(src.type); + if (type_traits.to_float == nullptr) { + DEVICE_LOG_DEBUG("[MUL_MAT]src.type(%s) cannot be cached, to_float is null\n", + hexagon::get_type_name(src.type)); + return false; + } + + const size_t type_size = type_traits.is_quantized ? 
sizeof(hexagon::dequant_target_type) : type_traits.type_size; + const auto vtcm_thread_quota_size = hexagon::default_thread_pool::get_per_thread_vtcm_quota(); + if (src.ne[0] * type_size > vtcm_thread_quota_size) { + DEVICE_LOG_DEBUG("[MUL_MAT]src.type(%s) ne[0] is too large: %ld, vtcm_thread_quota_size: %zu\n", + hexagon::get_type_name(src.type), (long) src.ne[0], vtcm_thread_quota_size); + return false; + } + + return true; +} + bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const npu_device_tensor_spec & src1) { if (src1.type != NPU_DATA_TYPE_F32 && src1.type != NPU_DATA_TYPE_F16) { DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) and src1.type(%s) mismatch and src1 is not F32\n", @@ -194,10 +236,7 @@ bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const n return false; } - const auto vtcm_thread_quota_size = hexagon::default_thread_pool::get_per_thread_vtcm_quota(); - if (src0.ne[0] * sizeof(hexagon::dequant_target_type) > vtcm_thread_quota_size) { - DEVICE_LOG_DEBUG("[MUL_MAT]src0.type(%s) ne[0] is too large: %ld, vtcm_thread_quota_size: %zu\n", - hexagon::get_type_name(src0.type), (long) src0.ne[0], vtcm_thread_quota_size); + if (!is_row_size_cacheable(src0)) { return false; } @@ -208,9 +247,8 @@ bool is_quantized_mul_mat_supported(const npu_device_tensor_spec & src0, const n bool is_mul_mat_f16_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::tensor * src1, bool is_src0_quantized) { const auto * src1_ptr = src1->get_read_buffer_as(); - const auto * src0_ptr = is_src0_quantized ? - src1->get_read_buffer_as() : - src0->get_read_buffer_as(); // skip src0 for quantized tensors + const auto * src0_ptr = + is_src0_quantized ? nullptr : src0->get_read_buffer_as(); // skip src0 for quantized tensors if (!hexagon::is_f16_f32_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) { DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0)); @@ -223,13 +261,23 @@ bool is_mul_mat_f16_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::ten bool is_mul_mat_f16_f16_src_tensors_aligned(hexagon::tensor * src0, hexagon::tensor * src1, bool is_src0_quantized) { const auto * src1_ptr = src1->get_read_buffer_as(); - const auto * src0_ptr = is_src0_quantized ? src1_ptr : src0->get_read_buffer_as(); + const auto * src0_ptr = is_src0_quantized ? 
nullptr : src0->get_read_buffer_as(); if (!hexagon::is_f16_f16_dot_product_aligned(src0_ptr, src1_ptr, src0->get_ne(0))) { DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_unaligned: ne[0]: %ld\n", (long) src0->get_ne(0)); return false; } + if (!is_src0_quantized && !hexagon::is_size_aligned(src0->get_nb(1))) { + DEVICE_LOG_DEBUG("[MUL_MAT]src0 tensor nb[1] is not aligned: %zu\n", src0->get_nb(1)); + return false; + } + + if (!hexagon::is_size_aligned(src1->get_nb(1))) { + DEVICE_LOG_DEBUG("[MUL_MAT]src1 tensor nb[1] is not aligned: %zu\n", src1->get_nb(1)); + return false; + } + DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_aligned: ne[0]: %ld\n", (long) src0->get_ne(0)); return true; } @@ -243,6 +291,16 @@ bool is_mul_mat_f32_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::ten return false; } + if (!hexagon::is_size_aligned(src0->get_nb(1))) { + DEVICE_LOG_DEBUG("[MUL_MAT]src0 tensor nb[1] is not aligned: %zu\n", src0->get_nb(1)); + return false; + } + + if (!hexagon::is_size_aligned(src1->get_nb(1))) { + DEVICE_LOG_DEBUG("[MUL_MAT]src1 tensor nb[1] is not aligned: %zu\n", src1->get_nb(1)); + return false; + } + DEVICE_LOG_DEBUG("[MUL_MAT]src_tensors_aligned: ne[0]: %ld\n", (long) src0->get_ne(0)); return true; } @@ -250,30 +308,32 @@ bool is_mul_mat_f32_f32_src_tensors_aligned(hexagon::tensor * src0, hexagon::ten typedef void (*mul_mat_func_type)(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tensor * dst, hexagon::compute_params * params); -constexpr const mul_mat_func_type kMulMatF16F32Funcs[2][2] = { - { - // non-quantized - mul_mat_impl, // F32 * F32 unaligned - mul_mat_impl, // F32 * F32 aligned - }, - { - // quantized - mul_mat_impl, // F32 * F32 quantized unaligned - mul_mat_impl, // F32 * F32 quantized aligned - }, +constexpr const mul_mat_func_type kMulMatF32F32CachedFuncs[2] = { + // quantized and non-quantized + mul_mat_impl, // F32 * F32 quantized unaligned + mul_mat_impl, // F32 * F32 quantized aligned +}; + +constexpr const mul_mat_func_type kMulMatF32F32Funcs[2] = { + // quantized and non-quantized + mul_mat_impl, // F32 * F32 quantized unaligned + mul_mat_impl, // F32 * F32 quantized aligned +}; + +constexpr const mul_mat_func_type kMulMatF16CachedFuncs[2] = { + mul_mat_impl, // F16 * F16 quantized unaligned + mul_mat_impl, // F16 * F16 quantized aligned }; -constexpr const mul_mat_func_type kMulMatF16Funcs[2][2] = { - { - // non-quantized - mul_mat_impl, // F16 * F16 unaligned - mul_mat_impl, // F16 * F16 aligned - }, - { - // quantized - mul_mat_impl, // F16 * F16 quantized unaligned - mul_mat_impl, // F16 * F16 quantized aligned - }, +constexpr const mul_mat_func_type kMulMatF16Funcs[2] = { + mul_mat_impl, // F16 * F16 quantized unaligned + mul_mat_impl, // F16 * F16 quantized aligned +}; + +constexpr const mul_mat_func_type kMulMatF16F32Funcs[2] = { + // quantized and non-quantized + mul_mat_impl, // F32 * F32 quantized unaligned + mul_mat_impl, // F32 * F32 quantized aligned }; } // namespace @@ -297,22 +357,26 @@ bool mul_mat_f32(hexagon::tensor * out, compute_params * params) { } const bool is_src0_quantized = is_quantized_type(src0->get_type()); + const bool should_cache_src0 = is_src0_quantized || src1->get_ne(1) > 1; switch (src1->get_type()) { case NPU_DATA_TYPE_F32: if (is_src0_quantized || src0->get_type() == NPU_DATA_TYPE_F16) { - kMulMatF16F32Funcs[is_src0_quantized][is_mul_mat_f16_f32_src_tensors_aligned( - src0, src1, is_src0_quantized)](src0, src1, out, params); + kMulMatF16F32Funcs[is_mul_mat_f16_f32_src_tensors_aligned(src0, src1, 
is_src0_quantized)](src0, src1, + out, params); + } else if (should_cache_src0) { + kMulMatF32F32CachedFuncs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1)](src0, src1, out, params); } else { - if (is_mul_mat_f32_f32_src_tensors_aligned(src0, src1)) { - mul_mat_impl(src0, src1, out, params); - } else { - mul_mat_impl(src0, src1, out, params); - } + kMulMatF32F32Funcs[is_mul_mat_f32_f32_src_tensors_aligned(src0, src1)](src0, src1, out, params); } return true; case NPU_DATA_TYPE_F16: - kMulMatF16Funcs[is_src0_quantized][is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized)]( - src0, src1, out, params); + if (should_cache_src0) { + kMulMatF16CachedFuncs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized)]( + src0, src1, out, params); + } else { + kMulMatF16Funcs[is_mul_mat_f16_f16_src_tensors_aligned(src0, src1, is_src0_quantized)](src0, src1, out, + params); + } return true; default: break; diff --git a/ggml/src/ggml-qnn/npu/device/op_rope.cpp b/ggml/src/ggml-qnn/npu/device/op_rope.cpp index 514c445290ef2..34bd0409db90e 100644 --- a/ggml/src/ggml-qnn/npu/device/op_rope.cpp +++ b/ggml/src/ggml-qnn/npu/device/op_rope.cpp @@ -270,8 +270,9 @@ bool rope_impl(hexagon::tensor * out, hexagon::compute_params * params) { } } else { // fill the remain channels with data from src tensor - memcpy(dst_row + n_dims * out->get_nb(0), src0_row + n_dims * src0->get_nb(0), - (out->get_ne(0) - n_dims) * sizeof(float)); + hexagon::vec_cpy_f32(reinterpret_cast(src0_row + n_dims * src0->get_nb(0)), + reinterpret_cast(dst_row + n_dims * out->get_nb(0)), + out->get_ne(0) - n_dims); } } } diff --git a/ggml/src/ggml-qnn/npu/device/tensor.hpp b/ggml/src/ggml-qnn/npu/device/tensor.hpp index 3bf834f826f4c..c6a7fb10779dc 100644 --- a/ggml/src/ggml-qnn/npu/device/tensor.hpp +++ b/ggml/src/ggml-qnn/npu/device/tensor.hpp @@ -60,7 +60,8 @@ class tensor { memcpy(_op_params, config.params, sizeof(_op_params)); for (size_t i = 0; i < DEVICE_TENSOR_MAX_SRC; ++i) { auto src_handle = config.src_handles[i]; - _src[i] = (src_handle ? reinterpret_cast(src_handle) : nullptr); + _src[i] = (src_handle != npu_device_INVALID_DEVICE_TENSOR_HANDLE ? reinterpret_cast(src_handle) : + nullptr); } } diff --git a/ggml/src/ggml-qnn/npu/device/type_traits.cpp b/ggml/src/ggml-qnn/npu/device/type_traits.cpp index 704607167fec5..31377f6e55219 100644 --- a/ggml/src/ggml-qnn/npu/device/type_traits.cpp +++ b/ggml/src/ggml-qnn/npu/device/type_traits.cpp @@ -28,22 +28,26 @@ inline npu_device_fp16_t to_fp16(const float src) { return reinterpret_cast(f16_value); } +template inline HVX_Vector load_into_vector(const _TStruct * src) { + static_assert(hexagon::kBytesPerVector >= sizeof(_TStruct) * _Count, "_TStruct too large for vector load"); + + const HVX_Vector * qs0 = reinterpret_cast(&(src->*_MemberPtr)); + HVX_Vector prev = *qs0; + HVX_Vector curr = hexagon::is_addr_aligned(qs0) ? 
Q6_V_vzero() : *(qs0 + 1); + return Q6_V_valign_VVR(curr, prev, (size_t) qs0); +} + template inline HVX_Vector load_block_generic(const _TBlock & src) { static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock), "wrong q4_0 block size/padding"); - - const HVX_Vector * qs0 = reinterpret_cast(src.qs); - const HVX_Vector * qs1 = qs0 + 1; - return Q6_V_valign_VVR(*qs1, *qs0, (size_t) src.qs); + return load_into_vector<_TBlock, 1, &_TBlock::qs>(&src); } template inline HVX_Vector load_dual_block_generic(const _TBlock * srcs) { static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 2, "wrong q4_0 block size/padding"); constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs); - const HVX_Vector * qs0 = reinterpret_cast(srcs->qs); - const HVX_Vector * qs1 = qs0 + 1; - HVX_Vector blocks = Q6_V_valign_VVR(*qs1, *qs0, (size_t) srcs->qs); - HVX_Vector block1 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock)); + HVX_Vector blocks = load_into_vector<_TBlock, 2, &_TBlock::qs>(srcs); + HVX_Vector block1 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock)); return Q6_V_lo_W(Q6_W_vshuff_VVR(block1, blocks, kSizeOfQs)); } @@ -51,15 +55,14 @@ template inline HVX_Vector load_qual_block_generic(const _TBl static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 4, "wrong q4_0 block size/padding"); constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs); - const HVX_Vector * qs0 = reinterpret_cast(srcs->qs); - const HVX_Vector * qs1 = qs0 + 1; - HVX_Vector blocks = Q6_V_valign_VVR(*qs1, *qs0, (size_t) srcs->qs); - HVX_Vector block1 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock)); - HVX_Vector block2 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock) * 2); - HVX_Vector block3 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock) * 3); + HVX_Vector blocks = load_into_vector<_TBlock, 4, &_TBlock::qs>(srcs); + HVX_Vector block1 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock)); + HVX_VectorPair qp0 = Q6_W_vshuff_VVR(block1, blocks, kSizeOfQs); + + HVX_Vector block2 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock) * 2); + HVX_Vector block3 = Q6_V_valign_VVR(Q6_V_vzero(), blocks, sizeof(_TBlock) * 3); + HVX_VectorPair qp1 = Q6_W_vshuff_VVR(block3, block2, kSizeOfQs); - HVX_VectorPair qp0 = Q6_W_vshuff_VVR(block1, blocks, kSizeOfQs); - HVX_VectorPair qp1 = Q6_W_vshuff_VVR(block3, block2, kSizeOfQs); return Q6_V_lo_W(Q6_W_vshuff_VVR(Q6_V_lo_W(qp1), Q6_V_lo_W(qp0), kSizeOfQs * 2)); } @@ -381,17 +384,22 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2 + 4)); q_lo = Q6_Vb_vsub_VbVb(Q6_V_lo_W(qp0), minus); qp0 = Q6_Wh_vunpack_Vb(q_lo); - q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(qp0)); - q_hi = Q6_Vhf_equals_Vh(Q6_V_hi_W(qp0)); - q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scales01); - q_hi = Q6_Vqf16_vmpy_VhfVhf(q_hi, scales23); + + q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(qp0)); + q_hi = Q6_Vhf_equals_Vh(Q6_V_hi_W(qp0)); + + q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scales01); + q_lo = Q6_Vhf_equals_Vqf16(q_lo); + + q_hi = Q6_Vqf16_vmpy_VhfVhf(q_hi, scales23); + q_hi = Q6_Vhf_equals_Vqf16(q_hi); if constexpr (_IsDstAligned) { - reinterpret_cast(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(q_lo); - reinterpret_cast(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(q_hi); + reinterpret_cast(dst_ptr)[0] = q_lo; + reinterpret_cast(dst_ptr)[1] = q_hi; } else { - reinterpret_cast(dst_ptr)[0] = Q6_Vhf_equals_Vqf16(q_lo); - reinterpret_cast(dst_ptr)[1] = Q6_Vhf_equals_Vqf16(q_hi); + reinterpret_cast(dst_ptr)[0] = q_lo; + 
reinterpret_cast(dst_ptr)[1] = q_hi; } dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_target_type) * 2; @@ -412,11 +420,12 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d qp0 = Q6_Wh_vunpack_Vb(q_lo); q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(qp0)); q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scales01); + q_lo = Q6_Vhf_equals_Vqf16(q_lo); if constexpr (_IsDstAligned) { - *reinterpret_cast(dst_ptr) = Q6_Vhf_equals_Vqf16(q_lo); + *reinterpret_cast(dst_ptr) = q_lo; } else { - *reinterpret_cast(dst_ptr) = Q6_Vhf_equals_Vqf16(q_lo); + *reinterpret_cast(dst_ptr) = q_lo; } dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_target_type); @@ -434,12 +443,12 @@ void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_target_type * d qp0 = Q6_Wh_vunpack_Vb(q_lo); q_lo = Q6_Vhf_equals_Vh(Q6_V_lo_W(qp0)); q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scales); + q_lo = Q6_Vhf_equals_Vqf16(q_lo); + if constexpr (_IsDstAligned) { - hexagon::q6op_vstu_variable_aligned(dst_ptr, Q6_Vhf_equals_Vqf16(q_lo)); + hexagon::q6op_vstu_variable_aligned(dst_ptr, q_lo); } else { - hexagon::q6op_vstu_variable_ARV( - dst_ptr, - Q6_Vhf_equals_Vqf16(q_lo)); // TODO: opt the store + hexagon::q6op_vstu_variable_ARV(dst_ptr, q_lo); } } } @@ -488,26 +497,24 @@ void dequantize_row_q4_K(const void * src, hexagon::dequant_target_type * dst, s } } -template struct dot_func_traits {}; - -template struct dot_func_traits { - using param_type = std::remove_const_t>; -}; - -template float wrap_dot_func(const void * src0, const void * src1, size_t count) { - using param_type = typename dot_func_traits::param_type; +void copy_row_f16(const void * src, hexagon::dequant_target_type * dst, size_t count) { + hexagon::vec_cpy_f16(reinterpret_cast(src), dst, count); +} - auto * src0_typed = reinterpret_cast(src0); - auto * src1_typed = reinterpret_cast(src1); - return _DotFunc(src0_typed, src1_typed, count); +void copy_row_f32(const void * src, hexagon::dequant_target_type * dst, size_t count) { + hexagon::vec_cpy_f32(reinterpret_cast(src), reinterpret_cast(dst), count); } constexpr const hexagon::device_type_traits kDeviceTypeTraits[] = { - { NPU_DATA_TYPE_F32, "F32", 1, sizeof(float), false, nullptr, nullptr, - wrap_dot_func }, - { NPU_DATA_TYPE_F16, "F16", 1, sizeof(npu_device_fp16_t), false, nullptr, quantize_row_fp16, - wrap_dot_func }, - { NPU_DATA_TYPE_I32, "I32", 1, sizeof(int32_t), false, nullptr, nullptr, nullptr }, + { NPU_DATA_TYPE_F32, "F32", 1, sizeof(float), false, copy_row_f32, nullptr, + hexagon::type_erase_dot_func, + hexagon::type_erase_dot_func, + hexagon::type_erase_dot_func }, + { NPU_DATA_TYPE_F16, "F16", 1, sizeof(npu_device_fp16_t), false, copy_row_f16, quantize_row_fp16, + hexagon::type_erase_dot_func, + hexagon::type_erase_dot_func, + hexagon::type_erase_dot_func }, + { NPU_DATA_TYPE_I32, "I32", 1, sizeof(int32_t), false }, { NPU_DATA_TYPE_Q8_0, "Q8_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q8_0), true, dequantize_row_q8_0, quantize_row_q8_0 }, { NPU_DATA_TYPE_Q4_0, "Q4_0", QUANT_BLOCK_SIZE, sizeof(npu_device_block_q4_0), true, dequantize_row_q4_0, @@ -552,4 +559,14 @@ const device_type_traits & get_type_traits(npu_device_tensor_data_type type) { return kDeviceTypeTraits[type]; } +size_t get_dequantized_row_size(const tensor * tensor) { + if (!is_quantized_type(tensor->get_type())) { + return tensor->get_nb(1); // for f32 and f16 + } + + auto row_elems_count = tensor->get_ne(0); + return hexagon::get_aligned_size( + row_elems_count * sizeof(dequant_target_type)); // 
dequant_target_type is currently restricted to f32 +} + } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/type_traits.hpp b/ggml/src/ggml-qnn/npu/device/type_traits.hpp index aa6e7d11ed500..645101a676c60 100644 --- a/ggml/src/ggml-qnn/npu/device/type_traits.hpp +++ b/ggml/src/ggml-qnn/npu/device/type_traits.hpp @@ -12,6 +12,7 @@ bool init_f16_f32_table(float * table, size_t count); typedef void (*quantize_row_type)(const float * src, void * dst, size_t count); typedef void (*dequantize_row_type)(const void * src, dequant_target_type * dst, size_t count); typedef float (*vec_dot_type)(const void * src0, const void * src1, size_t count); +typedef bool (*can_use_aligned_vec_dot_type)(const void * src0, const void * src1, size_t count); struct device_type_traits { npu_device_tensor_data_type type; @@ -20,9 +21,11 @@ struct device_type_traits { size_t type_size; bool is_quantized; - dequantize_row_type to_float; - quantize_row_type from_float; - vec_dot_type vec_dot; + dequantize_row_type to_float; + quantize_row_type from_float; + vec_dot_type vec_dot; + vec_dot_type vec_dot_aligned; + can_use_aligned_vec_dot_type can_use_aligned_vec_dot; }; const device_type_traits & get_type_traits(npu_device_tensor_data_type type); @@ -31,14 +34,7 @@ inline bool is_quantized_type(npu_device_tensor_data_type type) { return get_type_traits(type).is_quantized; } -inline size_t get_dequantized_row_size(const tensor * tensor) { - if (!is_quantized_type(tensor->get_type())) { - return tensor->get_nb(1); // for f32 and f16 - } - - auto row_elems_count = tensor->get_ne(0); - return row_elems_count * sizeof(dequant_target_type); // currently only f32 is supported -} +size_t get_dequantized_row_size(const tensor * tensor); inline const char * get_type_name(npu_device_tensor_data_type type) { return get_type_traits(type).type_name; diff --git a/ggml/src/ggml-qnn/npu/device/util.hpp b/ggml/src/ggml-qnn/npu/device/util.hpp index 86da92b9a3130..4fdcc786bacc3 100644 --- a/ggml/src/ggml-qnn/npu/device/util.hpp +++ b/ggml/src/ggml-qnn/npu/device/util.hpp @@ -344,8 +344,10 @@ inline auto make_scoped_perf_timer(const char * format, ...) { } // namespace hexagon #ifdef GGML_HEXAGON_ENABLE_PERFORMANCE_TRACKING +# define _MAKE_VARIABLE_NAME2(name, postfix) name##postfix +# define _MAKE_VARIABLE_NAME(name, postfix) _MAKE_VARIABLE_NAME2(name, postfix) # define DEVICE_SCOPED_PERFORMANCE_TRACKER(fmt, ...) \ - auto __npu_timer_##__LINE__ = hexagon::make_scoped_perf_timer(fmt, __VA_ARGS__) + auto _MAKE_VARIABLE_NAME(__npu_timer_, __LINE__) = hexagon::make_scoped_perf_timer(fmt, __VA_ARGS__) #else # define DEVICE_SCOPED_PERFORMANCE_TRACKER(fmt, ...) 
((void) 0) #endif diff --git a/ggml/src/ggml-qnn/npu/device/vec_ops.cpp b/ggml/src/ggml-qnn/npu/device/vec_ops.cpp deleted file mode 100644 index 4375bb7d5b7ae..0000000000000 --- a/ggml/src/ggml-qnn/npu/device/vec_ops.cpp +++ /dev/null @@ -1,321 +0,0 @@ -#include "vec_ops.hpp" - -#include "util.hpp" - -namespace { - -template -inline float vec_dot_product_impl(const _TElem * src0, const _TElem * src1, size_t count) { - constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem); - - HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0); - HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector; - HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1); - HVX_Vector prev0 = *src0_vec_ptr++; - HVX_Vector prev1 = *src1_vec_ptr++; - HVX_Vector sum = Q6_V_vzero(); - HVX_Vector sum0 = Q6_V_vzero(); - HVX_Vector sum1 = Q6_V_vzero(); - - while (src0_vec_ptr_end - src0_vec_ptr > 1) { - HVX_VectorPair curr0 = reinterpret_cast(src0_vec_ptr)[0]; - HVX_VectorPair curr1 = reinterpret_cast(src1_vec_ptr)[0]; - - HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0); - HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1); - HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0); - HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1); - prev0 = Q6_V_hi_W(curr0); - prev1 = Q6_V_hi_W(curr1); - src0_vec_ptr += 2; - src1_vec_ptr += 2; - - sum0 = _AddFunc(_MpyFunc(l0, l1), sum0); - sum1 = _AddFunc(_MpyFunc(h0, h1), sum1); - } - - sum = _AddFunc(sum0, sum1); - if (src0_vec_ptr_end - src0_vec_ptr > 0) { - HVX_Vector curr0 = *src0_vec_ptr++; - HVX_Vector curr1 = *src1_vec_ptr++; - HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - prev0 = curr0; - prev1 = curr1; - - sum = _AddFunc(_MpyFunc(s0, s1), sum); - } - - const size_t leftover = count % kElementsPerVector; - if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) { - // handle the last vector - // see also: - // https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147 - // or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c - bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr); - bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(src1_vec_ptr); - HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0; - HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1; - src0_vec_ptr += should_fetch_src0 ? 1 : 0; - src1_vec_ptr += should_fetch_src1 ? 1 : 0; - HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - prev0 = curr0; - prev1 = curr1; - - sum = _AddFunc(_MpyFunc(s0, s1), sum); - } - - const size_t leftover_bytes = leftover * sizeof(_TElem); - if (leftover > 0) { - // handle the leftover elements - HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ? - *src0_vec_ptr : - prev0; - curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - - HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ? 
- *src1_vec_ptr : - prev1; - curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - - sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes), sum); - } - - return _ReduceFunc(sum); -} - -template -inline float vec_dot_product_aligned_impl(const _TElem * src0, const _TElem * src1, size_t count) { - constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem); - - HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0); - HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector; - HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1); - HVX_Vector sum0 = Q6_V_vzero(); - HVX_Vector sum1 = Q6_V_vzero(); - - while (src0_vec_ptr_end - src0_vec_ptr > 1) { - HVX_VectorPair curr0 = reinterpret_cast(src0_vec_ptr)[0]; - HVX_VectorPair curr1 = reinterpret_cast(src1_vec_ptr)[0]; - src0_vec_ptr += 2; - src1_vec_ptr += 2; - - sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr0), Q6_V_lo_W(curr1)), sum0); - sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr0), Q6_V_hi_W(curr1)), sum1); - } - - if (src0_vec_ptr_end - src0_vec_ptr > 0) { - HVX_Vector curr0 = src0_vec_ptr[0]; - HVX_Vector curr1 = src1_vec_ptr[0]; - - sum0 = _AddFunc(_MpyFunc(curr0, curr1), sum0); - } - - return _ReduceFunc(_AddFunc(sum0, sum1)); -} - -inline HVX_Vector vec_mpy_qf32(HVX_Vector src0, HVX_Vector src1) { - return Q6_Vqf32_vmpy_VsfVsf(src0, src1); -} - -inline HVX_Vector vec_add_qf32(HVX_Vector sum, HVX_Vector result) { - return Q6_Vqf32_vadd_Vqf32Vqf32(sum, result); -} - -inline HVX_Vector vec_mpy_qf16(HVX_Vector src0, HVX_Vector src1) { - return Q6_Vqf16_vmpy_VhfVhf(src0, src1); -} - -inline HVX_Vector vec_add_qf16(HVX_Vector sum, HVX_Vector result) { - return Q6_Vqf16_vadd_Vqf16Vqf16(sum, result); -} - -template -inline float vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) { - static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1"); - static_assert((sizeof(_TElem1) / sizeof(_TElem0)) == 2, - "Element size mismatch: _TElem1 must be twice the size of _TElem0"); - static_assert((sizeof(_TElem1) % sizeof(_TElem0)) == 0, - "Element size mismatch: _TElem1 must be a multiple of _TElem0"); - - constexpr const size_t kElementsPerVector0 = hexagon::kBytesPerVector / sizeof(_TElem0); - constexpr const size_t kElementsPerVector1 = hexagon::kBytesPerVector / sizeof(_TElem1); - - constexpr const __fp16 kOne = 1.0f; - const HVX_Vector kOneV = Q6_Vh_vsplat_R(reinterpret_cast(kOne)); - - const _TElem0 * const src0_ptr_end = src0 + count; - HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0); - HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1); - HVX_Vector * const src1_vec_ptr_end = ((HVX_Vector *) src1) + count / kElementsPerVector1; - HVX_Vector prev0 = *src0_vec_ptr++; - HVX_Vector prev1 = *src1_vec_ptr++; - HVX_Vector sum = Q6_V_vzero(); - HVX_Vector sum0 = Q6_V_vzero(); - HVX_Vector sum1 = Q6_V_vzero(); - - while (src1_vec_ptr_end - src1_vec_ptr > 1) { - HVX_Vector curr0 = src0_vec_ptr[0]; - HVX_VectorPair curr1 = reinterpret_cast(src1_vec_ptr)[0]; - - HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1); - HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1); - HVX_VectorPair s0_pair = _ExpandFunc(s0, kOneV); - prev0 = curr0; - prev1 = Q6_V_hi_W(curr1); - src0_vec_ptr++; - src1_vec_ptr += 2; - - sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(s0_pair), l1), sum0); - sum1 = 
_AddFunc(_MpyFunc(Q6_V_hi_W(s0_pair), h1), sum1); - } - - sum = _AddFunc(sum0, sum1); - const size_t leftover1 = count % kElementsPerVector1; - if ((src1_vec_ptr_end - ((HVX_Vector *) src1)) > 0) { - // handle the last vector - const bool should_fetch_src0 = - reinterpret_cast(hexagon::align_down(src0_vec_ptr)) < src0_ptr_end; - HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0; - src0_vec_ptr += should_fetch_src0 ? 1 : 0; - - HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - HVX_VectorPair s0_pair = _ExpandFunc(s0, kOneV); - - const bool has_remaining_src1_vector = src1_vec_ptr_end - src1_vec_ptr > 0; - if (has_remaining_src1_vector) { - HVX_Vector curr1 = *src1_vec_ptr++; - HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - prev1 = curr1; - - // should_handle_last_vector will be always true here - sum = _AddFunc(_MpyFunc(Q6_V_lo_W(s0_pair), s1), sum); - } - - bool should_fetch_src1 = leftover1 != 0 || !hexagon::is_addr_aligned(src1_vec_ptr); - HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1; - src1_vec_ptr += should_fetch_src1 ? 1 : 0; - HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - prev0 = curr0; - prev1 = curr1; - - sum = _AddFunc(_MpyFunc(has_remaining_src1_vector ? Q6_V_hi_W(s0_pair) : Q6_V_lo_W(s0_pair), s1), sum); - } - - const size_t leftover0 = count % kElementsPerVector0; - const size_t leftover_bytes1 = leftover1 * sizeof(_TElem1); - if (leftover1 > 0) { - // handle the leftover elements - HVX_Vector curr0 = - reinterpret_cast(hexagon::align_down(src0_vec_ptr)) < src0_ptr_end ? *src0_vec_ptr : prev0; - HVX_Vector curr1 = (leftover_bytes1 + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ? - *src1_vec_ptr : - prev1; - curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); - curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); - HVX_VectorPair curr0_pair = _ExpandFunc(curr0, kOneV); - - curr0 = leftover1 == leftover0 ? 
Q6_V_lo_W(curr0_pair) : Q6_V_hi_W(curr0_pair); - sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes1), sum); - } - - return _ReduceFunc(sum); -} - -template -inline float vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) { - static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1"); - static_assert((sizeof(_TElem1) / sizeof(_TElem0)) == 2, - "Element size mismatch: _TElem1 must be twice the size of _TElem0"); - static_assert((sizeof(_TElem1) % sizeof(_TElem0)) == 0, - "Element size mismatch: _TElem1 must be a multiple of _TElem0"); - - constexpr const size_t kElementsPerVector1 = hexagon::kBytesPerVector / sizeof(_TElem1); - - constexpr const __fp16 kOne = 1.0f; - const HVX_Vector kOneV = Q6_Vh_vsplat_R(reinterpret_cast(kOne)); - - HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0); - HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1); - HVX_Vector * const src1_vec_ptr_end = ((HVX_Vector *) src1) + count / kElementsPerVector1; - HVX_Vector sum0 = Q6_V_vzero(); - HVX_Vector sum1 = Q6_V_vzero(); - - { - HVX_Vector sum2 = Q6_V_vzero(); - HVX_Vector sum3 = Q6_V_vzero(); - - while (src1_vec_ptr_end - src1_vec_ptr > 3) { - HVX_VectorPair curr0 = reinterpret_cast(src0_vec_ptr)[0]; - HVX_VectorPair curr10 = reinterpret_cast(src1_vec_ptr)[0]; - HVX_VectorPair curr11 = reinterpret_cast(src1_vec_ptr)[1]; - - HVX_VectorPair curr00 = _ExpandFunc(Q6_V_lo_W(curr0), kOneV); - HVX_VectorPair curr01 = _ExpandFunc(Q6_V_hi_W(curr0), kOneV); - src0_vec_ptr += 2; - src1_vec_ptr += 4; - - sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr00), Q6_V_lo_W(curr10)), sum0); - sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr00), Q6_V_hi_W(curr10)), sum1); - sum2 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr01), Q6_V_lo_W(curr11)), sum2); - sum3 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr01), Q6_V_hi_W(curr11)), sum3); - } - - sum0 = _AddFunc(sum0, sum2); - sum1 = _AddFunc(sum1, sum3); - } - - if (src1_vec_ptr_end - src1_vec_ptr > 1) { - HVX_Vector curr0 = src0_vec_ptr[0]; - HVX_VectorPair curr1 = reinterpret_cast(src1_vec_ptr)[0]; - - HVX_VectorPair s0_pair = _ExpandFunc(curr0, kOneV); - - sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(s0_pair), Q6_V_lo_W(curr1)), sum0); - sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(s0_pair), Q6_V_hi_W(curr1)), sum1); - } - - return _ReduceFunc(_AddFunc(sum0, sum1)); -} - -} // namespace - -namespace hexagon { - -float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count) { - return vec_dot_product_impl(src0, src1, count); -} - -float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count) { - return vec_dot_product_aligned_impl(src0, src1, count); -} - -float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) { - return vec_dot_product_impl(src0, src1, - count); -} - -float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) { - return vec_dot_product_aligned_impl( - src0, src1, count); -} - -float vec_dot_product_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) { - return vec_dot_product_mixed_impl(src0, src1, count); -} - -float vec_dot_product_aligned_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) { - return vec_dot_product_mix_aligned_impl(src0, src1, count); -} - -} // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/vec_ops.hpp b/ggml/src/ggml-qnn/npu/device/vec_ops.hpp index 
220dc8f77c02d..051255c9b76a7 100644 --- a/ggml/src/ggml-qnn/npu/device/vec_ops.hpp +++ b/ggml/src/ggml-qnn/npu/device/vec_ops.hpp @@ -31,6 +31,10 @@ inline bool is_addr_aligned(const void * addr) { return unaligned_bytes(addr) == 0; } +inline bool is_size_aligned(size_t size) { + return (size & kAlignMask) == 0; +} + inline float get_flt0_from_fltv(HVX_Vector vect) { static_assert(sizeof(vect[0]) == sizeof(float), "vect[0] should be a float"); int32_t i = vect[0]; @@ -157,31 +161,25 @@ inline HVX_VectorPair hvx_vqf32_convert_vhf(HVX_Vector vxl) { return qhmath_hvx_vqf32_convert_vqf16(qhmath_hvx_vqf16_convert_vhf(vxl)); } -inline HVX_VectorPair hvx_vsf_convert_vhf(HVX_Vector vxl, HVX_Vector one) { - HVX_VectorPair res = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vxl), one); - HVX_Vector vxl_w = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(res)); - HVX_Vector vxh_w = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(res)); - return Q6_W_vcombine_VV(vxh_w, vxl_w); +using HVX_Vector_Dual = std::pair; + +inline HVX_Vector_Dual hvx_vsf_convert_vhf(HVX_Vector vxl, HVX_Vector one) { + HVX_VectorPair res = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vxl), one); + return { + Q6_Vsf_equals_Vqf32(Q6_V_lo_W(res)), + Q6_Vsf_equals_Vqf32(Q6_V_hi_W(res)), + }; } inline HVX_Vector vec_reduction_qf32(HVX_Vector sums) { constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float); - static_assert(kFloatsPerVector == 32 || kFloatsPerVector == 16, "kFloatsPerVector should be 16 or 32"); - - // TODO: do we have a better way to do the reduction? - switch (kFloatsPerVector) { - default: - case 32: - sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 16 * sizeof(float))); - // fallthrough - case 16: - sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 8 * sizeof(float))); - sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 4 * sizeof(float))); - sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 2 * sizeof(float))); - sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, sizeof(float))); - break; - } + static_assert(kFloatsPerVector == 32, "kFloatsPerVector should be 32"); + sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 16 * sizeof(float))); + sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 8 * sizeof(float))); + sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 4 * sizeof(float))); + sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, 2 * sizeof(float))); + sums = Q6_Vqf32_vadd_Vqf32Vqf32(sums, Q6_V_vror_VR(sums, sizeof(float))); return sums; } @@ -191,23 +189,14 @@ inline float vec_reduction_f32_qf32(HVX_Vector sums) { inline HVX_Vector vec_reduction_qf16(HVX_Vector sums) { constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(npu_device_fp16_t); - static_assert(kFloatsPerVector == 64 || kFloatsPerVector == 32, "kFloatsPerVector should be 32 or 64"); - - // TODO: do we have a better way to do the reduction? 
- switch (kFloatsPerVector) { - default: - case 64: - sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 32 * sizeof(npu_device_fp16_t))); - // fallthrough - case 32: - sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 16 * sizeof(npu_device_fp16_t))); - sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 8 * sizeof(npu_device_fp16_t))); - sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 4 * sizeof(npu_device_fp16_t))); - sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 2 * sizeof(npu_device_fp16_t))); - sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, sizeof(npu_device_fp16_t))); - break; - } - + static_assert(kFloatsPerVector == 64, "kFloatsPerVector should be 64"); + + sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 32 * sizeof(npu_device_fp16_t))); + sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 16 * sizeof(npu_device_fp16_t))); + sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 8 * sizeof(npu_device_fp16_t))); + sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 4 * sizeof(npu_device_fp16_t))); + sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, 2 * sizeof(npu_device_fp16_t))); + sums = Q6_Vqf16_vadd_Vqf16Vqf16(sums, Q6_V_vror_VR(sums, sizeof(npu_device_fp16_t))); return sums; } @@ -221,62 +210,6 @@ inline HVX_Vector hvx_scale_f32(float scale) { return Q6_V_vsplat_R(reinterpret_cast(scale)); } -template -inline void vec_scale_impl(const _TParam * src, float scale, _TParam * dst, size_t count) { - constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TParam); - - HVX_Vector * src_vec_ptr = ((HVX_Vector *) src); - HVX_Vector * const src_vec_end = ((HVX_Vector *) src) + (count / kElementsPerVector); - HVX_UVector * dst_vec_ptr = ((HVX_UVector *) dst); // TODO: opt the unaligned case? - HVX_Vector prev = *src_vec_ptr++; - const size_t leftover = count % kElementsPerVector; - const size_t leftover_bytes = leftover * sizeof(_TParam); - - HVX_Vector scale_vec = _FuncScaleConvert(scale); - - while (src_vec_end - src_vec_ptr > 1) { - HVX_VectorPair curr = reinterpret_cast(src_vec_ptr)[0]; - src_vec_ptr += 2; - - HVX_Vector lo = Q6_V_valign_VVR(Q6_V_lo_W(curr), prev, (size_t) src); - HVX_Vector hi = Q6_V_valign_VVR(Q6_V_hi_W(curr), Q6_V_lo_W(curr), (size_t) src); - - dst_vec_ptr[0] = _Func(lo, dst_vec_ptr, scale_vec); - dst_vec_ptr[1] = _Func(hi, dst_vec_ptr + 1, scale_vec); - - dst_vec_ptr += 2; - prev = Q6_V_hi_W(curr); - } - - if (src_vec_end - src_vec_ptr > 0) { - HVX_Vector curr = *src_vec_ptr++; - HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src); - dst_vec_ptr[0] = _Func(s0, dst_vec_ptr, scale_vec); - dst_vec_ptr++; - prev = curr; - } - - if ((src_vec_end - ((HVX_Vector *) src)) > 0) { - // handle the last vector - bool should_fetch_next = leftover == 0 && hexagon::is_addr_aligned(src_vec_ptr); - HVX_Vector curr = should_fetch_next ? prev : *src_vec_ptr; - src_vec_ptr = should_fetch_next ? src_vec_ptr : src_vec_ptr + 1; - HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src); - dst_vec_ptr[0] = _Func(s0, dst_vec_ptr, scale_vec); - dst_vec_ptr++; - prev = curr; - } - - if (leftover > 0) { - // handle the leftover elements - HVX_Vector curr = - (leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? 
*src_vec_ptr : prev; - curr = Q6_V_valign_VVR(curr, prev, (size_t) src); - q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _Func(curr, dst_vec_ptr, scale_vec)); - } -} - inline HVX_Vector hvx_vec_scale_f32_f32(HVX_Vector src, HVX_UVector *, HVX_Vector scale_vec) { return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(src, scale_vec)); } @@ -288,14 +221,6 @@ inline HVX_Vector hvx_vec_mad_f32_f32(HVX_Vector src, HVX_UVector * dst_ptr, HVX return Q6_Vsf_equals_Vqf32(src); } -inline void vec_scale_f32(const float * src, float scale, float * dst, size_t count) { - vec_scale_impl(src, scale, dst, count); -} - -inline void vec_mad_f32(const float * src, float scale, float * dst, size_t count) { - vec_scale_impl(src, scale, dst, count); -} - inline HVX_Vector hvx_scale_f16(float scale) { __fp16 f16_scale = scale; return Q6_Vh_vsplat_R(reinterpret_cast(f16_scale)); @@ -312,19 +237,65 @@ inline HVX_Vector hvx_vec_mad_f16_f16(HVX_Vector src, HVX_UVector * dst_ptr, HVX return Q6_Vhf_equals_Vqf16(result); } +inline HVX_Vector hvx_nop(float scale) { + return HVX_Vector(); +} + +inline HVX_Vector hvx_passthru(HVX_Vector src, HVX_UVector *, HVX_Vector) { + return src; +} + +} // namespace hexagon + +#include "vec_ops.inl" + +namespace hexagon { + +inline void vec_scale_f32(const float * src, float scale, float * dst, size_t count) { + using namespace hexagon::vec; + vec_scale_impl(src, scale, dst, count); +} + +inline void vec_mad_f32(const float * src, float scale, float * dst, size_t count) { + using namespace hexagon::vec; + vec_scale_impl(src, scale, dst, count); +} + +inline void vec_cpy_f32(const float * src, float * dst, size_t count) { + using namespace hexagon::vec; + vec_scale_impl(src, 0, dst, count); +} + +inline void vec_zero_f32(float * src, size_t count) { + using namespace hexagon::vec; + vec_zero_impl(src, count); +} + inline void vec_scale_f16(const npu_device_fp16_t * src, float scale, npu_device_fp16_t * dst, size_t count) { + using namespace hexagon::vec; vec_scale_impl(src, scale, dst, count); } inline void vec_mad_f16(const npu_device_fp16_t * src, float scale, npu_device_fp16_t * dst, size_t count) { + using namespace hexagon::vec; vec_scale_impl(src, scale, dst, count); } +inline void vec_cpy_f16(const npu_device_fp16_t * src, npu_device_fp16_t * dst, size_t count) { + using namespace hexagon::vec; + vec_scale_impl(src, 0, dst, count); +} + +inline void vec_zero_f16(npu_device_fp16_t * src, size_t count) { + using namespace hexagon::vec; + vec_zero_impl(src, count); +} + template inline bool is_dot_product_aligned(const _TElem0 * src0, const _TElem1 * src1, size_t count) { static_assert(sizeof(_TElem0) <= sizeof(_TElem1), "src0 should be smaller than src1"); - if (!hexagon::is_addr_aligned(src0) || !hexagon::is_addr_aligned(src1)) { + if ((src0 && !hexagon::is_addr_aligned(src0)) || (src1 && !hexagon::is_addr_aligned(src1))) { return false; } @@ -335,26 +306,107 @@ inline bool is_dot_product_aligned(const _TElem0 * src0, const _TElem1 * src1, s return true; } -float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count); -float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count); +inline HVX_Vector vec_dot_product_vqf32_f32_f32(const float * src0, const float * src1, size_t count) { + using namespace hexagon::vec; + return vec_dot_product_impl(src0, src1, count); +} + +inline HVX_Vector vec_dot_product_aligned_vqf32_f32_f32(const float * src0, const float * src1, size_t count) { + using namespace hexagon::vec; + return 
vec_dot_product_aligned_impl(src0, src1, + count); +} + +inline float vec_dot_product_f32_f32(const float * src0, const float * src1, size_t count) { + using namespace hexagon::vec; + return vec_dot_product_impl(src0, src1, count); +} + +inline float vec_dot_product_aligned_f32_f32(const float * src0, const float * src1, size_t count) { + using namespace hexagon::vec; + return vec_dot_product_aligned_impl(src0, src1, + count); +} inline bool is_f32_f32_dot_product_aligned(const float * src0, const float * src1, size_t count) { return is_dot_product_aligned(src0, src1, count); } -float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count); -float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count); +inline HVX_Vector vec_dot_product_vqf16_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, + size_t count) { + using namespace hexagon::vec; + return vec_dot_product_impl( + src0, src1, count); +} + +inline HVX_Vector vec_dot_product_aligned_vqf16_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, + size_t count) { + using namespace hexagon::vec; + return vec_dot_product_aligned_impl( + src0, src1, count); +} + +inline float vec_dot_product_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) { + using namespace hexagon::vec; + return vec_dot_product_impl( + src0, src1, count); +} + +inline float vec_dot_product_aligned_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, + size_t count) { + using namespace hexagon::vec; + return vec_dot_product_aligned_impl( + src0, src1, count); +} inline bool is_f16_f16_dot_product_aligned(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count) { return is_dot_product_aligned(src0, src1, count); } -float vec_dot_product_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count); -float vec_dot_product_aligned_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count); +inline HVX_Vector vec_dot_product_vqf32_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) { + using namespace hexagon::vec; + return vec_dot_product_mixed_impl(src0, src1, count); +} + +inline HVX_Vector vec_dot_product_aligned_vqf32_f16_f32(const npu_device_fp16_t * src0, const float * src1, + size_t count) { + using namespace hexagon::vec; + return vec_dot_product_mix_aligned_impl(src0, src1, count); +} + +inline float vec_dot_product_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) { + using namespace hexagon::vec; + return vec_dot_product_mixed_impl(src0, src1, count); +} + +inline float vec_dot_product_aligned_f16_f32(const npu_device_fp16_t * src0, const float * src1, size_t count) { + using namespace hexagon::vec; + return vec_dot_product_mix_aligned_impl(src0, src1, count); +} inline bool is_f16_f32_dot_product_aligned(const npu_device_fp16_t * src0, const float * src1, size_t count) { return is_dot_product_aligned(src0, src1, count); } +template struct dot_func_traits {}; + +template struct dot_func_traits<_TReturn (*)(_TData, _TData, size_t)> { + using param_type = std::remove_const_t>; + using return_type = _TReturn; +}; + +template ::return_type> +_TReturn type_erase_dot_func(const void * src0, const void * src1, size_t count) { + using param_type = typename dot_func_traits::param_type; + + auto * src0_typed = reinterpret_cast(src0); + auto * src1_typed = reinterpret_cast(src1); + return 
_DotFunc(src0_typed, src1_typed, count); +} + } // namespace hexagon diff --git a/ggml/src/ggml-qnn/npu/device/vec_ops.inl b/ggml/src/ggml-qnn/npu/device/vec_ops.inl new file mode 100644 index 0000000000000..f21d6b06d9e46 --- /dev/null +++ b/ggml/src/ggml-qnn/npu/device/vec_ops.inl @@ -0,0 +1,499 @@ +#pragma once + +#include + +#include + +#include "hexagon_npu.h" + +namespace hexagon::vec { + +template +inline _TRet vec_dot_product_impl(const _TElem * src0, const _TElem * src1, size_t count) { + constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem); + + HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0); + HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector; + HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1); + HVX_Vector prev0 = *src0_vec_ptr++; + HVX_Vector prev1 = *src1_vec_ptr++; + HVX_Vector sum = Q6_V_vzero(); + + if (src0_vec_ptr_end - src0_vec_ptr > 1) { + HVX_Vector sum0 = Q6_V_vzero(); + HVX_Vector sum1 = Q6_V_vzero(); + + do { + HVX_VectorPair curr0 = reinterpret_cast(src0_vec_ptr)[0]; + HVX_VectorPair curr1 = reinterpret_cast(src1_vec_ptr)[0]; + + HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0); + HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1); + sum0 = _AddFunc(_MpyFunc(l0, l1), sum0); + + HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0); + HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1); + sum1 = _AddFunc(_MpyFunc(h0, h1), sum1); + + prev0 = Q6_V_hi_W(curr0); + prev1 = Q6_V_hi_W(curr1); + src0_vec_ptr += 2; + src1_vec_ptr += 2; + } while (src0_vec_ptr_end - src0_vec_ptr > 1); + + sum = _AddFunc(sum0, sum1); + } + + if (src0_vec_ptr_end - src0_vec_ptr > 0) { + HVX_Vector curr0 = *src0_vec_ptr++; + HVX_Vector curr1 = *src1_vec_ptr++; + HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + prev0 = curr0; + prev1 = curr1; + + sum = _AddFunc(_MpyFunc(s0, s1), sum); + } + + const size_t leftover = count % kElementsPerVector; + if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) { + // handle the last vector + // see also: + // https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147 + // or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c + bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr); + bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(src1_vec_ptr); + HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0; + HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1; + src0_vec_ptr += should_fetch_src0 ? 1 : 0; + src1_vec_ptr += should_fetch_src1 ? 1 : 0; + HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + prev0 = curr0; + prev1 = curr1; + + sum = _AddFunc(_MpyFunc(s0, s1), sum); + } + + if (leftover > 0) { + // handle the leftover elements + const size_t leftover_bytes = leftover * sizeof(_TElem); + HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ? + *src0_vec_ptr : + prev0; + curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + + HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ? 
+ *src1_vec_ptr : + prev1; + curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + + sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes), sum); + } + + return _ReduceFunc(sum); +} + +template +inline _TRet vec_dot_product_aligned_impl(const _TElem * src0, const _TElem * src1, size_t count) { + constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TElem); + + HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0); + HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector; + HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1); + HVX_Vector sum = Q6_V_vzero(); + + { + HVX_Vector sum0 = Q6_V_vzero(); + HVX_Vector sum1 = Q6_V_vzero(); + if (src0_vec_ptr_end - src0_vec_ptr > 3) { + HVX_Vector sum2 = Q6_V_vzero(); + HVX_Vector sum3 = Q6_V_vzero(); + + do { + HVX_VectorPair curr00 = reinterpret_cast(src0_vec_ptr)[0]; + HVX_VectorPair curr10 = reinterpret_cast(src1_vec_ptr)[0]; + sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr00), Q6_V_lo_W(curr10)), sum0); + sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr00), Q6_V_hi_W(curr10)), sum1); + + HVX_VectorPair curr01 = reinterpret_cast(src0_vec_ptr)[1]; + HVX_VectorPair curr11 = reinterpret_cast(src1_vec_ptr)[1]; + sum2 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr01), Q6_V_lo_W(curr11)), sum2); + sum3 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr01), Q6_V_hi_W(curr11)), sum3); + + src0_vec_ptr += 4; + src1_vec_ptr += 4; + } while (src0_vec_ptr_end - src0_vec_ptr > 3); + + sum0 = _AddFunc(sum2, sum0); + sum1 = _AddFunc(sum3, sum1); + } + + if (src0_vec_ptr_end - src0_vec_ptr > 1) { + HVX_VectorPair curr0 = reinterpret_cast(src0_vec_ptr)[0]; + HVX_VectorPair curr1 = reinterpret_cast(src1_vec_ptr)[0]; + src0_vec_ptr += 2; + src1_vec_ptr += 2; + + sum0 = _AddFunc(_MpyFunc(Q6_V_lo_W(curr0), Q6_V_lo_W(curr1)), sum0); + sum1 = _AddFunc(_MpyFunc(Q6_V_hi_W(curr0), Q6_V_hi_W(curr1)), sum1); + } + + sum = _AddFunc(sum0, sum1); + } + + if (src0_vec_ptr_end - src0_vec_ptr > 0) { + HVX_Vector curr0 = src0_vec_ptr[0]; + HVX_Vector curr1 = src1_vec_ptr[0]; + + sum = _AddFunc(_MpyFunc(curr0, curr1), sum); + } + + return _ReduceFunc(sum); +} + +inline HVX_Vector vec_mpy_qf32(HVX_Vector src0, HVX_Vector src1) { + return Q6_Vqf32_vmpy_VsfVsf(src0, src1); +} + +inline HVX_Vector vec_add_qf32(HVX_Vector sum, HVX_Vector result) { + return Q6_Vqf32_vadd_Vqf32Vqf32(sum, result); +} + +inline HVX_Vector vec_mpy_qf16(HVX_Vector src0, HVX_Vector src1) { + return Q6_Vqf16_vmpy_VhfVhf(src0, src1); +} + +inline HVX_Vector vec_add_qf16(HVX_Vector sum, HVX_Vector result) { + return Q6_Vqf16_vadd_Vqf16Vqf16(sum, result); +} + +template +inline _TRet vec_dot_product_mixed_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) { + static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1"); + static_assert((sizeof(_TElem1) / sizeof(_TElem0)) == 2, + "Element size mismatch: _TElem1 must be twice the size of _TElem0"); + static_assert((sizeof(_TElem1) % sizeof(_TElem0)) == 0, + "Element size mismatch: _TElem1 must be a multiple of _TElem0"); + + constexpr const size_t kElementsPerVector0 = hexagon::kBytesPerVector / sizeof(_TElem0); + constexpr const size_t kElementsPerVector1 = hexagon::kBytesPerVector / sizeof(_TElem1); + + constexpr const __fp16 kOne = 1.0f; + const HVX_Vector kOneV = Q6_Vh_vsplat_R(reinterpret_cast(kOne)); + + const _TElem0 * const src0_ptr_end = src0 + count; + HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0); + HVX_Vector * src1_vec_ptr = ((HVX_Vector *) 
src1); + HVX_Vector * const src1_vec_ptr_end = ((HVX_Vector *) src1) + count / kElementsPerVector1; + HVX_Vector prev0 = *src0_vec_ptr++; + HVX_Vector prev1 = *src1_vec_ptr++; + HVX_Vector sum = Q6_V_vzero(); + + if (src1_vec_ptr_end - src1_vec_ptr > 1) { + HVX_Vector sum0 = Q6_V_vzero(); + HVX_Vector sum1 = Q6_V_vzero(); + + do { + HVX_Vector curr0 = src0_vec_ptr[0]; + HVX_VectorPair curr1 = reinterpret_cast(src1_vec_ptr)[0]; + + HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + HVX_Vector_Dual s0_pair = _ExpandFunc(s0, kOneV); + + HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1); + sum0 = _AddFunc(_MpyFunc(s0_pair.first, l1), sum0); + + HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) src1); + sum1 = _AddFunc(_MpyFunc(s0_pair.second, h1), sum1); + + prev0 = curr0; + prev1 = Q6_V_hi_W(curr1); + src0_vec_ptr++; + src1_vec_ptr += 2; + } while (src1_vec_ptr_end - src1_vec_ptr > 1); + + sum = _AddFunc(sum0, sum1); + } + + const size_t leftover1 = count % kElementsPerVector1; + if ((src1_vec_ptr_end - ((HVX_Vector *) src1)) > 0) { + // handle the last vector + const bool should_fetch_src0 = + reinterpret_cast(hexagon::align_down(src0_vec_ptr)) < src0_ptr_end; + HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0; + src0_vec_ptr += should_fetch_src0 ? 1 : 0; + + HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + HVX_Vector_Dual s0_pair = _ExpandFunc(s0, kOneV); + + const bool has_remaining_src1_vector = src1_vec_ptr_end - src1_vec_ptr > 0; + if (has_remaining_src1_vector) { + HVX_Vector curr1 = *src1_vec_ptr++; + HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + sum = _AddFunc(_MpyFunc(s0_pair.first, s1), sum); + prev1 = curr1; + } + + bool should_fetch_src1 = leftover1 != 0 || !hexagon::is_addr_aligned(src1_vec_ptr); + HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1; + src1_vec_ptr += should_fetch_src1 ? 1 : 0; + HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + prev0 = curr0; + prev1 = curr1; + + sum = _AddFunc(_MpyFunc(has_remaining_src1_vector ? s0_pair.second : s0_pair.first, s1), sum); + } + + if (leftover1 > 0) { + // handle the leftover elements + const size_t leftover0 = count % kElementsPerVector0; + const size_t leftover_bytes1 = leftover1 * sizeof(_TElem1); + HVX_Vector curr0 = + reinterpret_cast(hexagon::align_down(src0_vec_ptr)) < src0_ptr_end ? *src0_vec_ptr : prev0; + curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + + HVX_Vector curr1 = (leftover_bytes1 + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ? + *src1_vec_ptr : + prev1; + curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + + HVX_Vector_Dual curr0_pair = _ExpandFunc(curr0, kOneV); + + curr0 = leftover1 == leftover0 ? 
curr0_pair.first : curr0_pair.second; + sum = _AddFunc(Q6_V_valign_VVR(_MpyFunc(curr0, curr1), Q6_V_vzero(), leftover_bytes1), sum); + } + + return _ReduceFunc(sum); +} + +template +inline _TRet vec_dot_product_mix_aligned_impl(const _TElem0 * src0, const _TElem1 * src1, size_t count) { + static_assert(sizeof(_TElem0) < sizeof(_TElem1), "Element size mismatch: _TElem0 must be smaller than _TElem1"); + static_assert((sizeof(_TElem1) / sizeof(_TElem0)) == 2, + "Element size mismatch: _TElem1 must be twice the size of _TElem0"); + static_assert((sizeof(_TElem1) % sizeof(_TElem0)) == 0, + "Element size mismatch: _TElem1 must be a multiple of _TElem0"); + + constexpr const size_t kElementsPerVector1 = hexagon::kBytesPerVector / sizeof(_TElem1); + + constexpr const __fp16 kOne = 1.0f; + const HVX_Vector kOneV = Q6_Vh_vsplat_R(reinterpret_cast(kOne)); + + HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0); + HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1); + HVX_Vector * const src1_vec_ptr_end = ((HVX_Vector *) src1) + count / kElementsPerVector1; + HVX_Vector sum0 = Q6_V_vzero(); + HVX_Vector sum1 = Q6_V_vzero(); + + if (src1_vec_ptr_end - src1_vec_ptr > 3) { + HVX_Vector sum2 = Q6_V_vzero(); + HVX_Vector sum3 = Q6_V_vzero(); + + do { + HVX_VectorPair curr0 = reinterpret_cast(src0_vec_ptr)[0]; + HVX_Vector_Dual curr00 = _ExpandFunc(Q6_V_lo_W(curr0), kOneV); + HVX_VectorPair curr10 = reinterpret_cast(src1_vec_ptr)[0]; + sum0 = _AddFunc(_MpyFunc(curr00.first, Q6_V_lo_W(curr10)), sum0); + sum1 = _AddFunc(_MpyFunc(curr00.second, Q6_V_hi_W(curr10)), sum1); + + HVX_Vector_Dual curr01 = _ExpandFunc(Q6_V_hi_W(curr0), kOneV); + HVX_VectorPair curr11 = reinterpret_cast(src1_vec_ptr)[1]; + sum2 = _AddFunc(_MpyFunc(curr01.first, Q6_V_lo_W(curr11)), sum2); + sum3 = _AddFunc(_MpyFunc(curr01.second, Q6_V_hi_W(curr11)), sum3); + + src0_vec_ptr += 2; + src1_vec_ptr += 4; + } while (src1_vec_ptr_end - src1_vec_ptr > 3); + + sum0 = _AddFunc(sum0, sum2); + sum1 = _AddFunc(sum1, sum3); + } + + if (src1_vec_ptr_end - src1_vec_ptr > 1) { + HVX_Vector curr0 = src0_vec_ptr[0]; + HVX_Vector_Dual s0_pair = _ExpandFunc(curr0, kOneV); + + HVX_VectorPair curr1 = reinterpret_cast(src1_vec_ptr)[0]; + sum0 = _AddFunc(_MpyFunc(s0_pair.first, Q6_V_lo_W(curr1)), sum0); + sum1 = _AddFunc(_MpyFunc(s0_pair.second, Q6_V_hi_W(curr1)), sum1); + } + + return _ReduceFunc(_AddFunc(sum0, sum1)); +} + +template +inline void vec_scale_impl(const _TParam * src, float scale, _TParam * dst, size_t count) { + constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TParam); + + HVX_Vector * src_vec_ptr = ((HVX_Vector *) src); + HVX_Vector * const src_vec_end = ((HVX_Vector *) src) + (count / kElementsPerVector); + HVX_UVector * dst_vec_ptr = ((HVX_UVector *) dst); // TODO: opt the unaligned case? 
+ HVX_Vector prev = *src_vec_ptr++; + const size_t leftover = count % kElementsPerVector; + + HVX_Vector scale_vec = _FuncScaleConvert(scale); + + while (src_vec_end - src_vec_ptr > 1) { + HVX_VectorPair curr = reinterpret_cast(src_vec_ptr)[0]; + src_vec_ptr += 2; + + HVX_Vector lo = Q6_V_valign_VVR(Q6_V_lo_W(curr), prev, (size_t) src); + dst_vec_ptr[0] = _Func(lo, dst_vec_ptr, scale_vec); + + HVX_Vector hi = Q6_V_valign_VVR(Q6_V_hi_W(curr), Q6_V_lo_W(curr), (size_t) src); + dst_vec_ptr[1] = _Func(hi, dst_vec_ptr + 1, scale_vec); + + dst_vec_ptr += 2; + prev = Q6_V_hi_W(curr); + } + + if (src_vec_end - src_vec_ptr > 0) { + HVX_Vector curr = *src_vec_ptr++; + HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src); + dst_vec_ptr[0] = _Func(s0, dst_vec_ptr, scale_vec); + dst_vec_ptr++; + prev = curr; + } + + if ((src_vec_end - ((HVX_Vector *) src)) > 0) { + // handle the last vector + bool should_fetch_next = leftover == 0 && hexagon::is_addr_aligned(src_vec_ptr); + HVX_Vector curr = should_fetch_next ? prev : *src_vec_ptr; + src_vec_ptr = should_fetch_next ? src_vec_ptr : src_vec_ptr + 1; + HVX_Vector s0 = Q6_V_valign_VVR(curr, prev, (size_t) src); + dst_vec_ptr[0] = _Func(s0, dst_vec_ptr, scale_vec); + dst_vec_ptr++; + prev = curr; + } + + if (leftover > 0) { + // handle the leftover elements + const size_t leftover_bytes = leftover * sizeof(_TParam); + HVX_Vector curr = + (leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? *src_vec_ptr : prev; + curr = Q6_V_valign_VVR(curr, prev, (size_t) src); + q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _Func(curr, dst_vec_ptr, scale_vec)); + } +} + +template inline void vec_zero_impl(_TData * src, size_t count) { + constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TData); + + HVX_UVector * src_vec_ptr = ((HVX_UVector *) src); + HVX_UVector * const src_vec_end = ((HVX_UVector *) src) + (count / kElementsPerVector); + + while (src_vec_end - src_vec_ptr > 1) { + src_vec_ptr[0] = Q6_V_vzero(); + src_vec_ptr[1] = Q6_V_vzero(); + src_vec_ptr += 2; + } + + if (src_vec_end - src_vec_ptr > 0) { + src_vec_ptr[0] = Q6_V_vzero(); + src_vec_ptr++; + } + + const size_t leftover = count % kElementsPerVector; + if (leftover > 0) { + // handle the leftover elements + const size_t leftover_bytes = leftover * sizeof(_TData); + q6op_vstu_variable_ARV(src_vec_ptr, leftover_bytes, Q6_V_vzero()); + } +} + +template +inline void vec_trans_op_impl(const _TyData * src0, const _TyData * src1, size_t count, _TyData * dst) { + constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData); + + HVX_Vector * src0_vec_ptr = ((HVX_Vector *) src0); + HVX_Vector * const src0_vec_ptr_end = ((HVX_Vector *) src0) + count / kElementsPerVector; + HVX_Vector * src1_vec_ptr = ((HVX_Vector *) src1); + HVX_Vector * dst_vec_ptr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned + HVX_Vector prev0 = *src0_vec_ptr++; + HVX_Vector prev1 = *src1_vec_ptr++; + + { + while (src0_vec_ptr_end - src0_vec_ptr > 1) { + HVX_VectorPair curr0 = reinterpret_cast(src0_vec_ptr)[0]; + HVX_VectorPair curr1 = reinterpret_cast(src1_vec_ptr)[0]; + + HVX_Vector l0 = Q6_V_valign_VVR(Q6_V_lo_W(curr0), prev0, (size_t) src0); + HVX_Vector l1 = Q6_V_valign_VVR(Q6_V_lo_W(curr1), prev1, (size_t) src1); + dst_vec_ptr[0] = _OpBinaryTransform(l0, l1); + + HVX_Vector h0 = Q6_V_valign_VVR(Q6_V_hi_W(curr0), Q6_V_lo_W(curr0), (size_t) src0); + HVX_Vector h1 = Q6_V_valign_VVR(Q6_V_hi_W(curr1), Q6_V_lo_W(curr1), (size_t) 
src1); + dst_vec_ptr[1] = _OpBinaryTransform(h0, h1); + + prev0 = Q6_V_hi_W(curr0); + prev1 = Q6_V_hi_W(curr1); + src0_vec_ptr += 2; + src1_vec_ptr += 2; + dst_vec_ptr += 2; + } + } + + if (src0_vec_ptr_end - src0_vec_ptr > 0) { + HVX_Vector curr0 = *src0_vec_ptr++; + HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + + HVX_Vector curr1 = *src1_vec_ptr++; + HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + + dst_vec_ptr[0] = _OpBinaryTransform(s0, s1); + + prev0 = curr0; + prev1 = curr1; + dst_vec_ptr++; + } + + const size_t leftover = count % kElementsPerVector; + if ((src0_vec_ptr_end - ((HVX_Vector *) src0)) > 0) { + // handle the last vector + // see also: + // https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147 + // or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c + bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(src0_vec_ptr); + bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(src1_vec_ptr); + + HVX_Vector curr0 = should_fetch_src0 ? *src0_vec_ptr : prev0; + HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + + HVX_Vector curr1 = should_fetch_src1 ? *src1_vec_ptr : prev1; + HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + + dst_vec_ptr[0] = _OpBinaryTransform(s0, s1); + + src0_vec_ptr += should_fetch_src0 ? 1 : 0; + src1_vec_ptr += should_fetch_src1 ? 1 : 0; + prev0 = curr0; + prev1 = curr1; + dst_vec_ptr++; + } + + if (leftover > 0) { + // handle the leftover elements + const size_t leftover_bytes = leftover * sizeof(_TyData); + HVX_Vector curr0 = (leftover_bytes + hexagon::unaligned_bytes(src0_vec_ptr) > hexagon::kBytesPerVector) ? + *src0_vec_ptr : + prev0; + curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0); + + HVX_Vector curr1 = (leftover_bytes + hexagon::unaligned_bytes(src1_vec_ptr) > hexagon::kBytesPerVector) ? 
+ *src1_vec_ptr : + prev1; + curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1); + + q6op_vstu_variable_ARV(dst_vec_ptr, leftover_bytes, _OpBinaryTransform(curr0, curr1)); + } +} + +} // namespace hexagon::vec diff --git a/ggml/src/ggml-qnn/npu/host/buffer.cpp b/ggml/src/ggml-qnn/npu/host/buffer.cpp index c7482f8b590e6..3eeb611f1d712 100644 --- a/ggml/src/ggml-qnn/npu/host/buffer.cpp +++ b/ggml/src/ggml-qnn/npu/host/buffer.cpp @@ -3,6 +3,7 @@ #include #include "host_device.hpp" +#include "profiler.hpp" #include "tensor.hpp" namespace { @@ -78,6 +79,8 @@ void backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { void backend_buffer_reset(ggml_backend_buffer_t buffer) { auto * buffer_obj = get_buffer_object(buffer); GGML_ASSERT(buffer_obj != nullptr); + + SCOPED_PERFORMANCE_TRACKER("[hexagon-npu][%p]backend_buffer_reset", (void *) buffer_obj); buffer_obj->clear_tensors(); } @@ -199,8 +202,8 @@ std::shared_ptr host_buffer::init_tensor(ggml_tensor * tensor, remo } void host_buffer::clear_tensors() { - _tensors.clear(); LOG_DEBUG("clear host_buffer(%p) tensors\n", (void *) _data); + host_tensor::destroy_tensors(_tensors); } host_buffer_type::host_buffer_type(ggml_backend_dev_t dev, const std::string & name, common::rpc_mem_ptr rpc_mem) : diff --git a/ggml/src/ggml-qnn/npu/host/graph.cpp b/ggml/src/ggml-qnn/npu/host/graph.cpp index 1d40fe0dd5176..526191173dd17 100644 --- a/ggml/src/ggml-qnn/npu/host/graph.cpp +++ b/ggml/src/ggml-qnn/npu/host/graph.cpp @@ -57,10 +57,10 @@ bool host_graph::update(ggml_cgraph * cgraph) { _tensor_handles.push_back(tensor_obj->get_device_tensor_handle()); _tensor_update_configs.push_back(tensor_obj->update_hosts_params_only(node)); - PROFILER_LOG_DEBUG("node[%d]%s(%s), addr(%p), %s_%ldx%ldx%ldx%ld, handle(%p)\n", i, ggml_get_name(node), - ggml_op_desc(node), (void *) tensor_obj, ggml_type_name(node->type), - (long) tensor_obj->get_ne(0), (long) tensor_obj->get_ne(1), (long) tensor_obj->get_ne(2), - (long) tensor_obj->get_ne(3), (void *) tensor_obj->get_device_tensor_handle()); + PROFILER_LOG_DEBUG("node[%d]%s(%s), addr(%p), %ldx%ldx%ldx%ld%s, handle(%p)\n", i, ggml_get_name(node), + ggml_op_desc(node), (void *) tensor_obj, (long) tensor_obj->get_ne(0), + (long) tensor_obj->get_ne(1), (long) tensor_obj->get_ne(2), (long) tensor_obj->get_ne(3), + ggml_type_name(node->type), (void *) tensor_obj->get_device_tensor_handle()); } GGML_ASSERT(_tensor_handles.size() == _tensor_update_configs.size()); diff --git a/ggml/src/ggml-qnn/npu/host/graph.hpp b/ggml/src/ggml-qnn/npu/host/graph.hpp index b871c125563f2..0f8efe1079785 100644 --- a/ggml/src/ggml-qnn/npu/host/graph.hpp +++ b/ggml/src/ggml-qnn/npu/host/graph.hpp @@ -22,7 +22,7 @@ class host_graph { private: remote_handle64 _device_handle = 0; - npu_device_graph_handle_t _graph_handle = 0; + npu_device_graph_handle_t _graph_handle = npu_device_INVALID_DEVICE_GRAPH_HANDLE; std::vector _tensor_handles; std::vector _tensor_update_configs; diff --git a/ggml/src/ggml-qnn/npu/host/tensor.hpp b/ggml/src/ggml-qnn/npu/host/tensor.hpp index 7e8ee8f34cc09..f70526bf25dff 100644 --- a/ggml/src/ggml-qnn/npu/host/tensor.hpp +++ b/ggml/src/ggml-qnn/npu/host/tensor.hpp @@ -1,6 +1,8 @@ #pragma once +#include #include +#include #include "common.hpp" #include "ggml-impl.h" @@ -42,7 +44,7 @@ class host_tensor { auto status = npu_device_tensor_init(_device_handle, &_info, &_device_tensor_handle); if (status != AEE_SUCCESS) { LOG_ERROR("Failed to init tensor: %d", (int) status); - _device_tensor_handle = 0; + _device_tensor_handle = 
npu_device_INVALID_DEVICE_TENSOR_HANDLE;
             return;
         }
 
@@ -66,6 +68,27 @@ class host_tensor {
         }
     }
 
+    static void destroy_tensors(std::list<std::shared_ptr<host_tensor>> & tensors) {
+        std::vector<npu_device_tensor_handle_t> handles;
+
+        handles.reserve(tensors.size());
+        remote_handle64 device_handle = 0;
+
+        for (auto tensor : tensors) {
+            if (tensor && tensor->_device_tensor_handle != npu_device_INVALID_DEVICE_TENSOR_HANDLE) {
+                handles.push_back(tensor->_device_tensor_handle);
+                tensor->_device_tensor_handle = npu_device_INVALID_DEVICE_TENSOR_HANDLE;  // prevent double free
+                device_handle = tensor->_device_handle;
+            }
+        }
+
+        if (!handles.empty()) {
+            npu_device_tensors_free(device_handle, handles.data(), handles.size());
+        }
+
+        tensors.clear();
+    }
+
     npu_device_tensor_handle_t get_device_tensor_handle() const { return _device_tensor_handle; }
 
     void update_params(ggml_tensor * ggml_tensor) {
@@ -157,7 +180,7 @@ class host_tensor {
         return _info_update;
     }
 
-    bool is_valid() const { return _device_tensor_handle != 0; }
+    bool is_valid() const { return _device_tensor_handle != npu_device_INVALID_DEVICE_TENSOR_HANDLE; }
 
     int64_t get_ne(size_t index) const {
         if (index >= DEVICE_TENSOR_MAX_DIMS) {
diff --git a/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl b/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl
index 0aa8d8e8ab48b..513b69d88a25b 100644
--- a/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl
+++ b/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl
@@ -20,6 +20,9 @@ interface npu_device : remote_handle64{
     typedef uint64_t tensor_handle_t;
     typedef uint64_t graph_handle_t;
 
+    const graph_handle_t INVALID_DEVICE_GRAPH_HANDLE = 0;
+    const tensor_handle_t INVALID_DEVICE_TENSOR_HANDLE = 0;
+
     typedef uint16_t fp16_t;
 
     struct block_q4_0 {
@@ -107,6 +110,10 @@ interface npu_device : remote_handle64{
         in tensor_handle_t tensor_handle
     );
 
+    AEEResult tensors_free(
+        in sequence<tensor_handle_t> tensor_handles
+    );
+
     AEEResult graph_init(
         rout graph_handle_t graph_handle
     );
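
A minimal host-side sketch (not part of the patch) of how the new batched free could be driven. It assumes the QAIC-generated stub signature that host_tensor::destroy_tensors calls above, i.e. npu_device_tensors_free(remote_handle64, const npu_device_tensor_handle_t *, int), and the npu_device_INVALID_DEVICE_TENSOR_HANDLE sentinel added to the IDL; the helper name free_tensor_handles is hypothetical. The intent of the batch method appears to be collapsing one FastRPC round trip per tensor into a single call when a buffer is reset.

// Hypothetical helper: releases a batch of device tensors with one remote call
// instead of one npu_device_tensor_free() round trip per tensor.
#include <algorithm>
#include <vector>

static void free_tensor_handles(remote_handle64 device_handle,
                                std::vector<npu_device_tensor_handle_t> & handles) {
    // Drop handles that were never initialized or were already released.
    handles.erase(std::remove(handles.begin(), handles.end(),
                              npu_device_INVALID_DEVICE_TENSOR_HANDLE),
                  handles.end());
    if (handles.empty()) {
        return;
    }

    // Single batched RPC; the device side frees each handle and logs invalid ones.
    auto status = npu_device_tensors_free(device_handle, handles.data(), (int) handles.size());
    if (status != AEE_SUCCESS) {
        LOG_ERROR("npu_device_tensors_free failed: %d", (int) status);
    }

    handles.clear();
}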