diff --git a/ggml/src/ggml-qnn/npu/device/dma_transfer.cpp b/ggml/src/ggml-qnn/npu/device/dma_transfer.cpp index 4778566a3d1dd..27bef6c7191d9 100644 --- a/ggml/src/ggml-qnn/npu/device/dma_transfer.cpp +++ b/ggml/src/ggml-qnn/npu/device/dma_transfer.cpp @@ -12,7 +12,7 @@ dma_transfer::dma_transfer() { dma_desc_set_next(_dma_1d_desc0, 0); dma_desc_set_dstate(_dma_1d_desc0, DESC_DSTATE_INCOMPLETE); dma_desc_set_desctype(_dma_1d_desc0, DMA_DESC_TYPE_1D); - dma_desc_set_order(_dma_1d_desc0, DESC_ORDER_ORDER); + dma_desc_set_order(_dma_1d_desc0, DESC_ORDER_NOORDER); dma_desc_set_bypasssrc(_dma_1d_desc0, DESC_BYPASS_ON); // for dram dma_desc_set_bypassdst(_dma_1d_desc0, DESC_BYPASS_OFF); // for vtcm dma_desc_set_length(_dma_1d_desc0, 0); @@ -20,7 +20,7 @@ dma_transfer::dma_transfer() { dma_desc_set_next(_dma_1d_desc1, 0); dma_desc_set_dstate(_dma_1d_desc1, DESC_DSTATE_INCOMPLETE); dma_desc_set_desctype(_dma_1d_desc1, DMA_DESC_TYPE_1D); - dma_desc_set_order(_dma_1d_desc1, DESC_ORDER_ORDER); + dma_desc_set_order(_dma_1d_desc1, DESC_ORDER_NOORDER); dma_desc_set_bypasssrc(_dma_1d_desc1, DESC_BYPASS_ON); // for dram dma_desc_set_bypassdst(_dma_1d_desc1, DESC_BYPASS_OFF); // for vtcm dma_desc_set_length(_dma_1d_desc1, 0); @@ -28,7 +28,7 @@ dma_transfer::dma_transfer() { dma_desc_set_next(_dma_2d_desc0, 0); dma_desc_set_dstate(_dma_2d_desc0, DESC_DSTATE_INCOMPLETE); dma_desc_set_desctype(_dma_2d_desc0, DMA_DESC_TYPE_2D); - dma_desc_set_order(_dma_2d_desc0, DESC_ORDER_ORDER); + dma_desc_set_order(_dma_2d_desc0, DESC_ORDER_NOORDER); dma_desc_set_bypasssrc(_dma_2d_desc0, DESC_BYPASS_ON); // for dram dma_desc_set_bypassdst(_dma_2d_desc0, DESC_BYPASS_OFF); // for vtcm dma_desc_set_cachealloc(_dma_2d_desc0, DESC_CACHEALLOC_NONE); diff --git a/ggml/src/ggml-qnn/npu/device/op/op_mul_mat.cpp b/ggml/src/ggml-qnn/npu/device/op/op_mul_mat.cpp index e7fc4bcfdf646..063e8d105d7a0 100644 --- a/ggml/src/ggml-qnn/npu/device/op/op_mul_mat.cpp +++ b/ggml/src/ggml-qnn/npu/device/op/op_mul_mat.cpp @@ -20,6 +20,12 @@ template <> struct convert_vector { static float convert(HVX_Vector vec) { return hexagon::get_flt0_from_fltv(Q6_Vsf_equals_Vqf32(vec)); } }; +inline std::pair unflatten_i3_i2(int64_t idx, const hexagon::tensor * t) { + const auto i3 = idx / t->get_ne(2); + const auto i2 = idx - i3 * t->get_ne(2); + return { i3, i2 }; +} + template <> struct convert_vector { static float convert(HVX_Vector vec) { HVX_Vector vect = Q6_Vhf_equals_Vqf16(vec); @@ -28,6 +34,27 @@ template <> struct convert_vector { } }; +template +inline bool init_dma_transfer(hexagon::compute_params * params, + const uint8_t * src, + uint8_t * dst, + size_t width, + size_t height, + size_t src_stride, + size_t dst_stride) { + if constexpr (_IsQuantized) { + if (!params->initiate_dma_row_transfer(src, dst, src_stride * height)) { + return false; + } + } else { + if (!params->initiate_dma_plane_transfer(src, dst, width, height, src_stride, dst_stride)) { + return false; + } + } + + return true; +} + template inline void batched_row_dot(const uint8_t * src0_plane, const size_t src0_ne0, @@ -75,9 +102,10 @@ inline void mul_mat_impl(hexagon::tensor * src0, hexagon::tensor * src1, hexagon::tensor * dst, hexagon::compute_params * params) { + using data_type0 = typename get_data_type::data_type0; using data_type1 = typename get_data_type::data_type1; - const auto src0_actual_row_size = hexagon::get_dequantized_row_size(src0); + const auto src0_actual_row_stride = hexagon::get_dequantized_row_size(src0); auto * dequantize_row_func = 
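The op_mul_mat.cpp hunk above introduces unflatten_i3_i2(), which splits a flattened plane index back into the two outer ggml dimensions. Below is a minimal standalone sketch of the same arithmetic, using plain integers instead of hexagon::tensor; the function and variable names are illustrative, not from the patch:

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Split a flattened plane index ip (0 <= ip < ne2 * ne3) back into (i3, i2),
// mirroring how the per-plane loop index was flattened in the first place.
inline std::pair<int64_t, int64_t> unflatten_plane_index(int64_t ip, int64_t ne2) {
    const int64_t i3 = ip / ne2;       // outer (batch) index
    const int64_t i2 = ip - i3 * ne2;  // plane index within the batch
    return { i3, i2 };
}

int main() {
    const int64_t ne2 = 5, ne3 = 3;
    for (int64_t ip = 0; ip < ne2 * ne3; ++ip) {
        const auto [i3, i2] = unflatten_plane_index(ip, ne2);
        assert(ip == i3 * ne2 + i2);               // round-trips to the flat index
        assert(i2 >= 0 && i2 < ne2 && i3 < ne3);   // stays inside the tensor extents
    }
    return 0;
}
```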
hexagon::get_type_traits(src0->get_type()).to_float; auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table; if (_IsSrcQuantized && dequantize_row_func == nullptr) { @@ -113,66 +141,70 @@ inline void mul_mat_impl(hexagon::tensor * src0, const uint8_t * src0_ptr = src0->get_read_buffer(true); // TODO: avoid invalidation - // cache the src0 plane in VTCM - size_t src0_plane_slice_row_count = start_end_element.second - start_end_element.first; - size_t src0_plane_cache_size = 0; - uint8_t * src0_plane_read_cache_ptr = nullptr; - uint8_t * src0_plane_write_cache_ptr = nullptr; - const uint8_t * last_write_cached_plane_ptr = nullptr; - const uint8_t * last_read_cached_plane_ptr = nullptr; - if constexpr (_IsSrcQuantized) { - src0_plane_slice_row_count = - std::min(params->get_vtcm_quota_size() / src0_actual_row_size, src0_plane_slice_row_count); - src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count; - src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size); - if (src0_plane_read_cache_ptr == nullptr) { - DEVICE_LOG_ERROR( - "mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, " - "src0_actual_row_size: %zu, will fallback to mem cache\n", - src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_size); - return; - } - } else { - src0_plane_slice_row_count = - std::min(params->get_vtcm_quota_size() / (src0_actual_row_size * 2), src0_plane_slice_row_count); - src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count; - src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2); - if (src0_plane_read_cache_ptr == nullptr) { + // cache the src0 plane in VTCM + const size_t valid_src0_row_bytes = _IsSrcQuantized ? 
src0->get_nb(1) : (src0->get_ne(0) * sizeof(data_type0)); + const size_t src1_actual_row_stride = hexagon::get_aligned_size(src1->get_nb(1)); + + // TODO: figure out why we have to add padding after src0 plane cache + const size_t src0_plane_slice_row_count = + std::min((params->get_vtcm_quota_size() - src1_actual_row_stride) / (src0_actual_row_stride * 2), + start_end_element.second - start_end_element.first); + uint8_t * src0_plane_read_cache_ptr = nullptr; + uint8_t * src0_plane_write_cache_ptr = nullptr; + size_t src0_plane_write_cache_offset = 0; + const uint8_t * last_write_cached_plane_ptr = nullptr; + const uint8_t * last_read_cached_plane_ptr = nullptr; + + { + const size_t src0_plane_cache_size = src0_actual_row_stride * src0_plane_slice_row_count; + src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2); + if (!src0_plane_read_cache_ptr) { DEVICE_LOG_ERROR( "mul_mat_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, " - "src0_actual_row_size: %zu, will fallback to mem cache\n", - src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_size); + "src0_actual_row_stride: %zu, will fallback to mem cache\n", + src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_stride); return; } src0_plane_write_cache_ptr = src0_plane_read_cache_ptr + src0_plane_cache_size; + if constexpr (_IsSrcQuantized) { + src0_plane_write_cache_offset = + src0_plane_cache_size - size_t(src0->get_nb(1) * src0_plane_slice_row_count); + } + + DEVICE_LOG_DEBUG( + "[%d]mul_mat_impl, src0_actual_row_stride:%zu, valid_src0_row_bytes:%zu, src_nb0:%zu, " + "slice_row_count:%zu, write_cache_offset: %zu, " + "total_planes:%lld, planes:[%d,%d), rows:[%d,%d), elems:[%d,%d), is_quant:%d, " + "vtcm_mem:%p(%zu)\n", + (int) params->get_thread_index(), src0_actual_row_stride, valid_src0_row_bytes, (size_t) src0->get_nb(1), + src0_plane_slice_row_count, src0_plane_write_cache_offset, total_planes, (int) start_end_plane.first, + (int) start_end_plane.second, (int) start_end_row.first, (int) start_end_row.second, + (int) start_end_element.first, (int) start_end_element.second, _IsSrcQuantized, + (void *) src0_plane_read_cache_ptr, params->get_vtcm_quota_size()); + } - const auto i3 = start_end_plane.first / dst->get_ne(2); - const auto i2 = start_end_plane.first - i3 * dst->get_ne(2); + { + const auto [i3, i2] = unflatten_i3_i2(start_end_plane.first, dst); const uint8_t * src0_plane = src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2) + start_end_element.first * src0->get_nb(1); - const int64_t next_row_count = - std::min(src0_plane_slice_row_count, - start_end_element.second - start_end_element.first); // number of rows in this slice - if (!params->initiate_dma_plane_transfer(src0_plane, src0_plane_write_cache_ptr, - src0_actual_row_size, // TODO: reduce to aligned valid_row0_bytes? 
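Both mul_mat paths now size the src0 slice from one shared budget: keep one aligned src1 row resident, then split the remaining VTCM quota into two equal src0 plane buffers so the DMA engine can fill one while the other is consumed. A standalone sketch of that arithmetic, with made-up sizes (the real quota comes from params->get_vtcm_quota_size()):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

// Rough model of the slice sizing above: two src0 plane buffers (double buffering
// for DMA) plus one aligned src1 row must fit in VTCM. Assumes quota > src1 row.
size_t plane_slice_row_count(size_t vtcm_quota,       // bytes available in VTCM
                             size_t src0_row_stride,  // dequantized row stride
                             size_t src1_row_stride,  // aligned src1 row size
                             size_t rows_wanted) {    // rows in the assigned element range
    const size_t budget       = vtcm_quota - src1_row_stride;         // keep one src1 row resident
    const size_t rows_that_fit = budget / (src0_row_stride * 2);      // x2: read + write buffers
    return std::min(rows_that_fit, rows_wanted);
}

int main() {
    // Illustrative numbers only.
    const size_t rows = plane_slice_row_count(256 * 1024, 8192, 8192, 512);
    std::printf("rows per slice: %zu\n", rows);  // 15 with these toy sizes
    return 0;
}
```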
- next_row_count, src0_actual_row_size, src0_actual_row_size)) { - DEVICE_LOG_ERROR("mul_mat_impl: failed to continue dma transfer for src0 plane\n"); + const size_t next_row_count = + std::min(src0_plane_slice_row_count, + start_end_element.second - start_end_element.first); // number of rows in this slice + if (!init_dma_transfer<_IsSrcQuantized>( + params, src0_plane, src0_plane_write_cache_ptr + src0_plane_write_cache_offset, valid_src0_row_bytes, + next_row_count, src0->get_nb(1), src0->get_nb(1))) { + DEVICE_LOG_ERROR("mul_mat_impl: failed to continue dma transfer for src0 plane, is_quant: %d\n", + (int) _IsSrcQuantized); return; } + DEVICE_LOG_DEBUG("mul_mat_impl: [i2,i3]:[%d,%d], src0_plane:%p, row_count:%zu\n", (int) i2, (int) i3, + (void *) src0_plane, next_row_count); + last_write_cached_plane_ptr = src0_plane; } - DEVICE_LOG_DEBUG( - "[%d]mul_mat_impl src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, total_planes: %lld, " - "start_end_plane: " - "[%d,%d), start_end_row: [%d,%d), start_end_element: [%d,%d), is_quantized: %d, vtcm_mem: %p(%zu)\n", - (int) params->get_thread_index(), src0_actual_row_size, src0_plane_slice_row_count, total_planes, - (int) start_end_plane.first, (int) start_end_plane.second, (int) start_end_row.first, - (int) start_end_row.second, (int) start_end_element.first, (int) start_end_element.second, _IsSrcQuantized, - (void *) src0_plane_read_cache_ptr, params->get_vtcm_quota_size()); - const size_t valid_row1_bytes = src0->get_ne(0) * sizeof(data_type1); // src0 and src1 should have the same element count in the 1st dimension DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat); @@ -184,11 +216,10 @@ inline void mul_mat_impl(hexagon::tensor * src0, return; } - auto dequant_table = load_dequant_table_func ? load_dequant_table_func() : HVX_Vector(); + const auto dequant_table = load_dequant_table_func ? 
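The initial init_dma_transfer() call above selects one of two transfer shapes: for quantized src0 a single linear row transfer of src_stride * height raw bytes, for float src0 a 2D plane transfer that copies `width` valid bytes per row with independent source and destination strides. A scalar reference model of the two shapes, using memcpy in place of the DMA engine:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Quantized src0: rows are already packed back to back, so one linear copy suffices.
void model_row_transfer(const uint8_t * src, uint8_t * dst, size_t bytes) {
    std::memcpy(dst, src, bytes);
}

// Float src0: copy the valid payload of each row; source and destination strides may differ.
void model_plane_transfer(const uint8_t * src, uint8_t * dst, size_t width, size_t height,
                          size_t src_stride, size_t dst_stride) {
    for (size_t r = 0; r < height; ++r) {
        std::memcpy(dst + r * dst_stride, src + r * src_stride, width);
    }
}

int main() {
    uint8_t src[64], dst[64] = {};
    for (size_t i = 0; i < sizeof(src); ++i) src[i] = uint8_t(i);
    model_row_transfer(src, dst, 32);              // what the quantized path asks the DMA for
    model_plane_transfer(src, dst, 8, 4, 16, 16);  // what the float path asks for: 4 rows, 8 valid bytes each
    return 0;
}
```

In the patch both strides passed to the plane transfer are src0->get_nb(1), so the copy is effectively dense; keeping the stride parameters leaves room for trimming to aligned valid bytes later, per the TODO above.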
load_dequant_table_func() : HVX_Vector(); const uint8_t * src1_ptr = src1->get_read_buffer(); for (int64_t ip = start_end_plane.first; ip < start_end_plane.second; ip++) { - const auto i3 = ip / dst->get_ne(2); - const auto i2 = ip - i3 * dst->get_ne(2); + const auto [i3, i2] = unflatten_i3_i2(ip, dst); const auto * src1_plane = src1_ptr + i3 * src1->get_nb(3) + i2 * src1->get_nb(2); auto * dst_plane = dst_ptr + i3 * dst->get_nb(3) + i2 * dst->get_nb(2); const uint8_t * src0_plane_base = src0_ptr + i3 / r03 * src0->get_nb(3) + i2 / r02 * src0->get_nb(2); @@ -198,57 +229,37 @@ inline void mul_mat_impl(hexagon::tensor * src0, const int64_t actual_row_count = std::min(src0_plane_slice_row_count, start_end_element.second - col_idx); // number of rows in this slice - if constexpr (_IsSrcQuantized) { - if (last_write_cached_plane_ptr != src0_plane) { - DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant); - - for (int64_t ir = 0; ir < actual_row_count; ir++) { - auto * src0_row = src0_plane + ir * src0->get_nb(1); - if (ir + 1 < actual_row_count) { - hexagon::l2fetch_row(src0_row + src0->get_nb(1), src0->get_nb(1)); - } - - auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_actual_row_size; - dequantize_row_func(src0_row, reinterpret_cast(cached_row_ptr), - src0->get_ne(0), dequant_table); - } - - last_write_cached_plane_ptr = src0_plane; - } - } else { - if (last_read_cached_plane_ptr != src0_plane) { - std::swap(src0_plane_read_cache_ptr, src0_plane_write_cache_ptr); - last_read_cached_plane_ptr = src0_plane; - params->wait_for_dma(); - } + { const uint8_t * src0_next_plane = last_write_cached_plane_ptr; - int64_t next_row_count = 0; + size_t next_row_count = 0; if (col_idx + src0_plane_slice_row_count < start_end_element.second) { const auto next_col_idx = col_idx + src0_plane_slice_row_count; - src0_next_plane = src0_plane_base + next_col_idx * src0_actual_row_size; + src0_next_plane = src0_plane_base + next_col_idx * src0->get_nb(1); next_row_count = - std::min(src0_plane_slice_row_count, - start_end_element.second - next_col_idx); // number of rows in this slice + std::min(src0_plane_slice_row_count, + start_end_element.second - next_col_idx); // number of rows in this slice } else if (ip + 1 < start_end_plane.second) { // prefetch the next plane's first slice - const auto ip_next = ip + 1; - const auto i3_next = ip_next / dst->get_ne(2); - const auto i2_next = ip_next - i3_next * dst->get_ne(2); + const auto [i3_next, i2_next] = unflatten_i3_i2(ip + 1, dst); const uint8_t * src0_next_plane_base = src0_ptr + i3_next / r03 * src0->get_nb(3) + i2_next / r02 * src0->get_nb(2); - src0_next_plane = src0_next_plane_base + start_end_element.first * src0_actual_row_size; - next_row_count = std::min( + src0_next_plane = src0_next_plane_base + start_end_element.first * src0->get_nb(1); + next_row_count = std::min( src0_plane_slice_row_count, start_end_element.second - start_end_element.first); // number of rows in this slice } + if (last_read_cached_plane_ptr != src0_plane) { + std::swap(src0_plane_read_cache_ptr, src0_plane_write_cache_ptr); + params->wait_for_dma(); + } + if (last_write_cached_plane_ptr != src0_next_plane) { - DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dma); - if (!params->initiate_dma_plane_transfer( - src0_next_plane, src0_plane_write_cache_ptr, - src0_actual_row_size, // TODO: reduce to aligned valid_row0_bytes? 
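The branch above decides what the in-flight DMA should prefetch while the current slice is consumed: the next column slice of the same plane if one remains, otherwise the first slice of the next plane, otherwise nothing (the pointer stays equal to last_write_cached_plane_ptr, so no new transfer is started). A compact sketch of that decision with illustrative names:

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>

struct NextSlice {
    int64_t plane_index;  // flattened plane index (ip) the slice belongs to
    int64_t first_row;    // first src0 row (output column) in the slice
    int64_t row_count;    // rows to transfer
};

// Prefer the next column slice of the current plane; fall back to the first slice
// of the next plane; otherwise there is nothing left to prefetch.
std::optional<NextSlice> pick_next_slice(int64_t ip, int64_t plane_end, int64_t col_idx,
                                         int64_t elem_begin, int64_t elem_end, int64_t slice_rows) {
    if (col_idx + slice_rows < elem_end) {
        const int64_t next_col = col_idx + slice_rows;
        return NextSlice{ ip, next_col, std::min(slice_rows, elem_end - next_col) };
    }
    if (ip + 1 < plane_end) {
        return NextSlice{ ip + 1, elem_begin, std::min(slice_rows, elem_end - elem_begin) };
    }
    return std::nullopt;
}

int main() {
    const auto s = pick_next_slice(/*ip=*/0, /*plane_end=*/2, /*col_idx=*/0,
                                   /*elem_begin=*/0, /*elem_end=*/100, /*slice_rows=*/64);
    return (s && s->first_row == 64) ? 0 : 1;  // next slice: rows [64, 100) of the same plane
}
```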
- next_row_count, src0_actual_row_size, src0_actual_row_size)) { + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, dma); + if (!init_dma_transfer<_IsSrcQuantized>( + params, src0_next_plane, src0_plane_write_cache_ptr + src0_plane_write_cache_offset, + valid_src0_row_bytes, next_row_count, src0->get_nb(1), src0->get_nb(1))) { DEVICE_LOG_ERROR("mul_mat_impl: failed to continue dma transfer for src0 plane\n"); return; } @@ -257,15 +268,30 @@ inline void mul_mat_impl(hexagon::tensor * src0, } } + if constexpr (_IsSrcQuantized) { + if (last_read_cached_plane_ptr != src0_plane) { + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant); + const uint8_t * src0_quant_plane = src0_plane_read_cache_ptr + src0_plane_write_cache_offset; + for (int64_t ir = 0; ir < actual_row_count; ir++) { + auto * src0_row = src0_quant_plane + ir * src0->get_nb(1); + auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_actual_row_stride; + dequantize_row_func(src0_row, reinterpret_cast(cached_row_ptr), + src0->get_ne(0), dequant_table); + } + } + } + + last_read_cached_plane_ptr = src0_plane; + if (start_end_row.second > start_end_row.first) { hexagon::l2fetch_row(src1_plane + start_end_row.first * src1->get_nb(1), valid_row1_bytes); } for (int64_t i1 = start_end_row.first; i1 < start_end_row.second; i1++) { - DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, vec_dot); + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, dot); auto * src1_row = src1_plane + i1 * src1->get_nb(1); auto * dst_row = reinterpret_cast(dst_plane + i1 * dst->get_nb(1)) + col_idx; - batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_actual_row_size, src1_row, + batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_actual_row_stride, src1_row, src1->get_nb(1), dst_row, actual_row_count, (ip + 1 < start_end_plane.second) ? valid_row1_bytes : 0); } @@ -283,7 +309,7 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0, using data_type0 = typename get_data_type::data_type0; using data_type1 = typename get_data_type::data_type1; - const auto src0_actual_row_size = hexagon::get_dequantized_row_size(src0); + const auto src0_actual_row_stride = hexagon::get_dequantized_row_size(src0); auto * dequantize_row_func = hexagon::get_type_traits(src0->get_type()).to_float; auto * load_dequant_table_func = hexagon::get_type_traits(src0->get_type()).load_dequant_table; if (_IsSrcQuantized && dequantize_row_func == nullptr) { @@ -309,52 +335,45 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0, return; } - const uint8_t * src0_ptr = src0->get_read_buffer(true); // TODO: avoid invalidation - const size_t valid_row0_bytes = src0->get_ne(0) * sizeof(data_type0); + const uint8_t * src0_ptr = src0->get_read_buffer(true); // TODO: avoid invalidation + const size_t valid_src0_row_bytes = _IsSrcQuantized ? 
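For quantized src0 the DMA now lands the raw plane at src0_plane_write_cache_offset, i.e. at the tail of the buffer, and after the buffer swap the dequantizer expands it in place from front to back. Since the offset equals slice_rows * (row_stride - nb1), the end of dequantized row r never passes the start of packed row r + 1 (provided the dequantized row stride is at least nb(1)), so later rows are still intact when their turn comes. A toy model of that layout with made-up sizes, not the real q4_0 block geometry:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Packed (quantized) rows are parked at the tail of the buffer, then expanded
// front-to-back into wider (dequantized) rows within the same buffer.
int main() {
    const size_t rows = 4, packed = 3, expanded = 8;        // expanded >= packed
    std::vector<uint8_t> vtcm(rows * expanded, 0);
    const size_t tail_offset = rows * (expanded - packed);  // == cache_size - packed * rows

    for (size_t r = 0; r < rows; ++r) {                     // pretend this is what the DMA wrote
        std::memset(vtcm.data() + tail_offset + r * packed, int(r + 1), packed);
    }
    for (size_t r = 0; r < rows; ++r) {                     // "dequantize" row r in place
        uint8_t tmp[packed];
        std::memcpy(tmp, vtcm.data() + tail_offset + r * packed, packed);
        std::memset(vtcm.data() + r * expanded, tmp[0], expanded);  // widened output row
        // Key invariant: output row r ends no later than where packed row r + 1 begins.
        assert((r + 1) * expanded <= tail_offset + (r + 1) * packed);
    }
    return 0;
}
```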
src0->get_nb(1) : (src0->get_ne(0) * sizeof(data_type0)); // cache the src0 plane in VTCM - size_t src0_plane_slice_row_count = start_end_element.second - start_end_element.first; - size_t src0_plane_cache_size = 0; - uint8_t * src0_plane_read_cache_ptr = nullptr; - uint8_t * src0_plane_write_cache_ptr = nullptr; - const auto src1_actual_row_size = hexagon::get_aligned_size(src1->get_nb(1)); - uint8_t * src1_row_cache_ptr = nullptr; - if constexpr (_IsSrcQuantized) { - src0_plane_slice_row_count = std::min( - (params->get_vtcm_quota_size() - src1_actual_row_size) / src0_actual_row_size, src0_plane_slice_row_count); - src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count; - src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size + src1_actual_row_size); - if (src0_plane_read_cache_ptr == nullptr) { + const size_t src1_actual_row_stride = hexagon::get_aligned_size(src1->get_nb(1)); + const size_t src0_plane_slice_row_count = + std::min((params->get_vtcm_quota_size() - src1_actual_row_stride) / (src0_actual_row_stride * 2), + start_end_element.second - start_end_element.first); + + uint8_t * src0_plane_read_cache_ptr = nullptr; + uint8_t * src0_plane_write_cache_ptr = nullptr; + size_t src0_plane_write_cache_offset = 0; + uint8_t * src1_row_cache_ptr = nullptr; + + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat); + { + const size_t src0_plane_cache_size = src0_actual_row_stride * src0_plane_slice_row_count; + src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2 + src1_actual_row_stride); + if (!src0_plane_read_cache_ptr) { DEVICE_LOG_ERROR( "mul_mat_gemv_impl: failed to get VTCM cache for src0, size: %zu, src0_plane_slice_row_count: %zu, " - "src0_actual_row_size: %zu, will fallback to mem cache\n", - src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_size); - return; - } - - src1_row_cache_ptr = src0_plane_read_cache_ptr + src0_plane_cache_size; - } else { - src0_plane_slice_row_count = - std::min((params->get_vtcm_quota_size() - src1_actual_row_size) / (src0_actual_row_size * 2), - src0_plane_slice_row_count); - src0_plane_cache_size = src0_actual_row_size * src0_plane_slice_row_count; - src0_plane_read_cache_ptr = params->get_vtcm_cache(src0_plane_cache_size * 2 + src1_actual_row_size); - if (src0_plane_read_cache_ptr == nullptr) { - DEVICE_LOG_ERROR("mul_mat_gemv_impl: failed to get VTCM cache for src1, size: %zu\n", src1_actual_row_size); + "src0_actual_row_stride: %zu, will fallback to mem cache\n", + src0_plane_cache_size, src0_plane_slice_row_count, src0_actual_row_stride); return; } src0_plane_write_cache_ptr = src0_plane_read_cache_ptr + src0_plane_cache_size; src1_row_cache_ptr = src0_plane_write_cache_ptr + src0_plane_cache_size; - } - DEVICE_LOG_DEBUG( - "mul_mat_gemv_impl: src0_actual_row_size: %zu, src0_plane_slice_row_count: %zu, is_quantized: %d, vtcm_mem: " - "%p(%zu)\n", - src0_actual_row_size, src0_plane_slice_row_count, _IsSrcQuantized, (void *) src0_plane_read_cache_ptr, - src0_plane_cache_size); + if constexpr (_IsSrcQuantized) { + src0_plane_write_cache_offset = src0_plane_cache_size - (src0->get_nb(1) * src0_plane_slice_row_count); + } - DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(dst, params->get_thread_index(), mul_mat); + DEVICE_LOG_DEBUG( + "mul_mat_gemv_impl: src0_actual_row_stride: %zu, src0_plane_slice_row_count: %zu, " + "src0_plane_write_cache_offset: %zu, src0.nb[1]: %d, is_quantized: %d, vtcm_mem: %p(%zu)\n", + 
src0_actual_row_stride, src0_plane_slice_row_count, src0_plane_write_cache_offset, int(src0->get_nb(1)), + _IsSrcQuantized, (void *) src0_plane_read_cache_ptr, src0_plane_cache_size); + } uint8_t * dst_ptr = dst->get_write_buffer(); if (!dst_ptr) { @@ -372,66 +391,62 @@ inline void mul_mat_gemv_impl(hexagon::tensor * src0, return; } - if constexpr (!_IsSrcQuantized) { - const uint8_t * src0_plane = src0_ptr + start_end_element.first * src0_actual_row_size; - const int64_t next_row_count = - std::min(src0_plane_slice_row_count, - start_end_element.second - start_end_element.first); // number of rows in this slice - params->wait_for_dma(); - if (!params->initiate_dma_plane_transfer(src0_plane, src0_plane_write_cache_ptr, valid_row0_bytes, - next_row_count, src0_actual_row_size, src0_actual_row_size)) { - DEVICE_LOG_ERROR("mul_mat_gemv_impl: failed to initiate dma transfer for src0 plane\n"); - return; - } - } else { - params->wait_for_dma(); + const uint8_t * src0_plane = src0_ptr + start_end_element.first * src0->get_nb(1); + const int64_t next_row_count = + std::min(src0_plane_slice_row_count, + start_end_element.second - start_end_element.first); // number of rows in this slice + params->wait_for_dma(); + + if (!init_dma_transfer<_IsSrcQuantized>( + params, src0_plane, src0_plane_write_cache_ptr + src0_plane_write_cache_offset, valid_src0_row_bytes, + next_row_count, src0->get_nb(1), src0->get_nb(1))) { + DEVICE_LOG_ERROR("mul_mat_gemv_impl: failed to initiate dma plane transfer for src0 plane, is_quant: %d\n", + (int) _IsSrcQuantized); + return; } } { for (int64_t col_idx = start_end_element.first; col_idx < start_end_element.second; col_idx += src0_plane_slice_row_count) { - const uint8_t * src0_plane = src0_ptr + col_idx * src0->get_nb(1); - const int64_t actual_row_count = + const int64_t actual_row_count = std::min(src0_plane_slice_row_count, start_end_element.second - col_idx); // number of rows in this slice + const auto next_col_idx = col_idx + src0_plane_slice_row_count; + std::swap(src0_plane_read_cache_ptr, src0_plane_write_cache_ptr); + params->wait_for_dma(); + + if (next_col_idx < start_end_element.second) { + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 2, dma); + const uint8_t * src0_next_plane = src0_ptr + next_col_idx * src0->get_nb(1); + const int64_t next_row_count = + std::min(src0_plane_slice_row_count, + start_end_element.second - next_col_idx); // number of rows in this slice + if (!init_dma_transfer<_IsSrcQuantized>( + params, src0_next_plane, src0_plane_write_cache_ptr + src0_plane_write_cache_offset, + valid_src0_row_bytes, next_row_count, src0->get_nb(1), src0->get_nb(1))) { + DEVICE_LOG_ERROR( + "mul_mat_gemv_impl: failed to continue dma plane transfer for src0 plane, is_quant: %d\n", + (int) _IsSrcQuantized); + return; + } + } + if constexpr (_IsSrcQuantized) { DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dequant); - + const uint8_t * src0_quant_plane = src0_plane_read_cache_ptr + src0_plane_write_cache_offset; for (int64_t ir = 0; ir < actual_row_count; ir++) { - auto * src0_row = src0_plane + ir * src0->get_nb(1); - if (ir + 1 < actual_row_count) { - hexagon::l2fetch_row(src0_row + src0->get_nb(1), src0->get_nb(1)); - } - - auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_actual_row_size; + auto * src0_row = src0_quant_plane + ir * src0->get_nb(1); + auto * cached_row_ptr = src0_plane_read_cache_ptr + ir * src0_actual_row_stride; dequantize_row_func(src0_row, reinterpret_cast(cached_row_ptr), 
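The gemv hunk above (and the batched path earlier) now follow the same double-buffered schedule regardless of quantization: kick off the first transfer before the loop, then each iteration swaps the read/write buffers, waits for the landed transfer, starts the next one, and only then computes. A generic skeleton of that pipeline; the callables stand in for params->initiate_dma_*, params->wait_for_dma() and the dequant + batched_row_dot work:

```cpp
#include <cstddef>
#include <cstring>
#include <utility>

struct Slice { const unsigned char * src; size_t rows; };

template <typename StartDma, typename WaitDma, typename Compute>
void pipelined_slices(const Slice * slices, size_t n, unsigned char * buf_a, unsigned char * buf_b,
                      StartDma start_dma, WaitDma wait_dma, Compute compute) {
    if (n == 0) return;
    unsigned char * read_buf  = buf_a;
    unsigned char * write_buf = buf_b;
    start_dma(slices[0], write_buf);                          // prologue: fetch the first slice
    for (size_t i = 0; i < n; ++i) {
        std::swap(read_buf, write_buf);                       // finished transfer becomes the read buffer
        wait_dma();                                           // make sure slice i has fully landed
        if (i + 1 < n) start_dma(slices[i + 1], write_buf);   // overlap the next transfer
        compute(slices[i], read_buf);                         // consume slice i while i + 1 is in flight
    }
}

int main() {
    unsigned char a[16], b[16], data[3][16] = {};
    Slice slices[3] = { { data[0], 1 }, { data[1], 1 }, { data[2], 1 } };
    pipelined_slices(slices, 3, a, b,
                     [](const Slice & s, unsigned char * dst) { std::memcpy(dst, s.src, 16); },
                     [] { /* nothing to wait on in this toy */ },
                     [](const Slice &, unsigned char *) { /* dequant + dot would go here */ });
    return 0;
}
```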
src0->get_ne(0), dequant_table); } - } else { - DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 0, dma); - std::swap(src0_plane_read_cache_ptr, src0_plane_write_cache_ptr); - params->wait_for_dma(); - - const auto next_col_idx = col_idx + src0_plane_slice_row_count; - if (next_col_idx < start_end_element.second) { - const uint8_t * src0_next_plane = src0_ptr + next_col_idx * src0_actual_row_size; - const int64_t next_row_count = - std::min(src0_plane_slice_row_count, - start_end_element.second - next_col_idx); // number of rows in this slice - if (!params->initiate_dma_plane_transfer(src0_next_plane, src0_plane_write_cache_ptr, - valid_row0_bytes, next_row_count, src0_actual_row_size, - src0_actual_row_size)) { - DEVICE_LOG_ERROR("mul_mat_gemv_impl: failed to continue dma transfer for src0 plane\n"); - return; - } - } } { - DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, vec_dot); + DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(mul_mat, 1, dot); auto * dst_row = reinterpret_cast(dst_ptr) + col_idx; - batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_actual_row_size, + batched_row_dot<_DotFunc>(src0_plane_read_cache_ptr, src0->get_ne(0), src0_actual_row_stride, src1_row_cache_ptr, src1->get_nb(1), dst_row, actual_row_count, 0); } } diff --git a/ggml/src/ggml-qnn/npu/device/type_traits.cpp b/ggml/src/ggml-qnn/npu/device/type_traits.cpp index f3af0a6408c6c..cc0ca77b38c2b 100644 --- a/ggml/src/ggml-qnn/npu/device/type_traits.cpp +++ b/ggml/src/ggml-qnn/npu/device/type_traits.cpp @@ -44,7 +44,7 @@ template inline HVX_Vector load_block_generic(const _TBlock & } template inline HVX_Vector make_scale_load_mask() { - static_assert(sizeof(_TBlock) < 32, "wrong block size/padding"); + static_assert(sizeof(_TBlock) + sizeof(npu_device_fp16_t) < 32, "wrong block size/padding"); static_assert(sizeof(_TBlock::qs) == 16 || sizeof(_TBlock::qs) == 32, "wrong quantization block size"); constexpr const size_t kScaleBlockSize = QUANT_BLOCK_SIZE * sizeof(hexagon::dequant_output_type); @@ -83,12 +83,13 @@ inline hexagon::HVX_Vector_x2 load_dual_block_generic(const _TBlock * srcs, hexagon::HVX_Vector_x2 result; - HVX_Vector blocks = load_into_vector<_TBlock, 2, &_TBlock::qs>(srcs); + const HVX_Vector blocks = load_struct_into_vector<_TBlock, 2>(srcs); - HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale); + HVX_Vector block0 = Q6_V_vror_VR(blocks, kSizeOfScale); + HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale * 2); HVX_Vector scale01 = Q6_Vb_vshuff_Vb(blocks); - result.val[0] = Q6_V_vmux_QVV(mask, blocks, block1); + result.val[0] = Q6_V_vmux_QVV(mask, block0, block1); result.val[1] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale01, 0); return result; @@ -143,6 +144,50 @@ inline hexagon::HVX_Vector_x3 load_qual_block_generic(const _TBlock * return result; } +template +inline hexagon::HVX_Vector_x5 load_hexa_block_generic(const _TBlock * srcs, + const hexagon::HVX_VectorPred_x3 mask, + const HVX_Vector scale_indices) { + static_assert(hexagon::kBytesPerVector >= sizeof(_TBlock) * 6, "wrong block size/padding"); + constexpr const uint32_t kSizeOfQs = sizeof(_TBlock::qs); + constexpr const uint32_t kSizeOfScale = sizeof(_TBlock) - kSizeOfQs; + + const HVX_Vector blocks = load_struct_into_vector<_TBlock, 6>(srcs); + + hexagon::HVX_Vector_x5 result; + { + HVX_Vector block0 = Q6_V_vror_VR(blocks, kSizeOfScale); + HVX_Vector block1 = Q6_V_vror_VR(blocks, kSizeOfScale * 2); + + HVX_Vector block2 = Q6_V_vror_VR(blocks, kSizeOfScale * 3); + 
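The load_dual_block_generic fix above rotates the loaded vector by kSizeOfScale and by 2 * kSizeOfScale before the mux; previously the unrotated vector was used for the first block, leaving the 2-byte scale prefix in front of block 0's quants. A scalar toy model of the rotate-then-select splice on a 64-byte "vector" (HVX vectors are 128 bytes, and the real per-byte mask comes from make_quad_block_mask):

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Two packed blocks { scale (2 bytes), qs[16] } sit in one loaded vector; rotating
// right by 2 and by 4 bytes lines the two qs payloads up so a byte mask can splice
// them into one contiguous 32-byte quant payload.
constexpr size_t kVec = 64;
constexpr size_t kScale = 2, kQs = 16, kBlock = kScale + kQs;

std::array<uint8_t, kVec> ror(const std::array<uint8_t, kVec> & v, size_t r) {
    std::array<uint8_t, kVec> out{};
    for (size_t i = 0; i < kVec; ++i) out[i] = v[(i + r) % kVec];  // Q6_V_vror_VR-style byte rotate
    return out;
}

int main() {
    std::array<uint8_t, kVec> v{};
    for (size_t b = 0; b < 2; ++b)                                 // fill two packed blocks
        for (size_t i = 0; i < kQs; ++i) v[b * kBlock + kScale + i] = uint8_t(b * kQs + i);

    const auto block0 = ror(v, kScale);      // qs of block 0 now starts at byte 0
    const auto block1 = ror(v, kScale * 2);  // qs of block 1 now starts at byte 16

    for (size_t i = 0; i < 2 * kQs; ++i) {
        const uint8_t got = (i < kQs) ? block0[i] : block1[i];     // the per-byte mux
        assert(got == uint8_t(i));           // contiguous 0..31 payload, as intended
    }
    return 0;
}
```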
HVX_Vector block3 = Q6_V_vror_VR(blocks, kSizeOfScale * 4); + + HVX_Vector block4 = Q6_V_vror_VR(blocks, kSizeOfScale + sizeof(_TBlock) * 4); + HVX_Vector block5 = Q6_V_vror_VR(blocks, kSizeOfScale * 2 + sizeof(_TBlock) * 4); + + HVX_Vector block01 = Q6_V_vmux_QVV(mask.val[0], block0, block1); + HVX_Vector block23 = Q6_V_vmux_QVV(mask.val[1], block2, block3); + + result.val[0] = Q6_V_vmux_QVV(mask.val[2], block01, block23); + result.val[3] = Q6_V_vmux_QVV(mask.val[0], block4, block5); // block45 + } + + { + HVX_Vector scale23 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 2); + HVX_Vector scale45 = Q6_V_vror_VR(blocks, sizeof(_TBlock) * 4); + + HVX_Vector scale01 = Q6_Vb_vshuff_Vb(blocks); + scale23 = Q6_Vb_vshuff_Vb(scale23); + scale45 = Q6_Vb_vshuff_Vb(scale45); + + result.val[1] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale01, 0); + result.val[2] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale23, 0); + result.val[4] = Q6_Vb_vlut32_VbVbR_nomatch(scale_indices, scale45, 0); + } + + return result; +} + inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) { // TODO: use intrinsics if (j < 4) { @@ -442,71 +487,100 @@ void dequantize_row_q8_0(const void * src, hexagon::dequant_output_type * dst, s } template -void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_output_type * dst, size_t count, HVX_Vector table) { - constexpr const int qk = QUANT_BLOCK_SIZE; - static_assert(qk % 2 == 0, "qk must be even"); - static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float)); +inline void dequantize_row_q4_0_2blocks(HVX_Vector qs, + HVX_Vector scale01, + HVX_Vector table, + hexagon::dequant_output_type * dst_ptr) { constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs); - static const auto load_masks = make_quad_block_mask(); - alignas(hexagon::kBytesPerVector) static const HVX_Vector scale_indices = - make_scale_load_mask(); + HVX_Vector q_lo = qs; + HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4); + HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2)); - const int nb = count / qk; - const auto * src_ptr = reinterpret_cast(src); - hexagon::dequant_output_type * dst_ptr = dst; // TODO: opt for aligned access + q_lo = Q6_V_lo_W(qp0); + q_lo = Q6_Vb_vshuff_Vb(q_lo); + qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0); - int i = 0; - for (; i + 3 < nb; i += 4) { - auto qs = load_qual_block_generic(src_ptr + i, load_masks, scale_indices); + q_lo = Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(qp0), scale01); + q_lo = Q6_Vhf_equals_Vqf16(q_lo); - HVX_Vector q_lo = qs.val[0]; - HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs.val[0], 4); + if constexpr (_IsDstAligned) { + *reinterpret_cast(dst_ptr) = q_lo; + } else { + *reinterpret_cast(dst_ptr) = q_lo; + } +} - HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2 + 4)); +template +inline void dequantize_row_q4_0_4blocks(HVX_Vector qs, + HVX_Vector scale01, + HVX_Vector scale23, + HVX_Vector table, + hexagon::dequant_output_type * dst_ptr) { + constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs); - q_lo = Q6_Vb_vshuff_Vb(Q6_V_lo_W(qp0)); - qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0); + HVX_Vector q_lo = qs; + HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs, 4); - q_lo = Q6_V_lo_W(qp0); - q_hi = Q6_V_hi_W(qp0); + HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2 + 4)); - q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, qs.val[1]); - q_hi = Q6_Vqf16_vmpy_VhfVhf(q_hi, qs.val[2]); + q_lo = Q6_V_lo_W(qp0); + q_lo = Q6_Vb_vshuff_Vb(q_lo); + qp0 = 
Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0); - q_lo = Q6_Vhf_equals_Vqf16(q_lo); - q_hi = Q6_Vhf_equals_Vqf16(q_hi); + q_lo = Q6_V_lo_W(qp0); + q_hi = Q6_V_hi_W(qp0); - if constexpr (_IsDstAligned) { - reinterpret_cast(dst_ptr)[0] = q_lo; - reinterpret_cast(dst_ptr)[1] = q_hi; - } else { - reinterpret_cast(dst_ptr)[0] = q_lo; - reinterpret_cast(dst_ptr)[1] = q_hi; - } + q_lo = Q6_Vqf16_vmpy_VhfVhf(q_lo, scale01); + q_hi = Q6_Vqf16_vmpy_VhfVhf(q_hi, scale23); + + q_lo = Q6_Vhf_equals_Vqf16(q_lo); + q_hi = Q6_Vhf_equals_Vqf16(q_hi); - dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_output_type) * 2; + if constexpr (_IsDstAligned) { + reinterpret_cast(dst_ptr)[0] = q_lo; + reinterpret_cast(dst_ptr)[1] = q_hi; + } else { + reinterpret_cast(dst_ptr)[0] = q_lo; + reinterpret_cast(dst_ptr)[1] = q_hi; } +} - for (; i + 1 < nb; i += 2) { - auto qs = load_dual_block_generic(src_ptr + i, load_masks.val[0], scale_indices); - HVX_Vector q_lo = qs.val[0]; - HVX_Vector q_hi = Q6_Vub_vlsr_VubR(qs.val[0], 4); - HVX_VectorPair qp0 = Q6_W_vshuff_VVR(q_hi, q_lo, kSizeOfQs * (1 + 2)); +template +void dequantize_row_q4_0_impl(const void * src, hexagon::dequant_output_type * dst, size_t count, HVX_Vector table) { + constexpr const size_t kElemsPerVec = hexagon::kBytesPerVector / sizeof(hexagon::dequant_output_type); + constexpr const uint32_t kSizeOfQs = sizeof(npu_device_block_q4_0::qs); + constexpr const int qk = QUANT_BLOCK_SIZE; + static_assert(qk % 2 == 0, "qk must be even"); + static_assert(QUANT_BLOCK_SIZE == hexagon::kBytesPerVector / sizeof(float)); - q_lo = Q6_Vb_vshuff_Vb(Q6_V_lo_W(qp0)); - qp0 = Q6_Wh_vlut16_VbVhR_nomatch(q_lo, table, 0); + static const auto load_masks = make_quad_block_mask(); + alignas(hexagon::kBytesPerVector) static const HVX_Vector scale_indices = + make_scale_load_mask(); - q_lo = Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(qp0), qs.val[1]); - q_lo = Q6_Vhf_equals_Vqf16(q_lo); + const int nb = count / qk; + const auto * src_ptr = reinterpret_cast(src); - if constexpr (_IsDstAligned) { - *reinterpret_cast(dst_ptr) = q_lo; - } else { - *reinterpret_cast(dst_ptr) = q_lo; - } + hexagon::dequant_output_type * dst_ptr = dst; // TODO: opt for aligned access + + int i = 0; + for (; i + 5 < nb; i += 6) { + auto qs = load_hexa_block_generic(src_ptr + i, load_masks, scale_indices); + dequantize_row_q4_0_4blocks<_IsDstAligned>(qs.val[0], qs.val[1], qs.val[2], table, dst_ptr); + dequantize_row_q4_0_2blocks<_IsDstAligned>(qs.val[3], qs.val[4], table, dst_ptr + kElemsPerVec * 2); + dst_ptr += kElemsPerVec * 3; + } - dst_ptr += hexagon::kBytesPerVector / sizeof(hexagon::dequant_output_type); + for (; i + 3 < nb; i += 4) { + auto qs = load_qual_block_generic(src_ptr + i, load_masks, scale_indices); + dequantize_row_q4_0_4blocks<_IsDstAligned>(qs.val[0], qs.val[1], qs.val[2], table, dst_ptr); + dst_ptr += kElemsPerVec * 2; + } + + for (; i + 1 < nb; i += 2) { + auto qs = load_dual_block_generic(src_ptr + i, load_masks.val[0], scale_indices); + dequantize_row_q4_0_2blocks<_IsDstAligned>(qs.val[0], qs.val[1], table, dst_ptr); + dst_ptr += kElemsPerVec; } if (i < nb) { diff --git a/ggml/src/ggml-qnn/npu/device/vec_ops.hpp b/ggml/src/ggml-qnn/npu/device/vec_ops.hpp index e286aebbb569b..f72e1c37c0fd3 100644 --- a/ggml/src/ggml-qnn/npu/device/vec_ops.hpp +++ b/ggml/src/ggml-qnn/npu/device/vec_ops.hpp @@ -18,6 +18,7 @@ template struct HEXAGON_pack { using HVX_Vector_x2 = HEXAGON_pack; using HVX_Vector_x3 = HEXAGON_pack; using HVX_Vector_x4 = HEXAGON_pack; +using HVX_Vector_x5 = HEXAGON_pack; using 
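dequantize_row_q4_0_impl now walks the blocks six at a time (three output vectors per hexa load: one quad pass plus one dual pass), then four, then two, then the generic single-block tail. Six is the practical upper bound here: six 18-byte q4_0 blocks occupy 108 of the 128 bytes of an HVX vector, which is what the static_assert in load_hexa_block_generic checks. A sketch of just the driver-loop shape, with placeholder work functions standing in for the HVX helpers:

```cpp
#include <cstddef>
#include <cstdio>

void work_6(size_t i) { std::printf("blocks [%zu,%zu) via hexa path\n", i, i + 6); }  // 3 output vectors
void work_4(size_t i) { std::printf("blocks [%zu,%zu) via quad path\n", i, i + 4); }  // 2 output vectors
void work_2(size_t i) { std::printf("blocks [%zu,%zu) via dual path\n", i, i + 2); }  // 1 output vector
void work_1(size_t i) { std::printf("block  %zu via scalar tail\n", i); }

void dequant_blocks(size_t nb) {
    size_t i = 0;
    for (; i + 5 < nb; i += 6) work_6(i);
    for (; i + 3 < nb; i += 4) work_4(i);
    for (; i + 1 < nb; i += 2) work_2(i);
    for (; i < nb; ++i)        work_1(i);
}

int main() {
    dequant_blocks(11);  // one hexa pass, one quad pass, then the scalar tail
    return 0;
}
```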
HVX_VectorPair_x4 = HEXAGON_pack<HVX_VectorPair, 4>; using HVX_VectorPred_x3 = HEXAGON_pack<HVX_VectorPred, 3>;
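Earlier in this hunk the template argument lists of HEXAGON_pack are not visible in the rendered diff; judging from the _x2/_x3/_x4 naming and the val[0]..val[4] accesses in load_hexa_block_generic, the pack is presumably a small fixed-size aggregate along these lines. A sketch under that assumption, with int standing in for HVX_Vector off-target:

```cpp
#include <cstddef>

// Presumed shape of HEXAGON_pack: an element type plus a compile-time count,
// exposing the elements as a plain val[] array (names here are hypothetical).
template <typename _TElem, size_t _Count> struct hexagon_pack_sketch {
    _TElem val[_Count];
};

using vector_x4_sketch = hexagon_pack_sketch<int, 4>;
using vector_x5_sketch = hexagon_pack_sketch<int, 5>;  // new in this patch: room for val[0]..val[4]

static_assert(sizeof(vector_x5_sketch) == 5 * sizeof(int), "plain aggregate, no extra padding expected");
```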