Skip to content

Commit b0e30f4

Browse files
committed
[Xe] Re-implement FlashAttention with new atoms
1 parent 22993de commit b0e30f4

File tree

10 files changed

+2037
-0
lines changed

10 files changed

+2037
-0
lines changed
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
/***************************************************************************************************
2+
* Copyright (C) 2025 Intel Corporation, All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice, this
9+
* list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
*
30+
**************************************************************************************************/
31+
32+
#pragma once
33+
34+
namespace cute {
35+
36+
/* Flat copies */
37+
template <class SrcEngine, class SrcLayout,
38+
class DstEngine, class DstLayout>
39+
CUTE_HOST_DEVICE
40+
void
41+
copy_block_r2s(Tensor<SrcEngine, SrcLayout> const& src,
42+
Tensor<DstEngine, DstLayout> & dst)
43+
{
44+
static_assert(is_rmem_v<SrcEngine> && is_smem_v<DstEngine>, "Expected rmem->smem copy");
45+
46+
auto atom_r2s = Copy_Atom<XE_1D_STSM<float>, float>{}; // TODO: larger block messages
47+
48+
auto atom_shape = make_shape(_1{}, size(src));
49+
auto src_v = src.compose(make_layout(atom_shape));
50+
auto dst_v = dst.compose(make_layout(atom_shape, Stride<_0, _16>{}));
51+
52+
copy(atom_r2s, src_v, dst_v);
53+
}
54+
55+
template <class SrcEngine, class SrcLayout,
56+
class DstEngine, class DstLayout>
57+
CUTE_HOST_DEVICE
58+
void
59+
copy_block_s2r(Tensor<SrcEngine, SrcLayout> const& src,
60+
Tensor<DstEngine, DstLayout> & dst)
61+
{
62+
static_assert(is_smem_v<SrcEngine> && is_rmem_v<DstEngine>, "Expected smem->rmem copy");
63+
64+
auto atom_s2r = Copy_Atom<XE_1D_LDSM<float>, float>{};
65+
66+
auto atom_shape = make_shape(_1{}, size(dst));
67+
auto src_v = src.compose(make_layout(atom_shape, Stride<_0, _16>{}));
68+
auto dst_v = dst.compose(make_layout(atom_shape));
69+
70+
copy(atom_s2r, src_v, dst_v);
71+
}
72+
73+
/* Coordinate-aware copies */
74+
template <class SrcEngine, class SrcLayout, class SrcCoordLayout,
75+
class DstEngine, class DstLayout, class DstCoordLayout>
76+
CUTE_HOST_DEVICE
77+
void
78+
copy_block_r2s(SubgroupTensor<SrcEngine, SrcLayout, SrcCoordLayout> const& src,
79+
Tensor<DstEngine, DstLayout> & dst,
80+
DstCoordLayout const& dst_c)
81+
{
82+
static_assert(is_rmem_v<SrcEngine> && is_smem_v<DstEngine>, "Expected rmem->smem copy");
83+
84+
auto atom_r2s = Copy_Atom<XE_1D_STSM<float>, float>{}; // TODO: larger block messages
85+
86+
auto atom_shape = make_shape(_1{}, size(SrcLayout{}));
87+
88+
auto src_c_wi0 = composition(project_strides(SrcCoordLayout{}), make_layout(atom_shape, Stride<_0, _16>{}));
89+
auto rlayout = composition(right_inverse(project_strides(dst_c)), src_c_wi0);
90+
91+
auto src_v = src.compose(make_layout(atom_shape));
92+
auto dst_v = dst.compose(rlayout);
93+
94+
copy(atom_r2s, src_v, dst_v);
95+
}
96+
97+
template <class SrcEngine, class SrcLayout, class SrcCoordLayout,
98+
class DstEngine, class DstLayout, class DstCoordLayout>
99+
CUTE_HOST_DEVICE
100+
void
101+
copy_block_s2r(Tensor<SrcEngine, SrcLayout> const& src,
102+
SrcCoordLayout const& src_c,
103+
SubgroupTensor<DstEngine, DstLayout, DstCoordLayout> & dst)
104+
{
105+
static_assert(is_smem_v<SrcEngine> && is_rmem_v<DstEngine>, "Expected smem->rmem copy");
106+
107+
auto atom_s2r = Copy_Atom<XE_1D_LDSM<float>, float>{};
108+
109+
auto atom_shape = make_shape(_1{}, size(DstLayout{}));
110+
111+
auto dst_c_wi0 = composition(project_strides(DstCoordLayout{}), make_layout(atom_shape, Stride<_0, _16>{}));
112+
auto rlayout = composition(right_inverse(project_strides(src_c)), dst_c_wi0);
113+
114+
auto src_v = src.compose(rlayout);
115+
auto dst_v = dst.compose(make_layout(atom_shape));
116+
117+
copy(atom_s2r, src_v, dst_v);
118+
}
119+
120+
/* Variants accepting rvalue dst */
121+
template <class SrcEngine, class SrcLayout,
122+
class DstEngine, class DstLayout>
123+
CUTE_HOST_DEVICE
124+
void
125+
copy_block_r2s(Tensor<SrcEngine, SrcLayout> const& src,
126+
Tensor<DstEngine, DstLayout> && dst)
127+
{
128+
return copy_block_r2s(src, dst);
129+
}
130+
131+
template <class SrcEngine, class SrcLayout,
132+
class DstEngine, class DstLayout>
133+
CUTE_HOST_DEVICE
134+
void
135+
copy_block_s2r(Tensor<SrcEngine, SrcLayout> const& src,
136+
Tensor<DstEngine, DstLayout> && dst)
137+
{
138+
return copy_block_s2r(src, dst);
139+
}
140+
141+
template <class SrcEngine, class SrcLayout, class SrcCoordLayout,
142+
class DstEngine, class DstLayout, class DstCoordLayout>
143+
CUTE_HOST_DEVICE
144+
void
145+
copy_block_r2s(SubgroupTensor<SrcEngine, SrcLayout, SrcCoordLayout> const& src,
146+
Tensor<DstEngine, DstLayout> && dst,
147+
DstCoordLayout const& dst_c)
148+
{
149+
return copy_block_r2s(src, dst, dst_c);
150+
}
151+
152+
template <class SrcEngine, class SrcLayout, class SrcCoordLayout,
153+
class DstEngine, class DstLayout, class DstCoordLayout>
154+
CUTE_HOST_DEVICE
155+
void
156+
copy_block_s2r(Tensor<SrcEngine, SrcLayout> const& src,
157+
SrcCoordLayout const& src_c,
158+
SubgroupTensor<DstEngine, DstLayout, DstCoordLayout> && dst)
159+
{
160+
return copy_block_s2r(src, dst);
161+
}
162+
163+
} /* namespace cute */

0 commit comments

Comments
 (0)