WIP: ScanOp #4211
Description
Still need exclusive + inclusive scan to get that all the way implemented.
There is still an overallocation problem, but it's getting closer to optimal.
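For context on that first point, the inclusive/exclusive distinction is the usual prefix-scan one. A minimal standard-library sketch (plain C++, not nvFuser's ScanOp API):

#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> in{1, 2, 3, 4};
  std::vector<int> inc(4), exc(4);
  // Inclusive scan: element i includes in[i]              -> 1 3 6 10
  std::inclusive_scan(in.begin(), in.end(), inc.begin());
  // Exclusive scan: element i only sums elements before i,
  // starting from the init value 0                        -> 0 1 3 6
  std::exclusive_scan(in.begin(), in.end(), exc.begin(), 0);
  for (int v : inc) std::printf("%d ", v);
  std::printf("\n");
  for (int v : exc) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}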
auto index_info = computeIndex(expr, indexed_ids, for_loops);
for (const auto& [indexed_id, index] : override_index) {
-  index_info.index_map.emplace(traversalGraph().toGroup(indexed_id), index);
+  index_info.index_map[traversalGraph().toGroup(indexed_id)] = index;
I believe this is a bug: emplace is a no-op (its returned bool is false) when there is a pre-existing entry, whereas we need to update the index here instead.
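A self-contained illustration of that emplace vs. operator[] behavior (plain std::unordered_map, not the indexer's actual index_map type):

#include <cassert>
#include <unordered_map>

int main() {
  std::unordered_map<int, int> index_map{{0, 10}};
  // emplace does nothing when the key already exists; the returned bool is false.
  bool inserted = index_map.emplace(0, 20).second;
  assert(!inserted && index_map[0] == 10);
  // operator[] assignment overwrites the existing entry, which is the
  // behavior the suggested change relies on.
  index_map[0] = 20;
  assert(index_map[0] == 20);
  return 0;
}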
std::unordered_map<IterDomain*, Val*> override_index_ids;
for (auto& [pos, idx] : override_index) {
  override_index_ids.emplace(
      consumer->getMaybeAllocationDomain().at((size_t)pos), idx);
}
index = GpuLower::current()->tensorIndexer().getLinearIndex(
-    consumer, consumer->definition(), loops);
+    consumer, consumer->definition(), loops, override_index_ids);
I believe this is necessary for getConsumerIndex to respect override_index.
// We don't inline the scans past the scan dimension
std::unordered_set<IterDomain*> uninlineable_ids;
for (TensorView* tv : {m, m_prev, denoms}) {
  for (IterDomain* id : tv->getLoopDomain()) {
    uninlineable_ids.insert(id);
  }
}

inlineMost(uninlineable_ids);

// These TVs are not inlined, but instead we set computeWith on them
for (TensorView* tv : {m, m_prev, denoms}) {
  tv->computeWith(-1);
  for (Val* v : tv->definition()->inputs()) {
    v->as<TensorView>()->computeWith(-1);
  }
}
Scheduling inlining with scan
The above schedule results in this fusion IR:
Inputs:
T0_g_float[iS0{32}]
Outputs:
T8_g_float[]
%kernel {
T9_l_float[iS11{32}] compute_with( 1 )
= Set( T0_g_float[iS0{32}], cache_op=Streaming )
T1_l_float[cS2{32}] compute_with( 1 ) maybe_produce_pos( 1 ),
T2_l_float[cS3{32}] compute_with( 1 ) maybe_produce_pos( 1 )
= scan(fmax,
T9_l_float[iS11{32}] compute_with( 1 ),
dim=0,
init=double(-inf))
T3_l_float[iS4{32}] ca_pos( 1 ) maybe_produce_pos( 1 )
= T9_l_float[iS11{32}] compute_with( 1 )
- T1_l_float[cS2{32}] compute_with( 1 ) maybe_produce_pos( 1 );
T4_l_float[iS5{32}] compute_with( 1 ) produce_pos( 1 )
= expf(T3_l_float[iS4{32}] ca_pos( 1 ) maybe_produce_pos( 1 ));
T5_l_float[iS6{32}] ca_pos( 1 ) maybe_produce_pos( 1 )
= T2_l_float[cS3{32}] compute_with( 1 ) maybe_produce_pos( 1 )
- T1_l_float[cS2{32}] compute_with( 1 ) maybe_produce_pos( 1 );
T6_l_float[iS7{32}] compute_with( 1 ) produce_pos( 1 )
= expf(T5_l_float[iS6{32}] ca_pos( 1 ) maybe_produce_pos( 1 ));
T7_l_float[cS9{32}] compute_with( 1 ) maybe_produce_pos( 1 )
= scan(add,
T4_l_float[iS5{32}] compute_with( 1 ) produce_pos( 1 ),
dim=0,
discount_factor=T6_l_float[iS7{32}] compute_with( 1 ) produce_pos( 1 ),
init=float(0))
T10_l_float[rS10{32}] maybe_produce_pos( 1 )
= reduction( T7_l_float[cS9{32}] compute_with( 1 ) maybe_produce_pos( 1 ), op = rhs, initial value = float(0), allreduce = false )
T8_g_float[]
= Set( T10_l_float[rS10{32}] maybe_produce_pos( 1 ), cache_op=Streaming )
and the following kernel:
// Codegen generated code
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 1, 1> T0, Tensor<float, 0, 0> T8) {
NVFUSER_DEFINE_MAGIC_ZERO;
Array<float, 32, 1> T9;
#pragma unroll
for(nvfuser_index_t i0 = 0; i0 < 32; ++i0) {
T9[i0] = 0;
}
NVFUSER_UPDATE_MAGIC_ZERO;
Array<float, 1, 1> T1;
Array<float, 1, 1> T2;
Array<float, 32, 1> T6;
Array<float, 32, 1> T4;
Array<float, 1, 1> T7;
Array<float, 1, 1> T10;
T10[0] = 0.000000000e+00f;
#pragma unroll
for(nvfuser_index_t i0 = 0; i0 < 32; ++i0) {
T9[i0]
= T0[(T0.alloc_stride[0LL] * (i0 + nvfuser_zero))];
T2[0] = ((i0 > 0) ? (T1[0]) : NEG_INFINITY);
T1[0] = fmax(
((i0 > 0) ? (T1[0]) : NEG_INFINITY),
(T9[i0]));
Array<float, 1, 1> T5;
T5[0]
= T2[0]
- T1[0];
T6[i0]
= expf(T5[0]);
Array<float, 1, 1> T3;
T3[0]
= T9[i0]
- T1[0];
T4[i0]
= expf(T3[0]);
T7[0]
= (T6[i0] * ((i0 > 0) ? (T7[0]) : 0.000000000e+00f))
+ (T4[i0]);
T10[0] = rhs(
T10[0],
T7[0]);
}
NVFUSER_UPDATE_MAGIC_ZERO;
T8[0]
= T10[0];
}

I don't yet understand how I can get T9, T4, and T6 to be allocated with size 1 instead of 32. The ID model maps all IDs together in the exact map but not in the loop map (before resolveComputeWith):
IterDomainGraphs {
IdGraph exact{
Disjoint Ids:
(idgs){
idg{0 2 3 4 5 6 7 9 10 11}
}
Disjoint Expression groups:
(exprgs){
}
} IdGraph
IdGraph almost_exact{
Disjoint Ids:
(idgs){
idg{0 2 3 4 5 6 7 9 10 11}
}
Disjoint Expression groups:
(exprgs){
}
} IdGraph
IdGraph broadcast{
Disjoint Ids:
(idgs){
idg{0 2 3 4 5 6 7 9 10 11}
}
Disjoint Expression groups:
(exprgs){
}
} IdGraph
IdGraph permissive{
Disjoint Ids:
(idgs){
idg{0 2 3 4 5 6 7 9 10 11}
}
Disjoint Expression groups:
(exprgs){
}
} IdGraph
IdGraph loop{
Disjoint Ids:
(idgs){
idg{0}
idg{2 3}
idg{4 5}
idg{6 7}
idg{9}
idg{10}
idg{11}
}
Disjoint Expression groups:
(exprgs){
}
} IdGraph
} IterDomainGraphs
I am wondering:
- why the loop graph is finer than exact (this is just because ca_pos is 0 for a lot of these tensors due to uninlineable_ids)
- how can I control inlining such that producers of a scan are inlined past the scan dimension so that I don't need to use computeWith on them. I think uninlineable_ids prevents inlining with both consumers and producers but I'd like to still inline the producers with it.
- whether there is a better way to force allocation at a higher scope but still inline the expression
- what exactly is "storeAt" referred to in the code comments
> I don't yet understand how I can get T9, T4, and T6 to be allocated with size 1 instead of 32.
Because their computeAt positions are all 0, they are allocated entirely. The computeWith position only changes the loop where the expression appears.
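A hand-written sketch (not nvFuser output) of what that means for the generated loops, using the 32-element extent from the kernel above; the consumer expression is just a stand-in:

// Illustration only: computeWith moves where the expression runs, but the
// allocation size is still determined by the computeAt position.
void allocationSketch(const float* T0, float* out) {
  // ca_pos == 0: the buffer lives outside the loop, so it must cover the
  // full extent of 32. This is what happens to T9, T4, and T6 above.
  float T9_full[32];
  for (int i = 0; i < 32; ++i) {
    T9_full[i] = T0[i];
    out[i] = T9_full[i] * 2.0f; // stand-in consumer
  }

  // ca_pos == 1: the copy is inlined into the i loop; only one element is
  // live per iteration, so a size-1 allocation suffices.
  for (int i = 0; i < 32; ++i) {
    float T9_elem = T0[i];
    out[i] = T9_elem * 2.0f; // stand-in consumer
  }
}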
> why the loop graph is finer than exact
Loop only maps when inlined, so by definition, it should be finer.
> how can I control inlining such that producers of a scan are inlined past the scan dimension so that I don't need to use computeWith on them.
I'm not sure, since I don't know how the scan is implemented, but if it's like reduction, don't we need to update MaxPosInliner?
> Whether there is another way to force allocation at a higher scope but still inline the expression
This is exactly what computeWith is supposed to address.
> What exactly is "storeAt" referred to in the code comments
Probably the same concept as store_at in Halide.
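For reference, here is roughly what that looks like in Halide (a sketch of the standard store_at/compute_at split, not nvFuser code): the producer's values are computed inside the consumer's inner loop, but its buffer is allocated at an outer loop.

#include "Halide.h"
using namespace Halide;

int main() {
  Var x("x"), y("y");
  Func producer("producer"), consumer("consumer");
  producer(x, y) = x + y;
  consumer(x, y) = producer(x, y) + producer(x + 1, y);
  // Allocate producer's buffer at consumer's y loop, but compute its values
  // inside consumer's x loop: storage and compute placement are decoupled.
  producer.store_at(consumer, y).compute_at(consumer, x);
  consumer.realize({16, 16});
  return 0;
}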
Currently doesn't inline two of the loops properly.
This test runs. I'm not sure if it's correct yet, but there is one main loop, so that's nice.
//
// Dao et al. 2022. FlashAttention: Fast and Memory-Efficient Exact Attention
// with IO-Awareness. https://arxiv.org/abs/2205.14135
TEST_F(ScanTest, FlashAttentionNoMma) {
As of deaff1f this generates the following code:
// Codegen generated code
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 8, 8> T0, Tensor<float, 8, 8> T1, Tensor<float, 8, 8> T2, Tensor<float, 8, 8> T31) {
NVFUSER_DEFINE_MAGIC_ZERO;
nvfuser_index_t i0;
i0 = T0.alloc_stride[0LL] * ((nvfuser_index_t)blockIdx.x);
nvfuser_index_t i1;
i1 = T2.alloc_stride[3LL] * ((nvfuser_index_t)blockIdx.y);
nvfuser_index_t i2;
i2 = (27 * ((nvfuser_index_t)blockIdx.y)) + (216 * ((nvfuser_index_t)blockIdx.x));
Array<float, 12, 1> T8;
Array<float, 12, 1> T9;
Array<float, 12, 1> T19;
Array<float, 12, 1> T20;
Array<float, 108, 1> T28;
Array<float, 27, 1> T30;
#pragma unroll
for(nvfuser_index_t i3 = 0; i3 < 3; ++i3) {
nvfuser_index_t i4;
i4 = 9 * i3;
#pragma unroll
for(nvfuser_index_t i5 = 0; i5 < 9; ++i5) {
T30[(i4 + i5)] = 0.000000000e+00f;
}
}
NVFUSER_UPDATE_MAGIC_ZERO;
#pragma unroll
for(nvfuser_index_t i6 = 0; i6 < 4; ++i6) {
nvfuser_index_t i7;
i7 = T1.alloc_stride[1LL] * i6;
nvfuser_index_t i8;
i8 = 3 * i6;
nvfuser_index_t i9;
i9 = i1 + (T2.alloc_stride[1LL] * i6);
nvfuser_index_t i10;
i10 = 27 * i6;
#pragma unroll
for(nvfuser_index_t i11 = 0; i11 < 6; ++i11) {
nvfuser_index_t i12;
i12 = i7 + (T1.alloc_stride[2LL] * i11);
nvfuser_index_t i13;
i13 = i0 + (T0.alloc_stride[2LL] * i11);
#pragma unroll
for(nvfuser_index_t i3 = 0; i3 < 3; ++i3) {
nvfuser_index_t i14;
i14 = i13 + (T0.alloc_stride[4LL] * i3);
nvfuser_index_t i15;
i15 = i8 + i3;
nvfuser_index_t i16;
i16 = 9 * i3;
nvfuser_index_t i17;
i17 = i10 + i16;
Array<float, 5, 1> T5;
#pragma unroll
for(nvfuser_index_t i18 = 0; i18 < 5; ++i18) {
nvfuser_index_t i19;
i19 = i12 + (T1.alloc_stride[5LL] * i18);
Array<float, 1, 1> T4;
T4[0] = 0.000000000e+00f;
#pragma unroll
for(nvfuser_index_t i20 = 0; i20 < 7; ++i20) {
nvfuser_index_t i21;
i21 = i20 + nvfuser_zero;
Array<float, 1, 1> T33;
T33[0] = 0;
T33[0]
= T1[(i19 + (T1.alloc_stride[6LL] * i21))];
Array<float, 1, 1> T32;
T32[0] = 0;
T32[0]
= T0[(i14 + (T0.alloc_stride[6LL] * i21))];
Array<float, 1, 1> T3;
T3[0]
= T32[0]
* T33[0];
T4[0]
= T4[0]
+ T3[0];
}
T5[i18]
= T4[0];
}
Array<float, 1, 1> T6;
T6[0] = NEG_INFINITY;
#pragma unroll
for(nvfuser_index_t i22 = 0; i22 < 5; ++i22) {
T6[0] = fmax(
T6[0],
T5[i22]);
}
Array<float, 1, 1> T7;
T7[0]
= T6[0];
T9[i15] = ((i11 > 0) ? (T8[i15]) : NEG_INFINITY);
T8[i15] = fmax(
T9[i15],
(T7[0]));
Array<float, 1, 1> T14;
T14[0]
= T9[i15]
- T8[i15];
Array<float, 1, 1> T15;
T15[0]
= expf(T14[0]);
Array<float, 1, 1> T12;
T12[0] = 0.000000000e+00f;
Array<float, 9, 1> T24;
#pragma unroll
for(nvfuser_index_t i23 = 0; i23 < 9; ++i23) {
T24[i23] = 0.000000000e+00f;
}
#pragma unroll
for(nvfuser_index_t i24 = 0; i24 < 5; ++i24) {
nvfuser_index_t i25;
i25 = i9 + (T2.alloc_stride[5LL] * i24);
Array<float, 1, 1> T10;
T10[0]
= T5[i24]
- T7[0];
Array<float, 1, 1> T11;
T11[0]
= expf(T10[0]);
T12[0]
= T12[0]
+ T11[0];
#pragma unroll
for(nvfuser_index_t i23 = 0; i23 < 9; ++i23) {
Array<float, 1, 1> T34;
T34[0] = 0;
T34[0]
= T2[(i25 + (T2.alloc_stride[7LL] * (i23 + nvfuser_zero)))];
Array<float, 1, 1> T23;
T23[0]
= T11[0]
* T34[0];
T24[i23]
= T24[i23]
+ T23[0];
}
}
Array<float, 1, 1> T13;
T13[0]
= T12[0];
Array<float, 1, 1> T16;
T16[0]
= T7[0]
- T8[i15];
Array<float, 1, 1> T17;
T17[0]
= expf(T16[0]);
Array<float, 1, 1> T18;
T18[0]
= T17[0]
* T13[0];
T20[i15] = ((i11 > 0) ? (T19[i15]) : 0.000000000e+00f);
T19[i15]
= (T15[0] * T20[i15])
+ (T18[0]);
Array<float, 1, 1> T21;
T21[0]
= T20[i15]
/ T19[i15];
Array<float, 1, 1> T22;
T22[0]
= T21[0]
* T15[0];
Array<float, 1, 1> T26;
T26[0]
= T17[0]
/ T19[i15];
#pragma unroll
for(nvfuser_index_t i5 = 0; i5 < 9; ++i5) {
nvfuser_index_t i26;
i26 = i17 + i5;
Array<float, 1, 1> T25;
T25[0]
= T24[i5];
Array<float, 1, 1> T27;
T27[0]
= T26[0]
* T25[0];
T28[i26]
= (T22[0] * ((i11 > 0) ? (T28[i26]) : 0.000000000e+00f))
+ (T27[0]);
Array<float, 1, 1> T29;
T29[0]
= T28[i26];
T30[(i16 + i5)] = rhs(
T30[(i16 + i5)],
T29[0]);
}
}
}
}
NVFUSER_UPDATE_MAGIC_ZERO;
#pragma unroll
for(nvfuser_index_t i27 = 0; i27 < 3; ++i27) {
nvfuser_index_t i28;
i28 = 9 * i27;
nvfuser_index_t i29;
i29 = i2 + i28;
#pragma unroll
for(nvfuser_index_t i30 = 0; i30 < 9; ++i30) {
Array<float, 1, 1> T35;
T35[0]
= T30[(i28 + i30)];
T31[(i29 + (i30 + nvfuser_zero))]
= T35[0];
}
}
NVFUSER_UPDATE_MAGIC_ZERO;
}

The goal here is to approximate the original FlashAttention pattern:
[Image: algorithm listing from the FlashAttention paper]
I am not sure if this is correct yet (it might be too far inlined in the main loop). But it does seem to compute something without requiring multiple top-level loops so that is encouraging...
Note that the original algorithm calls for caching l, m, and O to gmem between outer loop iterations, which I don't currently know how best to do. I suppose we should just set their MemoryTypes and cache them before using in downstream expressions. Avoiding redundant gmem use for the output O is a challenge here but might be enabled if we can get ScanOp to properly return a reduction tensor (I've had trouble with scheduling/lowering such an approach so far but the interface is there already).
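A speculative sketch of that plan, assuming the usual TensorView cache/memory-type scheduling calls (setMemoryType and cacheAfter); l, m, and O stand for the corresponding intermediate TensorViews, and none of this is in the PR yet:

// Not part of this PR: keep the loop-carried tensors in gmem across the
// outer loop, and let downstream expressions read a cached copy instead of
// the gmem tensor directly.
void cacheCarriedTensorsInGmem(TensorView* l, TensorView* m, TensorView* O) {
  for (TensorView* tv : {l, m, O}) {
    tv->setMemoryType(MemoryType::Global);
    // cacheAfter() inserts a copy between tv and its consumers, so the gmem
    // tensor is only touched at the loop boundary.
    tv->cacheAfter();
  }
}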
ScanTest.OnlineSoftmax currently passes and generates this code (as of db5798c). This is the softmax denominator only.
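For reference, the scalar recurrence this corresponds to (the same fmax scan plus discounted add scan shown in the kernel earlier in this thread), written as a plain C++ sketch:

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Online softmax denominator: a running max plus a denominator that is
// rescaled by exp(m_prev - m_new) whenever the running max changes.
float onlineSoftmaxDenominator(const std::vector<float>& x) {
  float m = -std::numeric_limits<float>::infinity(); // running max (T1)
  float d = 0.0f;                                    // running denominator (T7)
  for (float xi : x) {
    float m_prev = m;            // previous max (T2)
    m = std::max(m, xi);         // scan(fmax, ...)
    d = d * std::exp(m_prev - m) // discount factor exp(m_prev - m)
        + std::exp(xi - m);      // contribution of the new element
  }
  return d;
}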
Next I will:
- Fix inlining of intermediates as mentioned above
- Add validation