1+ /* **************************************************************************************************
2+ * Copyright (C) 2025 Intel Corporation, All rights reserved.
3+ * SPDX-License-Identifier: BSD-3-Clause
4+ *
5+ * Redistribution and use in source and binary forms, with or without
6+ * modification, are permitted provided that the following conditions are met:
7+ *
8+ * 1. Redistributions of source code must retain the above copyright notice, this
9+ * list of conditions and the following disclaimer.
10+ *
11+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12+ * this list of conditions and the following disclaimer in the documentation
13+ * and/or other materials provided with the distribution.
14+ *
15+ * 3. Neither the name of the copyright holder nor the names of its
16+ * contributors may be used to endorse or promote products derived from
17+ * this software without specific prior written permission.
18+ *
19+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+ *
30+ **************************************************************************************************/
31+
32+ #pragma once
33+
34+ namespace cute {
35+
36+ /* Flat copies */
37+ template <class SrcEngine , class SrcLayout ,
38+ class DstEngine , class DstLayout >
39+ CUTE_HOST_DEVICE
40+ void
41+ copy_block_r2s (Tensor<SrcEngine, SrcLayout> const & src,
42+ Tensor<DstEngine, DstLayout> & dst)
43+ {
44+ static_assert (is_rmem_v<SrcEngine> && is_smem_v<DstEngine>, " Expected rmem->smem copy" );
45+
46+ auto atom_r2s = Copy_Atom<XE_1D_STSM<float >, float >{}; // TODO: larger block messages
47+
48+ auto atom_shape = make_shape (_1{}, size (src));
49+ auto src_v = src.compose (make_layout (atom_shape));
50+ auto dst_v = dst.compose (make_layout (atom_shape, Stride<_0, _16>{}));
51+
52+ copy (atom_r2s, src_v, dst_v);
53+ }
54+
55+ template <class SrcEngine , class SrcLayout ,
56+ class DstEngine , class DstLayout >
57+ CUTE_HOST_DEVICE
58+ void
59+ copy_block_s2r (Tensor<SrcEngine, SrcLayout> const & src,
60+ Tensor<DstEngine, DstLayout> & dst)
61+ {
62+ static_assert (is_smem_v<SrcEngine> && is_rmem_v<DstEngine>, " Expected smem->rmem copy" );
63+
64+ auto atom_s2r = Copy_Atom<XE_1D_LDSM<float >, float >{};
65+
66+ auto atom_shape = make_shape (_1{}, size (dst));
67+ auto src_v = src.compose (make_layout (atom_shape, Stride<_0, _16>{}));
68+ auto dst_v = dst.compose (make_layout (atom_shape));
69+
70+ copy (atom_s2r, src_v, dst_v);
71+ }
72+
73+ /* Coordinate-aware copies */
74+ template <class SrcEngine , class SrcLayout , class SrcCoordLayout ,
75+ class DstEngine , class DstLayout , class DstCoordLayout >
76+ CUTE_HOST_DEVICE
77+ void
78+ copy_block_r2s (SubgroupTensor<SrcEngine, SrcLayout, SrcCoordLayout> const & src,
79+ Tensor<DstEngine, DstLayout> & dst,
80+ DstCoordLayout const & dst_c)
81+ {
82+ static_assert (is_rmem_v<SrcEngine> && is_smem_v<DstEngine>, " Expected rmem->smem copy" );
83+
84+ auto atom_r2s = Copy_Atom<XE_1D_STSM<float >, float >{}; // TODO: larger block messages
85+
86+ auto atom_shape = make_shape (_1{}, size (SrcLayout{}));
87+
88+ auto src_c_wi0 = composition (project_strides (SrcCoordLayout{}), make_layout (atom_shape, Stride<_0, _16>{}));
89+ auto rlayout = composition (right_inverse (project_strides (dst_c)), src_c_wi0);
90+
91+ auto src_v = src.compose (make_layout (atom_shape));
92+ auto dst_v = dst.compose (rlayout);
93+
94+ copy (atom_r2s, src_v, dst_v);
95+ }
96+
97+ template <class SrcEngine , class SrcLayout , class SrcCoordLayout ,
98+ class DstEngine , class DstLayout , class DstCoordLayout >
99+ CUTE_HOST_DEVICE
100+ void
101+ copy_block_s2r (Tensor<SrcEngine, SrcLayout> const & src,
102+ SrcCoordLayout const & src_c,
103+ SubgroupTensor<DstEngine, DstLayout, DstCoordLayout> & dst)
104+ {
105+ static_assert (is_smem_v<SrcEngine> && is_rmem_v<DstEngine>, " Expected smem->rmem copy" );
106+
107+ auto atom_s2r = Copy_Atom<XE_1D_LDSM<float >, float >{};
108+
109+ auto atom_shape = make_shape (_1{}, size (DstLayout{}));
110+
111+ auto dst_c_wi0 = composition (project_strides (DstCoordLayout{}), make_layout (atom_shape, Stride<_0, _16>{}));
112+ auto rlayout = composition (right_inverse (project_strides (src_c)), dst_c_wi0);
113+
114+ auto src_v = src.compose (rlayout);
115+ auto dst_v = dst.compose (make_layout (atom_shape));
116+
117+ copy (atom_s2r, src_v, dst_v);
118+ }
119+
120+ /* Variants accepting rvalue dst */
121+ template <class SrcEngine , class SrcLayout ,
122+ class DstEngine , class DstLayout >
123+ CUTE_HOST_DEVICE
124+ void
125+ copy_block_r2s (Tensor<SrcEngine, SrcLayout> const & src,
126+ Tensor<DstEngine, DstLayout> && dst)
127+ {
128+ return copy_block_r2s (src, dst);
129+ }
130+
131+ template <class SrcEngine , class SrcLayout ,
132+ class DstEngine , class DstLayout >
133+ CUTE_HOST_DEVICE
134+ void
135+ copy_block_s2r (Tensor<SrcEngine, SrcLayout> const & src,
136+ Tensor<DstEngine, DstLayout> && dst)
137+ {
138+ return copy_block_s2r (src, dst);
139+ }
140+
141+ template <class SrcEngine , class SrcLayout , class SrcCoordLayout ,
142+ class DstEngine , class DstLayout , class DstCoordLayout >
143+ CUTE_HOST_DEVICE
144+ void
145+ copy_block_r2s (SubgroupTensor<SrcEngine, SrcLayout, SrcCoordLayout> const & src,
146+ Tensor<DstEngine, DstLayout> && dst,
147+ DstCoordLayout const & dst_c)
148+ {
149+ return copy_block_r2s (src, dst, dst_c);
150+ }
151+
152+ template <class SrcEngine , class SrcLayout , class SrcCoordLayout ,
153+ class DstEngine , class DstLayout , class DstCoordLayout >
154+ CUTE_HOST_DEVICE
155+ void
156+ copy_block_s2r (Tensor<SrcEngine, SrcLayout> const & src,
157+ SrcCoordLayout const & src_c,
158+ SubgroupTensor<DstEngine, DstLayout, DstCoordLayout> && dst)
159+ {
160+ return copy_block_s2r (src, dst);
161+ }
162+
163+ } /* namespace cute */
0 commit comments