[Backend][AMD] Introduce stream pipeliner v2#4148
[Backend][AMD] Introduce stream pipeliner v2#4148pawelszczerbuk merged 38 commits intotriton-lang:mainfrom
Conversation
|
Thanks for doing that, when you are ready please have @pawelszczerbuk review it :) |
|
Thanks @sjw36! Some high level comments before reviewing detailed implementation--can we have a separate pull request for the NFC moving passes? Basically commit e0bd4d8. It's easier to review that way and if later we need to revert the changes to AMD part for whatever reason, we also don't need to revert the NFC code shuffling. |
c678fd8 to
f586572
Compare
|
@ThomasRaoux: @sjw36 and I chatted a bit offline. This pull request is great at showing the global picture. But we want to break it into smaller pieces to make it easier for review and restructure a bit. Overall the direction is increase reuse without abstracting too much; so we will expose some useful functions like |
69b4536 to
9517277
Compare
…structure
- Copied scheduler from MatmulLoopPipeline (much could be consolidated)
- Enable register buffering (even though may increases register pressure)
- Enable num_stages=2+, including multi-buffering, and make `2` the default
- updated tutorial for new tuning default
- added lit tests
- Also move independent(from loop-carried buffer) `triton_gpu.local_store` as early as possible
- check for last atomic (sync?) - also check for other accesses to the source
…replaced with loop fusion * Reorder will not move loads/local_stores over loops
4eeb8cc to
faf95cb
Compare
* Added TRITONAMD_OLD_STREAM_PIPELINER env variable to temporarily select old pipeliner
* update test
|
|
||
| // Create a cluster for the prefetches. It may end up being empty, but this | ||
| // is OK. | ||
| tt::CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); |
There was a problem hiding this comment.
Prefetch cluster is needed to push the copies to the end of the loop, so they work well with prefetching, that is needed for nvidia A100. I'm not sure you need it?
There was a problem hiding this comment.
Yes we also need to prefetch for AMD GPUs. The most naive pipelining we want should have the following structure:
S = <alloc-shared-memory>
R(0) = load Global(0)
store R(0) to S
for i = 0 .. N-1
barrier
R(i+1) = load Gloal(i+1)
dot (load S) (load S)
barrier
store R(i+1) to S
There was a problem hiding this comment.
Here is computation without pipelining
# load from HBM to SRAM
Load R0 : Read global0 (global -> registers)
Store R0: (registers -> SMEM)
# compute on SRAM
Load Si : (Si(SMEM) -> registers Ri)
Compute
Store Si : (registers -> Si)
# store data on SRAM
Store Rn : write SRAM data back to global
@pawelszczerbuk @antiagainst I have a question. Can we load from global to SRAM directly?
My question is "if we load data from global to register first" why dont't we compute it then store it back to SRAM:
# load from HBM
R0_0 load
R1_1 load
# instead of store it back to SMEM
compute R0, R1 (R0_0 + R1_0)
store SMEM O0
# store back to global
store
There was a problem hiding this comment.
No direct global to shared support in normal global load in mi300. (buffer load supports that but we are. not using that, yet.)
There was a problem hiding this comment.
No direct global to shared support in normal global load in mi300. (buffer load supports that but we are. not using that, yet.)
Hi @antiagainst , does buffer load look like "cp.async"? Is there any reason the pipeline doesn't use buffer load? Thanks!
There was a problem hiding this comment.
We just added support for buffer ops in #4966. Enabling it still takes more iterations to stablize though.
There was a problem hiding this comment.
We just added support for buffer ops in #4966. Enabling it still takes more iterations to stablize though.
Thanks very much for your reply. I would take investagation on the corresponding PRs.
This PR first promotes common infrastructure in `lib/Dialect/TritonGPU/Transforms/Pipeliner` to enable inclusion by other target backends. No other changes have been made to the lib/include directories. Second, the `tritonamdgpu-stream-pipeline` pass has been completely revamped based on code from `lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp` using similar scheduling passes to compute multi-stage pipelines. Some of this code could be consolidated further in the CoarseSchedule class (or perhaps a derived LoopScheduler class). This modulo scheduler collects `tt.load` ops and generates local_storage and management ops for the ramp-up stage (stage-0), then collecting all uses of the loads for stage-1. Multi-buffering is introduced when num_stages exceeds the max distance between load and uses. Buffering may be in Shared memory for `tt.dot` uses or Registers for all other uses. This current implement does not support peeling the last iteration if the loop is dynamic. Lastly, the `tritonamdgpu-reorder-instructions` pass has been enhanced to move `tt.load` ops as early as possible in its region. This includes loop bodies as well as func entry blocks for the case of ramp-up. This pass will also move `triton_gpu.local_store` ops as early as possible if their source is not directly from a `tt.load`. In this way, a multi-buffered pipeline will overlap in this order: 1. tt.load buffer+2 2. tg.local_store buffer+1 3. tt.dot buffer+0 --------- Co-authored-by: Lei Zhang <antiagainst@gmail.com>
This PR first promotes common infrastructure in `lib/Dialect/TritonGPU/Transforms/Pipeliner` to enable inclusion by other target backends. No other changes have been made to the lib/include directories. Second, the `tritonamdgpu-stream-pipeline` pass has been completely revamped based on code from `lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp` using similar scheduling passes to compute multi-stage pipelines. Some of this code could be consolidated further in the CoarseSchedule class (or perhaps a derived LoopScheduler class). This modulo scheduler collects `tt.load` ops and generates local_storage and management ops for the ramp-up stage (stage-0), then collecting all uses of the loads for stage-1. Multi-buffering is introduced when num_stages exceeds the max distance between load and uses. Buffering may be in Shared memory for `tt.dot` uses or Registers for all other uses. This current implement does not support peeling the last iteration if the loop is dynamic. Lastly, the `tritonamdgpu-reorder-instructions` pass has been enhanced to move `tt.load` ops as early as possible in its region. This includes loop bodies as well as func entry blocks for the case of ramp-up. This pass will also move `triton_gpu.local_store` ops as early as possible if their source is not directly from a `tt.load`. In this way, a multi-buffered pipeline will overlap in this order: 1. tt.load buffer+2 2. tg.local_store buffer+1 3. tt.dot buffer+0 --------- Co-authored-by: Lei Zhang <antiagainst@gmail.com>
* Add blocked to dot shortcut * pack tensors in vectors instead of structures * fix * add moe bypass option * initial commit * fix * fix * add missing configurations and add more checks in passes * adjust global load layout for vllm swizzling format * Remove debug print * make load width dependable on data type * fix int 8 logic * generalize load analysis: return last load in dependant laod chain instead of 2 * Add message for assert failure So that people know what the problem is when this compiler error shows up * add k=512/1024 cases * [BACKEND] Add memory space to memdesc type. (triton-lang#4027) Currently only shared memory is supported but this will allow supporting different kinds of local memory (like private) or others. * [BACKEND] Fix memory side effects of `tt.dot` (triton-lang#4033) 1. Replaced `triton_nvidia_gpu.async_dot` with `triton_nvidia_gpu.group_dot` which has a `isAsync` attribute. Maybe `warp_group_dot` is a better name? 2. Removed `memdesc` from `tt.dot` because `tt.dot` should be pure, without any side effects 3. Removed hacks in Membar analysis. 4. Unified wgmma code generation in the backend. 5. Introduced the `DotLike` trait for `tt.dot` and `triton_nvidia_gpu.group_dot`. 6. Updated comments in matmul loop pipeline (maybe incomplete). 7. Removed the `ConvertDotConvert` pattern * remove streamPipelinev2 * [TEST] NFC: Drop irrelevant NVIDIA specific attributes (triton-lang#4384) Software pipeling should be not using them. This makes it cleaner and prepares reusing the same test inputs for AMD side. * [Pipeliner] NFC: Expose Pipeliner infrastructure for use by other target backends (triton-lang#4155) Non-functional changes to expose `lib/Dialect/TritonGPU/Transforms/Pipeliner` infrastructure for use by other target backends. * [BACKEND] Fix regression in pipeliner pre-checks. (triton-lang#4196) During some previous refactoring we changed the logic and started pipeling cases that had incompatible shared encoding. This was missed because one of the lit test had not been updated :( * [Backend][AMD] Introduce stream pipeliner v2 (triton-lang#4148) This PR first promotes common infrastructure in `lib/Dialect/TritonGPU/Transforms/Pipeliner` to enable inclusion by other target backends. No other changes have been made to the lib/include directories. Second, the `tritonamdgpu-stream-pipeline` pass has been completely revamped based on code from `lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp` using similar scheduling passes to compute multi-stage pipelines. Some of this code could be consolidated further in the CoarseSchedule class (or perhaps a derived LoopScheduler class). This modulo scheduler collects `tt.load` ops and generates local_storage and management ops for the ramp-up stage (stage-0), then collecting all uses of the loads for stage-1. Multi-buffering is introduced when num_stages exceeds the max distance between load and uses. Buffering may be in Shared memory for `tt.dot` uses or Registers for all other uses. This current implement does not support peeling the last iteration if the loop is dynamic. Lastly, the `tritonamdgpu-reorder-instructions` pass has been enhanced to move `tt.load` ops as early as possible in its region. This includes loop bodies as well as func entry blocks for the case of ramp-up. This pass will also move `triton_gpu.local_store` ops as early as possible if their source is not directly from a `tt.load`. In this way, a multi-buffered pipeline will overlap in this order: 1. tt.load buffer+2 2. tg.local_store buffer+1 3. tt.dot buffer+0 --------- Co-authored-by: Lei Zhang <antiagainst@gmail.com> * [AMD] Prefetch loads and independent local_stores (triton-lang#4429) This pass is enhanced to move tt.loads as early as possible. This enables buffering in registers for global loads while computing previous tiles (stream-pipelining), but may increase register pressure. If ttg.local_stores are independent of loads in the loop (i.e. double buffering in shared memory), then this pass will also move those early to overlap with global loads and compute. * [Pipeliner] Implement dynamic loop peeling - enabled for tritonamdgpu-stream-pipeline * * disabled for num_stages > 2 * updated tests * * guard each stage of ramp-down in epilogue * enable peeling for any num_stages * * pipeline reg buffers * [AMD] Fixed bug with tritonamdgpu-reorder-instructions - blindly moving local_loads can violate memory access order - also fixed case when moving instructions to top of loop * * only move ops early * Fix in streamPipelinerV2 * Fix lit tests * [Backend][AMD] Add temporary environment variable for pipeliner v2 (triton-lang#4430) This commit adds a new environment variable to enable pipeliner v2. It is expected to be temporary while we enable the new pipeliner and get all cases covered. Co-authored-by: SJW <swaters@amd.com> --------- Co-authored-by: Ognjen Plavsic <plognjen@amd.com> Co-authored-by: Alexander Efimov <efimov.alexander@gmail.com> Co-authored-by: Ognjen Plavsic <ognjen.plavsic@luxoft.com> Co-authored-by: Vinayak Gokhale <Vinayak.Gokhale@amd.com> Co-authored-by: Lixun Zhang <lixun.zhang@amd.com> Co-authored-by: Thomas Raoux <thomas.raoux@openai.com> Co-authored-by: Keren Zhou <kerenzhou@openai.com> Co-authored-by: Lei Zhang <antiagainst@gmail.com> Co-authored-by: SJW <48454132+sjw36@users.noreply.github.com> Co-authored-by: SJW <swaters@amd.com>
This PR first promotes common infrastructure in
lib/Dialect/TritonGPU/Transforms/Pipelinerto enable inclusion by other target backends. No other changes have been made to the lib/include directories.Second, the
tritonamdgpu-stream-pipelinepass has been completely revamped based on code fromlib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cppusing similar scheduling passes to compute multi-stage pipelines. Some of this code could be consolidated further in the CoarseSchedule class (or perhaps a derived LoopScheduler class). This modulo scheduler collectstt.loadops and generates local_storage and management ops for the ramp-up stage (stage-0), then collecting all uses of the loads for stage-1. Multi-buffering is introduced when num_stages exceeds the max distance between load and uses. Buffering may be in Shared memory fortt.dotuses or Registers for all other uses. This current implement does not support peeling the last iteration if the loop is dynamic.Lastly, the
tritonamdgpu-reorder-instructionspass has been enhanced to movett.loadops as early as possible in its region. This includes loop bodies as well as func entry blocks for the case of ramp-up. This pass will also movetriton_gpu.local_storeops as early as possible if their source is not directly from att.load. In this way, a multi-buffered pipeline will overlap in this order: