What we're getting rid of
The old fusion flow leaned heavily on Tosa and Linalg fusion to create complicated linalg.generic operations that we'd then fix up so they could be aligned with our GEMM's structure. This had evolved into a complicated and brittle pile of hacks, and we (@sjw36 and @krzysz00) realized, after a whole bunch of iterations, that there was a better way. Hence the new fusion system.
Overview
The new scheme has three main pieces:
- Rewrite tosa.reshape and tosa.transpose into corresponding calls to rock.transform. This moves more of the kernel into our map language earlier in the flow, and protects against general-purpose rewrites that we'd need to undo later.
- Remove the old fusion passes (FoldTranspose and CopyOpt) and replace them with a new scheme. This gets rid of a lot of special-case rewrites that broke very easily and created a lot of bugs. The main component of this new fusion system will be the (tentatively named) "global alignment" pass, which aims to rearrange all the code into something that's easier to fuse with. This alignment pass goes right before --rock-gridwise-gemm-to-blockwise.
- Have gridwise-gemm-to-blockwise create a rock.threadwise_write_all, allowing us to defer generating the actual store loop until after fusion has taken place.

A running example will follow after the general overview.
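To make the first item above concrete, here is a purely schematic sketch (the op spelling, attributes, and types are illustrative, not the exact conversion):

// Before (schematic): a Tosa-level permutation that generic Tosa/Linalg rewrites would otherwise absorb
%t = tosa.transpose %x { perms = [1, 0] } : tensor<32x64> -> tensor<64x32>
// After (schematic): the same permutation expressed as a view in our map language
%t = rock.transform %x by transpose : memref<32x64> to memref<64x32>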
Global regularization
The regularization pass consists of a pair of transformations, pulling views onto writers and pushing views onto readers, which run in that order.
Pulling views onto writers
The goal of pulling views onto writers is to take the views that a reader applies to a temporary buffer and move them, inverted, onto the operation that writes that temporary.
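Concretely, a minimal before/after sketch (the transpose, the shapes, and the names %gemm_out, %tmp, and %out are illustrative, mirroring the running example below):

// Before: the view sits on %tmp at its read site
%tmp = memref.alloc() : memref<32x64>
la.generic (%gemm_out) -> %tmp { elementwise op } : memref<32x64>, memref<32x64>
%tmp_tr = rock.transform %tmp by transpose : memref<32x64> to memref<64x32>
memref.copy %tmp_tr -> %out : memref<64x32>

// After: the writer writes through the inverse view; the reader sees the plain buffer
%tmp = memref.alloc() : memref<64x32>
%tmp_tr_inv = rock.transform %tmp by inv(transpose) : memref<64x32> to memref<32x64>
la.generic (%gemm_out) -> %tmp_tr_inv { elementwise op } : memref<32x64>, memref<32x64>
memref.copy %tmp -> %out : memref<64x32>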
Note that this transformation doesn't change what type any of the computations are operating on.
The reason we need this rewrite is that it turns a general graph of views applied to the "main result" (which starts at the gemm and goes through any generics and copies) into a linear sequence of transformations applied when an operation writes to a new temporary. This will make it possible to do the next normalization step, which is pushing the transforms of a writer to its readers.
In other words, the problem with the pre-transformation code is that it's hard to generate the sequence of functions from gemm coordinates to the final output buffer (which is the only one we'll actually be writing to), if you have these intermediate views being applied when the gemm's output is being read.
Or, for another perspective, views applied to %tmp that don't participate in the GEMM coordinates -> buffer transformation stack are ones that can't be applied to the values each thread holds in registers in any consistent way.

Pushing writers' transforms to readers
The purpose of this transformation is to ensure consistency and type safety while we fuse. In general, during a fusion, we'll be replacing the sequence "write gemm results (through view A) to a temporary; compute a function on that temporary; write those results (through view B) to the next buffer" with "compute the function on the gemm results; write those results to the next buffer through view (A and then B)". This only works if we can transparently replace the temporary output with the memref that the fused function (or memref.copy) was outputting to. And, since gridwise gemm has a map from (bid, tid, register #) to GEMM coordinates, we need all the new write targets to be viewed with GEMM coordinates.

In addition, when doing a fusion, we may encounter additional arguments. We need to load the values from those arguments that would be paired with each result in the temporary buffer. However, when we're accessing the results in each thread's registers, all we have is a way to go from (bid, tid, register #) to GEMM coordinates. This means that, when we have additional arguments, we need to put views on them so they too are indexed by GEMM coordinates.
As an example, consider an elementwise computation_2 that reads the gemm's temporary directly and also takes an extra argument with no views on it; it needs to be rewritten so that every operand is viewed with GEMM coordinates.
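A minimal before/after sketch (the add_batch and inv(transpose) views and the shapes are illustrative placeholders; computation_2 and %next_view_push are the names referred to below):

// Before: computation_2 reads the gemm's raw temporary, and %other carries no views
%other : memref<32x64> // extra elementwise argument
%res = memref.alloc() : memref<32x64>
%c = rock.transform %res by add_batch : memref<32x64> to memref<1x32x64>
rock.gridwise_gemm %c = ... : memref<1x32x64> = ...
%next = memref.alloc() : memref<64x32>
%next_view = rock.transform %next by inv(transpose) : memref<64x32> to memref<32x64>
la.generic (%res, %other) -> %next_view { computation_2 } : memref<32x64>, memref<32x64>, memref<32x64>

// After: the writer's transform stack (here just add_batch) is replicated onto every operand,
// so computation_2 now operates entirely in GEMM coordinates
%other_batch = rock.transform %other by add_batch : memref<32x64> to memref<1x32x64>
%next_view_push = rock.transform %next_view by add_batch : memref<32x64> to memref<1x32x64>
la.generic (%c, %other_batch) -> %next_view_push { computation_2 } : memref<1x32x64>, memref<1x32x64>, memref<1x32x64>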
Then, if the output of computation_2 is used by some other operation (another elementwise function, a memref.copy, or so on), we take the transformations on %next_view_push and replicate them onto its reader in the same way.

Fusion
After these transformations, we run --rock-gridwise-gemm-to-blockwise, which creates a rock.threadwise_write_all marker.

The fusion pass itself will then, if the target of a rock.threadwise_write_all [map](%src) -> %dest is a temporary buffer - that is, %dest = f(%buf) (where f is a transform sequence) - do the following, depending on what reads that buffer:
- For a memref.copy %dest to %out_view, replace %dest with %out_view (a minimal sketch of this case follows the list).
- For a la.generic(f(%buf), g(%otherArg)) -> %nextDest doing an elementwise operation, we will:
  - For each extra argument %otherArg, emit a load loop loading len(%src) values from %otherArg into registers %other_arg_reg, using map ; g as the function (bid, tid, register#) -> (other_arg_coords).
  - Allocate a register buffer %fusion_result_reg.
  - Rewrite the linalg.generic to operate on %src, %other_arg_reg and write to %fusion_result_reg.
  - Move the rock.threadwise_write_all after the linalg.generic, replacing %src with %fusion_result_reg and %dest with %nextDest.
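Schematically, the memref.copy case looks like this (f, g, map, and the operands are placeholders):

// Before: the store loop targets a temporary that is only copied to the real output
%dest = rock.transform %buf by f : ...
%out_view = rock.transform %out by g : ...
rock.threadwise_write_all [map](%src) -> %dest
memref.copy %dest -> %out_view
// After: the copy is folded away and the store loop targets the output view directly
rock.threadwise_write_all [map](%src) -> %out_view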
We will repeat this process until there are no more linalg.generic or memref.copy operations to fold in.

Running example
This example aims to show off most of the complexities of these rewrites
We begin with
%arg2 : memref<8> // intermediate
%arg3 : memref<64x32> // output
%tmp1 = memref.alloc() : memref<32x64>
%c = rock.transform %tmp1 by add_batch : memref<32x64> to memref<1x32x64>
rock.gridwise_gemm %c = ... : memref<1x32x64> = ...
%tmp1_split = rock.transform %tmp1 by reshape : memref<32x64> to memref<32x8x8>
%arg2_bcast = rock.transform %arg2 by broadcast : memref<8> to memref<32x8x8>
%tmp2 = rock.alloc() : memref<32x8x8>
la.generic (%tmp1_split, %arg2_bcast) -> %tmp2 { elementwise add } : memref<32x8x8>, memref<32x8x8>, memref<32x8x8>
%tmp2_concat = rock.transform %tmp2 by concat : memref<32x8x8> to memref<32x64>
%tmp2_tr = rock.transform %tmp2_concat by transpose : memref<32x64> to memref<64x32>
memref.copy %tmp2_tr -> %arg3 : memref<64x32>
Pull readers' transforms to writers
Now, we will be applying the "pull up" rewrite, also known as "pull reader views to writers" or what have you. I'll show each application separately; these can happen in either order since they're rather local.
First, on the memref copy
%arg2 : memref<8> // intermediate
%arg3 : memref<64x32> // output
%tmp1 = memref.alloc() : memref<32x64>
%c = rock.transform %tmp1 by add_batch : memref<32x64> to memref<1x32x64>
rock.gridwise_gemm %c = ... : memref<1x32x64> = ...
%tmp1_split = rock.transform %tmp1 by reshape : memref<32x64> to memref<32x8x8>
%arg2_bcast = rock.transform %arg2 by broadcast : memref<8> to memref<32x8x8>
%tmp2 = rock.alloc() : memref<64x32>
%tmp2_tr_inv = rock.transform %tmp2 by inv(transpose) : memref<64x32> to memref<32x64>
%tmp2_concat_inv = rock.transform %tmp2_tr_inv by inv(concat) : memref<32x64> to memref<32x8x8>
la.generic (%tmp1_split, %arg2_bcast) -> %tmp2_concat_inv { elementwise add } : memref<32x8x8>, memref<32x8x8>, memref<32x8x8>
memref.copy %tmp2 -> %arg3 : memref<64x32>
And then on the addition
%arg2 : memref<8> // intermediate
%arg3 : memref<64x32> // output
%tmp1 = memref.alloc() : memref<32x8x8>
%tmp1_split_inv = rock.transform %tmp1 by inv(reshape) : memref<32x8x8> to memref<32x64>
%c = rock.transform %tmp1_split_inv by add_batch : memref<32x64> to memref<1x32x64>
rock.gridwise_gemm %c = ... : memref<1x32x64> = ...
%arg2_bcast = rock.transform %arg2 by broadcast : memref<8> to memref<32x8x8>
%tmp2 = rock.alloc() : memref<64x32>
%tmp2_tr_inv = rock.transform %tmp2 by inv(transpose) : memref<64x32> to memref<32x64>
%tmp2_concat_inv = rock.transform %tmp2_tr_inv by inv(concat) : memref<32x64> to memref<32x8x8>
la.generic (%tmp1, %arg2_bcast) -> %tmp2_concat_inv { elementwise add } : memref<32x8x8>, memref<32x8x8>, memref<32x8x8>
memref.copy %tmp2 -> %arg3 : memref<64x32>
Pushing writers' transforms to readers
And now the push-down step. For this, we will need to work top to bottom, starting at the gemm. (As a side note, for rewrites like this it can make sense not to use the pattern rewrite infrastructure; a lot of upstream passes, like, say, the loop unroller, don't.)
Starting from the previous code, we first move the gemm's transforms to its readers.
%arg2 : memref<8> // intermediate
%arg3 : memref<64x32> // output
%tmp1 = memref.alloc() : memref<32x8x8>
%tmp1_split_inv = rock.transform %tmp1 by inv(reshape) : memref<32x8x8> to memref<32x64>
%c = rock.transform %tmp1_split_inv by add_batch : memref<32x64> to memref<1x32x64>
rock.gridwise_gemm %c = ... : memref<1x32x64> = ...
// Note, I've CSE'd the transform stack on %tmp1
%arg2_bcast = rock.transform %arg2 by broadcast : memref<8> to memref<32x8x8>
%arg2_split_inv = rock.transform %arg2_bcast by inv(reshape) : memref<32x8x8> to memref<32x64>
%arg2_add_batch = rock.transform %arg2_split_inv by add_batch : memref<32x64> to memref<1x32x64>
%tmp2 = rock.alloc() : memref<64x32>
%tmp2_tr_inv = rock.transform %tmp2 by inv(transpose) : memref<64x32> to memref<32x64>
%tmp2_concat_inv = rock.transform %tmp2_tr_inv by inv(concat) : memref<32x64> to memref<32x8x8>
%tmp2_split_inv = rock.transform %tmp2_concat_inv by inv(reshape) : memref<32x8x8> to memref<32x64>
%tmp2_add_batch = rock.transform %tmp2_split_inv by add_batch : memref<32x64> to memref<1x32x64>
la.generic (%c, %arg2_add_batch) -> %tmp2_add_batch { elementwise add } : memref<1x32x64>, memref<1x32x64>, memref<1x32x64>
memref.copy %tmp2 -> %arg3 : memref<64x32>
And then we will propagate the la.generic's output's transforms to the memref.copy
%arg2 : memref<8> // intermediate
%arg3 : memref<64x32> // output
%tmp1 = memref.alloc() : memref<32x8x8>
%tmp1_split_inv = rock.transform %tmp1 by inv(reshape) : memref<32x8x8> to memref<32x64>
%c = rock.transform %tmp1_split_inv by add_batch : memref<32x64> to memref<1x32x64>
rock.gridwise_gemm %c = ... : memref<1x32x64> = ...
// Note, I've CSE'd the transform stack on %tmp1
%arg2_bcast = rock.transform %arg2 by broadcast : memref<8> to memref<32x8x8>
%arg2_split_inv = rock.transform %arg2_bcast by inv(reshape) : memref<32x8x8> to memref<32x64>
%arg2_add_batch = rock.transform %arg2_split_inv by add_batch : memref<32x64> to memref<1x32x64>
%tmp2 = rock.alloc() : memref<64x32>
%tmp2_tr_inv = rock.transform %tmp2 by inv(transpose) : memref<64x32> to memref<32x64>
%tmp2_concat_inv = rock.transform %tmp2_tr_inv by inv(concat) : memref<32x64> to memref<32x8x8>
%tmp2_split_inv = rock.transform %tmp2_concat_inv by inv(reshape) : memref<32x8x8> to memref<32x64>
%tmp2_add_batch = rock.transform %tmp2_split_inv by add_batch : memref<32x64> to memref<1x32x64>
la.generic (%c, %arg2_add_batch) -> %tmp2_add_batch { elementwise add } : memref<1x32x64>, memref<1x32x64>, memref<1x32x64>
%arg3_tr_inv = rock.transform %arg3 by inv(transpose) : memref<64x32> to memref<32x64>
%arg3_concat_inv = rock.transform %arg3_tr_inv by inv(concat) : memref<32x64> to memref<32x8x8>
%arg3_split_inv = rock.transform %arg3_concat_inv by inv(reshape) : memref<32x8x8> to memref<32x64>
%arg3_add_batch = rock.transform %arg3_split_inv by add_batch : memref<32x64> to memref<1x32x64>
memref.copy %tmp2_add_batch -> %arg3_add_batch : memref<1x32x64>
Note that now the types of all operations in the chain from the GEMM to the memref.copy are aligned to GEMM space. This means that, when we want to load extra arguments, we will be able to use GEMM coordinates as a starting point, and that we don't have to worry about rearranging our in-register values as we do elementwise things to them.

Fusion
Continuing the running example, the fusion rewrites, which happen after gridwise-gemm-to-blockwise, go like this. We start with the code after we've added a rock.threadwise_write_all
%arg2 : memref<8> // intermediate
%arg3 : memref<64x32> // output
%tmp1 = memref.alloc() : memref<32x8x8>
%tmp1_split_inv = rock.transform %tmp1 by inv(reshape) : memref<32x8x8> to memref<32x64>
%c = rock.transform %tmp1_split_inv by add_batch : memref<32x64> to memref<1x32x64>
%c_reg = alloca() : memref<16, 5>
... // the actual gemm
rock.threadwise_write_all id2mat -> %c : memref<16, 5> -> memref<1x32x64>
%arg2_bcast = rock.transform %arg2 by broadcast : memref<8> to memref<32x8x8>
%arg2_split_inv = rock.transform %arg2_bcast by inv(reshape) : memref<32x8x8> to memref<32x64>
%arg2_add_batch = rock.transform %arg2_split_inv by add_batch : memref<32x64> to memref<1x32x64>
%tmp2 = rock.alloc() : memref<64x32>
%tmp2_tr_inv = rock.transform %tmp2 by inv(transpose) : memref<64x32> to memref<32x64>
%tmp2_concat_inv = rock.transform %tmp2_tr_inv by inv(concat) : memref<32x64> to memref<32x8x8>
%tmp2_split_inv = rock.transform %tmp2_concat_inv by inv(reshape) : memref<32x8x8> to memref<32x64>
%tmp2_add_batch = rock.transform %tmp2_split_inv by add_batch : memref<32x64> to memref<1x32x64>
la.generic (%c, %arg2_add_batch) -> %tmp2_add_batch { elementwise add } : memref<1x32x64>, memref<1x32x64>, memref<1x32x64>
%arg3_tr_inv = rock.transform %arg3 by inv(transpose) : memref<64x32> to memref<32x64>
%arg3_concat_inv = rock.transform %arg3_tr_inv by inv(concat) : memref<32x64> to memref<32x8x8>
%arg3_split_inv = rock.transform %arg3_concat_inv by inv(reshape) : memref<32x8x8> to memref<32x64>
%arg3_add_batch = rock.transform %arg3_split_inv by add_batch : memref<32x64> to memref<1x32x64>
memref.copy %tmp2_add_batch -> %arg3_add_batch : memref<1x32x64>
Then, we fuse in the linalg.generic, adding a load loop
%arg2 : memref<8> // intermediate
%arg3 : memref<64x32> // output
%tmp1 = memref.alloc() : memref<32x8x8>
%tmp1_split_inv = rock.transform %tmp1 by inv(reshape) : memref<32x8x8> to memref<32x64>
%c = rock.transform %tmp1_split_inv by add_batch : memref<32x64> to memref<1x32x64>
%c_reg = alloca() : memref<16, 5>
... // the actual gemm
%arg2_reg = alloca() : memref<16, 5>
transforming_for (%mem) = [id2mat, add_batch, inv(reshape), broadcast](%bid, %tid, %c0), (%, %, %i) = [](%c0, %c0, %c0) bounds = [1, 1, 16], strides = [1, 1, V_arg2] { // or equivalent register access transforms
%val = rock.global_load {...OOB...} %arg2[%mem] : memref<8> -> vector
rock.insert_slice %val, %arg2_reg[i] : vector -> memref<16, 5>
}
%tmp1_reg = alloca() : memref<16, 5>
%arg2_bcast = rock.transform %arg2 by broadcast : memref<8> to memref<32x8x8>
%arg2_split_inv = rock.transform %arg2_bcast by inv(reshape) : memref<32x8x8> to memref<32x64>
%arg2_add_batch = rock.transform %arg2_split_inv by add_batch : memref<32x64> to memref<1x32x64>
%tmp2 = rock.alloc() : memref<64x32>
%tmp2_tr_inv = rock.transform %tmp2 by inv(transpose) : memref<64x32> to memref<32x64>
%tmp2_concat_inv = rock.transform %tmp2_tr_inv by inv(concat) : memref<32x64> to memref<32x8x8>
%tmp2_split_inv = rock.transform %tmp2_concat_inv by inv(reshape) : memref<32x8x8> to memref<32x64>
%tmp2_add_batch = rock.transform %tmp2_split_inv by add_batch : memref<32x64> to memref<1x32x64>
la.generic (%c_reg, %arg2_reg) -> %tmp1_reg { elementwise add } : memref<16, 5>, memref<16, 5>, memref<16, 5>
// Moved after the fusion, or, equivalently, the fusion was moved before it
rock.threadwise_write_all id2mat -> %tmp2_add_batch : memref<16, 5> -> memref<1x32x64> // Note, we have replaced %tmp1 with the result of the thing we fused in, namely %tmp2_add_batch
%arg3_tr_inv = rock.transform %arg3 by inv(transpose) : memref<64x32> to memref<32x64>
%arg3_concat_inv = rock.transform %arg3_tr_inv by inv(concat) : memref<32x64> to memref<32x8x8>
%arg3_split_inv = rock.transform %arg3_concat_inv by inv(reshape) : memref<32x8x8> to memref<32x64>
%arg3_add_batch = rock.transform %arg3_split_inv by add_batch : memref<32x64> to memref<1x32x64>
memref.copy %tmp2_add_batch -> %arg3_add_batch : memref<1x32x64>
And then we remove the copy by observing that we're copying f(%tmp2) to f(%arg3) (the same transform stack f on both sides), which can be replaced by writing through f to %arg3 directly
%arg2 : memref<8> // intermediate
%arg3 : memref<64x32> // output
%c_reg = alloca() : memref<16, 5>
... // the actual gemm
%arg2_reg = alloca() : memref<16, 5>
transforming_for (%mem) = [id2mat, add_batch, inv(reshape), broadcast](%bid, %tid, %c0), (%, %, %i) = [](%c0, %c0, %c0) bounds = [1, 1, 16], strides = [1, 1, V_arg2] { // or equivalent register access transforms
%val = rock.global_load {...OOB...} %arg2[%mem] : memref<8> -> vector
rock.insert_slice %val, %arg2_reg[i] : vector -> memref<16, 5>
}
%tmp1_reg = alloca() : memref<16, 5>
%tmp2 = rock.alloc() : memref<64x32>
%tmp2_tr_inv = rock.transform %tmp2 by inv(transpose) : memref<64x32> to memref<32x64>
%tmp2_concat_inv = rock.transform %tmp2_tr_inv by inv(concat) : memref<32x64> to memref<32x8x8>
%tmp2_split_inv = rock.transform %tmp2_concat_inv by inv(reshape) : memref<32x8x8> to memref<32x64>
%tmp2_add_batch = rock.transform %tmp2_split_inv by add_batch : memref<32x64> to memref<1x32x64>
la.generic (%c_reg, %arg2_reg) -> %tmp1_reg { elementwise add } : memref<16, 5>, memref<16, 5>, memref<16, 5>
%arg3_tr_inv = rock.transform %arg3 by inv(transpose) : memref<64x32> to memref<32x64>
%arg3_concat_inv = rock.transform %arg3_tr_inv by inv(concat) : memref<32x64> to memref<32x8x8>
%arg3_split_inv = rock.transform %arg3_concat_inv by inv(reshape) : memref<32x8x8> to memref<32x64>
%arg3_add_batch = rock.transform %arg3_split_inv by add_batch : memref<32x64> to memref<1x32x64>
rock.threadwise_write_all id2mat -> %arg3_add_batch : memref<16, 5> -> memref<1x32x64>
Note that this pass was just taking la.generics and moving them above the writeback, adding load loops for extra arguments as needed (which was easy because we already had a full transform stack applied to said arguments to work with), and optimizing out copies.

Summary / let's have one more section header to break things up
In summary, we're replacing a bunch of hacks by actively impeding the rewrites the linalg dialect does and then using that lack of interference to build a rather conceptually clean fusion system.
P.S. Impacts on other code
Because rock.threadwise_write_all moves the vectorization decision further down the lowering pipeline, the impact of that rewrite is minimal.

However, gridwise_gemm_v2 still needs to do a full vectorization analysis in order to determine if we should swizzle. To do that, we will introduce a find_final_output(Value) helper that traces through this alloc + computation chain in order to get the full transform stack off the terminating memref.copy, linalg.generic, or whatever comes last.
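For instance, reusing the names from the first running-example snippet, the chain such an analysis has to walk looks like this (illustrative only; the helper's exact signature and return value aren't pinned down here):

%tmp1 = memref.alloc() : memref<32x64>                  // the gemm writes here, through %c
la.generic (%tmp1_split, %arg2_bcast) -> %tmp2 { ... }  // follow %tmp1 to its elementwise reader
memref.copy %tmp2_tr -> %arg3 : memref<64x32>           // and through the copy to the real output
// find_final_output(%c) would walk this chain to the terminating memref.copy and read the full
// transform stack off it, so gridwise_gemm_v2 can run its vectorization / swizzle analysis
// against the actual output layout.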