Skip to content

Commit 460d34a

Browse files
committed
[Xe] Re-implement FlashAttention with new atoms
1 parent cc6d10d commit 460d34a

File tree

10 files changed

+2022
-0
lines changed

10 files changed

+2022
-0
lines changed
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/***************************************************************************************************
2+
* Copyright (C) 2025 Intel Corporation, All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice, this
9+
* list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
*
30+
**************************************************************************************************/
31+
32+
#pragma once
33+
34+
namespace cute {
35+
36+
/* Flat copies */
37+
template <class SrcEngine, class SrcLayout,
38+
class DstEngine, class DstLayout>
39+
CUTE_HOST_DEVICE
40+
void
41+
copy_block_r2s(Tensor<SrcEngine, SrcLayout> const& src,
42+
Tensor<DstEngine, DstLayout> & dst)
43+
{
44+
static_assert(is_rmem_v<SrcEngine> && is_smem_v<DstEngine>, "Expected rmem->smem copy");
45+
46+
auto atom_r2s = Copy_Atom<XE_1D_STSM<float>, float>{}; // TODO: larger block messages
47+
48+
auto atom_shape = make_shape(_1{}, size(src));
49+
auto src_v = src.compose(make_layout(atom_shape));
50+
auto dst_v = dst.compose(make_layout(atom_shape, Stride<_0, _16>{}));
51+
52+
copy(atom_r2s, src_v, dst_v);
53+
}
54+
55+
template <class SrcEngine, class SrcLayout,
56+
class DstEngine, class DstLayout>
57+
CUTE_HOST_DEVICE
58+
void
59+
copy_block_s2r(Tensor<SrcEngine, SrcLayout> const& src,
60+
Tensor<DstEngine, DstLayout> & dst)
61+
{
62+
static_assert(is_smem_v<SrcEngine> && is_rmem_v<DstEngine>, "Expected smem->rmem copy");
63+
64+
auto atom_s2r = Copy_Atom<XE_1D_LDSM<float>, float>{};
65+
66+
auto atom_shape = make_shape(_1{}, size(dst));
67+
auto src_v = src.compose(make_layout(atom_shape, Stride<_0, _16>{}));
68+
auto dst_v = dst.compose(make_layout(atom_shape));
69+
70+
copy(atom_s2r, src_v, dst_v);
71+
}
72+
73+
/* Coordinate-aware copies */
74+
template <class SrcEngine, class SrcLayout, class SrcCoordLayout,
75+
class DstEngine, class DstLayout, class DstCoordLayout>
76+
CUTE_HOST_DEVICE
77+
void
78+
copy_block_r2s(SubgroupTensor<SrcEngine, SrcLayout, SrcCoordLayout> const& src,
79+
Tensor<DstEngine, DstLayout> & dst,
80+
DstCoordLayout const& dst_c)
81+
{
82+
using _SG = intel::_SGSize;
83+
84+
static_assert(is_rmem_v<SrcEngine> && is_smem_v<DstEngine>, "Expected rmem->smem copy");
85+
static_assert(sizeof_bits_v<typename SrcEngine::value_type> == 32, "Only 32-bit data supported");
86+
87+
auto atom_r2s = Copy_Atom<XE_1D_STSM<float>, float>{}; // TODO: larger block messages
88+
89+
auto atom_shape = make_shape(_1{}, size(SrcLayout{}));
90+
91+
auto src_c_wi0 = composition(project_strides(SrcCoordLayout{}), make_layout(atom_shape, Stride<_0, _SG>{}));
92+
auto rlayout = composition(right_inverse(project_strides(dst_c)), src_c_wi0);
93+
94+
auto src_v = src.compose(make_layout(atom_shape));
95+
auto dst_v = dst.compose(rlayout);
96+
97+
copy(atom_r2s, src_v, dst_v);
98+
}
99+
100+
template <class SrcEngine, class SrcLayout, class SrcCoordLayout,
101+
class DstEngine, class DstLayout, class DstCoordLayout>
102+
CUTE_HOST_DEVICE
103+
void
104+
copy_block_s2r(Tensor<SrcEngine, SrcLayout> const& src,
105+
SrcCoordLayout const& src_c,
106+
SubgroupTensor<DstEngine, DstLayout, DstCoordLayout> & dst)
107+
{
108+
using _SG = intel::_SGSize;
109+
110+
static_assert(is_smem_v<SrcEngine> && is_rmem_v<DstEngine>, "Expected smem->rmem copy");
111+
static_assert(sizeof_bits_v<typename SrcEngine::value_type> == 32, "Only 32-bit data supported");
112+
113+
auto atom_s2r = Copy_Atom<XE_1D_LDSM<float>, float>{};
114+
115+
auto atom_shape = make_shape(_1{}, size(DstLayout{}));
116+
117+
auto dst_c_wi0 = composition(project_strides(DstCoordLayout{}), make_layout(atom_shape, Stride<_0, _SG>{}));
118+
auto rlayout = composition(right_inverse(project_strides(src_c)), dst_c_wi0);
119+
120+
auto src_v = src.compose(rlayout);
121+
auto dst_v = dst.compose(make_layout(atom_shape));
122+
123+
copy(atom_s2r, src_v, dst_v);
124+
}
125+
126+
/* Variants accepting rvalue dst */
127+
template <class SrcEngine, class SrcLayout,
128+
class DstEngine, class DstLayout>
129+
CUTE_HOST_DEVICE
130+
void
131+
copy_block_r2s(Tensor<SrcEngine, SrcLayout> const& src,
132+
Tensor<DstEngine, DstLayout> && dst)
133+
{
134+
return copy_block_r2s(src, dst);
135+
}
136+
137+
template <class SrcEngine, class SrcLayout,
138+
class DstEngine, class DstLayout>
139+
CUTE_HOST_DEVICE
140+
void
141+
copy_block_s2r(Tensor<SrcEngine, SrcLayout> const& src,
142+
Tensor<DstEngine, DstLayout> && dst)
143+
{
144+
return copy_block_s2r(src, dst);
145+
}
146+
147+
template <class SrcEngine, class SrcLayout, class SrcCoordLayout,
148+
class DstEngine, class DstLayout, class DstCoordLayout>
149+
CUTE_HOST_DEVICE
150+
void
151+
copy_block_r2s(SubgroupTensor<SrcEngine, SrcLayout, SrcCoordLayout> const& src,
152+
Tensor<DstEngine, DstLayout> && dst,
153+
DstCoordLayout const& dst_c)
154+
{
155+
return copy_block_r2s(src, dst, dst_c);
156+
}
157+
158+
template <class SrcEngine, class SrcLayout, class SrcCoordLayout,
159+
class DstEngine, class DstLayout, class DstCoordLayout>
160+
CUTE_HOST_DEVICE
161+
void
162+
copy_block_s2r(Tensor<SrcEngine, SrcLayout> const& src,
163+
SrcCoordLayout const& src_c,
164+
SubgroupTensor<DstEngine, DstLayout, DstCoordLayout> && dst)
165+
{
166+
return copy_block_s2r(src, dst);
167+
}
168+
169+
} /* namespace cute */

0 commit comments

Comments
 (0)