Merged
Commits
51 commits
8d5ba3f
A new Triton compiler sans CUDA support.
xinyazhang Jun 18, 2024
003b06e
Fix the compiler for new Triton
xinyazhang Jun 18, 2024
cd9615c
Mitigate compiler bug (https://github.com/ROCm/triton/issues/596)
xinyazhang Jun 18, 2024
77fa1e1
add wheel as another required package.
xinyazhang Jun 18, 2024
579bf82
Port to performant kernel and moving away from block pointers for tl.…
xinyazhang Jun 19, 2024
2a6872a
Fix the off_h_k computation.
xinyazhang Jun 25, 2024
c62c904
Fix writing to encoded_softmax
xinyazhang Jun 25, 2024
8be265d
Submit the Triton kernel as we are testing. All UTs passed
xinyazhang Jun 25, 2024
e150f54
remove debugging output
xinyazhang Jun 25, 2024
713e87a
cpp tuning: Add basic C++ tuning support
xinyazhang Jul 8, 2024
36ca9fe
v2src/flash/attn_fwd: add missing num_head_q and num_head_k
xinyazhang Jul 8, 2024
63cc8fb
Flash API now returns selected psels and copts to extra arguments, if…
xinyazhang Jul 8, 2024
fdd14e1
Implement tune_flash with AOT kernels
xinyazhang Jul 9, 2024
0e988f9
Fix the dropout_mask and add a progressbar to test/tune_flash.py
xinyazhang Jul 9, 2024
c4c201f
Save memory for long seq length
xinyazhang Jul 9, 2024
8528c7a
Update the tuning database for MI200 only GPUs
xinyazhang Jul 10, 2024
1d4fbd0
Remove seqlen_q/k >= 32k rows from the database
xinyazhang Jul 10, 2024
dd6a26b
Fix CMakeLists. Do not pass empty string as cmd argument if GENERATE_…
xinyazhang Jul 10, 2024
98c404c
Return hipErrorSharedObjectSymbolNotFound for untuned cases.
xinyazhang Jul 10, 2024
ec12934
Fix test/test_backward.py
xinyazhang Jul 10, 2024
4513dfe
Fix AUTOTUNE_KEYS for backward kernels.
xinyazhang Jul 10, 2024
350b6bb
tritonsrc: add type annotation 'i32' to num_seqlens, and fix varlen h…
xinyazhang Jul 10, 2024
ece99b8
fix the assignment of .num_head_q/k
xinyazhang Jul 11, 2024
b3f9dab
Add Navi 31/32 compiler options.
xinyazhang Jul 11, 2024
d33cf43
Fix various problems and now most fwd kernel tests passed.
xinyazhang Jul 11, 2024
ad33017
Various fixes to tune_flash
xinyazhang Jul 11, 2024
09583e2
Make zstd quite
xinyazhang Jul 12, 2024
87262c1
Add draft document 'How To Generate Tuning Database.md'
xinyazhang Jul 12, 2024
f95d878
doc -> docs
xinyazhang Jul 12, 2024
a5a3189
Debugging output in bwd kernel
xinyazhang Jul 12, 2024
c091ef3
add num_head_q/k argument to varlen's attention module.
xinyazhang Jul 12, 2024
34ac678
tritonsrc/performance_forward: read env var N_CTX to determine testin…
xinyazhang Jul 12, 2024
22c3197
Reduce the tuning time since there are too many cases to test...
xinyazhang Jul 12, 2024
0b40af3
cpp autotune: x2 num_warps if warp_size == 32
xinyazhang Jul 11, 2024
7457a5a
Navi32: skip autotune configs that takes too long to build
xinyazhang Jul 12, 2024
b7d647c
Add --use_multigpu to test/tune_flash.py for multi-GPU tuning
xinyazhang Jul 12, 2024
e819160
test/tune_flash.py: actually distribute tensor/computing to different…
xinyazhang Jul 12, 2024
b8c702d
Move dev-only packages from requirements.txt into requirements-dev.txt
xinyazhang Jul 12, 2024
b5869a1
tune_flash.py: Fix the slow splice_pipes
xinyazhang Jul 12, 2024
c4f0de5
Fix single GPU script.
xinyazhang Jul 13, 2024
8558d73
Move database accessing to a separate process, and unify the task gen…
xinyazhang Jul 13, 2024
18b56c0
tune_flash: add --json_file, improve --dry_run to report total numbers,
xinyazhang Jul 13, 2024
fcfa3e8
tune_flash: Move the testing to a separate process to avoid segfault.
xinyazhang Jul 14, 2024
f7b1f28
Cache the minesweeping process to avoid creating processes repeatedly
xinyazhang Jul 14, 2024
a150716
Remove 16k from seqlen_q/k, record task id and skipped tests in json
xinyazhang Jul 15, 2024
d93ceef
tune_flash: add --continue_from_json_file
xinyazhang Jul 15, 2024
7d224cd
table_tool: skip result=skipped json objects
xinyazhang Jul 15, 2024
6ef9a40
tuning_database: Update FLASH$attn_fwd for gfx90a and gfx942
xinyazhang Jul 15, 2024
e92f7d1
track aotriton-hyperjump branch in third_party/triton
xinyazhang Jul 15, 2024
7ed9ac6
Fix test/performance_forward.py
xinyazhang Jul 15, 2024
bb1a5e8
Remove old_compile.py
xinyazhang Jul 16, 2024
8 changes: 7 additions & 1 deletion .gitignore
@@ -1,4 +1,10 @@
__pycache__/
build/
*build*/
*.swp
tritonsrc/tune-*.json
*.csv
*.png
1
2
1.*
2.*
1 change: 1 addition & 0 deletions .gitmodules
@@ -1,6 +1,7 @@
[submodule "third_party/triton"]
path = third_party/triton
url = https://github.com/ROCmSoftwarePlatform/triton.git
branch = aotriton-hyperjump
[submodule "third_party/incbin"]
path = third_party/incbin
url = https://github.com/graphitemaster/incbin.git
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -15,6 +15,7 @@ set(AOTRITON_HIPCC_PATH "hipcc" CACHE STRING "Set HIPCC Path")
option(AOTRITON_NO_SHARED "Disable shared object build. Incompatible with AOTRITON_COMPRESS_KERNEL." ON)
option(AOTRITON_NO_PYTHON "Disable python binding build" OFF)
option(AOTRITON_ENABLE_ASAN "Enable Address Sanitizer. Implies -g" OFF)
option(AOTRITON_BUILD_FOR_TUNING "Build all GPU kernels and set -DAOTRITON_BUILD_FOR_TUNING=1 (=0 otherwise)" OFF)
set(TARGET_GPUS "MI200;MI300X" CACHE STRING "Target Architecture (Note here uses Trade names)")
set(AMDHSA_LD_PRELOAD "/opt/rocm/lib/libhsa-runtime64.so" CACHE STRING "Workaround of libamdhip64.so.5: undefined symbol: hsa_amd_memory_async_copy_on_engine")

6 changes: 6 additions & 0 deletions bindings/CMakeLists.txt
@@ -13,3 +13,9 @@ if(AOTRITON_OVERRIDE_ZSTD_LIB)
else()
target_link_libraries(pyaotriton PRIVATE ${ZSTD_TARGET})
endif()
# TODO: Unify build option macros with "interface target+public compile definitions"
if(AOTRITON_BUILD_FOR_TUNING)
target_compile_definitions(pyaotriton PRIVATE -DAOTRITON_BUILD_FOR_TUNING=1)
else(AOTRITON_BUILD_FOR_TUNING)
target_compile_definitions(pyaotriton PRIVATE -DAOTRITON_BUILD_FOR_TUNING=0)
endif(AOTRITON_BUILD_FOR_TUNING)
22 changes: 18 additions & 4 deletions bindings/module.cc
@@ -14,8 +14,18 @@ namespace pyaotriton {
namespace pyaotriton {
namespace v2 {
namespace flash {
using aotriton::v2::flash::ExtraArguments;
void setup_module(py::module_& m) {
m.def("check_gpu", &aotriton::v2::flash::check_gpu, py::arg("stream"));
py::class_<ExtraArguments>(m, "ExtraArguments")
.def(py::init<>())
#if AOTRITON_BUILD_FOR_TUNING
.def_readwrite("force_kernel_index", &ExtraArguments::force_kernel_index)
.def_readonly("total_number_of_kernels", &ExtraArguments::total_number_of_kernels)
.def_readonly("selected_kernel_psels", &ExtraArguments::selected_kernel_psels)
.def_readonly("selected_kernel_copts", &ExtraArguments::selected_kernel_copts)
#endif
;
m.def("attn_fwd",
&aotriton::v2::flash::attn_fwd,
"Flash Attention Forward Pass",
@@ -31,7 +41,8 @@
py::arg("philox_offset"),
py::arg("encoded_softmax"),
py::arg("is_causal"),
py::arg("stream") = nullptr);
py::arg("stream") = nullptr,
py::arg("extargs") = ExtraArguments());
m.def("attn_fwd_compact_varlen",
&aotriton::v2::flash::attn_fwd_compact_varlen,
"Flash Attention Forward Pass, Compact Stored Varlen",
@@ -51,7 +62,8 @@
py::arg("philox_offset"),
py::arg("encoded_softmax"),
py::arg("is_causal"),
py::arg("stream") = nullptr);
py::arg("stream") = nullptr,
py::arg("extargs") = ExtraArguments());
m.def("attn_bwd",
&aotriton::v2::flash::attn_bwd,
"Flash Attention Backward Pass",
@@ -72,7 +84,8 @@
py::arg("philox_seed"),
py::arg("philox_offset"),
py::arg("is_causal"),
py::arg("stream") = nullptr);
py::arg("stream") = nullptr,
py::arg("extargs") = ExtraArguments());
m.def("attn_bwd_compact_varlen",
&aotriton::v2::flash::attn_bwd_compact_varlen,
"Flash Attention Backward Pass, Compact Stored Varlen",
@@ -97,7 +110,8 @@
py::arg("philox_seed"),
py::arg("philox_offset"),
py::arg("is_causal"),
py::arg("stream") = nullptr);
py::arg("stream") = nullptr,
py::arg("extargs") = ExtraArguments());
m.def("debug_fill_dropout_rng",
&aotriton::v2::flash::debug_fill_dropout_rng,
"Flash Attention Debugging Function to get raw RNG numbers used in dropout",
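With the bindings above, a library configured with `-DAOTRITON_BUILD_FOR_TUNING=ON` exposes the tuning-only fields of `ExtraArguments` to Python. The snippet below is a minimal sketch, not part of this PR, of how a caller could force a specific kernel candidate and inspect what was launched; the candidate index is illustrative, and the assumption that the read-only fields are populated after the kernel call is hedged accordingly.

```python
# Minimal sketch (not in this PR): exercise the tuning-only fields exposed above.
# Requires a pyaotriton build configured with -DAOTRITON_BUILD_FOR_TUNING=ON.
from pyaotriton.v2.flash import ExtraArguments

extargs = ExtraArguments()
extargs.force_kernel_index = 2   # hypothetical candidate index; the default -1 presumably leaves selection to the library

# Pass `extargs` as the trailing argument of attn_fwd / attn_bwd / *_compact_varlen.
# After the call returns, the read-only fields are expected to describe the launched kernel:
#   extargs.total_number_of_kernels
#   extargs.selected_kernel_psels
#   extargs.selected_kernel_copts
```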
12 changes: 12 additions & 0 deletions docs/How To Generate Tuning Database.md
@@ -0,0 +1,12 @@
# TL;DR

```
mkdir cpptune_build
cd cpptune_build
cmake .. -DCMAKE_INSTALL_PREFIX=./install_dir -DCMAKE_BUILD_TYPE=Release -DAOTRITON_BUILD_FOR_TUNING=ON -G Ninja
# Optionally only build for one arch
# cmake .. -DCMAKE_INSTALL_PREFIX=./install_dir -DCMAKE_BUILD_TYPE=Release -DAOTRITON_BUILD_FOR_TUNING=ON -DTARGET_GPUS=Navi32 -G Ninja
ninja install
cd ..
PYTHONPATH=cpptune_build/bindings/ python test/tune_flash.py --bias_type 0 --db_file v2python/rules/tuning_database.sqlite3
```
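Before kicking off a long tuning run, it can be worth confirming that the bindings on `PYTHONPATH` really came from a tuning build. The check below is a suggestion, not part of this PR; it relies only on `ExtraArguments` gaining its tuning fields when `AOTRITON_BUILD_FOR_TUNING=1`, and the script name is hypothetical.

```python
# Suggested sanity check (not in this PR); save as check_tuning_build.py (hypothetical name)
# and run: PYTHONPATH=cpptune_build/bindings python check_tuning_build.py
from pyaotriton.v2.flash import ExtraArguments

extargs = ExtraArguments()
# force_kernel_index is only bound when the library was built for tuning.
assert hasattr(extargs, "force_kernel_index"), \
    "pyaotriton was not built with -DAOTRITON_BUILD_FOR_TUNING=ON"
print("tuning build detected")
```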
22 changes: 18 additions & 4 deletions include/aotriton/flash.h
@@ -16,6 +16,16 @@ using T4 = aotriton::TensorView<4>;
using T2 = aotriton::TensorView<2>;
using T1 = aotriton::TensorView<1>;

struct ExtraArguments {
#if AOTRITON_BUILD_FOR_TUNING
// TODO: Move them into a base class since they are common to all kernels
int force_kernel_index = -1;
int total_number_of_kernels = -1;
const char* selected_kernel_psels = nullptr;
const char* selected_kernel_copts = nullptr;
#endif
};

hipError_t
attn_fwd(T4 q, // batch_size x num_heads x seqlen_q x head_size
T4 k, // batch_size x num_heads x seqlen_k x head_size
@@ -29,7 +39,8 @@ attn_fwd(T4 q, // batch_size x num_heads x seqlen_q x head_size
uint64_t philox_offset,
T4 encoded_softmax,
bool is_causal,
aotriton::Stream stream);
aotriton::Stream stream,
ExtraArguments* extargs = nullptr);

hipError_t
attn_fwd_compact_varlen(T4 q, // 1 x num_heads x total_q x head_size, total_q := \sum_{i=0}^{b} s_i
@@ -48,7 +59,8 @@ attn_fwd_compact_varlen(T4 q, // 1 x num_heads x total_q x head_size, total_q :=
uint64_t philox_offset,
T4 encoded_softmax,
bool is_causal,
aotriton::Stream stream);
aotriton::Stream stream,
ExtraArguments* extargs = nullptr);

hipError_t
attn_bwd(T4 q, // batch_size x num_heads x seqlen_q x head_size
@@ -68,7 +80,8 @@ attn_bwd(T4 q, // batch_size x num_heads x seqlen_q x head_size
uint64_t philox_seed,
uint64_t philox_offset,
bool is_causal,
aotriton::Stream stream);
aotriton::Stream stream,
ExtraArguments* extargs = nullptr);

hipError_t
attn_bwd_compact_varlen(T4 q, // 1 x num_heads x total_q x head_size, total_q := \sum_{i=0}^{b}
@@ -92,7 +105,8 @@ attn_bwd_compact_varlen(T4 q, // 1 x num_heads x total_q x head_size, total_q :=
uint64_t philox_seed,
uint64_t philox_offset,
bool is_causal,
aotriton::Stream stream);
aotriton::Stream stream,
ExtraArguments* extargs = nullptr);

hipError_t
debug_fill_dropout_rng(T4 r,
2 changes: 2 additions & 0 deletions include/aotriton/util.h
@@ -37,6 +37,8 @@ enum GpuArch : uint64_t {
GPU_ARCH_UNKNOWN = 0,
GPU_ARCH_AMD_GFX90A = CAT(GpuVendor::kAMD, 0x90a),
GPU_ARCH_AMD_GFX942 = CAT(GpuVendor::kAMD, 0x942),
GPU_ARCH_AMD_GFX1100 = CAT(GpuVendor::kAMD, 0x1100),
GPU_ARCH_AMD_GFX1101 = CAT(GpuVendor::kAMD, 0x1101),
};

template<int Rank>
3 changes: 3 additions & 0 deletions requirements-dev.txt
@@ -0,0 +1,3 @@
-r requirements.txt
tqdm
textual
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,3 +4,4 @@ packaging
pluggy
numpy
setuptools
wheel
43 changes: 34 additions & 9 deletions test/aotriton_flash.py
@@ -7,8 +7,9 @@
attn_fwd_compact_varlen as fa_forward_compact_varlen,
attn_bwd_compact_varlen as fa_backward_compact_varlen,
debug_fill_dropout_rng as fa_debug_fill_dropout_rng,
ExtraArguments as ExtraArguments,
)
from pyaotriton import T1, T2, T4, DType, Stream
from pyaotriton import T1, T2, T4, DType, Stream, hipError_t

def cast_dtype(dtype):
assert not dtype.is_complex
@@ -37,7 +38,9 @@ def mk_aotensor(q, if_empty_then_like=None):
return klass(q.data_ptr(), tuple(q.size()), q.stride(), cast_dtype(q.dtype))

def attn_fwd(q, k, v, b, sm_scale, M, o,
dropout_p, philox_seed, philox_offset, encoded_softmax, is_causal):
dropout_p, philox_seed, philox_offset, encoded_softmax, is_causal,
extargs=None):
extargs = ExtraArguments() if extargs is None else extargs
err = fa_forward(mk_aotensor(q),
mk_aotensor(k),
mk_aotensor(v),
@@ -50,13 +53,31 @@ def attn_fwd(q, k, v, b, sm_scale, M, o,
int(philox_offset),
mk_aotensor(encoded_softmax, if_empty_then_like=q),
is_causal,
Stream())
print(f'{err=}')
Stream(),
extargs)
# print(f'{err=}')
return err

def ipc_attn_fwd(ipc_to_read, ipc_to_write):
import torch
while True:
tup = ipc_to_read.get()
if tup is None:
break
q, k, v, b, sm_scale, M, o, dropout_p, philox_seed, philox_offset, encoded_softmax, is_causal, force_kernel_index, shard = tup
extargs = ExtraArguments()
extargs.force_kernel_index = force_kernel_index
with torch.cuda.device(shard):
ret = attn_fwd(q, k, v, b, sm_scale, M, o,
dropout_p, philox_seed, philox_offset, encoded_softmax, is_causal,
extargs)
torch.cuda.synchronize()
ipc_to_write.put(ret)

def attn_bwd(q, k, v, b, sm_scale, o, dout, dq, dk, dv, db, L, delta,
dropout_p, philox_seed, philox_offset, is_causal):
b = mk_aotensor(b, if_empty_then_like=q)
print(f'{b=}')
# print(f'{b=}')
err = fa_backward(mk_aotensor(q),
mk_aotensor(k),
mk_aotensor(v),
@@ -75,14 +96,16 @@ def attn_bwd(q, k, v, b, sm_scale, o, dout, dq, dk, dv, db, L, delta,
int(philox_offset),
is_causal,
Stream())
print(f'{err=}')
# print(f'{err=}')
return err

def debug_fill_dropout_rng(R, philox_seed, philox_offset):
err = fa_debug_fill_dropout_rng(mk_aotensor(R),
philox_seed,
philox_offset,
Stream())
print(f'{err=}')
# print(f'debug_fill_dropout_rng {err=}')
return err

def attn_fwd_compact_varlen(q, k, v,
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
@@ -105,7 +128,8 @@ def attn_fwd_compact_varlen(q, k, v,
mk_aotensor(encoded_softmax, if_empty_then_like=q),
is_causal,
Stream())
print(f'{err=}')
# print(f'{err=}')
return err

def attn_bwd_compact_varlen(q, k, v,
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
@@ -135,4 +159,5 @@ def attn_bwd_compact_varlen(q, k, v,
int(philox_offset),
is_causal,
Stream())
print(f'{err=}')
# print(f'{err=}')
return err
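The new `ipc_attn_fwd` loop above runs in a worker process and pulls work items from a queue until it sees a `None` sentinel. The sketch below, not taken from this PR, shows one way a parent process might drive it; the tuple layout mirrors the unpacking inside `ipc_attn_fwd`, and the use of `torch.multiprocessing` with the `spawn` start method is an assumption made so each shard gets its own CUDA context.

```python
# Hypothetical driver for ipc_attn_fwd (not in this PR).
import torch.multiprocessing as mp
from aotriton_flash import ipc_attn_fwd

def run_one_case(work_item):
    # work_item must match the unpacking in ipc_attn_fwd:
    # (q, k, v, b, sm_scale, M, o, dropout_p, philox_seed, philox_offset,
    #  encoded_softmax, is_causal, force_kernel_index, shard)
    ctx = mp.get_context('spawn')
    to_worker, from_worker = ctx.Queue(), ctx.Queue()
    worker = ctx.Process(target=ipc_attn_fwd, args=(to_worker, from_worker))
    worker.start()
    to_worker.put(work_item)
    err = from_worker.get()   # hipError_t returned by attn_fwd
    to_worker.put(None)       # sentinel: lets the worker loop exit
    worker.join()
    return err
```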