 #include <aclnnop/aclnn_grouped_matmul_v3.h>
 #include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
 #include <aclnnop/aclnn_zero.h>
+#include <aclnnop/aclnn_index_copy.h>
+#include <aclnnop/aclnn_index_select.h>
 #include <float.h>

 #include <cmath>
@@ -1614,50 +1616,97 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 }

 /**
- * @brief Performs embedding operation on a 4D tensor using the CANN backend.
+ * @brief Performs an index-select operation on a 4D tensor using the CANN backend.
  *
- * This function extracts slices from the source tensor (`src_buffer`),
- * index tensor (`index`), and destination tensor (`dst`), and performs an
- * embedding operation on them. The embedding operation is applied by iterating
- * over the last two dimensions of the source tensor, creating the necessary
- * tensors for the source, index, and output, and executing the embedding operation.
+ * This function applies the `IndexSelect` operation along a specific dimension
+ * of the source tensor (`src_buffer`) using the indices from the index tensor (`index`).
+ * It iterates over the last two dimensions of the source tensor, creates the
+ * corresponding CANN tensors for the source, index, and output slices, and executes
+ * the `IndexSelect` operation for each slice.
  *
  * @param ctx The context for CANN backend operations.
- * @param src_buffer The source buffer holding the data for the source tensor.
+ * @param src_buffer The source buffer containing the 4D input tensor data.
  * @param src_ne The dimensions of the source tensor.
  * @param src_nb The strides (byte offsets) of the source tensor.
- * @param index The index tensor used in the embedding operation.
- * @param dst The destination tensor where the result will be stored.
+ * @param dst_buffer The destination buffer where the output tensor data will be written.
+ * @param dst_ne The dimensions of the destination tensor.
+ * @param dst_nb The strides (byte offsets) of the destination tensor.
+ * @param index The index tensor specifying the indices to select from the source tensor.
+ * @param type The data type of the source and destination tensors.
  */
-static void aclnn_embedding_4d(ggml_backend_cann_context& ctx, void* src_buffer,
-                               int64_t* src_ne, size_t* src_nb, ggml_tensor* index,
-                               ggml_tensor* dst) {
+static void aclnn_index_select_4d(ggml_backend_cann_context& ctx,
+                                  void* src_buffer, int64_t* src_ne, size_t* src_nb,
+                                  void* dst_buffer, int64_t* dst_ne, size_t* dst_nb,
+                                  ggml_tensor* index, ggml_type type) {
     for (int64_t i = 0; i < src_ne[3]; i++) {
         for (int64_t j = 0; j < src_ne[2]; j++) {
             // src
-            int64_t acl_src_ne[2] = {src_ne[0], src_ne[1]};
-            size_t acl_src_nb[2] = {src_nb[0], src_nb[1]};
             aclTensor* acl_src_tensor = ggml_cann_create_tensor(
                 (char*)src_buffer + i * src_nb[3] + j * src_nb[2],
-                ggml_cann_type_mapping(dst->type), ggml_element_size(dst),
-                acl_src_ne, acl_src_nb, 2);
+                ggml_cann_type_mapping(type), ggml_type_size(type),
+                src_ne, src_nb, 2);

             // index
-            int64_t acl_index_ne[1] = {index->ne[0]};
-            size_t acl_index_nb[1] = {index->nb[0]};
             aclTensor* acl_index = ggml_cann_create_tensor(
-                (char*)index->data + i * index->nb[2] + j * index->nb[1],
+                (char*)index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
                 ggml_cann_type_mapping(index->type), ggml_element_size(index),
-                acl_index_ne, acl_index_nb, 1);
+                index->ne, index->nb, 1);

             // out
-            int64_t acl_out_ne[2] = {dst->ne[0], dst->ne[1]};
-            size_t acl_out_nb[2] = {dst->nb[0], dst->nb[1]};
             aclTensor* acl_out = ggml_cann_create_tensor(
-                (char*)dst->data + i * dst->nb[3] + j * dst->nb[2],
-                ggml_cann_type_mapping(dst->type), ggml_element_size(dst),
-                acl_out_ne, acl_out_nb, 2);
-            GGML_CANN_CALL_ACLNN_OP(ctx, Embedding, acl_src_tensor, acl_index, acl_out);
+                (char*)dst_buffer + i * dst_nb[3] + j * dst_nb[2],
+                ggml_cann_type_mapping(type), ggml_type_size(type),
+                dst_ne, dst_nb, 2);
+            GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, acl_src_tensor, 0, acl_index, acl_out);
+            ggml_cann_release_resources(ctx, acl_src_tensor, acl_index, acl_out);
+        }
+    }
+}
+
+/**
+ * @brief Performs an in-place index-copy operation on a 4D tensor using the CANN backend.
+ *
+ * This function applies the `IndexCopy` operation along a specific dimension of the
+ * destination tensor (`dst_buffer`), copying elements from the source tensor (`src_buffer`)
+ * to the positions specified by the index tensor (`index`).
+ * It iterates over the last two dimensions of the tensors, creates the corresponding
+ * CANN tensors for the source, index, and destination slices, and performs the index-copy
+ * operation for each slice.
+ *
+ * @param ctx The context for CANN backend operations.
+ * @param src_buffer The source buffer containing the 4D input tensor data to be copied.
+ * @param src_ne The dimensions of the source tensor.
+ * @param src_nb The strides (byte offsets) of the source tensor.
+ * @param dst_buffer The destination buffer where values will be copied to.
+ * @param dst_ne The dimensions of the destination tensor.
+ * @param dst_nb The strides (byte offsets) of the destination tensor.
+ * @param index The index tensor specifying target positions in the destination tensor.
+ * @param type The data type of the source and destination tensors.
+ */
+static void aclnn_index_copy_4d(ggml_backend_cann_context& ctx,
+                                void* src_buffer, int64_t* src_ne, size_t* src_nb,
+                                void* dst_buffer, int64_t* dst_ne, size_t* dst_nb,
+                                ggml_tensor* index, ggml_type type) {
+    for (int64_t i = 0; i < src_ne[3]; i++) {
+        for (int64_t j = 0; j < src_ne[2]; j++) {
+            // src
+            aclTensor* acl_src_tensor = ggml_cann_create_tensor(
+                (char*)src_buffer + i * src_nb[3] + j * src_nb[2],
+                ggml_cann_type_mapping(type), ggml_type_size(type),
+                src_ne, src_nb, 2);
+
+            // index
+            aclTensor* acl_index = ggml_cann_create_tensor(
+                (char*)index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
+                ggml_cann_type_mapping(index->type), ggml_element_size(index),
+                index->ne, index->nb, 1);
+
+            // out
+            aclTensor* acl_out = ggml_cann_create_tensor(
+                (char*)dst_buffer + i * dst_nb[3] + j * dst_nb[2],
+                ggml_cann_type_mapping(type), ggml_type_size(type),
+                dst_ne, dst_nb, 2);
+            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_out, 0, acl_index, acl_src_tensor);
             ggml_cann_release_resources(ctx, acl_src_tensor, acl_index, acl_out);
         }
     }
@@ -1669,8 +1718,9 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     switch (src0->type) {
         case GGML_TYPE_F32: {
-            aclnn_embedding_4d(ctx, src0->data, src0->ne, src0->nb, src1,
-                               dst);
+            aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
+                                  dst->data, dst->ne, dst->nb,
+                                  src1, dst->type);
             break;
         }
         case GGML_TYPE_F16: {
@@ -1687,8 +1737,9 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
                 src0->ne, src_trans_nb, GGML_MAX_DIMS);
             aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
-            aclnn_embedding_4d(ctx, src_trans_buffer, src0->ne,
-                               src_trans_nb, src1, dst);
+            aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
+                                  dst->data, dst->ne, dst->nb,
+                                  src1, dst->type);
             ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
             break;
         }
@@ -1748,8 +1799,10 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
             }

-            aclnn_embedding_4d(ctx, dequant_buffer_allocator.get(),
-                               dequant_ne, dequant_nb, src1, dst);
+            aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(),
+                                  dequant_ne, dequant_nb,
+                                  dst->data, dst->ne, dst->nb,
+                                  src1, dst->type);

             ggml_cann_release_resources(ctx, dequant_tensor);
             break;
@@ -1760,6 +1813,43 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
 }

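+/**
+ * @brief Writes rows from the source tensor into the destination tensor at the
+ *        row positions given by the index tensor (GGML_OP_SET_ROWS) using the
+ *        CANN backend.
+ *
+ * For an F32 destination the source rows are copied directly; for an F16
+ * destination the source is first cast to the destination type and the index
+ * copy is then performed on the converted buffer.
+ *
+ * @param ctx The context for CANN backend operations.
+ * @param dst The destination tensor; dst->src[0] holds the rows to write and
+ *            dst->src[1] holds the target row indices.
+ */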
+void ggml_cann_set_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src0 = dst->src[0];  // src
+    ggml_tensor* src1 = dst->src[1];  // index
+
+    switch (dst->type) {
+        case GGML_TYPE_F32: {
+            aclnn_index_copy_4d(ctx, src0->data, src0->ne, src0->nb,
+                                dst->data, dst->ne, dst->nb,
+                                src1, dst->type);
+            break;
+        }
+        case GGML_TYPE_F16: {
+            aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
+            ggml_cann_pool_alloc src_buffer_allocator(
+                ctx.pool(), ggml_nelements(src0) * sizeof(float16_t));
+            void* src_trans_buffer = src_buffer_allocator.get();
+            size_t src_trans_nb[GGML_MAX_DIMS];
+            src_trans_nb[0] = sizeof(float16_t);
+            for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
+            }
+            aclTensor* src_trans_tensor = ggml_cann_create_tensor(
+                src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type),
+                src0->ne, src_trans_nb, GGML_MAX_DIMS);
+            aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
+            aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
+                                dst->data, dst->ne, dst->nb,
+                                src1, dst->type);
+            ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
+            break;
+        }
+        default:
+            GGML_ABORT("Unsupported tensor type for GGML_OP_SET_ROWS");
+            break;
+    }
+}
+
 /**
  * @brief Repeats elements of a tensor along a specified dimension.
  *
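For reference, the two new helpers implement the usual row-gather (get_rows, backed by `IndexSelect`) and row-scatter (set_rows, backed by `InplaceIndexCopy`) semantics. Below is a minimal host-side sketch of those per-slice semantics, assuming contiguous F32 rows and int64_t indices; the function names are illustrative only and are not part of the patch.

#include <cstdint>
#include <cstring>

// Gather: dst row r receives src row idx[r] (the per-slice semantics of GGML_OP_GET_ROWS).
static void rows_gather_ref(const float* src, const int64_t* idx, float* dst,
                            int64_t n_idx, int64_t row_size) {
    for (int64_t r = 0; r < n_idx; r++) {
        std::memcpy(dst + r * row_size, src + idx[r] * row_size, row_size * sizeof(float));
    }
}

// Scatter: dst row idx[r] receives src row r (the per-slice semantics of GGML_OP_SET_ROWS).
static void rows_scatter_ref(const float* src, const int64_t* idx, float* dst,
                             int64_t n_rows, int64_t row_size) {
    for (int64_t r = 0; r < n_rows; r++) {
        std::memcpy(dst + idx[r] * row_size, src + r * row_size, row_size * sizeof(float));
    }
}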