diff --git a/.gitignore b/.gitignore index f75f6a5d..a9eec59c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,10 @@ __pycache__/ -build/ +*build*/ *.swp tritonsrc/tune-*.json +*.csv +*.png +1 +2 +1.* +2.* diff --git a/.gitmodules b/.gitmodules index 9a37e3cc..5ce5d5a1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,7 @@ [submodule "third_party/triton"] path = third_party/triton url = https://github.com/ROCmSoftwarePlatform/triton.git + branch = aotriton-hyperjump [submodule "third_party/incbin"] path = third_party/incbin url = https://github.com/graphitemaster/incbin.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 4d936106..4f227789 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,6 +15,7 @@ set(AOTRITON_HIPCC_PATH "hipcc" CACHE STRING "Set HIPCC Path") option(AOTRITON_NO_SHARED "Disable shared object build. Incompatible with AOTRITON_COMPRESS_KERNEL." ON) option(AOTRITON_NO_PYTHON "Disable python binding build" OFF) option(AOTRITON_ENABLE_ASAN "Enable Address Sanitizer. Implies -g" OFF) +option(AOTRITON_BUILD_FOR_TUNING "Build all GPU kernels and set -DAOTRITON_BUILD_FOR_TUNING=1 (=0 otherwise)" OFF) set(TARGET_GPUS "MI200;MI300X" CACHE STRING "Target Architecture (Note here uses Trade names)") set(AMDHSA_LD_PRELOAD "/opt/rocm/lib/libhsa-runtime64.so" CACHE STRING "Workaround of libamdhip64.so.5: undefined symbol: hsa_amd_memory_async_copy_on_engine") diff --git a/bindings/CMakeLists.txt b/bindings/CMakeLists.txt index dbf69793..31ddf8ea 100644 --- a/bindings/CMakeLists.txt +++ b/bindings/CMakeLists.txt @@ -13,3 +13,9 @@ if(AOTRITON_OVERRIDE_ZSTD_LIB) else() target_link_libraries(pyaotriton PRIVATE ${ZSTD_TARGET}) endif() +# TODO: Unify build option macros with "interface target+public compile definitions" +if(AOTRITON_BUILD_FOR_TUNING) + target_compile_definitions(pyaotriton PRIVATE -DAOTRITON_BUILD_FOR_TUNING=1) +else(AOTRITON_BUILD_FOR_TUNING) + target_compile_definitions(pyaotriton PRIVATE -DAOTRITON_BUILD_FOR_TUNING=0) +endif(AOTRITON_BUILD_FOR_TUNING) diff --git a/bindings/module.cc b/bindings/module.cc index a966a840..e42d91f2 100644 --- a/bindings/module.cc +++ b/bindings/module.cc @@ -14,8 +14,18 @@ namespace py = pybind11; namespace pyaotriton { namespace v2 { namespace flash { + using aotriton::v2::flash::ExtraArguments; void setup_module(py::module_& m) { m.def("check_gpu", &aotriton::v2::flash::check_gpu, py::arg("stream")); + py::class_<ExtraArguments>(m, "ExtraArguments") + .def(py::init<>()) +#if AOTRITON_BUILD_FOR_TUNING + .def_readwrite("force_kernel_index", &ExtraArguments::force_kernel_index) + .def_readonly("total_number_of_kernels", &ExtraArguments::total_number_of_kernels) + .def_readonly("selected_kernel_psels", &ExtraArguments::selected_kernel_psels) + .def_readonly("selected_kernel_copts", &ExtraArguments::selected_kernel_copts) +#endif + ; m.def("attn_fwd", &aotriton::v2::flash::attn_fwd, "Flash Attention Forward Pass", @@ -31,7 +41,8 @@ namespace pyaotriton { py::arg("philox_offset"), py::arg("encoded_softmax"), py::arg("is_causal"), - py::arg("stream") = nullptr); + py::arg("stream") = nullptr, + py::arg("extargs") = ExtraArguments()); m.def("attn_fwd_compact_varlen", &aotriton::v2::flash::attn_fwd_compact_varlen, "Flash Attention Forward Pass, Compact Stored Varlen", @@ -51,7 +62,8 @@ namespace pyaotriton { py::arg("philox_offset"), py::arg("encoded_softmax"), py::arg("is_causal"), - py::arg("stream") = nullptr); + py::arg("stream") = nullptr, + py::arg("extargs") = ExtraArguments()); m.def("attn_bwd", &aotriton::v2::flash::attn_bwd, "Flash Attention 
Backward Pass", @@ -72,7 +84,8 @@ namespace pyaotriton { py::arg("philox_seed"), py::arg("philox_offset"), py::arg("is_causal"), - py::arg("stream") = nullptr); + py::arg("stream") = nullptr, + py::arg("extargs") = ExtraArguments()); m.def("attn_bwd_compact_varlen", &aotriton::v2::flash::attn_bwd_compact_varlen, "Flash Attention Backward Pass, Compact Stored Varlen", @@ -97,7 +110,8 @@ namespace pyaotriton { py::arg("philox_seed"), py::arg("philox_offset"), py::arg("is_causal"), - py::arg("stream") = nullptr); + py::arg("stream") = nullptr, + py::arg("extargs") = ExtraArguments()); m.def("debug_fill_dropout_rng", &aotriton::v2::flash::debug_fill_dropout_rng, "Flash Attention Debugging Function to get raw RNG numbers used in dropout", diff --git a/docs/How To Generate Tuning Database.md b/docs/How To Generate Tuning Database.md new file mode 100644 index 00000000..f16f4a99 --- /dev/null +++ b/docs/How To Generate Tuning Database.md @@ -0,0 +1,12 @@ +# TL;DR + +``` +mkdir cpptune_build +cd cpptune_build +cmake .. -DCMAKE_INSTALL_PREFIX=./install_dir -DCMAKE_BUILD_TYPE=Release -DAOTRITON_BUILD_FOR_TUNING=ON -G Ninja +# Optionally only build for one arch +# cmake .. -DCMAKE_INSTALL_PREFIX=./install_dir -DCMAKE_BUILD_TYPE=Release -DAOTRITON_BUILD_FOR_TUNING=ON -DTARGET_GPUS=Navi32 -G Ninja +ninja install +cd .. +PYTHONPATH=cpptune_build/bindings/ python test/tune_flash.py --bias_type 0 --db_file v2python/rules/tuning_database.sqlite3 +``` diff --git a/doc/How To Update Constraints of Tuning Database.md b/docs/How To Update Constraints of Tuning Database.md similarity index 100% rename from doc/How To Update Constraints of Tuning Database.md rename to docs/How To Update Constraints of Tuning Database.md diff --git a/include/aotriton/flash.h b/include/aotriton/flash.h index 2fb50244..6ae38214 100644 --- a/include/aotriton/flash.h +++ b/include/aotriton/flash.h @@ -16,6 +16,16 @@ using T4 = aotriton::TensorView<4>; using T2 = aotriton::TensorView<2>; using T1 = aotriton::TensorView<1>; +struct ExtraArguments { +#if AOTRITON_BUILD_FOR_TUNING + // TODO: Move them into a base class since they are common to all kernels + int force_kernel_index = -1; + int total_number_of_kernels = -1; + const char* selected_kernel_psels = nullptr; + const char* selected_kernel_copts = nullptr; +#endif +}; + hipError_t attn_fwd(T4 q, // batch_size x num_heads x seqlen_q x head_size T4 k, // batch_size x num_heads x seqlen_k x head_size @@ -29,7 +39,8 @@ attn_fwd(T4 q, // batch_size x num_heads x seqlen_q x head_size uint64_t philox_offset, T4 encoded_softmax, bool is_causal, - aotriton::Stream stream); + aotriton::Stream stream, + ExtraArguments* extargs = nullptr); hipError_t attn_fwd_compact_varlen(T4 q, // 1 x num_heads x total_q x head_size, total_q := \sum_{i=0}^{b} s_i @@ -48,7 +59,8 @@ attn_fwd_compact_varlen(T4 q, // 1 x num_heads x total_q x head_size, total_q := uint64_t philox_offset, T4 encoded_softmax, bool is_causal, - aotriton::Stream stream); + aotriton::Stream stream, + ExtraArguments* extargs = nullptr); hipError_t attn_bwd(T4 q, // batch_size x num_heads x seqlen_q x head_size @@ -68,7 +80,8 @@ attn_bwd(T4 q, // batch_size x num_heads x seqlen_q x head_size uint64_t philox_seed, uint64_t philox_offset, bool is_causal, - aotriton::Stream stream); + aotriton::Stream stream, + ExtraArguments* extargs = nullptr); hipError_t attn_bwd_compact_varlen(T4 q, // 1 x num_heads x total_q x head_size, total_q := \sum_{i=0}^{b} @@ -92,7 +105,8 @@ attn_bwd_compact_varlen(T4 q, // 1 x num_heads x total_q x 
head_size, total_q := uint64_t philox_seed, uint64_t philox_offset, bool is_causal, - aotriton::Stream stream); + aotriton::Stream stream, + ExtraArguments* extargs = nullptr); hipError_t debug_fill_dropout_rng(T4 r, diff --git a/include/aotriton/util.h b/include/aotriton/util.h index 91f61d98..efb93950 100644 --- a/include/aotriton/util.h +++ b/include/aotriton/util.h @@ -37,6 +37,8 @@ enum GpuArch : uint64_t { GPU_ARCH_UNKNOWN = 0, GPU_ARCH_AMD_GFX90A = CAT(GpuVendor::kAMD, 0x90a), GPU_ARCH_AMD_GFX942 = CAT(GpuVendor::kAMD, 0x942), + GPU_ARCH_AMD_GFX1100 = CAT(GpuVendor::kAMD, 0x1100), + GPU_ARCH_AMD_GFX1101 = CAT(GpuVendor::kAMD, 0x1101), }; template diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..f627c2c1 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,3 @@ +-r requirements.txt +tqdm +textual diff --git a/requirements.txt b/requirements.txt index edc348cb..0abbfe0c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ packaging pluggy numpy setuptools +wheel diff --git a/test/aotriton_flash.py b/test/aotriton_flash.py index 908035eb..0be463f7 100644 --- a/test/aotriton_flash.py +++ b/test/aotriton_flash.py @@ -7,8 +7,9 @@ attn_fwd_compact_varlen as fa_forward_compact_varlen, attn_bwd_compact_varlen as fa_backward_compact_varlen, debug_fill_dropout_rng as fa_debug_fill_dropout_rng, + ExtraArguments as ExtraArguments, ) -from pyaotriton import T1, T2, T4, DType, Stream +from pyaotriton import T1, T2, T4, DType, Stream, hipError_t def cast_dtype(dtype): assert not dtype.is_complex @@ -37,7 +38,9 @@ def mk_aotensor(q, if_empty_then_like=None): return klass(q.data_ptr(), tuple(q.size()), q.stride(), cast_dtype(q.dtype)) def attn_fwd(q, k, v, b, sm_scale, M, o, - dropout_p, philox_seed, philox_offset, encoded_softmax, is_causal): + dropout_p, philox_seed, philox_offset, encoded_softmax, is_causal, + extargs=None): + extargs = ExtraArguments() if extargs is None else extargs err = fa_forward(mk_aotensor(q), mk_aotensor(k), mk_aotensor(v), @@ -50,13 +53,31 @@ def attn_fwd(q, k, v, b, sm_scale, M, o, int(philox_offset), mk_aotensor(encoded_softmax, if_empty_then_like=q), is_causal, - Stream()) - print(f'{err=}') + Stream(), + extargs) + # print(f'{err=}') + return err + +def ipc_attn_fwd(ipc_to_read, ipc_to_write): + import torch + while True: + tup = ipc_to_read.get() + if tup is None: + break + q, k, v, b, sm_scale, M, o, dropout_p, philox_seed, philox_offset, encoded_softmax, is_causal, force_kernel_index, shard = tup + extargs = ExtraArguments() + extargs.force_kernel_index = force_kernel_index + with torch.cuda.device(shard): + ret = attn_fwd(q, k, v, b, sm_scale, M, o, + dropout_p, philox_seed, philox_offset, encoded_softmax, is_causal, + extargs) + torch.cuda.synchronize() + ipc_to_write.put(ret) def attn_bwd(q, k, v, b, sm_scale, o, dout, dq, dk, dv, db, L, delta, dropout_p, philox_seed, philox_offset, is_causal): b = mk_aotensor(b, if_empty_then_like=q) - print(f'{b=}') + # print(f'{b=}') err = fa_backward(mk_aotensor(q), mk_aotensor(k), mk_aotensor(v), @@ -75,14 +96,16 @@ def attn_bwd(q, k, v, b, sm_scale, o, dout, dq, dk, dv, db, L, delta, int(philox_offset), is_causal, Stream()) - print(f'{err=}') + # print(f'{err=}') + return err def debug_fill_dropout_rng(R, philox_seed, philox_offset): err = fa_debug_fill_dropout_rng(mk_aotensor(R), philox_seed, philox_offset, Stream()) - print(f'{err=}') + # print(f'debug_fill_dropout_rng {err=}') + return err def attn_fwd_compact_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, 
max_seqlen_q, max_seqlen_k, @@ -105,7 +128,8 @@ def attn_fwd_compact_varlen(q, k, v, mk_aotensor(encoded_softmax, if_empty_then_like=q), is_causal, Stream()) - print(f'{err=}') + # print(f'{err=}') + return err def attn_bwd_compact_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, @@ -135,4 +159,5 @@ def attn_bwd_compact_varlen(q, k, v, int(philox_offset), is_causal, Stream()) - print(f'{err=}') + # print(f'{err=}') + return err diff --git a/test/attn_torch_function.py b/test/attn_torch_function.py index 638175a2..2dd8ea6d 100644 --- a/test/attn_torch_function.py +++ b/test/attn_torch_function.py @@ -3,7 +3,23 @@ # SPDX-License-Identifier: MIT import torch -from aotriton_flash import attn_fwd, attn_bwd, debug_fill_dropout_rng +import queue +from torch.multiprocessing import Process +from aotriton_flash import attn_fwd, ipc_attn_fwd, attn_bwd, debug_fill_dropout_rng, ExtraArguments, hipError_t +from collections import namedtuple +from cpp_autotune import do_bench, cpp_autotune + +AttentionExtraArgs = namedtuple('AttentionExtraArgs', + ['return_encoded_softmax', + 'autotune', + 'return_autotune', + 'autotune_validator', + 'cpp_autotune_tqdm_position', + 'cpp_autotune_tqdm_prefix', + 'gpu_device', + 'tune_worker', + ], + defaults=[False, False, False, None, None, '', None, None]) VERBOSE=False DEFAULT_PHILOX_SEED = 0x1BF52 @@ -21,8 +37,9 @@ class _attention(torch.autograd.Function): # DEBUG_MASK_DTYPE = torch.float32 @staticmethod - def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax, - autotune=False, return_autotune=False): + def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, + attn_extra_args=AttentionExtraArgs()): + return_encoded_softmax, autotune, return_autotune = attn_extra_args[:3] # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv @@ -58,8 +75,73 @@ def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax philox_seed = DEFAULT_PHILOX_SEED philox_offset = DEFAULT_PHILOX_OFFSET - attn_fwd(q, k, v, b, sm_scale, M, o, - dropout_p, philox_seed, philox_offset, encoded_softmax, causal); + if autotune and return_autotune: + assert attn_extra_args.autotune_validator is not None + def sameprocess_func(extargs): + args = (q, k, v, b, sm_scale, M, o, + dropout_p, philox_seed, philox_offset, encoded_softmax, causal, + extargs) + try: + ret = attn_fwd(*args) + except Exception as e: + print(e) + return hipError_t.hipErrorLaunchFailure, None + return ret, (o,) + def ipc_func(force_kernel_index): + shard = attn_extra_args.gpu_device + tune_worker = attn_extra_args.tune_worker + def factory(): + ipc_to_worker = torch.multiprocessing.Queue() + ipc_from_worker = torch.multiprocessing.Queue() + ipc_to_worker.cancel_join_thread() + ipc_from_worker.cancel_join_thread() + p = Process(target=ipc_attn_fwd, + args=(ipc_to_worker, ipc_from_worker)) + p.start() + return (ipc_to_worker, ipc_from_worker, p) + ipc_to_worker, ipc_from_worker, p = tune_worker.request_cached_gpukernel_process(ipc_attn_fwd, factory) + # print(f'{q.data_ptr()=:x}') + # print(f'{k.data_ptr()=:x}') + # print(f'{v.data_ptr()=:x}') + # print(f'{b.data_ptr()=:x}') + # print(f'{M.data_ptr()=:x}') + # print(f'{o.data_ptr()=:x}') + ipc_to_worker.put((q, k, v, b, sm_scale, M, o, + dropout_p, philox_seed, philox_offset, encoded_softmax, causal, + force_kernel_index, shard)) + while p.is_alive(): + try: + iret = ipc_from_worker.get(timeout=1) + break + except queue.Empty: + # print(f'Process timeout {p.is_alive()=}') 
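+                        # Poll the worker with a short timeout instead of blocking:
+                        # a faulty kernel configuration can crash or wedge the child
+                        # process, and catching queue.Empty lets this loop re-check
+                        # p.is_alive() so the tuner never hangs on a dead worker.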
+ pass + # print(f'Process attn_fwd starting') + if not p.is_alive(): + # print(f'Process exitcode {p.exitcode}') + tune_worker.invalid_gpukernel_process_cache(ipc_attn_fwd) + p.join() + ret = hipError_t.hipErrorLaunchFailure + else: + ret = hipError_t.hipSuccess if iret == 0 else hipError_t.hipErrorLaunchFailure + # print(f'Process attn_fwd joined') + # print(f'Process exitcode {p.exitcode}') + return ret, (o,) + def func(extargs, is_testing): + # print(f'{is_testing=}') + if not is_testing: + return sameprocess_func(extargs) + o.fill_(float('nan')) + return ipc_func(extargs.force_kernel_index) + # print(f'running attn_fwd with {extargs.force_kernel_index=}') + tuning_result = cpp_autotune(ExtraArguments, func, + attn_extra_args.autotune_validator, + tqdm_position=attn_extra_args.cpp_autotune_tqdm_position, + tqdm_prefix=attn_extra_args.cpp_autotune_tqdm_prefix) + else: + attn_fwd(q, k, v, b, sm_scale, M, o, + dropout_p, philox_seed, philox_offset, encoded_softmax, causal); + tuning_result = None ctx.save_for_backward(q, k, v, b, o, M) ctx.sm_scale = sm_scale @@ -69,12 +151,13 @@ def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax ctx.philox_seed = philox_seed ctx.philox_offset = philox_offset ctx.encoded_softmax = encoded_softmax # FIXME: for debugging only - return o, encoded_softmax, None + ctx.tuning_result = [('attn_fwd', tuning_result)] if tuning_result is not None else None + return o, encoded_softmax, ctx.tuning_result @staticmethod def backward(ctx, do, _, __): q, k, v, b, o, L = ctx.saved_tensors - print(f'{b=}') + # print(f'{b=}') sm_scale = ctx.sm_scale dropout_p = ctx.dropout_p philox_seed = ctx.philox_seed @@ -82,15 +165,16 @@ def backward(ctx, do, _, __): causal = ctx.causal # if q.shape[-1] <= 32: # do = do.contiguous() - dq = torch.zeros_like(q) + dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) db = torch.empty_like(b) if b is not None else None delta = torch.empty_like(L) seqlen_q = q.shape[2] seqlen_k = k.shape[2] - attn_bwd(q, k, v, b, sm_scale, o, do, dq, dk, dv, db, L, delta, - dropout_p, philox_seed, philox_offset, causal); - return dq, dk, dv, db, None, None, None, None, None, None, None + ret = attn_bwd(q, k, v, b, sm_scale, o, do, dq, dk, dv, db, L, delta, + dropout_p, philox_seed, philox_offset, causal); + assert ret == hipError_t.hipSuccess, ret + return dq, dk, dv, db, None, None, None, None, None attention = _attention.apply diff --git a/test/cpp_autotune.py b/test/cpp_autotune.py new file mode 100644 index 00000000..994ac3d6 --- /dev/null +++ b/test/cpp_autotune.py @@ -0,0 +1,160 @@ +from collections import namedtuple +from aotriton_flash import hipError_t +import json +import sys +import math +from tqdm import tqdm + +def do_bench(fn, *, warmup=25, rep=100, + grad_to_none=None, + quantiles=None, + fast_flush=True, + return_mode="mean", + validator=None): + """ + Benchmark the runtime of the provided function. By default, return the mean runtime of :code:`fn`; + pass :code:`quantiles` or another :code:`return_mode` to select a different statistic. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentiles to return instead of the :code:`return_mode` statistic. 
+ :type quantiles: list[float], optional + :param fast_flush: Use faster kernel to flush L2 cache between measurements + :type fast_flush: bool, default is True + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", or "median". Default is "mean". + :type return_mode: str + """ + assert return_mode in ["min", "max", "mean", "median"] + import torch + + torch.cuda.synchronize() + ret, outs = fn(is_testing=True) + if ret != hipError_t.hipSuccess: + # print(f'{ret=}', file=sys.stderr, flush=True) + return float('inf') + torch.cuda.synchronize() + valret = validator(*outs) + # print(f'{valret=}', flush=True) + if not valret: + # assert False + return float('inf') + torch.cuda.synchronize() + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 cache + # doesn't contain any input data before the run + cache_size = 256 * 1024 * 1024 + if fast_flush: + cache = torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda') + else: + cache = torch.empty(int(cache_size), dtype=torch.int8, device='cuda') + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + torch.cuda.synchronize() + times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + return getattr(torch, return_mode)(times).item() + +AutotuneResult = namedtuple('AutotuneResult', ['kernel_index', 'time', 'psels', 'copts']) + +# cannot be maxint in case extargs.total_number_of_kernels never returns a positive number +CPP_AUTOTUNE_MAX_KERNELS = 100 + +def cpp_autotune(extarg_klass, kernel_func, validator, *, tqdm_position=None, tqdm_prefix=''): + assert validator is not None + kernel_index = 0 + extargs = extarg_klass() + timings = [] + pbar = None + failed = 0 + success = 0 + total_number_of_kernels = CPP_AUTOTUNE_MAX_KERNELS + while True: + extargs.force_kernel_index = kernel_index + def func(is_testing=False): + return kernel_func(extargs, is_testing) + # t = do_bench(func, validator=validator, quantiles=(0.5, 0.2, 0.8)) + t = do_bench(func, validator=validator) + ''' + if kernel_index == 0: + print(f'Benchmarking with {kernel_index=}. Time {t}') + else: + print(f'Benchmarking with {kernel_index=} out of {extargs.total_number_of_kernels}. 
Time {t}') + ''' + # assert extargs.total_number_of_kernels > 0 + if math.isinf(t): + failed += 1 + # Also record failures (with infinite time) so min() below stays + # well-defined and the final sanity check triggers when every + # kernel configuration fails. + timings.append(AutotuneResult(kernel_index=kernel_index, time=t, psels=None, copts=None)) + else: + if extargs.total_number_of_kernels > 0: + assert extargs.total_number_of_kernels <= CPP_AUTOTUNE_MAX_KERNELS + total_number_of_kernels = extargs.total_number_of_kernels + success += 1 + r = AutotuneResult(kernel_index=kernel_index, + time=t, + psels=json.loads(extargs.selected_kernel_psels), + copts=json.loads(extargs.selected_kernel_copts)) + timings.append(r) + + if pbar is None and extargs.total_number_of_kernels > 0: + pbar = tqdm(total=extargs.total_number_of_kernels, unit="configs", position=tqdm_position) + pbar.set_description(f'{tqdm_prefix} Success {success}, Fail {failed}. Last time {t:.2g}') + if pbar is not None: + pbar.set_description(f'{tqdm_prefix} Success {success}, Fail {failed}. Last time {t:.2g}') + pbar.update(1) + + # print(f'{r.psels=}') + # print(f'{r.copts=}') + kernel_index += 1 + if kernel_index >= total_number_of_kernels: + break + # print(f'cpp_autotune {timings=}') + ret = min(timings, key=lambda atr:atr.time) + # print(f'{ret=}') + if math.isinf(ret.time): + # with open("/proc/self/maps") as f: + # print(f.read(), file=sys.stderr) + print("ERROR: No configuration works") + return ret diff --git a/test/performance_forward.py b/test/performance_forward.py index 32fbf7a2..8dbb856b 100644 --- a/test/performance_forward.py +++ b/test/performance_forward.py @@ -6,7 +6,7 @@ import torch import triton -from attn_torch_function import attention +from attn_torch_function import attention, AttentionExtraArgs try: from flash_attn.flash_attn_interface import \ @@ -62,10 +62,12 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) sm_scale = 1.3 - autotune = False - return_encoded_softmax = False + dropout_p = 0.0 b = None - fn = lambda: attention(q, k, v, b, causal, sm_scale, split_kernel, return_encoded_softmax, autotune) + ext = AttentionExtraArgs(return_encoded_softmax=causal, + autotune=False, + return_autotune=False) + fn = lambda: attention(q, k, v, b, causal, sm_scale, dropout_p, ext) if mode == 'bwd': o = fn() do = torch.randn_like(o) diff --git a/test/rocm_arch.py b/test/rocm_arch.py new file mode 120000 index 00000000..e7c4a439 --- /dev/null +++ b/test/rocm_arch.py @@ -0,0 +1 @@ +../tritonsrc/rocm_arch.py \ No newline at end of file diff --git a/test/test_backward.py b/test/test_backward.py index 2fd4aee7..5b9d6e03 100644 --- a/test/test_backward.py +++ b/test/test_backward.py @@ -4,8 +4,15 @@ import pytest import torch +import os -from attn_torch_function import attention +from attn_torch_function import ( + DEFAULT_PHILOX_SEED, + DEFAULT_PHILOX_OFFSET, + attention, + debug_fill_dropout_rng, + AttentionExtraArgs +) from _common_test import SdpaContext, SdpaParams def _make_block_eyes(q, base=1.0, inc=0.0): @@ -54,11 +61,13 @@ def _do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale bias_type=bias_type, storage_flip=transpose, device='cuda') ctx.create_ref_inputs() ctx.set_require_grads(skip_dq=SKIP_DQ, skip_dk_dv=SKIP_DK_DV, skip_db=SKIP_DB) - return_encoded_softmax = True q, k, v, b = ctx.dev_tensors # autotune = True # # triton implementation - tri_out, encoded_softmax, _ = attention(q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax, USE_AUTOTUNE) + ext = AttentionExtraArgs(return_encoded_softmax=True, 
autotune=False, + return_autotune=False) + tri_out, encoded_softmax, _ = attention(q, k, v, b, causal, sm_scale, dropout_p, ext) dropout_mask = encoded_softmax >= 0 sdpa_params = SdpaParams(causal=causal, sm_scale=sm_scale, dropout_p=dropout_p, dropout_mask=dropout_mask) ref_out, _ = ctx.compute_ref_forward(sdpa_params) @@ -211,6 +220,24 @@ def test_op_bwd_with_matrix_bias(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, sm_ ''' _do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) +def main2(): + # Memo: False-0.0-dtype0-0.0-False-4-256-8-4-1 + # Memo: False-0.0-dtype0-0.0-False-4-256-8-1-4 + # False-1.2-dtype0-0.0-False-4-4-72-1-4 + BATCH = 8 + D_HEAD = 32 + N_HEADS = 8 + seqlen_q = 16 + seqlen_k = 16 + causal = False + + sm_scale = 1.2 + dropout_p = 0.0 + dtype = torch.float16 + storage_flip = False + bias_type = None + _do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) + def main(): BATCH = 1 D_HEAD = 80 @@ -231,4 +258,4 @@ def main(): _do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) if __name__ == '__main__': - main() + main2() diff --git a/test/test_forward.py b/test/test_forward.py index 451779ad..6a39121a 100644 --- a/test/test_forward.py +++ b/test/test_forward.py @@ -4,61 +4,16 @@ import pytest import torch +import os -from attn_torch_function import attention, debug_fill_dropout_rng, DEFAULT_PHILOX_SEED, DEFAULT_PHILOX_OFFSET - -def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: - # Efficient implementation equivalent to the following: - L, S = query.size(-2), key.size(-2) - scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale - attn_bias = torch.zeros(L, S, dtype=query.dtype) - if is_causal: - assert attn_mask is None - temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) - attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) - - """ - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) - else: - attn_bias += attn_mask - """ - attn_weight = query @ key.transpose(-2, -1) * scale_factor - SPARSE_HEAD_SINCE = 5 - SPARSE_SEQ_SINCE = 5 - print(f'{query=}') - print(f'{key=}') - print(f'BEFORE softmax {attn_weight[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - # attn_weight += attn_bias - attn_weight = torch.softmax(attn_weight, dim=-1) - print(f'BEFORE DROPOUT_MASK {attn_weight[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - if dropout_p > 0.0: - if dropout_mask is not None: - print(f'BEFORE DROPOUT_MASK {attn_weight[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - attn_weight.masked_fill_(dropout_mask.logical_not(), float("0.0")) - print(f'AFTER DROPOUT_MASK {attn_weight[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - value = value / (1 - dropout_p) - else: - # assert False, "TESTING dropout_mask code path" - attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - else: - # assert False, "TESTING dropout_mask code path" - pass - print(f'{value[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - av = attn_weight @ value - print(f'{av[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - return av, attn_weight - -def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dtype: torch.dtype = None, 
device=None): - """ Clones the query, key, and value tensors and moves them to the specified dtype. """ - if dtype is None: - dtype = query.dtype - query_ref = query.clone().detach().to(dtype=dtype, device=device).requires_grad_(query.requires_grad) - key_ref = key.clone().detach().to(dtype=dtype, device=device).requires_grad_(key.requires_grad) - value_ref = value.clone().detach().to(dtype=dtype, device=device).requires_grad_(value.requires_grad) - return query_ref, key_ref, value_ref +from attn_torch_function import ( + DEFAULT_PHILOX_SEED, + DEFAULT_PHILOX_OFFSET, + attention, + debug_fill_dropout_rng, + AttentionExtraArgs +) +from _common_test import SdpaContext, SdpaParams ''' Flash Attention is batch operator that evaluates sm(QK')V @@ -74,243 +29,172 @@ def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch. but in PyTorch API it does not present at all ''' -class FwdTester(object): - - def __init__(self): - self.use_fill_rng_for_dropout = False - - def do_test_op_fwd(self, BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type,): - if causal and seqlen_q != seqlen_k: - pytest.skip("PyTorch's Flash V2 does not accept casual=True when seqlen_q != seqlen_k. Skipping") - if causal and bias_type is not None: - pytest.skip("_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True") - torch.manual_seed(20) - print(f"test_op_fwd {BATCH=}, {N_HEADS=}, {seqlen_q=}, {seqlen_k=}, {D_HEAD=}, {causal=}") - SPARSE_HEAD_SINCE = 3 - SPARSE_SEQ_SINCE = 3 - Z = BATCH - H = N_HEADS - if True: # Real UT - qdims = (BATCH, N_HEADS, seqlen_q, D_HEAD) - kdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) - vdims = (BATCH, N_HEADS, seqlen_k, D_HEAD) - bdims = (BATCH, N_HEADS, seqlen_q, seqlen_k) - if storage_flip: - qdims = (qdims[0], qdims[2], qdims[1], qdims[3]) - kdims = (kdims[0], kdims[2], kdims[1], kdims[3]) - vdims = (vdims[0], vdims[2], vdims[1], vdims[3]) - bdims = (bdims[0], bdims[2], bdims[1], bdims[3]) - q = ( - torch.empty(qdims, dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) - k = ( - torch.empty(kdims, dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) - v = ( - torch.empty(vdims, dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) - if bias_type is None: - b = None - elif bias_type == 'matrix': - b = torch.empty(bdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5) - else: - assert False, f'Unsupported bias_type {bias_type}' - if storage_flip: - q = torch.transpose(q, 1, 2) - k = torch.transpose(k, 1, 2) - v = torch.transpose(v, 1, 2) - if b is not None: - b = torch.transpose(b, 1, 2) - assert q.shape == (BATCH, N_HEADS, seqlen_q, D_HEAD) - assert k.shape == (BATCH, N_HEADS, seqlen_k, D_HEAD) - assert v.shape == (BATCH, N_HEADS, seqlen_k, D_HEAD) - if False: # Debugging - q = ( - torch.empty((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0., std=0.5) - .requires_grad_() - ) - k = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 1.0 - v = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 1.0 - if False: - q = torch.ones((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda") * 1.0 - k = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 2.0 - v = torch.ones((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda") * 3.0 - if False: - import numpy as np - q = torch.arange(np.prod([Z, H, seqlen_q, D_HEAD]), dtype=dtype, 
device="cuda").reshape((Z, H, seqlen_q, D_HEAD)) - k = torch.arange(np.prod([Z, H, seqlen_k, D_HEAD]), dtype=dtype, device="cuda").reshape((Z, H, seqlen_q, D_HEAD)) - v = torch.arange(np.prod([Z, H, seqlen_k, D_HEAD]), dtype=dtype, device="cuda").reshape((Z, H, seqlen_q, D_HEAD)) - q = (q - 128.0) * 0.01 - k = (k - 128.0) * 0.01 - v = (v - 128.0) * 0.01 - q[:, :, :, SPARSE_HEAD_SINCE: ] = 0.0 - k[:, :, :, SPARSE_HEAD_SINCE: ] = 0.0 - v[:, :, :, SPARSE_HEAD_SINCE: ] = 0.0 - q[:, :, SPARSE_SEQ_SINCE:, : ] = 0.0 - k[:, :, SPARSE_SEQ_SINCE:, : ] = 0.0 - v[:, :, SPARSE_SEQ_SINCE:, : ] = 0.0 +def _do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type): + if causal and seqlen_q != seqlen_k: + pytest.skip("PyTorch's Flash V2 does not accept causal=True when seqlen_q != seqlen_k. Skipping") + if causal and bias_type is not None: + pytest.skip("_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True") + # if BATCH > 1 and seqlen_q >= 1024 and seqlen_k >= 1024: + # torch.cuda.empty_cache() + SKIP_DK_DV = True + SKIP_DQ = True + SKIP_DB = True if bias_type is None else False + USE_AUTOTUNE = False + torch.manual_seed(20) + SPARSE_HEAD_SINCE = 1 + SPARSE_SEQ_SINCE = 1 + transpose = (1, 2) if storage_flip else None + ctx = SdpaContext(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, dtype, + bias_type=bias_type, storage_flip=transpose, device='cuda') + ctx.create_ref_inputs() + ctx.set_require_grads(skip_dq=SKIP_DQ, skip_dk_dv=SKIP_DK_DV, skip_db=SKIP_DB) + q, k, v, b = ctx.dev_tensors + # autotune = True + # # triton implementation + ext = AttentionExtraArgs(return_encoded_softmax=causal, + autotune=False, + return_autotune=False) + tri_out, encoded_softmax, _ = attention(q, k, v, b, causal, sm_scale, dropout_p, ext) + dropout_mask = encoded_softmax >= 0 if causal else None + sdpa_params = SdpaParams(causal=causal, sm_scale=sm_scale, dropout_p=dropout_p, dropout_mask=dropout_mask) + ref_out, _ = ctx.compute_ref_forward(sdpa_params) ''' dout = torch.randn_like(q) # reference implementation M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) p = torch.matmul(q, k.transpose(2, 3)) * sm_scale if causal: p[:, :, M == 0] = float("-inf") p = torch.softmax(p.float(), dim=-1).half() ref_out = torch.matmul(p, v) ''' - return_encoded_softmax = dropout_p > 0.0 and not self.use_fill_rng_for_dropout - # return_encoded_softmax = dropout_p > 0.0 # Reserved for debugging use_fill_rng_for_dropout - higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32 - REF_DEVICE=None - q_ref, k_ref, v_ref = query_key_value_clones(q, k, v, dtype=higher_precision_dtype, device=REF_DEVICE) - def TO(ref_tensor): - return ref_tensor.to(device=q.device, dtype=dtype) - autotune = False - return_autotune = False - tri_out, encoded_softmax, _ = attention(q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax, - autotune, return_autotune) + is_allclose, adiff, _, _ = ctx.validate_with_reference(tri_out, None, no_backward=True) + if not is_allclose: + import numpy as np + err_idx = np.unravel_index(torch.argmax(torch.abs(ref_out.to(device=tri_out.device) - tri_out)).cpu().numpy(), ref_out.shape) + print(f'{err_idx=}') + print(f'{tri_out[err_idx]=}') + print(f'{ref_out[err_idx]=}') + assert is_allclose, f'Forward pass {is_allclose=}' + print(f'{adiff=}') - if self.use_fill_rng_for_dropout: - rdims = (BATCH, N_HEADS, seqlen_q, seqlen_k) - if storage_flip: - rdims = (rdims[0], rdims[2], 
rdims[1], rdims[3]) - r = torch.empty(rdims, device=q.device, dtype=torch.float32) - if storage_flip: - r = torch.transpose(r, 1, 2) - philox_seed = DEFAULT_PHILOX_SEED - philox_offset = DEFAULT_PHILOX_OFFSET - debug_fill_dropout_rng(r, philox_seed, philox_offset) - # Reserved for debugging use_fill_rng_for_dropout - # print(f'{r[0,0,:16, :16]}=') - # print(f'{r[0,0,:16, :16] > dropout_p}=') - # print(f'{encoded_softmax[0,0,:16, :16] > 0}=') - dropout_mask = r > dropout_p - else: - dropout_mask = encoded_softmax > 0 if encoded_softmax is not None else None - # assert torch.allclose(dropout_mask, dropout_mask_naive) - ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q_ref, k_ref, v_ref, - dropout_p=dropout_p, - is_causal=causal, - attn_mask=b, - scale=sm_scale, - dropout_mask=dropout_mask) - if False: - mref_out, mref_softmax = scaled_dot_product_attention(q, k, v, - dropout_p=dropout_p, - is_causal=causal, - scale=sm_scale, - dropout_mask=dropout_mask) - print(f'{tri_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - print(f'{ref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - print(f'{mref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - print(f'{tri_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]/ref_out[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - print(f'{q.shape=} {q.stride()=}') - print(f'{k.shape=} {k.stride()=}') - print(f'{v.shape=} {v.stride()=}') - print(f'{encoded_softmax=}') - if encoded_softmax is not None: - print(f'{encoded_softmax.shape=} {encoded_softmax.stride()=}') - print(f'{encoded_softmax[:,:, :SPARSE_SEQ_SINCE, :SPARSE_SEQ_SINCE]=}') - print(f'{dropout_mask.shape=} {dropout_mask.stride()=}') - print(f'{dropout_mask[:,:, :SPARSE_SEQ_SINCE, :SPARSE_HEAD_SINCE]=}') - if dtype==torch.bfloat16: - ATOL = 1e-1 * max(1.0, (seqlen_q + seqlen_k + D_HEAD) / 128.0) - else: - ATOL = 1e-2 * max(1.0, (seqlen_q + seqlen_k + D_HEAD) / 128.0) - RTOL = 0.0 - print(f'Using ATOL={ATOL} RTOL={RTOL}') - is_allclose = torch.allclose(TO(ref_out), tri_out, atol=ATOL, rtol=RTOL) - if not is_allclose: - import numpy as np - err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_out) - tri_out)).cpu().numpy(), ref_out.shape) - print(f'{err_idx=}') - print(f'{tri_out[err_idx]=} {ref_out[err_idx]=} error: {tri_out[err_idx] - ref_out[err_idx]}') - # if not is_allclose: - if False: - import numpy as np - err_idx = np.unravel_index(torch.argmax(torch.abs(TO(ref_out) - tri_out)).cpu().numpy(), ref_out.shape) - print(f'{tri_out[0][0][0][:]=}') - print(f'{ref_out[0][0][0][:]=}') - print(f'{mref_out[0][0][0][:]=}') - if encoded_softmax is not None: - print(f'{encoded_softmax[0][0][0][:]=}') - print(f'{ref_softmax[0][0][0][:]=}') - print(f'{tri_out[-1][0][0][:]=}') - print(f'{ref_out[-1][0][0][:]=}') - print(f'{err_idx=}') - print(f'{tri_out[err_idx]=}') - print(f'{ref_out[err_idx]=}') - if dropout_p > 0: - # print(f'{unmasked_ref_out[0][0][0][:]=}') - print(f'{dropout_mask[0][0][0][:]=}') - print(f'{dropout_mask[err_idx]=}') - # tri_cpu = tri_out[0, 0].cpu().detach().numpy() - # print(f'{tri_cpu.shape=}') - # compare - assert is_allclose - -# @pytest.mark.parametrize('BATCH', [1, 4]) -# @pytest.mark.parametrize('N_HEADS', [1, 4]) +# @pytest.mark.parametrize('BATCH', [1]) +# @pytest.mark.parametrize('N_HEADS', [1]) @pytest.mark.parametrize('BATCH', [1, 4]) @pytest.mark.parametrize('N_HEADS', [1, 4]) +# @pytest.mark.parametrize('D_HEAD', [16, 32, 64, 128, 256]) +# Irregular-only PyTorch set +# @pytest.mark.parametrize('D_HEAD', [8, 21, 72, 96, 160, 192, 203]) 
+# @pytest.mark.parametrize('seqlen_q', [1, 4, 32, 128, 256, 512, 1024, 7, 394, 250, 399, 511, 1019]) +# @pytest.mark.parametrize('seqlen_k', [1, 4, 32, 128, 256, 512, 1024, 3, 217, 339, 313, 491, 988]) +# PyTorch set @pytest.mark.parametrize('D_HEAD', [8, 16, 21, 32, 64, 72, 96, 128, 160, 192, 203, 256]) -# @pytest.mark.parametrize('seqlen_q', [16,32,64,128,256,512,1024]) -# @pytest.mark.parametrize('seqlen_k', [16,32,64,128,256,512,1024]) @pytest.mark.parametrize('seqlen_q', [4, 8, 64, 143, 256, 512, 1024, 2048]) @pytest.mark.parametrize('seqlen_k', [4, 8, 64, 128, 256, 587, 1024, 2048]) +# Minimal set # @pytest.mark.parametrize('seqlen_q', [32, 128]) # @pytest.mark.parametrize('seqlen_k', [32, 128]) @pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize('dropout_p', [0.0, 0.5]) # @pytest.mark.parametrize('dropout_p', [0.0]) -@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16, torch.float32]) +# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) @pytest.mark.parametrize('sm_scale', [0.0, 1.2]) @pytest.mark.parametrize('storage_flip', [False, True]) # @pytest.mark.parametrize('return_encoded_softmax', [False]) def test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip): bias_type = None - tester = FwdTester() - tester.do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) + _do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) +# @pytest.mark.parametrize('BATCH', [1, 4]) +# @pytest.mark.parametrize('N_HEADS', [1, 4]) @pytest.mark.parametrize('BATCH', [1, 4]) @pytest.mark.parametrize('N_HEADS', [1, 4]) @pytest.mark.parametrize('D_HEAD', [16,32,64,128,256]) +# @pytest.mark.parametrize('D_HEAD', [128]) +# Complete set +# @pytest.mark.parametrize('seqlen_q', [4,8,16,17,32,64,128,143,256,512,1024,2048]) +# @pytest.mark.parametrize('seqlen_k', [4,8,16,23,32,64,128,256,512,587,1024,2048]) +# PyTorch set @pytest.mark.parametrize('seqlen_q', [4, 8, 64, 143, 256, 512, 1024, 2048]) @pytest.mark.parametrize('seqlen_k', [4, 8, 64, 128, 256, 587, 1024, 2048]) +# @pytest.mark.parametrize('seqlen_q', [128,256,512,1024]) +# @pytest.mark.parametrize('seqlen_k', [128,256,512,1024]) +# @pytest.mark.parametrize('seqlen_q', [128, 113]) +# @pytest.mark.parametrize('seqlen_k', [128, 79]) @pytest.mark.parametrize('dropout_p', [0.0, 0.5]) -@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16, torch.float32]) +# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) @pytest.mark.parametrize('sm_scale', [0.0, 1.2]) @pytest.mark.parametrize('storage_flip', [False, True]) +# @pytest.mark.parametrize('return_encoded_softmax', [False]) def test_op_fwd_with_matrix_bias(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, sm_scale, dropout_p, dtype, storage_flip): causal = False bias_type = 'matrix' - tester = FwdTester() ''' _scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True ''' - tester.do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) + _do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) -@pytest.mark.parametrize('BATCH', 
[1, 4]) -@pytest.mark.parametrize('N_HEADS', [1, 4]) -@pytest.mark.parametrize('seqlen_q', [4, 8, 64, 143, 256, 512, 1024, 2048]) -@pytest.mark.parametrize('seqlen_k', [4, 8, 64, 128, 256, 587, 1024, 2048]) -@pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('storage_flip', [False, True]) -def test_fill_dropout_rng(BATCH, N_HEADS, seqlen_q, seqlen_k, causal, storage_flip): - D_HEAD = 128 - dropout_p = 0.5 +dtype0 = torch.float16 +dtype1 = torch.bfloat16 +dtype2 = torch.float32 + +# Testing test_op_fwd_with_matrix_bias from string +def main4(): + # utshort = 'False-1.2-dtype0-0.0-587-2048-32-1-1' + utshort = 'False-1.2-dtype0-0.0-4-2048-32-1-1' + # utshort = 'False-1.2-dtype0-0.0-4-1024-32-1-1' + utlist_str = list(reversed(utshort.split('-'))) + BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, dropout_p, dtype, sm_scale, storage_flip = [eval(e) for e in utlist_str] + causal = False + bias_type = 'matrix' + _do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) + +def main3(): + # utshort = 'False-1.2-dtype0-0.0-4-2048-32-1-1' + # utshort = 'False-1.2-dtype0-0.0-4-1024-32-1-1' + utlist_str = list(reversed(utshort.split('-'))) + BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, dropout_p, dtype, sm_scale, storage_flip = [eval(e) for e in utlist_str] + bias_type = None + _do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) + +def main2(): + # False-1.2-dtype0-0.0-587-2048-32-1-1 + # Memo: False-0.0-dtype0-0.0-False-4-256-8-4-1 + # Memo: False-0.0-dtype0-0.0-False-4-256-8-1-4 + # False-1.2-dtype0-0.0-False-4-4-72-1-4 + # BATCH = 1 + # D_HEAD = 32 + # N_HEADS = 4 + # seqlen_q = 16 + # seqlen_k = 16 + # causal = False + + BATCH = 2 + D_HEAD = 4 + N_HEADS = 1 + seqlen_q = 8 + seqlen_k = 8 + causal = False + + sm_scale = 1.2 + dropout_p = 0.0 dtype = torch.float16 + storage_flip = False + bias_type = None + _do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) + +def main(): + BATCH = 1 + D_HEAD = 80 + N_HEADS = 2 + seqlen_q = 6432 + seqlen_k = 6432 + ''' + N_HEADS = 6432 + seqlen_q = 2 + seqlen_k = 2 + ''' + causal = False sm_scale = 1.2 + dropout_p = 0.0 + dtype = torch.bfloat16 + storage_flip = False bias_type = None - tester = FwdTester() - tester.use_fill_rng_for_dropout = True - tester.do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) + _do_test_op_fwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) + +if __name__ == '__main__': + main4() diff --git a/test/tune_flash.py b/test/tune_flash.py new file mode 100644 index 00000000..b193a8ed --- /dev/null +++ b/test/tune_flash.py @@ -0,0 +1,494 @@ +#!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. 
+# SPDX-License-Identifier: MIT + +import os +# FIXME: Should set PYTORCH_NO_HIP_MEMORY_CACHING=1 as well but need to wait for +# https://github.com/pytorch/pytorch/issues/114534 +os.environ['HSA_SVM_GUARD_PAGES'] = '1' +os.environ['HSA_DISABLE_FRAGMENT_ALLOCATOR'] = '1' + +import pytest +import torch +import json +import sys +import subprocess +import queue +import multiprocessing +from multiprocessing import Process, Queue +import argparse +import itertools +import time +import math + +from rocm_arch import rocm_get_gpuarch +from attn_torch_function import ( + DEFAULT_PHILOX_SEED, + DEFAULT_PHILOX_OFFSET, + attention, + debug_fill_dropout_rng, + AttentionExtraArgs +) +from _common_test import SdpaContext, SdpaParams + +_DEBUG_SKIP_TUNE_BACKWARD = True + +class ArgArchVerbose: + def __init__(self, args): + self._args = args + self._arch = rocm_get_gpuarch() + + @property + def verbose(self): + return self._args.verbose + + def print(self, text): + if self.verbose: + print(text) + +class TunerWorker(ArgArchVerbose): + def __init__(self, args): + super().__init__(args) + self._tqdm_position = 0 + self._gpu_device = 'cuda' + self._cached_gpukernel_process = {} + + def profile_single_config(self, tup, *, prefix='', shard=None): + a = self._args + BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, return_encoded_softmax, dtype, bias_type = tup + head_dim_rounded = 2 ** (D_HEAD - 1).bit_length() + head_dim_rounded = max(16, head_dim_rounded) + inputs = { + 'Q_dtype': str(dtype), + 'N_HEADS': N_HEADS, + 'D_HEAD': D_HEAD, + 'max_seqlen_q': seqlen_q, + 'max_seqlen_k': seqlen_k, + 'CAUSAL': causal, + 'RETURN_ENCODED_SOFTMAX': return_encoded_softmax, + 'BLOCK_DMODEL': head_dim_rounded, + 'ENABLE_DROPOUT' : dropout_p > 0.0, + 'PADDED_HEAD' : head_dim_rounded != D_HEAD, + 'BIAS_TYPE' : bias_type, + } + if seqlen_q > 8192 and seqlen_k > 8192: + N_HEADS = 1 + if causal and seqlen_q != seqlen_k: + self.print('FA does not accept causal=True when seqlen_q != seqlen_k. Skipping') + return 'Skip', inputs, None + if causal and bias_type != 0: + self.print('FA does not accept causal=True when bias_type != 0. 
Skipping') + return 'Skip', inputs, None + if a.dry_run: + return 'Dryrun', None, None + torch.cuda.empty_cache() + ''' + Create reference dropout_mask + ''' + if dropout_p > 0.0: + rdims = (BATCH, N_HEADS, seqlen_q, seqlen_k) + r = torch.empty(rdims, device=self._gpu_device, dtype=torch.float32) + philox_seed = DEFAULT_PHILOX_SEED + philox_offset = DEFAULT_PHILOX_OFFSET + debug_fill_dropout_rng(r, philox_seed, philox_offset) + mask = r > dropout_p + torch.cuda.synchronize() + del r + else: + mask = None + torch.cuda.empty_cache() + ''' + Create SdpaContext for testing + ''' + ctx = SdpaContext(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, dtype, + bias_type=bias_type, storage_flip=None, device=self._gpu_device) + ctx.create_ref_inputs(target_gpu_device=self._gpu_device) + ctx.set_require_grads(skip_db=True) + q, k, v, b = ctx.dev_tensors + sdpa_params = SdpaParams(causal=causal, sm_scale=sm_scale, dropout_p=dropout_p, dropout_mask=mask) + ref_out, _ = ctx.compute_ref_forward(sdpa_params) + + ''' + Now, enable autotune (C++ form), enable output validation + ''' + def fwd_validator(tri_out): + is_allclose, adiff, _, _ = ctx.validate_with_reference(tri_out, None, no_backward=True) + ''' + if not is_allclose: + import numpy as np + err_idx = np.unravel_index(torch.argmax(torch.abs(ref_out - tri_out)).cpu().numpy(), ref_out.shape) + print(f'{err_idx=}') + print(f'{tri_out[err_idx]=}') + print(f'{ref_out[err_idx]=}') + ''' + return is_allclose + + ext = AttentionExtraArgs(return_encoded_softmax=return_encoded_softmax, + autotune=True, + return_autotune=True, + autotune_validator=fwd_validator, + cpp_autotune_tqdm_position=self._tqdm_position, + cpp_autotune_tqdm_prefix=f'{prefix}{tup}', + gpu_device=shard, + tune_worker=self, + ) + tri_out, encoded_softmax, best_configs = attention(q, k, v, b, causal, sm_scale, dropout_p, ext) + if self.verbose: + print('Returned best configs') + for kernel_name, best in best_configs: + # print(f'{kernel_name=} {best.kwargs=} {best.num_warps=} {best.num_stages=}') + print(f'{kernel_name=}') + if not _DEBUG_SKIP_TUNE_BACKWARD: + dout = torch.randn_like(q) + tri_out.backward(dout) + if self.verbose: + print('Returned best configs after backward') + for kernel_name, best in best_configs: + print(f'{kernel_name=}') + return 'Success', inputs, best_configs + + def do_profile(self, dba, gen): + dry_run_counter = 0 + skip_counter = 0 + for i, tup in gen(): + action, inputs, best_configs = self.profile_single_config(tup) + if action == 'Success': + dba.pipe_configs(inputs, best_configs, _debug_task_id=i) + if action == 'Dryrun': + dry_run_counter += 1 + if action == 'Skip': + dba.pipe_skipped_configs(inputs, _debug_task_id=i) + skip_counter += 1 + dba.stop() + self.clean_cached_gpukernel_process() + print(f"Valid sample points {dry_run_counter=}. 
Skipped invalid/unsupported points {skip_counter}") + + def request_cached_gpukernel_process(self, target, factory): + if target not in self._cached_gpukernel_process: + self._cached_gpukernel_process[target] = factory() + return self._cached_gpukernel_process[target] + + def invalid_gpukernel_process_cache(self, target): + del self._cached_gpukernel_process[target] + + def clean_cached_gpukernel_process(self): + for k, tup in self._cached_gpukernel_process.items(): + ipc_to, ipc_from, p = tup + ipc_to.put(None) + p.join() + ipc_to.close() + ipc_from.close() + +class IPCTunerWorker(TunerWorker): + END_OF_QUEUE_OBJECT = (-1, None) + + ''' + Initialize multiprocessing related variables + ''' + def init_mp(self, shard): + self._shard = shard + self._tqdm_position = shard + self._gpu_device = f'cuda:{shard}' + self._cached_gpukernel_process = {} + + def clean_mp(self, shard): + self.clean_cached_gpukernel_process() + + def do_profile(self, ipc_read, ipc_write): + a = self._args + shard, total_shards = ipc_read.get() + print(f'{shard=} {total_shards=}') + shard_prefix= '' if shard is None else f'[Shard {shard:02d}/{total_shards:02d}]' + self.init_mp(shard) + with torch.cuda.device(shard): + while True: + try: + i, tup = ipc_read.get() + # print(f'ipc_read {i} {tup}') + if i == -1 and tup is None: + break + prefix = shard_prefix + f'[{i:06d}]' + action, inputs, best_configs = self.profile_single_config(tup, + prefix=prefix, + shard=self._shard) + ipc_write.put((i, action, inputs, best_configs)) + except ValueError: # mp.Queue closed + break + self.clean_mp(shard) + ''' + with torch.cuda.device(shard): + for i, tup in enumerate(self.gen()): + if i % total_shards != shard: + continue + print(f"{shard_prefix}[{i:06d}] Handling {tup}") + if a.continue_from is not None and i < a.continue_from: + continue + if a.stop_at is not None and i > a.stop_at: + break + if a.dry_run: + continue + action, inputs, best_configs = self.profile_single_config(tup) + if action == 'Success': + ipc_write.put((inputs, best_configs)) + ipc_write.put((None, shard)) + ''' + +class DbAccessor(ArgArchVerbose): + KERNEL_FAMILY = 'FLASH' + END_OF_QUEUE_OBJECT = (-1, None, None) + + def create_dbp(self): + a = self._args + if a.json_file is not None and not a.dry_run: + assert a.json_file != a.db_file + self._jsonfile = open(a.json_file, 'a' if a.continue_from_json_file else 'w') + else: + self._jsonfile = None + dbargs = ['python3', '-m', 'v2python.table_tool'] + if self.verbose: + dbargs += ['-v'] + dbargs += ['-f', self._args.db_file, '-k', self.KERNEL_FAMILY] + if a.create_table_only: + dbargs += ['--action', 'createtableonly'] + self._dbp = subprocess.Popen(dbargs, + stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, + text=True) + os.set_blocking(self._dbp.stdout.fileno(), False) + os.set_blocking(self._dbp.stderr.fileno(), False) + + def pipe_from_ipc(self, ipc_read): + self.create_dbp() + while True: + try: + i, action, inputs, best_configs = ipc_read.get() + if i == -1 and inputs is None: + print('[DbAccessor] No more tasks. 
Exiting') + break + if action == 'Success': + self.pipe_configs(inputs, best_configs, prefix=f'[{i:06d}]', _debug_task_id=i) + if action == 'Skip': + self.pipe_skipped_configs(inputs, prefix=f'[{i:06d} (skip)]', _debug_task_id=i) + except ValueError: # mp.Queue closed + break + self.stop() + return + + def pipe_configs(self, inputs, best_configs, *, prefix='', _debug_task_id=None): + for kernel_name, best in best_configs: + j = self.translate_config(inputs, kernel_name, best) + if _debug_task_id is not None: + j['_debug_task_id'] = _debug_task_id + js = json.dumps(j, separators=(',', ':')) + if self._jsonfile is None: + print(f'{prefix}Piping to db process {js}') + else: + print(js, file=self._jsonfile, flush=True) + print(js, file=self._dbp.stdin, flush=True) + self.splice_pipes() + + def pipe_skipped_configs(self, inputs, _debug_task_id, *, prefix=''): + skipped_result = { + 'arch' : self._arch, + 'inputs' : inputs, + '_debug_task_id' : _debug_task_id, + 'result' : 'skipped', + } + js = json.dumps(skipped_result, separators=(',', ':')) + if self._jsonfile is not None: + print(js, file=self._jsonfile, flush=True) + + def splice_pipes(self): + nattempts = 10 if self.verbose else 1 + for i in range(nattempts): + while True: + line = self._dbp.stdout.readline() + if line: + print(line, end='') + else: + if self.verbose: + time.sleep(0.1) + break + + for i in range(nattempts): + while True: + line = self._dbp.stderr.readline() + if line: + print(line, end='', file=sys.stderr) + else: + if self.verbose: + time.sleep(0.1) + break + sys.stdout.flush() + sys.stderr.flush() + + def translate_config(self, inputs, kernel_name, best): + tuning_result = { + 'arch' : self._arch, + 'kernel_name' : kernel_name, + 'inputs' : inputs, + 'result' : 'tuned', + 'tuned_kernel' : best.psels, + 'compiler_options' : best.copts, + } + return tuning_result + + def stop(self): + if self._jsonfile is not None: + self._jsonfile.close() + self._dbp.stdin.close() + print("Waiting for database process to terminate") + self._dbp.wait() + self.splice_pipes() + +class TunerManager(ArgArchVerbose): + + def gen(self): + a = self._args + yield from itertools.product(a.batch, a.n_heads, a.d_head, a.seqlen_q, a.seqlen_k, a.causal, a.sm_scale, a.dropout_p, a.return_encoded_softmax, a.dtype, a.bias_type) + + def gen_itup(self): + a = self._args + skip_set = set() + if a.continue_from_json_file and a.json_file is not None: + with open(a.json_file, 'r') as f: + for line in f.readlines(): + j = json.loads(line) + skip_set.add(j['_debug_task_id']) + for i, tup in enumerate(self.gen()): + # print(f"[{i:06d}] gen_itup {tup}") + if a.continue_from is not None and i < a.continue_from: + continue + if i in skip_set: + continue + if a.stop_at is not None and i > a.stop_at: + break + yield i, tup + + def profile_all(self): + a = self._args + dba = DbAccessor(a) + if a.use_multigpu is None: + dba.create_dbp() + worker = TunerWorker(a) + worker.do_profile(dba, self.gen_itup) + return + shards = list([i for i in range(torch.cuda.device_count())]) if -1 in a.use_multigpu else a.use_multigpu + ipc_write = Queue() + ipc_worker_out = Queue() + ipc_tuners = [IPCTunerWorker(self._args) for i in shards] + workers = [Process(target=worker.do_profile, args=(ipc_write, ipc_worker_out)) for worker in ipc_tuners] + db_accessor = Process(target=dba.pipe_from_ipc, args=(ipc_worker_out,)) + + ''' + Start processes + ''' + nlive_processes = len(workers) + for i, p in enumerate(workers): + ipc_write.put((i, nlive_processes)) + for p in workers: + p.start() + 
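+        # Pipeline topology: this manager feeds (i, tup) tasks into ipc_write,
+        # each IPCTunerWorker shard consumes them and emits
+        # (i, action, inputs, best_configs) tuples on ipc_worker_out, and a
+        # single DbAccessor process drains that queue into the sqlite
+        # database through the v2python.table_tool subprocess.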
+ + def profile_all(self): + a = self._args + dba = DbAccessor(a) + if a.use_multigpu is None: + dba.create_dbp() + worker = TunerWorker(a) + worker.do_profile(dba, self.gen_itup) + return + shards = list(range(torch.cuda.device_count())) if -1 in a.use_multigpu else a.use_multigpu + ipc_write = Queue() + ipc_worker_out = Queue() + ipc_tuners = [IPCTunerWorker(self._args) for i in shards] + workers = [Process(target=worker.do_profile, args=(ipc_write, ipc_worker_out)) for worker in ipc_tuners] + db_accessor = Process(target=dba.pipe_from_ipc, args=(ipc_worker_out,)) + + ''' + Start processes + ''' + nlive_processes = len(workers) + for i, p in enumerate(workers): + ipc_write.put((i, nlive_processes)) + for p in workers: + p.start() + db_accessor.start() + ''' + Dispatching tasks to ipc_write + ''' + for i, tup in self.gen_itup(): + obj = (i, tup) + # print(f"write_to_ipc {obj}") + any_process_alive = self.write_to_ipc(ipc_write, obj, workers) + if not any_process_alive: + break + nlive_processes = self.scan_live_processes(workers) + for i in range(nlive_processes): + self.write_to_ipc(ipc_write, IPCTunerWorker.END_OF_QUEUE_OBJECT, workers) + ipc_write.close() + """ + while nlive_processes > 0: + try: + inputs, best_configs = ipc_worker_out.get(timeout=30) + # print(f'{inputs=}') + # print(f'{best_configs=}') + if inputs is None: + shard = best_configs + nlive_processes -= 1 + print(f'Shard {shard} has completed all tasks. Updated {nlive_processes=}') + continue + self.pipe_configs(inputs, best_configs) + except queue.Empty: + print("Timed out. Re-scan live processes") + # "watchdog" + """ + for p in workers: + p.join() + # ipc_write.close() # redundant: the queue was already closed above + print('All workers joined') + ipc_worker_out.put(DbAccessor.END_OF_QUEUE_OBJECT) + ipc_worker_out.close() + db_accessor.join() + print('Db accessor joined') + # Otherwise current process may block if any child died + ipc_write.cancel_join_thread() + ipc_worker_out.cancel_join_thread() + + def write_to_ipc(self, ipc_write, obj, workers): + while True: + try: + ipc_write.put(obj, timeout=60) + return True + except queue.Full: + print("Task Queue Full. Re-scan live processes") + nlive_processes = self.scan_live_processes(workers) + print(f"{nlive_processes=}") + if nlive_processes == 0: + print("PANIC: All Processes Died") + return False + + def scan_live_processes(self, workers): + nlive_processes = 0 + for p in workers: + nlive_processes += 1 if p.is_alive() else 0 + return nlive_processes
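+ +# Example invocation (illustrative; the tuner's module path is elided here): +#   python -m <tuner_module> --db_file flash.sqlite3 --use_multigpu -1 \ +#       --seqlen_q 1024 2048 --seqlen_k 1024 2048 --dtype float16 +# Use --dry_run to enumerate functional sets without profiling; combine +# --json_file with --continue_from_json_file to resume an interrupted run.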
+ +def parse(): + p = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + p.add_argument('--batch', type=int, nargs=1, default=[1], help='(Not a functional) Batch size.') + p.add_argument('--n_heads', type=int, nargs=1, default=[12], help='(Not a functional) Number of heads') + p.add_argument('--sm_scale', type=float, nargs=1, default=[1.2], help='(Not a functional) Softmax Scale') + p.add_argument('--return_encoded_softmax', type=int, nargs=1, default=[0], choices=[0, 1], + help="(A functional for debugging) Return softmax(dropout(QK')) from the kernel to validate the correctness of dropout. (Use 0/1 for False/True)") + p.add_argument('--d_head', type=int, nargs='+', default=[16,32,64,128,256], help='Head dimensions.') + p.add_argument('--seqlen_q', type=int, nargs='+', default=[4,8,16,32,64,128,256,1024,2048,4096,8192], help='Sequence length of Q.') + p.add_argument('--seqlen_k', type=int, nargs='+', default=[4,8,16,32,64,128,256,1024,2048,4096,8192], help='Sequence length of K/V.') + p.add_argument('--causal', type=int, nargs='+', default=[1, 0], choices=[0, 1], help='Causal mask. (Use 0/1 for False/True)') + p.add_argument('--dropout_p', type=float, nargs='+', default=[0.5, 0.0], help='Probability to dropout (0 to disable).') + p.add_argument('--dtype', type=str, nargs='+', + default=['float16', 'bfloat16', 'float32'], + choices=['float16', 'bfloat16', 'float32'], + help='Datatype to profile.') + p.add_argument('--bias_type', type=int, nargs='+', default=[0, 1], choices=[0, 1], help='Bias types to profile, 0: None, 1: Matrix.') + p.add_argument('--verbose', action='store_true', help='Verbose') + p.add_argument('--validate', + action='store_true', help='Validate the correctness of the output to avoid faulty autotune configs') + p.add_argument('--dry_run', action='store_true', help="Print parameter combinations without running tests") + p.add_argument('--continue_from', type=int, default=None, help="Continue from n-th functional set") + p.add_argument('--stop_at', type=int, default=None, help="Stop at n-th functional set") + p.add_argument('--db_file', type=str, required=True, help="SQLite database file") + p.add_argument('--json_file', type=str, default=None, help="JSON file for record. Disables printing JSON to stdout") + p.add_argument('--continue_from_json_file', action='store_true', help="Append to JSON file instead of overwrite, and skip already tested entries.") + p.add_argument('--create_table_only', action='store_true', help="Do not insert data, only create tables. Used for schema updates.") + p.add_argument('--use_multigpu', type=int, nargs='+', default=None, help='Profiling on multiple GPUs. Pass -1 for all GPUs available to PyTorch.') + args = p.parse_args() + args.dtype = [ getattr(torch, t) for t in args.dtype ] + args.causal = [ bool(c) for c in args.causal ] + args.return_encoded_softmax = [ bool(x) for x in args.return_encoded_softmax ] + return args + +def main(): + assert os.getenv('PYTORCH_NO_CUDA_MEMORY_CACHING', default='0') == '0', 'PYTORCH_NO_CUDA_MEMORY_CACHING does not play nicely with torch.multiprocessing. 
See https://github.com/pytorch/pytorch/issues/114534' + torch.multiprocessing.set_start_method('spawn', force=True) # Otherwise torch complains + # multiprocessing.set_start_method('spawn') # "context has already been set" + args = parse() + tuner = TunerManager(args) + tuner.profile_all() + +if __name__ == '__main__': + main() diff --git a/third_party/triton b/third_party/triton index b5f3d15d..f471ba9e 160000 --- a/third_party/triton +++ b/third_party/triton @@ -1 +1 @@ -Subproject commit b5f3d15da4a191fc81ade329e5f60c7a1118f921 +Subproject commit f471ba9e3dd672b948dda9dc0b8a7658e6f39e95 diff --git a/tritonsrc/_common_test.py b/tritonsrc/_common_test.py index 69fadc55..1a762d70 100644 --- a/tritonsrc/_common_test.py +++ b/tritonsrc/_common_test.py @@ -101,9 +101,9 @@ def __init__(self, BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, dtype, q = torch.rand(*qdims, dtype=dtype, device=device) k = torch.rand(*kdims, dtype=dtype, device=device) v = torch.rand(*vdims, dtype=dtype, device=device) - if bias_type is None: + if bias_type is None or bias_type == 0: b = None - elif bias_type == 'matrix': + elif bias_type == 'matrix' or bias_type == 1: # b = torch.empty(bdims, dtype=dtype, device="cuda").normal_(mean=0., std=0.5) b = torch.rand(*bdims, dtype=dtype, device=device) b = b.expand(BATCH, N_HEADS, b.shape[0], b.shape[1]) @@ -149,7 +149,7 @@ def clone_tensor(t, dtype, device=None): def clone_tensor_tuple(in_tensors, dtype, device=None): return tuple([SdpaContext.clone_tensor(t, dtype=dtype, device=device) for t in in_tensors]) - def create_ref_inputs(self): + def create_ref_inputs(self, target_gpu_device='cuda'): ref_device_option = os.getenv('AOTRITON_REF_DEVICE_OPTION', default='default') if ref_device_option == 'default': seqlen_k = self.seqlen_k @@ -161,9 +161,9 @@ def create_ref_inputs(self): if seqlen_k == 587: ref_device = 'cpu' else: - ref_device = 'cuda' + ref_device = target_gpu_device elif ref_device_option == 'cuda': - ref_device = 'cuda' + ref_device = target_gpu_device elif ref_device_option == 'cpu': ref_device = 'cpu' else: @@ -257,8 +257,10 @@ def _validate(out, ref, lp_ref, fudge_factor, tname): max_adiff = float(torch.max(torch.abs(x - y))) return torch.allclose(x, y, atol=atol, rtol=rtol), max_adiff - def validate_with_reference(self, out, grads): + def validate_with_reference(self, out, grads, *, no_backward=False): out_allclose, out_adiff = self._validate(out, self.refout_tensors[0], self.lp_refout_tensors[0], self.OUT_FUDGE_FACTOR, 'out') + if no_backward: + return out_allclose, out_adiff, [], [] grads_allclose = [] grads_adiff = [] for grad, ref, lp_ref, fudge_factor, tname in zip(grads, self.dref_tensors, self.lp_dref_tensors, self.fudge_factors, self.TENSOR_NAMES): diff --git a/tritonsrc/attn_torch_function.py b/tritonsrc/attn_torch_function.py index d74ec176..d5936b78 100644 --- a/tritonsrc/attn_torch_function.py +++ b/tritonsrc/attn_torch_function.py @@ -56,6 +56,8 @@ def tuned_attn_fwd( stride_vz, stride_vh, stride_vk, stride_vn, stride_bz, stride_bh, stride_bm, stride_bn, stride_oz, stride_oh, stride_om, stride_on, + num_head_q, + num_head_k, cu_seqlens_q, cu_seqlens_k, num_seqlens, @@ -83,6 +85,8 @@ def tuned_attn_fwd( stride_vz, stride_vh, stride_vk, stride_vn, stride_bz, stride_bh, stride_bm, stride_bn, stride_oz, stride_oh, stride_om, stride_on, + num_head_q, + num_head_k, cu_seqlens_q, cu_seqlens_k, num_seqlens, @@ -252,6 +256,8 @@ def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax head_dim_rounded = 2 ** (Lk - 1).bit_length() 
head_dim_rounded = max(16, head_dim_rounded) padded_head = head_dim_rounded != Lk + num_head_q = q.shape[1] + num_head_k = k.shape[1] max_seqlen_q = q.shape[2] max_seqlen_k = k.shape[2] o = torch.zeros_like(q) @@ -269,7 +275,7 @@ def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax null_tensor = torch.empty((0), device=q.device, dtype=torch.int32) M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) if return_encoded_softmax: - encoded_softmax = torch.ones((q.shape[0], q.shape[1], q.shape[2], k.shape[2]), device=q.device, dtype=_attention.DEBUG_MASK_DTYPE) * 114.514 + encoded_softmax = torch.ones((q.shape[0], q.shape[1], q.shape[2], k.shape[2]), device=q.device, dtype=q.dtype) else: encoded_softmax = None if False or VERBOSE: @@ -321,6 +327,8 @@ def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax v.stride(0), v.stride(1), v.stride(2), v.stride(3), b.stride(0), b.stride(1), b.stride(2), b.stride(3), o.stride(0), o.stride(1), o.stride(2), o.stride(3), + num_head_q=num_head_q, + num_head_k=num_head_k, cu_seqlens_q=null_tensor, cu_seqlens_k=null_tensor, num_seqlens=0, @@ -342,6 +350,10 @@ def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax RETURN_ENCODED_SOFTMAX=encoded_softmax is not None print(f'{BLOCK_M=} {BLOCK_N=} {RETURN_ENCODED_SOFTMAX=} seqlen_q={q.shape[2]} seqlen_k={k.shape[2]}', flush=True) + print(f'{q.data_ptr()=:x} {k.data_ptr()=:x} {v.data_ptr()=:x} {b.data_ptr()=:x} {M.data_ptr()=:x} {o.data_ptr()=:x}', flush=True) + print(f'{encoded_softmax.data_ptr()=:x}', flush=True) + print(f'{q.shape=} {k.shape=} {v.shape=} {b.shape=} {M.shape=} {o.shape=}', flush=True) + print(f'{q.stride()=} {k.stride()=} {v.stride()=} {b.stride()=} {M.stride()=} {o.stride()=}', flush=True) bare_attn_fwd[grid]( q, k, v, b, sm_scale, M, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), @@ -349,6 +361,8 @@ def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax v.stride(0), v.stride(1), v.stride(2), v.stride(3), b.stride(0), b.stride(1), b.stride(2), b.stride(3), o.stride(0), o.stride(1), o.stride(2), o.stride(3), + num_head_q=num_head_q, + num_head_k=num_head_k, cu_seqlens_q=null_tensor, cu_seqlens_k=null_tensor, num_seqlens=0, @@ -420,6 +434,8 @@ def forward(ctx, q, k, v, b, causal, sm_scale, dropout_p, return_encoded_softmax tuning_result = None block_m = min(128, q.shape[2], k.shape[2]) grid = (triton.cdiv(q.shape[2], block_m), q.shape[1], q.shape[0]) + # print(f'{M=}') + # print(f'{M.shape=}') ctx.save_for_backward(q, k, v, b, o, M) ctx.grid = grid ctx.sm_scale = sm_scale @@ -452,6 +468,8 @@ def backward(ctx, do, _, fwd_tuning_result): db = torch.empty_like(b) delta = torch.empty_like(L) null_tensor = torch.empty((0), device=q.device, dtype=torch.int32) + num_head_q = q.shape[1] + num_head_k = k.shape[1] max_seqlen_q = q.shape[2] max_seqlen_k = k.shape[2] MAX_BLOCK = 64 if ctx.dropout_p == 0 else 16 @@ -608,6 +626,10 @@ def backward(ctx, do, _, fwd_tuning_result): PADDED_HEAD=padded_head, BIAS_TYPE=ctx.bias_type, ) + print(f"{dq.stride()=}", flush=True) + print(f"{dq.data_ptr()=:x}", flush=True) + print(f"{dk.stride()=}", flush=True) + print(f"{dk.data_ptr()=:x}", flush=True) # mask_allclose = torch.allclose(debug_mask < 0, ctx.encoded_softmax < 0) if False: mask_allclose = torch.allclose(torch.abs(debug_mask), torch.abs(ctx.encoded_softmax)) # Stores QK diff --git a/tritonsrc/bwd_kernel_common.py b/tritonsrc/bwd_kernel_common.py index 
1921c49a..effca286 100644 --- a/tritonsrc/bwd_kernel_common.py +++ b/tritonsrc/bwd_kernel_common.py @@ -151,8 +151,7 @@ def bwd_kernel_dk_dv_common( if BLOCK_M == 1: dv += p.to(Q_block_ptr.dtype.element_ty) * do else: - # dv += tl.dot(tl.trans(p.to(do.dtype)), do) - dv += tl.dot(tl.trans(p).to(do.dtype), do) + dv += tl.dot(tl.trans(p.to(do.dtype)), do) dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # compute dp = dot(do, vt) # dp += dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, do, vt) @@ -173,9 +172,7 @@ def bwd_kernel_dk_dv_common( DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0)) # Debug DO accessing problems if BIAS_TYPE == 1: B_block_ptr = tl.advance(B_block_ptr, (BLOCK_M, 0)) - # initialize pointers to output - tl.store(DK_block_ptr, (dk * sm_scale).to(DK_block_ptr.type.element_ty), boundary_check=(0,1)) - tl.store(DV_block_ptr, dv.to(DV_block_ptr.type.element_ty), boundary_check=(0,1)) + return (dk * sm_scale).to(DK_block_ptr.type.element_ty), dv.to(DV_block_ptr.type.element_ty) @triton.jit def bwd_kernel_dq_db_common( @@ -308,4 +305,5 @@ def bwd_kernel_dq_db_common( if BIAS_TYPE == 1: B_block_ptr = tl.advance(B_block_ptr, (0, BLOCK_N)) DB_block_ptr = tl.advance(DB_block_ptr, (0, BLOCK_N)) - tl.store(DQ_block_ptr, (dq * sm_scale).to(DQ_block_ptr.type.element_ty), boundary_check=(0,1)) + return (dq * sm_scale).to(DQ_block_ptr.type.element_ty) + # tl.store(DQ_block_ptr, (dq * sm_scale).to(DQ_block_ptr.type.element_ty), boundary_check=(0,1)) diff --git a/tritonsrc/bwd_split_kernel.py b/tritonsrc/bwd_split_kernel.py index fce7e3cf..0bf89384 100644 --- a/tritonsrc/bwd_split_kernel.py +++ b/tritonsrc/bwd_split_kernel.py @@ -17,6 +17,7 @@ import triton import triton.language as tl from bwd_kernel_common import bwd_kernel_dk_dv_common, bwd_kernel_dq_db_common +from masked_load_store import mstore2d # TODO: Remove Unused 'Out' Argument from kernels below @triton.jit @@ -34,7 +35,7 @@ def bwd_kernel_dk_dv( stride_dvz, stride_dvh, stride_dvk, stride_dvn, cu_seqlens_q, cu_seqlens_k, - num_seqlens, # set num_seqlens to zero to ignore cu_seqlens_q/k + num_seqlens : 'i32', # set num_seqlens to zero to ignore cu_seqlens_q/k max_seqlen_q, # and use max_seqlen_q/k for all seqlen_q/k max_seqlen_k, head_dim, @@ -55,6 +56,13 @@ def bwd_kernel_dk_dv( num_h = tl.num_programs(1) num_z = tl.num_programs(2) off_zh = off_z * num_h + off_h * 1 + + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + batch_index = off_z + if num_seqlens > 0: cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) @@ -65,13 +73,8 @@ def bwd_kernel_dk_dv( cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start batch_index = 0 - elif num_seqlens == 0: - cu_seqlens_q_start = 0 - cu_seqlens_k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - batch_index = off_z - else: # < 0 for padded seqlen + + if num_seqlens < 0: # for padded seqlen cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start @@ -171,7 +174,7 @@ def bwd_kernel_dk_dv( order=(1, 0) ) - bwd_kernel_dk_dv_common( + dk, dv = bwd_kernel_dk_dv_common( Q_block_ptr, KT_block_ptr, VT_block_ptr, B_block_ptr, sm_scale, DO_block_ptr, DK_block_ptr, DV_block_ptr, @@ -191,6 +194,26 @@ def bwd_kernel_dk_dv( ENABLE_DROPOUT, PADDED_HEAD, BIAS_TYPE) + mstore2d(dk, + BLOCK_N, + BLOCK_DMODEL, + o_base=DK + dk_offset, + 
o_start_row=start_m, + o_start_col=0, + o_rows=seqlen_k, + o_cols=head_dim, + stride_row=stride_dkn, + stride_col=stride_dkk) + mstore2d(dv, + BLOCK_N, + BLOCK_DMODEL, + o_base=DV + dv_offset, + o_start_row=start_m, + o_start_col=0, + o_rows=seqlen_k, + o_cols=head_dim, + stride_row=stride_dvk, + stride_col=stride_dvn) @triton.jit def bwd_kernel_dq( @@ -227,6 +250,13 @@ def bwd_kernel_dq( num_h = tl.num_programs(1) num_z = tl.num_programs(2) off_zh = off_z * num_h + off_h * 1 + + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + batch_index = off_z + if num_seqlens > 0: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) @@ -237,13 +267,8 @@ def bwd_kernel_dq( cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start batch_index = 0 - elif num_seqlens == 0: - cu_seqlens_q_start = 0 - cu_seqlens_k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - batch_index = off_z - else: # < 0 for padded seqlen + + if num_seqlens < 0: # for padded seqlen cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start @@ -302,7 +327,10 @@ def bwd_kernel_dq( batch_philox_offset = 0 # initialize pointers to output - dq_offset = off_h * stride_dqh + batch_index * stride_dqz + cu_seqlens_q_start * stride_dqm + dq_offset = batch_index * stride_dqz + off_h * stride_dqh + cu_seqlens_q_start * stride_dqm + # tl.device_print('batch_index ', batch_index) + # tl.device_print('off_h ', off_h) + # tl.device_print('cu_seqlens_q_start ', cu_seqlens_q_start) DQ_block_ptr = tl.make_block_ptr( base=DQ + dq_offset, shape=(seqlen_q, head_dim), @@ -339,7 +367,7 @@ def bwd_kernel_dq( else: tl.static_assert(False, f'Unsupported BIAS_TYPE {BIAS_TYPE}') - bwd_kernel_dq_db_common( + dq = bwd_kernel_dq_db_common( Q_block_ptr, K_block_ptr, V_block_ptr, B_block_ptr, sm_scale, DO_block_ptr, DQ_block_ptr, DB_block_ptr, store_db, @@ -359,3 +387,19 @@ def bwd_kernel_dq( ENABLE_DROPOUT, PADDED_HEAD, BIAS_TYPE) + dq_ptrs, dq_masks = mstore2d(dq, + BLOCK_M, + BLOCK_DMODEL, + o_base=DQ + dq_offset, + o_start_row=start_m, + o_start_col=0, + o_rows=seqlen_q, + o_cols=head_dim, + stride_row=stride_dqm, + stride_col=stride_dqk) + # tl.device_print('dq_offset ', dq_offset) + # tl.device_print('stride_dqm ', stride_dqm) + # tl.device_print('stride_dqk ', stride_dqk) + # tl.device_print('head_dim ', head_dim) + # tl.device_print('dq_ptrs ', dq_ptrs) + # tl.device_print('dq_masks ', dq_masks) diff --git a/tritonsrc/fwd_kernel.py b/tritonsrc/fwd_kernel.py index 058cf08f..70612d98 100644 --- a/tritonsrc/fwd_kernel.py +++ b/tritonsrc/fwd_kernel.py @@ -6,54 +6,65 @@ Fused Attention =============== -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) - -Extra Credits: -- Original flash attention paper (https://arxiv.org/abs/2205.14135) -- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) -- Adam P. 
Goucher for simplified vector math +This is a Triton implementation of the Flash Attention v2 algorithm +See https://tridao.me/publications/flash2/flash2.pdf +Credits: +AMD Triton kernels team +OpenAI kernel team """ import triton import triton.language as tl -from fwd_kernel_common import attn_fwd_common +from fwd_kernel_inner import attn_fwd_inner +from masked_load_store import mstore2d + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y @triton.jit def attn_fwd( - Q, K, V, B, sm_scale, M, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - stride_bz, stride_bh, stride_bm, stride_bn, - stride_oz, stride_oh, stride_om, stride_on, - cu_seqlens_q, - cu_seqlens_k, - num_seqlens, # set num_seqlens to zero to ignore cu_seqlens_q/k - max_seqlen_q, # and use max_seqlen_q/k for all seqlen_q/k - max_seqlen_k, - head_dim, - dropout_p, - philox_seed, - philox_offset_base, - encoded_softmax, - CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - pre_load_v: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - PADDED_HEAD: tl.constexpr, - BIAS_TYPE: tl.constexpr, + Q, K, V, B, sm_scale, L, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_bz, stride_bh, stride_bm, stride_bn, + stride_oz, stride_oh, stride_om, stride_on, + num_head_q : 'i32', + num_head_k : 'i32', + cu_seqlens_q, + cu_seqlens_k, + num_seqlens : 'i32', + max_seqlen_q : 'i32', + max_seqlen_k : 'i32', + head_dim : 'i32', + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, + BIAS_TYPE: tl.constexpr, ): + # Keep lower-case pre_load_v for backward compatibility, to minimize changes to + # other files. It will be renamed in a separate PR. + PRE_LOAD_V : tl.constexpr = pre_load_v + # No ALIBI interface for now + USE_ALIBI : tl.constexpr = False + # alibi_slopes = None start_m = tl.program_id(0) - off_h = tl.program_id(1) # head index - off_z = tl.program_id(2) # batch index - num_h = tl.num_programs(1) - num_z = tl.num_programs(2) - off_zh = off_z * num_h + off_h * 1 - # FIXME: Better pattern for this branch? It's copied into three kernels + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) if num_seqlens > 0: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) @@ -84,109 +95,250 @@ def attn_fwd( - if start_m * BLOCK_M + BLOCK_M > seqlen_q: - q_padded = True - else: - q_padded = False + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if CAUSAL: + # If seqlen_q == seqlen_k, the attn scores are a square matrix. 
+ # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or the left edge (seqlen_q > seqlen_k). + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = Out + batch_index * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + o_ptrs_mask = offs_m[:, None] < seqlen_q + # We still need to write 0s to the result + tl.store(o_ptrs, acc, mask=o_ptrs_mask) + # The tensor allocated for L is based on max_seqlen_q as that is + # statically known. + L_ptr_base = L + (off_z * num_head_q + off_h_q) * max_seqlen_q + l_ptrs = L_ptr_base + offs_m + # We store inf to LSE, not -inf because in the bwd pass, we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. + l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + l_ptrs_mask = offs_m < max_seqlen_q + tl.store(l_ptrs, l, mask=l_ptrs_mask) + # TODO: Should dropout and return encoded softmax be handled here too? + return + + # If MQA / GQA, set the K and V head offsets appropriately. + # group_size = num_head_q // num_head_k + # if group_size != 1: + # off_h_k = off_h_q // group_size + # else: + # off_h_k = off_h_q + off_h_k = off_h_q if num_head_q == num_head_k else off_h_q // (num_head_q // num_head_k) + + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N
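+ + # Worked example of the off_h_k mapping above (illustrative): with num_head_q = 8 + # and num_head_k = 2, group_size = num_head_q // num_head_k = 4, so query heads + # 0..3 read K/V head 0 and query heads 4..7 read K/V head 1, i.e. + # off_h_k = off_h_q // 4 maps [0,1,2,3,4,5,6,7] -> [0,0,0,0,1,1,1,1]. + # For plain MHA (num_head_q == num_head_k) the mapping is the identity.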
+ + # Compute pointers for all the tensors used in this kernel. + q_offset = Q + batch_index * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + k_offset = batch_index * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + k_ptrs = K + k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn + # tl.device_print('batch_index ', batch_index) + # tl.device_print('off_h_k ', off_h_k) + # tl.device_print('cu_seqlens_k_start ', cu_seqlens_k_start) + # tl.device_print('k_offset ', k_offset) + v_offset = V + batch_index * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + if BIAS_TYPE == 0: + bias_ptrs = None + elif BIAS_TYPE == 1: + # Note: this might get large enough to overflow on some configs + bias_offset = batch_index * stride_bz + off_h_q * stride_bh + bias_ptrs = B + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn + else: + tl.static_assert(False, f'Unsupported BIAS_TYPE {BIAS_TYPE}') + + if USE_ALIBI: + a_offset = batch_index * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + off_zh = batch_index * num_head_q + off_h_q + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + off_zh * max_seqlen_q * max_seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. In + # this case, we return an invalid pointer to indicate the mask is not valid. + if RETURN_ENCODED_SOFTMAX: + encoded_sm_base = encoded_softmax + off_zh * max_seqlen_q * max_seqlen_k + # encoded_sm_ptrs = encoded_sm_base + offs_m[:, None] * max_seqlen_k + offs_n[None, :] + else: + encoded_sm_base = None + # encoded_sm_ptrs = None + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q_ptrs_mask = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < head_dim) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + q = (q * qk_scale).to(q.type.element_ty)
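+ + # Illustrative numbers for the block accounting below: with seqlen_k = 200, + # BLOCK_M = BLOCK_N = 64 and CAUSAL == False, n_blocks = cdiv(200, 64) = 4 and + # n_extra_tokens = 200 % 64 = 8, so the last block is partial: masked_blocks = 1, + # n_full_blocks = 3, and only the final attn_fwd_inner iteration applies the + # seqlen_k boundary mask (MASK_STEPS == True).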
+ + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its actual + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = attn_fwd_inner( + acc, l_i, m_i, + q, k_ptrs, v_ptrs, bias_ptrs, + stride_kn, stride_vk, stride_bn, + seqlen_q, seqlen_k, head_dim, + start_m, block_min, block_max, + dropout_p, philox_seed, batch_philox_offset, max_seqlen_k, + encoded_sm_base, + # offs_n_causal, masked_blocks, n_extra_tokens + 0, 0, 0, + alibi_slope, + # CAUSAL, .... + False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, + # _, MASK_STEPS, ... + PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are masked and are handled below. + if masked_blocks > 0: + if CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vk + if BIAS_TYPE == 0: + pass + elif BIAS_TYPE == 1: + bias_ptrs += n_full_blocks * BLOCK_N * stride_bn + else: + tl.static_assert(False, f'Unsupported BIAS_TYPE {BIAS_TYPE}') + # if RETURN_ENCODED_SOFTMAX: + # encoded_sm_base += n_full_blocks * BLOCK_N + # encoded_sm_ptrs += n_full_blocks * BLOCK_N + acc, l_i, m_i = attn_fwd_inner( + acc, l_i, m_i, + q, k_ptrs, v_ptrs, bias_ptrs, + stride_kn, stride_vk, stride_bn, + seqlen_q, seqlen_k, head_dim, + start_m, block_min, block_max, + dropout_p, philox_seed, batch_philox_offset, max_seqlen_k, + encoded_sm_base, + offs_n_causal, masked_blocks, n_extra_tokens, + alibi_slope, + # CAUSAL, .... + CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, + # _, MASK_STEPS, ... + PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. 
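+ # (Illustrative:) for such a fully-masked row every qk entry is -inf, so + # qk - m_ij evaluates to -inf - (-inf) = NaN before exp2; the same effect shows up + # as torch.softmax(torch.full((4,), float('-inf')), dim=0) returning NaNs. + # The masking below overwrites those rows of acc with zeros.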
+ end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if CAUSAL: + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # FIXME: MQA/GQA L tensor + # TODO: make writing of L optional + # write back LSE + + # L's shape: (batch, head, seqlen_q) + L_ptr_base = L + (off_z * num_head_q + off_h_q) * max_seqlen_q + l_ptrs = L_ptr_base + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. For others, overflow_size will be negative + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary + tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + else: + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) - attn_fwd_common(Q_block_ptr, - K_block_ptr, - V_block_ptr, - B_block_ptr, - O_block_ptr, - M_ptr_base, - sm_scale, - start_m, - seqlen_q, - seqlen_k, - seqlen_k_faligned, - q_padded, - dropout_p, - philox_seed, - batch_philox_offset, - max_seqlen_k, - encoded_softmax_block_ptr, - CAUSAL=CAUSAL, - BLOCK_M=BLOCK_M, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK_N, - pre_load_v=pre_load_v, - ENABLE_DROPOUT=ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX=RETURN_ENCODED_SOFTMAX, - PADDED_HEAD=PADDED_HEAD, - BIAS_TYPE=BIAS_TYPE) + o_base = Out + batch_index * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + mstore2d(acc, # acc was already cast to Out's element type above + BLOCK_M, + BLOCK_DMODEL, + o_base=o_base, + o_start_row=start_m * BLOCK_M, + o_start_col=0, + o_rows=seqlen_q, + o_cols=head_dim, + stride_row=stride_om, + stride_col=stride_on) + # # write back O + # o_offset = Out + batch_index * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + # o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + # o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + # if overflow_size > 0: + # o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) + # if PADDED_HEAD: + # o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < head_dim) + # tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) diff --git a/tritonsrc/fwd_kernel_common.py b/tritonsrc/fwd_kernel_common.py index 20d3a02d..0746dab5 100644 --- a/tritonsrc/fwd_kernel_common.py +++ b/tritonsrc/fwd_kernel_common.py @@ -6,6 +6,38 @@ from fwd_kernel_inner import attn_fwd_inner +''' +@triton.jit +def store0(O_block_ptr, acc): + tl.store(O_block_ptr, acc.to(O_block_ptr.type.element_ty)) + +@triton.jit +def store1(O_block_ptr, acc): + tl.store(O_block_ptr, acc.to(O_block_ptr.type.element_ty), boundary_check=(0,)) + +@triton.jit +def store2(O_block_ptr, acc): + tl.store(O_block_ptr, acc.to(O_block_ptr.type.element_ty), boundary_check=(1,)) + +@triton.jit +def store3(O_block_ptr, acc): + tl.store(O_block_ptr, acc.to(O_block_ptr.type.element_ty), boundary_check=(1,0)) +''' + +@triton.jit +def store_a(O_block_ptr, acc, q_padded): + if not q_padded: + tl.store(O_block_ptr, acc) + else: + tl.store(O_block_ptr, acc, boundary_check=(0,)) + +@triton.jit +def store_b(O_block_ptr, acc, q_padded): + if not q_padded: + tl.store(O_block_ptr, 
acc, boundary_check=(1,)) + else: + tl.store(O_block_ptr, acc, boundary_check=(1,0,)) + @triton.jit def attn_fwd_common( Q_block_ptr, @@ -48,6 +80,8 @@ def attn_fwd_common( # don't work as expected with `exp` in the loop qk_scale = sm_scale * 1.44269504089 # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs + q = tl.load(Q_block_ptr, boundary_check=(0,1), padding_option="zero") + ''' if q_padded: if PADDED_HEAD: q = tl.load(Q_block_ptr, boundary_check=(0,1), padding_option="zero") @@ -58,6 +92,7 @@ def attn_fwd_common( q = tl.load(Q_block_ptr, boundary_check=(1,), padding_option="zero") else: q = tl.load(Q_block_ptr) + ''' q = (q * qk_scale).to(Q_block_ptr.type.element_ty) # stage 1: off-band # For causal = True, STAGE = 3 and attn_fwd_inner gets 1 as its STAGE @@ -125,6 +160,10 @@ def attn_fwd_common( else: tl.store(m_ptrs, m_i + tl.math.log2(l_i)) + acc = acc.to(O_block_ptr.type.element_ty) + tl.store(O_block_ptr, acc, boundary_check=(1,0,)) + + ''' if q_padded: if PADDED_HEAD: tl.store(O_block_ptr, acc.to(O_block_ptr.type.element_ty), boundary_check=(0,1)) @@ -135,3 +174,4 @@ def attn_fwd_common( tl.store(O_block_ptr, acc.to(O_block_ptr.type.element_ty), boundary_check=(1,)) else: tl.store(O_block_ptr, acc.to(O_block_ptr.type.element_ty)) + ''' diff --git a/tritonsrc/fwd_kernel_inner.py b/tritonsrc/fwd_kernel_inner.py index 4df80b83..145a5cbd 100644 --- a/tritonsrc/fwd_kernel_inner.py +++ b/tritonsrc/fwd_kernel_inner.py @@ -4,138 +4,171 @@ import triton import triton.language as tl from dropout import dropout_mask, dropout_rng, dropout_offsets +from masked_load_store import mstore2d +# Convenience function to load with optional boundary checks. +# "First" is the major dim, "second" is the minor dim. @triton.jit -def max_fn(x, y): - return tl.math.max(x, y) +def load_fn(ptrs, offset_first, offset_second, _in_boundary_first, _in_boundary_second): + boundary_first = _in_boundary_first + boundary_second = _in_boundary_second + """ + # Debugging GPU segfault + boundary_first = 0 + boundary_second = 0 + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + return tl.load(ptrs, mask=mask, other=0.0) + """ + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) + else: + tensor = tl.load(ptrs) + return tensor @triton.jit def attn_fwd_inner( - acc, l_i, m_i, q, - K_block_ptr, V_block_ptr, B_block_ptr, - start_m, - seqlen_q, - q_padded, - seqlen_k_low, - seqlen_k_high, - k_padded, - dropout_p, - dropout_seqlen_k, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - CAUSAL: tl.constexpr, - offs_m: tl.constexpr, - offs_n: tl.constexpr, - pre_load_v: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - MARGINAL_BLOCK: tl.constexpr, # MARGINAL_BLOCK = CAUSAL or k_padded - PADDED_HEAD: tl.constexpr, - BIAS_TYPE: tl.constexpr, + # Problem Description + acc, l_i, m_i, + q, k_ptrs, v_ptrs, bias_ptrs, + stride_kn, stride_vk, stride_bn, + seqlen_q, seqlen_k, head_dim, + # Sub-problem range + start_m, 
block_min, block_max, + # Auxiliary options + ## Dropout + dropout_p, philox_seed, batch_philox_offset, max_seqlen_k, + ## Debug Return + encoded_sm_base, + ## Irregular support + offs_n_causal, masked_blocks, n_extra_tokens, + ## Alibi + alibi_slope, + # constexpr starts here + CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, +): - lo, hi = seqlen_k_low, seqlen_k_high - if MARGINAL_BLOCK: - K_block_ptr = tl.advance(K_block_ptr, (0, lo)) - V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, lo)) - if BIAS_TYPE == 1: - B_block_ptr = tl.advance(B_block_ptr, (0, lo)) - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - # -- compute qk ---- - # MARGINAL_BLOCK serves as a compile-time switch for first attn_fwd_inner calls to "solid" blocks - if MARGINAL_BLOCK and k_padded: - if PADDED_HEAD: - k = tl.load(K_block_ptr, boundary_check=(1,0), padding_option="zero") - else: - k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero") + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + if MASK_STEPS: + k_offs_n = start_n + tl.arange(0, BLOCK_N) else: - if PADDED_HEAD: - k = tl.load(K_block_ptr, boundary_check=(0,), padding_option="zero") - else: - k = tl.load(K_block_ptr) - if pre_load_v: - if MARGINAL_BLOCK and k_padded: - if PADDED_HEAD: - v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") - else: - v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") - else: - if PADDED_HEAD: - v = tl.load(V_block_ptr, boundary_check=(1,), padding_option="zero") - else: - v = tl.load(V_block_ptr) + k_offs_n = None + k_offs_d = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + ''' + k_offs_n = start_n + tl.arange(0, BLOCK_N) + k_offs_d = tl.arange(0, BLOCK_DMODEL) + # k_offs_d = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + ''' + k = load_fn(k_ptrs, k_offs_d, k_offs_n, head_dim, seqlen_k) + if PRE_LOAD_V: + # We can use the same offsets as k, just with dims transposed. + v = load_fn(v_ptrs, k_offs_n, k_offs_d, seqlen_k, head_dim) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - if MARGINAL_BLOCK: - if CAUSAL: - mask = offs_m[:, None] >= (start_n + offs_n[None, :]) - qk = tl.where(mask, qk, float("-inf")) - if k_padded: - boundary_m = tl.full([BLOCK_M], seqlen_k_high, dtype=tl.int32) - size_n = start_n + offs_n[None,:] - mask = size_n < boundary_m[:,None] + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of BLOCK_N + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # The last step might get wasted but that is okay; check that this masking works for + # that case. 
+ if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], seqlen_k, dtype=tl.int32) + size_n = start_n + offs_n[None, :] + mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) - if BIAS_TYPE == 0: - pass - elif BIAS_TYPE == 1: - if q_padded and k_padded: # CAVEAT: using "or" disables the partial boundary_check branches - bias = tl.load(B_block_ptr, boundary_check=(0,1), padding_option="zero") - elif q_padded: - bias = tl.load(B_block_ptr, boundary_check=(0,), padding_option="zero") - elif k_padded: - bias = tl.load(B_block_ptr, boundary_check=(1,), padding_option="zero") - else: - bias = tl.load(B_block_ptr) - qk += bias * 1.44269504089 - else: - tl.static_assert(False, f'Unsupported BIAS_TYPE {BIAS_TYPE}') + if CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = offs_m[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- qk += tl.dot(q, k) + if bias_ptrs is not None: + bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None + bias = load_fn(bias_ptrs, offs_m, bias_offs_n, seqlen_q, seqlen_k) + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += (bias * 1.44269504089) + + if alibi_slope is not None: + # Compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + alibi_block = compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, global_m_positions, + global_n_positions) + qk += (alibi_block * 1.44269504089) # scale factor of log2(e) + + # softmax m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk = qk - m_ij[:, None] p = tl.math.exp2(qk) + # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) - # Note about the conflicts of Flash attention algorithm and PyTorch's CUDA implementation - # PyTorch needs to return softmax(qk) (dropout mask encoded in sign bits) - # While Flash attention paper computer the dropout AFTER exp2(qk- m_ij) if ENABLE_DROPOUT: - philox_offset = batch_philox_offset + start_m * BLOCK_M * dropout_seqlen_k + start_n - keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, dropout_seqlen_k) + philox_offset = batch_philox_offset + start_m * BLOCK_M * max_seqlen_k + start_n + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, max_seqlen_k) if RETURN_ENCODED_SOFTMAX: - tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty), boundary_check=(0,1)) + mstore2d(tl.where(keep, p, -p).to(q.type.element_ty), + BLOCK_M, + BLOCK_N, + o_base=encoded_sm_base, + o_start_row=start_m * BLOCK_M, + o_start_col=start_n, + o_rows=seqlen_q, + o_cols=seqlen_k, + stride_row=max_seqlen_k, + stride_col=1) + # tl.store(encoded_sm_ptrs, tl.where(keep, p, -p).to(q.type.element_ty)) p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: - tl.store(encoded_softmax_block_ptr, - p.to(encoded_softmax_block_ptr.type.element_ty), - boundary_check=(0,1)) + mstore2d(p.to(q.type.element_ty), + BLOCK_M, + BLOCK_N, + o_base=encoded_sm_base, + o_start_row=start_m * BLOCK_M, + o_start_col=start_n, + o_rows=seqlen_q, + o_cols=seqlen_k, + stride_row=max_seqlen_k, + stride_col=1) + # tl.store(encoded_sm_ptrs, p.to(q.type.element_ty)) # -- update output accumulator -- alpha = tl.math.exp2(m_i - 
m_ij) acc = acc * alpha[:, None] - if not pre_load_v: - if MARGINAL_BLOCK and k_padded: - if PADDED_HEAD: - v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") - else: - v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") - else: - if PADDED_HEAD: - v = tl.load(V_block_ptr, boundary_check=(1,), padding_option="zero") - else: - v = tl.load(V_block_ptr) + if not PRE_LOAD_V: + v = load_fn(v_ptrs, k_offs_n, k_offs_d, seqlen_k, head_dim) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) - if BIAS_TYPE == 1: - B_block_ptr = tl.advance(B_block_ptr, (0, BLOCK_N)) + acc += tl.dot(p.to(v.type.element_ty), v) + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if bias_ptrs is not None: + bias_ptrs += BLOCK_N * stride_bn + # if RETURN_ENCODED_SOFTMAX: + # encoded_sm_ptrs += BLOCK_N return acc, l_i, m_i diff --git a/tritonsrc/masked_load_store.py b/tritonsrc/masked_load_store.py new file mode 100644 index 00000000..68804f88 --- /dev/null +++ b/tritonsrc/masked_load_store.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# Copyright © 2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +import triton +import triton.language as tl + +@triton.jit +def mstore2d( + registers, + REG_ROWS : tl.constexpr, + REG_COLS : tl.constexpr, + o_base, + o_start_row, + o_start_col, + o_rows, + o_cols, + stride_row, + stride_col, +): + """Store a REG_ROWS x REG_COLS register tile into a 2D view of shape (o_rows, o_cols), masking off rows/columns that fall outside it; returns the store pointers and mask (useful for debugging).""" + off_rows = tl.arange(0, REG_ROWS) + o_start_row + off_cols = tl.arange(0, REG_COLS) + o_start_col + o_ptrs = o_base + off_rows[:, None] * stride_row + off_cols[None, :] * stride_col + o_ptrs_mask = tl.full([REG_ROWS, REG_COLS], 1, dtype=tl.int1) + row_overflow = o_start_row + REG_ROWS - o_rows + if row_overflow > 0: + o_ptrs_mask = o_ptrs_mask & (off_rows[:, None] < o_rows) + col_overflow = o_start_col + REG_COLS - o_cols + if col_overflow > 0: + o_ptrs_mask = o_ptrs_mask & (off_cols[None, :] < o_cols) + tl.store(o_ptrs, registers, mask=o_ptrs_mask) + return o_ptrs, o_ptrs_mask
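+ +# Illustrative usage (simplified; mirrors the call in fwd_kernel.py): +#   mstore2d(acc, BLOCK_M, BLOCK_DMODEL, +#            o_base=Out + batch_index * stride_oz + off_h_q * stride_oh, +#            o_start_row=start_m * BLOCK_M, o_start_col=0, +#            o_rows=seqlen_q, o_cols=head_dim, +#            stride_row=stride_om, stride_col=stride_on) +# writes one BLOCK_M x BLOCK_DMODEL tile of the output, masking rows and +# columns that fall outside the (seqlen_q, head_dim) view.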
 diff --git a/tritonsrc/old_fwd_kernel.py b/tritonsrc/old_fwd_kernel.py new file mode 100644 index 00000000..26f1cf4d --- /dev/null +++ b/tritonsrc/old_fwd_kernel.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Extra Credits: +- Original flash attention paper (https://arxiv.org/abs/2205.14135) +- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) +- Adam P. Goucher for simplified vector math + +""" + +import triton +import triton.language as tl +from fwd_kernel_common import attn_fwd_common + +@triton.jit +def attn_fwd( + Q, K, V, B, sm_scale, M, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_bz, stride_bh, stride_bm, stride_bn, + stride_oz, stride_oh, stride_om, stride_on, + cu_seqlens_q, + cu_seqlens_k, + num_seqlens, # set num_seqlens to zero to ignore cu_seqlens_q/k + max_seqlen_q, # and use max_seqlen_q/k for all seqlen_q/k + max_seqlen_k, + head_dim, + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, + BIAS_TYPE: tl.constexpr, +): + start_m = tl.program_id(0) + off_h = tl.program_id(1) # head index + off_z = tl.program_id(2) # batch index + num_h = tl.num_programs(1) + num_z = tl.num_programs(2) + off_zh = off_z * num_h + off_h * 1 + # FIXME: Better pattern for this branch? It's copied into three kernels + if num_seqlens > 0: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + if start_m * BLOCK_M >= seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + batch_index = 0 + elif num_seqlens == 0: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + batch_index = off_z + else: # < 0 for padded seqlen + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + if start_m * BLOCK_M >= seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + # Varlen, but padded to Rank 4 tensor + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + batch_index = off_z + + if start_m * BLOCK_M + BLOCK_M > seqlen_q: + q_padded = True + else: + q_padded = False + if seqlen_k < BLOCK_N: + seqlen_k_faligned = 0 # floor aligned + elif seqlen_k % BLOCK_N: + extra_tokens_n = seqlen_k % BLOCK_N + seqlen_k_faligned = seqlen_k - extra_tokens_n + else: + seqlen_k_faligned = seqlen_k + + q_offset = off_h * stride_qh + batch_index * stride_qz + cu_seqlens_q_start * stride_qm + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, head_dim), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + k_offset = off_h * stride_kh + batch_index * stride_kz + cu_seqlens_k_start * stride_kn + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(head_dim, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + v_offset = off_h * stride_vh + batch_index * stride_vz + cu_seqlens_k_start * stride_vk + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, head_dim), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0) + ) + if BIAS_TYPE == 0: + B_block_ptr = 0 + elif BIAS_TYPE == 1: + B_block_ptr = 
tl.make_block_ptr( + base=B + off_h * stride_bh + batch_index * stride_bz, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0) + ) + else: + tl.static_assert(False, f'Unsupported BIAS_TYPE {BIAS_TYPE}') + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_zh * max_seqlen_q * max_seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(max_seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0) + ) + else: + encoded_softmax_block_ptr = 0 + # write back O + o_offset = off_h * stride_oh + batch_index * stride_oz + cu_seqlens_q_start * stride_om + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, head_dim), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + + M_ptr_base = M + off_zh * max_seqlen_q + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + off_zh * max_seqlen_q * max_seqlen_k + else: + batch_philox_offset = 0 + + attn_fwd_common(Q_block_ptr, + K_block_ptr, + V_block_ptr, + B_block_ptr, + O_block_ptr, + M_ptr_base, + sm_scale, + start_m, + seqlen_q, + seqlen_k, + seqlen_k_faligned, + q_padded, + dropout_p, + philox_seed, + batch_philox_offset, + max_seqlen_k, + encoded_softmax_block_ptr, + CAUSAL=CAUSAL, + BLOCK_M=BLOCK_M, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_N=BLOCK_N, + pre_load_v=pre_load_v, + ENABLE_DROPOUT=ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX=RETURN_ENCODED_SOFTMAX, + PADDED_HEAD=PADDED_HEAD, + BIAS_TYPE=BIAS_TYPE) + + diff --git a/tritonsrc/old_fwd_kernel_inner.py b/tritonsrc/old_fwd_kernel_inner.py new file mode 100644 index 00000000..db5e0c0f --- /dev/null +++ b/tritonsrc/old_fwd_kernel_inner.py @@ -0,0 +1,149 @@ +# Copyright © 2024 Advanced Micro Devices, Inc. 
+# SPDX-License-Identifier: MIT + +import triton +import triton.language as tl +from dropout import dropout_mask, dropout_rng, dropout_offsets + +@triton.jit +def attn_fwd_inner( + acc, l_i, m_i, q, + K_block_ptr, V_block_ptr, B_block_ptr, + start_m, + seqlen_q, + q_padded, + seqlen_k_low, + seqlen_k_high, + k_padded, + dropout_p, + dropout_seqlen_k, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + CAUSAL: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, + pre_load_v: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + MARGINAL_BLOCK: tl.constexpr, # MARGINAL_BLOCK = CAUSAL or k_padded + PADDED_HEAD: tl.constexpr, + BIAS_TYPE: tl.constexpr, +): + lo, hi = seqlen_k_low, seqlen_k_high + if MARGINAL_BLOCK: + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, lo)) + if BIAS_TYPE == 1: + B_block_ptr = tl.advance(B_block_ptr, (0, lo)) + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + # -- compute qk ---- + # MARGINAL_BLOCK serves as a compile-time switch for first attn_fwd_inner calls to "solid" blocks + k = tl.load(K_block_ptr, boundary_check=(1,0), padding_option="zero") + ''' + if MARGINAL_BLOCK and k_padded: + if PADDED_HEAD: + k = tl.load(K_block_ptr, boundary_check=(1,0), padding_option="zero") + else: + k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero") + else: + if PADDED_HEAD: + k = tl.load(K_block_ptr, boundary_check=(0,), padding_option="zero") + else: + k = tl.load(K_block_ptr) + if pre_load_v: + if MARGINAL_BLOCK and k_padded: + if PADDED_HEAD: + v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") + else: + v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") + else: + if PADDED_HEAD: + v = tl.load(V_block_ptr, boundary_check=(1,), padding_option="zero") + else: + v = tl.load(V_block_ptr) + ''' + if pre_load_v: + v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if MARGINAL_BLOCK: + if CAUSAL: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = tl.where(mask, qk, float("-inf")) + if k_padded: + boundary_m = tl.full([BLOCK_M], seqlen_k_high, dtype=tl.int32) + size_n = start_n + offs_n[None,:] + mask = size_n < boundary_m[:,None] + qk = tl.where(mask, qk, float("-inf")) + if BIAS_TYPE == 0: + pass + elif BIAS_TYPE == 1: + bias = tl.load(B_block_ptr, boundary_check=(0,1), padding_option="zero") + ''' + if q_padded and k_padded: # CAVEAT: using "or" disables the partial boundary_check branches + bias = tl.load(B_block_ptr, boundary_check=(0,1), padding_option="zero") + elif q_padded: + bias = tl.load(B_block_ptr, boundary_check=(0,), padding_option="zero") + elif k_padded: + bias = tl.load(B_block_ptr, boundary_check=(1,), padding_option="zero") + else: + bias = tl.load(B_block_ptr) + ''' + qk += bias * 1.44269504089 + else: + tl.static_assert(False, f'Unsupported BIAS_TYPE {BIAS_TYPE}') + qk += tl.dot(q, k) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + # Note about the conflicts of Flash attention algorithm and PyTorch's CUDA implementation + # PyTorch needs to return softmax(qk) (dropout mask encoded in 
sign bits) + # while the Flash attention paper computes the dropout AFTER exp2(qk - m_ij) + if ENABLE_DROPOUT: + philox_offset = batch_philox_offset + start_m * BLOCK_M * dropout_seqlen_k + start_n + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, dropout_seqlen_k) + if RETURN_ENCODED_SOFTMAX: + tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty), boundary_check=(0,1)) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store(encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + boundary_check=(0,1)) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") + ''' + if MARGINAL_BLOCK and k_padded: + if PADDED_HEAD: + v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") + else: + v = tl.load(V_block_ptr, boundary_check=(0,1), padding_option="zero") + else: + if PADDED_HEAD: + v = tl.load(V_block_ptr, boundary_check=(1,), padding_option="zero") + else: + v = tl.load(V_block_ptr) + ''' + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) + if BIAS_TYPE == 1: + B_block_ptr = tl.advance(B_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + diff --git a/tritonsrc/performance_forward.py b/tritonsrc/performance_forward.py index 2c9445b0..cec50c08 100644 --- a/tritonsrc/performance_forward.py +++ b/tritonsrc/performance_forward.py @@ -5,6 +5,7 @@ import pytest import torch +import os import triton from attn_torch_function import attention @@ -20,6 +21,12 @@ FLASH_VER = None HAS_FLASH = FLASH_VER is not None +n_ctx = os.getenv('N_CTX', default=list(range(10, 14))) +if isinstance(n_ctx, str): + n_ctx = map(lambda x: int(x), n_ctx.split(',')) +X_VALS = list(map(lambda x: 2 ** x, n_ctx)) +print(f'{X_VALS=}') + BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 # vary seq length for fixed head and batch=4 configs = [] @@ -27,7 +34,9 @@ for causal in [False, True]: configs.append(triton.testing.Benchmark( x_names=['N_CTX'], - x_vals=[2**i for i in range(10, 15)], + x_vals=list(X_VALS), + # x_vals=[2**i for i in range(10, 14)], # 2 ** 15 not working for now + # x_vals=[2**12], line_arg='provider', line_vals=['triton'] + (['flash'] if HAS_FLASH else []), line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), diff --git a/tritonsrc/test_backward.py b/tritonsrc/test_backward.py index 9e48ed9f..7ee8de32 100644 --- a/tritonsrc/test_backward.py +++ b/tritonsrc/test_backward.py @@ -71,6 +71,8 @@ def _do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale print(f'{err_idx=}') print(f'{tri_out[err_idx]=}') print(f'{ref_out[err_idx]=}') + print(f'{tri_out[0, 0, :4, :4]=}') + print(f'{ref_out[0, 0, :4, :4]=}') assert is_allclose, f'Forward pass {is_allclose=}' dq_allclose, dk_allclose, dv_allclose, db_allclose = grads_allclose @@ -207,6 +209,28 @@ def test_op_bwd_with_matrix_bias(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, sm_ ''' _do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) +def main2(): + # Memo: False-0.0-dtype0-0.0-False-4-256-8-4-1 + # Memo:
False-0.0-dtype0-0.0-False-4-256-8-1-4 + # False-1.2-dtype0-0.0-False-4-4-72-1-4 + # BATCH = 8 + # D_HEAD = 32 + # N_HEADS = 8 + # seqlen_q = 16 + # seqlen_k = 16 + BATCH = 4 + D_HEAD = 1 + N_HEADS = 8 + seqlen_q = 256 + seqlen_k = 4 + causal = False + sm_scale = 1.2 + dropout_p = 0.0 + dtype = torch.float16 + storage_flip = False + bias_type = None + _do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) + def main(): BATCH = 1 D_HEAD = 80 @@ -225,4 +249,4 @@ def main(): _do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type) if __name__ == '__main__': - main() + main2() diff --git a/tritonsrc/varlen_attn_torch_function.py b/tritonsrc/varlen_attn_torch_function.py index d73b17c0..afcc5e3d 100644 --- a/tritonsrc/varlen_attn_torch_function.py +++ b/tritonsrc/varlen_attn_torch_function.py @@ -37,6 +37,8 @@ def forward(ctx, q, k, v, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, retur head_dim_rounded = 2 ** (Lk - 1).bit_length() head_dim_rounded = max(16, head_dim_rounded) padded_head = head_dim_rounded != Lk + num_head_q = q.shape[1] + num_head_k = k.shape[1] # Varlen packed all batches of seqlens into dim[0] batch = len(seqlen_q) num_heads = q.shape[1] @@ -131,6 +133,8 @@ def forward(ctx, q, k, v, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, retur v.stride(0), v.stride(1), v.stride(2), v.stride(3), b.stride(0), b.stride(1), b.stride(2), b.stride(3), o.stride(0), o.stride(1), o.stride(2), o.stride(3), + num_head_q=num_head_q, + num_head_k=num_head_k, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, num_seqlens=len(cu_seqlens_q), diff --git a/v2python/autotune_config.py b/v2python/autotune_config.py new file mode 100644 index 00000000..c35908b3 --- /dev/null +++ b/v2python/autotune_config.py @@ -0,0 +1,33 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. 
+# Copyright © 2020-2022 OpenAI +# SPDX-License-Identifier: MIT + +from .kernel_argument import TunedArgument + +class Config: + ''' + A compatible class to store triton.Config + ''' + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, maxnreg=None, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.maxnreg = maxnreg + self.pre_hook = pre_hook + + def translate_to_psel_and_co(self, perf_metas : 'list[ArgumentMetadata]'): + psels = [] + for k, v in self.kwargs.items(): + for meta in perf_metas: + if meta.has_argument(k): + psels.append(TunedArgument(meta, v)) + break + if 'waves_per_eu' in self.kwargs: + co = {'waves_per_eu' : self.kwargs['waves_per_eu'] } + else: + co = {} + co['num_warps'] = self.num_warps + co['num_stages'] = self.num_stages + # print(f'translate_to_psel_and_co {psels=} {co=}') + return psels, co diff --git a/v2python/common_tuning_database.py b/v2python/common_tuning_database.py index fe1ff7a4..e947e0e9 100644 --- a/v2python/common_tuning_database.py +++ b/v2python/common_tuning_database.py @@ -40,6 +40,10 @@ def gpu(self): def arch_number(self): return self._arch_number + @classmethod + def is_passthrough_tuning(klass): + return False + ''' Create db index, and also initialize _fsel_positions so that _extract_keys_from_fsels can use it ''' @@ -47,6 +51,13 @@ def arch_number(self): def _build_db_index(self, fsels): pass + ''' + Callgraph: select -> _select_from_db -> _lookup_tuning_info, + then _select_from_db -> craft_perf_selection, whose results are + yielded back out of select. + Called by KernelDescription.gen_all_object_files to narrow down kernels to build for fsels + ''' def select(self, fsels : 'list[ArgumentSelection]', perf_meta : 'list[ArgumentMetadata]') -> 'list[ArgumentSelection], dict[str,str]': if self.empty: yield [], None @@ -61,12 +72,31 @@ def _select_from_db(self, pass ''' - tinfo: one piece of tuning information, can be a json object, or a row in SQLite database + Translate a row into a dict that only maps input keys to their values + ''' + @abstractmethod + def extract_inputs(self, columns, row): + pass + + ''' + columns, row: one piece of tuning information, can be a json object, or a + single row in an SQLite database. + For json database, columns is None (schemaless, metadata included in rows) + Called by select -> _select_from_db + or + KernelTuningEntryForFunctionalOnGPU ''' @abstractmethod - def craft_perf_selection(self, tinfo, perf_meta: 'list[ArgumentSelection]'): + def craft_perf_selection(self, + columns, + row, + perf_meta: 'list[ArgumentSelection]') -> 'list[TunedArgument], compiler_options': pass + ''' + Callgraph: get_lut -> (extract tuning info for kdesc+fsels) + -> construct KernelTuningEntryForFunctionalOnGPU from the tuning info + ''' @abstractmethod def get_lut(self, kdesc : 'KernelDescription', diff --git a/v2python/compile.py b/v2python/compile.py old mode 100755 new mode 100644 index 8d5be9fc..adccbb1e --- a/v2python/compile.py +++ b/v2python/compile.py @@ -1,7 +1,4 @@ -#!/usr/bin/env python -# Copyright © 2023-2024 Advanced Micro Devices, Inc.
-# SPDX-License-Identifier: MIT - +import binascii import hashlib import importlib.util import sys @@ -10,6 +7,8 @@ from typing import List import triton +from triton.compiler.code_generator import kernel_suffix +from triton.backends.amd.driver import ty_to_cpp import shutil import subprocess @@ -19,24 +18,40 @@ Triton ahead-of-time compiler: """ +from triton.backends.compiler import GPUTarget + +KNOWN_TARGETS = { + None : None, + 'MI200' : GPUTarget('hip', 'gfx90a', 64), + 'MI300X' : GPUTarget('hip', 'gfx942', 64), + 'Navi31' : GPUTarget('hip', 'gfx1100', 32), + 'Navi32' : GPUTarget('hip', 'gfx1101', 32), +} + def main(): # command-line arguments parser = ArgumentParser(description=desc) - parser.add_argument("path", help="Path to Python source containing desired kernel in its scope. File will be executed.") - parser.add_argument("--target", type=str, default=None, help="Ahead of Time (AOT) Compile Architecture. PyTorch is required for autodetection if --target is missing.") - parser.add_argument("--kernel_name", "-n", type=str, default="", help="Name of the kernel to compile", required=True) + parser.add_argument("path", + help="Path to Python source containing desired kernel in its scope. File will be executed.") + parser.add_argument("--target", type=str, default=None, + choices=list(KNOWN_TARGETS.keys()), + help="Ahead of Time (AOT) Compile Architecture. PyTorch is required for autodetection if --target is missing.") + parser.add_argument("--kernel_name", "-n", type=str, default="", help="Name of the kernel to compile", + required=True) parser.add_argument("--num_warps", "-w", type=int, default=1, help="Number of warps to launch the kernel") - parser.add_argument("--num_stages", "-ns", type=int, default=3, help="Number of stages (meta-parameter of the kernel)") + parser.add_argument("--num_stages", "-ns", type=int, default=3, + help="Number of stages (meta-parameter of the kernel)") parser.add_argument("--waves_per_eu", type=int, default=0) - parser.add_argument("--out_path", "-o", type=Path, default=None, help="Out filename", required=True) + parser.add_argument("--out_name", "-on", type=str, default=None, help="Out name for the compiled kernel") + parser.add_argument("--out_path", "-o", type=Path, default=None, help="Out filename") parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True) parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True) parser.add_argument("--verbose", "-v", help="Enable verbose output", action='store_true') parser.add_argument("--nostrip", help="Keep debugging symbols", action='store_true') args = parser.parse_args() - out_path = args.out_path - out_path = out_path.with_suffix('') + out_name = args.out_name if args.out_name else args.kernel_name + out_path = args.out_path if args.out_path else Path(out_name) # execute python sources and extract functions wrapped in JITFunction arg_path = Path(args.path) @@ -55,13 +70,6 @@ def main(): # kernel = globals()[f"{arg_path.stem}.{args.kernel_name}"] mod = globals()[arg_path.stem] kernel = getattr(mod, args.kernel_name) - # print(fused_attention_trimmed.attn_fwd) - if False: - mod = importlib.import_module(arg_path.stem) - print(mod.attn_fwd) - # print(fused_attention_trimmed.attn_fwd) - kernel = globals()[f"{arg_path.stem}.{args.kernel_name}"] - print(f"{kernel=}") grid = args.grid.split(",") assert len(grid) == 3 @@ -96,12 +104,12 @@ def constexpr(s): hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s}
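# Worked example of the signature syntax parsed here (hypothetical values; the
# hint rules follow from the asserts below): --signature '*fp16:16, i32:1, 64'
#   hints     -> {0: 16, 1: 1}           (':16' marks divisible-by-16, ':1' marks always-1)
#   constants -> {2: 64}                 (bare constexprs become compile-time constants)
#   signature -> {0: '*fp16', 1: 'i32'}  (remaining args keep their Triton type strings)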
hints = {k: v for k, v in hints.items() if v is not None} - constexprs = {i: constexpr(s) for i, s in enumerate(signature)} - constexprs = {k: v for k, v in constexprs.items() if v is not None} + constants = {i: constexpr(s) for i, s in enumerate(signature)} + constants = {k: v for k, v in constants.items() if v is not None} # print(f"{constexprs=}") - signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constexprs} - const_sig = 'x'.join([str(v) for v in constexprs.values()]) - doc_string = [f"{kernel.arg_names[i]}={constexprs[i]}" for i in constexprs.keys()] + signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constants} + const_sig = 'x'.join([str(v) for v in constants.values()]) + doc_string = [f"{kernel.arg_names[i]}={constants[i]}" for i in constants.keys()] doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] # compile ast into cubin @@ -109,26 +117,55 @@ def constexpr(s): assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" divisible_by_16 = [i for i, h in hints.items() if h == 16] equal_to_1 = [i for i, h in hints.items() if h == 1] - config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) + attrs = triton.compiler.AttrsDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) for i in equal_to_1: - constexprs.update({i: 1}) - # print(f'{kernel=}') - ccinfo = triton.compile(kernel, single_cpu=True, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps, num_stages=args.num_stages, waves_per_eu=args.waves_per_eu, aot_arch=args.target) - hsaco_path = ccinfo.asm.get('hsaco_path', None) - if args.verbose: - print(dir(ccinfo)) - print(f'{ccinfo.asm.keys()=}') - print(f'{ccinfo.fn=}') - print(f'{hsaco_path=}') - - if hsaco_path is not None: - if args.nostrip: - shutil.copy(hsaco_path, out_path.with_suffix('.hsaco')) - else: - subprocess.run(['/opt/rocm/llvm/bin/llvm-objcopy', '--remove-section', '.debug_*', str(hsaco_path), str(out_path.with_suffix('.hsaco'))]) - - with out_path.with_suffix('.json').open("w") as fp: - json.dump(ccinfo.metadata, fp, indent=2) + constants.update({i: 1}) + src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs) + opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} + ccinfo = triton.compile(src, target=KNOWN_TARGETS[args.target], options=opts) + # import pdb; pdb.set_trace() + with open(out_path.with_suffix('.hsaco'), 'bw') as f: + f.write(ccinfo.kernel) + with open(out_path.with_suffix('.json'), 'w') as f: + di = ccinfo.metadata._asdict() + del di['target'] + json.dump(di, f, indent=2) + + ''' + arg_names = [] + arg_types = [] + for i in signature.keys(): + if i not in equal_to_1: + arg_names += [kernel.arg_names[i]] + arg_types += [signature[i]] + + # dump C stub code + suffix = kernel_suffix(signature.values(), attrs) + func_name = '_'.join([out_name, sig_hash, suffix]) + hex_ = str(binascii.hexlify(ccinfo.asm["cubin"]))[2:-1] + params = { + "kernel_name": func_name, + "triton_kernel_name": args.kernel_name, + "bin_size": len(hex_), + "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]), + "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), + "full_signature": ", ".join([f"{ty_to_cpp(signature[i])} {kernel.arg_names[i]}" for i in signature.keys()]), + "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names]), + "num_args": len(arg_names), + "kernel_docstring": 
doc_string, + "shared": ccinfo.metadata.shared, + "num_warps": args.num_warps, + "algo_info": '_'.join([const_sig, meta_sig]), + "gridX": grid[0], + "gridY": grid[1], + "gridZ": grid[2], + "_placeholder": "", + } + for ext in ['h', 'c']: + template_path = Path(__file__).parent / f"compile.{ext}" + with out_path.with_suffix(f".{sig_hash}_{suffix}.{ext}").open("w") as fp: + fp.write(Path(template_path).read_text().format(**params)) + ''' if __name__ == "__main__": main() diff --git a/v2python/generate_compile.py b/v2python/generate_compile.py index 66a78d69..a20ec3b3 100644 --- a/v2python/generate_compile.py +++ b/v2python/generate_compile.py @@ -20,6 +20,7 @@ def parse(): p.add_argument("--python", type=str, default=None, help="python binary to run compile.py") p.add_argument("--enable_zstd", type=str, default=None, help="Use zstd to compress the compiled kernel") p.add_argument("--bare_mode", action='store_true', help="Instead of generating a proper Makefile, only generate compiler options and leave the remaining tasks to cmake.") + p.add_argument("--build_for_tuning", action='store_true', help="Build all possible GPU kernels for performance tuning.") # p.add_argument("--autotune_data", type=str, default=None, help="Autotune results generated by tune_flash.py") args = p.parse_args() # print(args) @@ -53,8 +54,9 @@ def gen_from_kernel(args, k, build_dir, makefile): all_targets = [] object_rules = io.StringIO() arches = [None] if args.target_gpus is None else args.target_gpus - ktd = KernelTuningDatabase(SOURCE_PATH.parent / 'rules', k) - if True: # Debugging + # ktd = None if args.build_for_tuning else KernelTuningDatabase(SOURCE_PATH.parent / 'rules', k) + ktd = KernelTuningDatabase(SOURCE_PATH.parent / 'rules', k, build_for_tuning=args.build_for_tuning) + if False: # Debugging if k.SHIM_KERNEL_NAME == 'attn_fwd': assert not ktd.empty k.set_target_gpus(arches) diff --git a/v2python/generate_shim.py b/v2python/generate_shim.py index 6140c21f..45c0dd93 100755 --- a/v2python/generate_shim.py +++ b/v2python/generate_shim.py @@ -51,6 +51,7 @@ def parse(): p.add_argument("--archive_only", action='store_true', help='Only generate archive library instead of shared library. 
No linking with dependencies.') p.add_argument("--enable_zstd", type=str, default=None, help="Use zstd to compress the compiled kernel") p.add_argument("--bare_mode", action='store_true', help="Instead of generating a proper Makefile, only generate a list of source files and leave the remaining tasks to cmake.") + p.add_argument("--build_for_tuning", action='store_true', help="Include all GPU kernels in the dispatcher for performance tuning.") p.add_argument("--verbose", action='store_true', help="Print debugging messages") args = p.parse_args() args._build_root = Path(args.build_dir) @@ -281,7 +282,9 @@ def __init__(self, args, out, k : 'KernelDescription'): # Autotune dispatcher self._autotune_path = Path(args.build_dir) / k.KERNEL_FAMILY / f'autotune.{k.SHIM_KERNEL_NAME}' self._autotune_path.mkdir(parents=True, exist_ok=True) - self._ktd = KernelTuningDatabase(SOURCE_PATH.parent / 'rules', k) + self._ktd = KernelTuningDatabase(SOURCE_PATH.parent / 'rules', + k, + build_for_tuning=self._args.build_for_tuning) self._objpaths = [] @property @@ -307,7 +310,6 @@ def gen_children(self, out): k = self._kdesc p = self._shim_path / f'gpu_kernel_image.{k.SHIM_KERNEL_NAME}' args = self._args - ktd = KernelTuningDatabase(SOURCE_PATH.parent / 'rules', k) debug_counter = 0 for gpu, fsels, lut in k.gen_tuned_kernel_lut(self._ktd): # print(f'KernelShimGenerator.gen_children {fsels=}') @@ -320,7 +322,7 @@ def gen_children(self, out): if self.is_bare: return - for o in k.gen_all_object_files(p, tuned_db=ktd, sancheck_fileexists=True): + for o in k.gen_all_object_files(p, tuned_db=self._ktd, sancheck_fileexists=True): yield ObjectShimCodeGenerator(self._args, k, o) def write_conclude(self): diff --git a/v2python/gpu_targets.py b/v2python/gpu_targets.py index 86ef5dc0..fb51fef5 100644 --- a/v2python/gpu_targets.py +++ b/v2python/gpu_targets.py @@ -1,13 +1,23 @@ # Copyright © 2023-2024 Advanced Micro Devices, Inc. 
# SPDX-License-Identifier: MIT - AOTRITON_SUPPORTED_GPUS = { 'MI200' : 'GPU_ARCH_AMD_GFX90A', 'MI300X' : 'GPU_ARCH_AMD_GFX942', + 'Navi31' : 'GPU_ARCH_AMD_GFX1100', + 'Navi32' : 'GPU_ARCH_AMD_GFX1101', } AOTRITON_GPU_ARCH_TUNING_STRING = { 'MI200' : 'gfx90a', 'MI300X' : 'gfx942', + 'Navi31' : 'gfx1100', + 'Navi32' : 'gfx1101', +} + +AOTRITON_GPU_WARPSIZE = { + 'MI200' : 64, + 'MI300X' : 64, + 'Navi31' : 32, + 'Navi32' : 32, } diff --git a/v2python/json_tuning_database.py b/v2python/json_tuning_database.py index c5bc322c..6cd46276 100644 --- a/v2python/json_tuning_database.py +++ b/v2python/json_tuning_database.py @@ -139,7 +139,11 @@ def _lookup_tuning_info(self, fsels, with_duplicates=True): tuning_info = self._index[fallback_tup] if with_duplicates else self._index_dedup[fallback_tup] return self._downgrade(fallback_applied_fsels, tuning_info) - def craft_perf_selection(self, tinfo, perf_meta: 'list[ArgumentSelection]'): + def craft_perf_selection(self, + columns, + row, + perf_meta: 'list[ArgumentSelection]') -> 'list[TunedArgument], compiler_options': + tinfo = row if tinfo is None: # default value when tuning db does not contain the kernel return [TunedArgument(meta, meta.default_value) for meta in perf_meta], None ps = dict(tinfo['tuned_kernel']) @@ -157,7 +161,7 @@ def _select_from_db(self, indexed = self._lookup_tuning_info(fsels, with_duplicates=not no_duplicate) assert indexed for tinfo in indexed: - yield self.craft_perf_selection(tinfo, perf_meta) + yield self.craft_perf_selection(None, tinfo, perf_meta) def get_lut(self, kdesc : 'KernelDescription', diff --git a/v2python/kernel_argument.py b/v2python/kernel_argument.py index f25047b8..a24a0398 100644 --- a/v2python/kernel_argument.py +++ b/v2python/kernel_argument.py @@ -75,6 +75,10 @@ def nchoices(self): def argument_names(self): return [a[0] for a in self._ordered_arguments] + @property + def repr_name(self): + return self._ordered_arguments[0][0] + def has_argument(self, aname): return aname in self.argument_names @@ -214,6 +218,18 @@ def update_triton_api_signature(self, sig: dict): for place in self._meta.ordered_argument_places: sig[place] = self.triton_signature + # FIXME: XXX_CHOICES's key is unordered + ''' + Build a dict that maps "representative name" to selected value + Consider changing it from frozenset to tuple + ''' + @staticmethod + def build_fsel_dict(fsels : 'list[ArgumentSelection]'): + d = {} + for fsel in fsels: + d[fsel.meta.repr_name] = fsel.argument_value + return d + class TunedArgument(ArgumentSelection): def __init__(self, meta : ArgumentMetadata, value): self._meta = meta diff --git a/v2python/kernel_desc.py b/v2python/kernel_desc.py index 81d379d5..7fadc181 100644 --- a/v2python/kernel_desc.py +++ b/v2python/kernel_desc.py @@ -12,7 +12,7 @@ ) from .kernel_signature import KernelSignature from .object_desc import ObjectFileDescription -from .gpu_targets import AOTRITON_SUPPORTED_GPUS +from .gpu_targets import AOTRITON_SUPPORTED_GPUS, AOTRITON_GPU_WARPSIZE SOURCE_PATH = Path(__file__).resolve() @@ -89,6 +89,16 @@ def insert_tensor_strides_to_choices(self, last_is_continuous=False): print(f"{self.TYPE_CHOICES=}") print(f"{self.FEAT_CHOICES=}") + def gen_patched_autotune_configs(self, gpu, fsel_dict): + if AOTRITON_GPU_WARPSIZE[gpu] == 64: + yield from self.gen_autotune_configs(fsel_dict) + return + for cfg in self.gen_autotune_configs(fsel_dict): + cfg.num_warps *= 2 + if cfg.num_warps > 8: # ignore super large block + continue + yield cfg + def __init__(self, triton_kernel_name, triton_file_path): 
self.insert_tensor_strides_to_choices(last_is_continuous=True) self._DATA_ARGUMENTS = None @@ -127,6 +137,11 @@ def __init__(self, triton_kernel_name, triton_file_path): break if is_type: self.AUTOTUNE_KEYS_VALIDATED.append((key, self.AUTOTUNE_KEYS[key])) + ''' + AUTOTUNE_KEYS sanity check; without it the autotune code may silently break (this has already happened twice). + ''' + for key in self.AUTOTUNE_KEYS: + assert key in self.ARGUMENTS, f'AUTOTUNE_KEYS "{key}" cannot be found in {self.__class__.__name__}.ARGUMENTS' @property def name(self): @@ -155,6 +170,14 @@ def gen_tuned_perf_selections(self, def set_target_gpus(self, gpus): self._target_gpus = ['native'] if gpus is None else list(gpus) + def gen_perf_selections_from_kdesc(self, + gpu : str, + fsels : 'list[ArgumentSelection]'): + fsel_dict = ArgumentSelection.build_fsel_dict(fsels) + for cfg in self.gen_patched_autotune_configs(gpu, fsel_dict): + psels, compiler_options = cfg.translate_to_psel_and_co(self._perf_meta) + yield gpu, fsels, psels, compiler_options + def gen_all_object_files(self, outpath : Path, # kernel_name : str = None, @@ -163,17 +186,27 @@ def gen_all_object_files(self, sancheck_fileexists = False) -> 'Iterator[ObjectFileDescription]': def gen(): if tuned_db is None or tuned_db.empty: - yield from itertools.product(self._target_gpus, - self.gen_func_selections(), - self.gen_perf_selections(), - [None]) + if not hasattr(self, 'gen_autotune_configs'): + yield from itertools.product(self._target_gpus, + self.gen_func_selections(), + self.gen_perf_selections(), + [None]) + return + for gpu, fsels in itertools.product(self._target_gpus, + self.gen_func_selections()): + yield from self.gen_perf_selections_from_kdesc(gpu, fsels) else: for gpu, fsels in itertools.product(self._target_gpus, self.gen_func_selections()): yield from self.gen_tuned_perf_selections(tuned_db, gpu, fsels) debug_counter = 0 for gpu, fsels, psels, compiler_options in gen(): - sig = KernelSignature(self, fsels, psels, compiler_options, gpu) + try: + sig = KernelSignature(self, fsels, psels, compiler_options, gpu) + except Exception: + print(f"{fsels=}") + print(f"{psels=}") + exit(1) yield self.build_object_file_description(outpath, sig, sancheck_fileexists=sancheck_fileexists) if False: # Debugging debug_counter += 1 diff --git a/v2python/kernel_signature.py b/v2python/kernel_signature.py index 10dae4d4..fbd43d1d 100644 --- a/v2python/kernel_signature.py +++ b/v2python/kernel_signature.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: MIT from .gpu_targets import AOTRITON_GPU_ARCH_TUNING_STRING +import json class KernelSignature(object): def __init__(self, @@ -83,3 +84,14 @@ def codegen_perf_object(self) -> str: for aname in ps.argument_names: perf_key_value.append(f'.{aname} = {value}') return ', '.join(perf_key_value) + + def jsongen_psels(self) -> str: + d = {} + for ps in self._perf_selections: + value = ps.argument_value + for aname in ps.argument_names: + d[aname] = value + return json.dumps(d) + + def jsongen_copts(self) -> str: + return json.dumps(self._compiler_options) diff --git a/v2python/rules/flash/_common.py b/v2python/rules/flash/_common.py index b024dc66..b9bec0a3 100644 --- a/v2python/rules/flash/_common.py +++ b/v2python/rules/flash/_common.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: MIT from ...kernel_desc import KernelDescription, get_possible_types, select_pattern +from ...autotune_config import Config from ...autotune_binning import BinningLessOrEqual, BinningExact class FlashKernel(KernelDescription): diff --git a/v2python/rules/flash/attn_fwd.py
b/v2python/rules/flash/attn_fwd.py index 615366ba..39a08854 100644 --- a/v2python/rules/flash/attn_fwd.py +++ b/v2python/rules/flash/attn_fwd.py @@ -1,7 +1,8 @@ # Copyright © 2023-2024 Advanced Micro Devices, Inc. # SPDX-License-Identifier: MIT -from ._common import FlashKernel, select_pattern, BinningLessOrEqual, BinningExact +import itertools +from ._common import FlashKernel, select_pattern, BinningLessOrEqual, BinningExact, Config class attn_fwd(FlashKernel): ARGUMENTS = [ @@ -11,6 +12,8 @@ class attn_fwd(FlashKernel): 'stride_vz', 'stride_vh', 'stride_vk', 'stride_vn', 'stride_bz', 'stride_bh', 'stride_bm', 'stride_bn', 'stride_oz', 'stride_oh', 'stride_om', 'stride_on', + 'num_head_q', + 'num_head_k', 'cu_seqlens_q', 'cu_seqlens_k', 'num_seqlens', @@ -43,7 +46,7 @@ class attn_fwd(FlashKernel): frozenset(['sm_scale']) : ['fp32'], frozenset(['M']) : ['*fp32:16'], frozenset(['cu_seqlens_q', 'cu_seqlens_k']) : ['*i32:16'], - frozenset(['num_seqlens', 'max_seqlen_q', 'max_seqlen_k']) : ['i32'], + frozenset(['num_head_q', 'num_head_k', 'num_seqlens', 'max_seqlen_q', 'max_seqlen_k']) : ['i32'], frozenset(['head_dim']) : ['i32'], frozenset(['dropout_p']) : ['fp32'], frozenset(['philox_seed']) : ['u64'], @@ -83,6 +86,7 @@ class attn_fwd(FlashKernel): 'max_seqlen_q' : BinningLessOrEqual, 'max_seqlen_k' : BinningLessOrEqual, 'CAUSAL' : BinningExact, + 'ENABLE_DROPOUT' : BinningExact, } # List of functionals that are not fully tuned in the tuning database # First element of the tuple is name. Second is the value to use instead @@ -109,3 +113,28 @@ def DOWNGRADE_RETURN_ENCODED_SOFTMAX(tuned_kernel, compiler_options): DOWNGRADER = [(('RETURN_ENCODED_SOFTMAX', True), DOWNGRADE_RETURN_ENCODED_SOFTMAX)] + @staticmethod + def gen_autotune_configs(fsel_dict : 'dict[str, Any]'): + dtype = fsel_dict['Q'] + BLOCK_SIZES = [(128, 64), (64, 64), (64, 32)] + WAVES_PER_EU = [0, 1, 2, 3, 4] + PRE_LOAD_V = [True, False] + for (M, N), waves, pre in itertools.product(BLOCK_SIZES, + WAVES_PER_EU, + PRE_LOAD_V): + if dtype == '*fp32:16': + M //= 2 + kw = {'BLOCK_M': M, 'BLOCK_N': N, 'waves_per_eu': waves, 'pre_load_v': pre} + yield Config(kw, num_stages=1, num_warps=4) + yield from [ + Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), + Config({'BLOCK_M': 128, 'BLOCK_N':128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4), + Config({'BLOCK_M': 256, 'BLOCK_N':128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), + Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=1, num_warps=8), + Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=1, num_warps=4), + Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=1, num_warps=8), + Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=1, num_warps=4), + ] diff --git a/v2python/rules/tuning_database.sqlite3 b/v2python/rules/tuning_database.sqlite3 index 19d86038..29aeb0e2 100644 Binary files a/v2python/rules/tuning_database.sqlite3 and b/v2python/rules/tuning_database.sqlite3 differ diff --git a/v2python/sqlite_tuning_database.py b/v2python/sqlite_tuning_database.py index eb4486f4..a1198ee9 100644 ---
a/v2python/sqlite_tuning_database.py +++ b/v2python/sqlite_tuning_database.py @@ -99,7 +99,10 @@ def _select_from_db(self, for row in selected_rows: yield self.craft_perf_selection(selected_columns, row, perf_meta) - def craft_perf_selection(self, columns, row, perf_meta: 'list[ArgumentSelection]'): + def craft_perf_selection(self, + columns, + row, + perf_meta: 'list[ArgumentSelection]') -> 'list[TunedArgument], compiler_options': if row is None: # default value when tuning db does not contain the kernel return [TunedArgument(meta, meta.default_value) for meta in perf_meta], None ps = self._row_to_dict(columns, row, prefix='tuned_kernel') diff --git a/v2python/table_tool.py b/v2python/table_tool.py index 7a84027c..3e89a10a 100644 --- a/v2python/table_tool.py +++ b/v2python/table_tool.py @@ -102,6 +102,11 @@ def upsert(self, line_text, *, create_table_only): if self.verbose: print(f'{line_text=}') print(f'{tune_info=}') + tune_result = tune_info.get('result', 'result-not-reported-in-older-version') + if tune_result != 'tuned': + if self.verbose: + print(f'{tune_result=}') + return sql_table = self.ensure_table(tune_info) if create_table_only: return diff --git a/v2python/tuning_database.py b/v2python/tuning_database.py index 5bfa48ce..f7ca313c 100644 --- a/v2python/tuning_database.py +++ b/v2python/tuning_database.py @@ -3,7 +3,7 @@ import pathlib import sqlite3 -from .kernel_argument import TunedArgument +from .kernel_argument import ArgumentSelection, TunedArgument from .gpu_targets import AOTRITON_GPU_ARCH_TUNING_STRING from .common_tuning_database import CommonKernelTuningDatabaseForArch from .sqlite_tuning_database import SQLiteKernelTuningDatabaseForArch @@ -27,7 +27,13 @@ def _select_from_db(self, no_duplicate=True): assert False - def craft_perf_selection(self, tinfo, perf_meta: 'list[ArgumentSelection]'): + def extract_inputs(self, columns, row): + assert False + + def craft_perf_selection(self, + columns, + row, + perf_meta: 'list[ArgumentSelection]') -> 'list[TunedArgument], compiler_options': return [TunedArgument(meta, meta.default_value) for meta in perf_meta], None def get_lut(self, @@ -40,12 +46,48 @@ def get_lut(self, autotune_keys=None, perf_meta=perf_meta) +class BootstrapTuningDatabaseForArch(EmptyKernelTuningDatabaseForArch): + + @classmethod + def is_passthrough_tuning(klass): + return True + + def extract_inputs(self, columns, row): + assert False + + def craft_perf_selection(self, + columns, + row, + perf_meta: 'list[ArgumentSelection]') -> 'list[TunedArgument], compiler_options': + if row is None: + return [TunedArgument(meta, meta.default_value) for meta in perf_meta], None + return row + + def get_lut(self, + kdesc : 'KernelDescription', + autotune_keys : 'list[tuple[str, Binning]]', + fsels : 'list[ArgumentSelection]', + perf_meta : 'list[ArgumentMetadata]'): + fsel_dict = ArgumentSelection.build_fsel_dict(fsels) + rows = [] + for cfg in kdesc.gen_autotune_configs(fsel_dict): + psels, compiler_options = cfg.translate_to_psel_and_co(perf_meta) + rows.append((psels, compiler_options)) + # print(f'get_lut {len(rows)=}') + return KernelTuningEntryForFunctionalOnGPU(kdesc, self, fsels, + columns=None, rows=rows, + autotune_keys=None, + perf_meta=perf_meta) + class KernelTuningDatabase(object): MONOLITHIC_TUNING_DATABASE_FILE = 'tuning_database.sqlite3' - def __init__(self, tune_info_dir : pathlib.Path, k : 'KernelDescription'): + def __init__(self, tune_info_dir : pathlib.Path, k : 'KernelDescription', build_for_tuning=False): self._kdesc = k self.arch_dict = {}
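+ # Build-for-tuning bypasses the on-disk tuning database: __init__ returns before
+ # probing the SQLite file, select_gpu hands out BootstrapTuningDatabaseForArch, and
+ # candidate kernels come from the kernel's own gen_autotune_configs instead.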
+ self._build_for_tuning = build_for_tuning and hasattr(k, 'gen_autotune_configs') + if self._build_for_tuning: + return td = pathlib.Path(tune_info_dir) / self.MONOLITHIC_TUNING_DATABASE_FILE # in case tune_info_dir is str # print(f"Trying to probe KernelTuningDatabase inside {td}") downgrader = TuningDowngrader.create_from_kdesc(k) @@ -62,10 +104,13 @@ def __init__(self, tune_info_dir : pathlib.Path, k : 'KernelDescription'): def select_gpu(self, gpu, index): arch = AOTRITON_GPU_ARCH_TUNING_STRING[gpu] if arch not in self.arch_dict: - print(f'For kernel {self._kdesc.KERNEL_FAMILY}.{self._kdesc.name}, Architecture {arch} was not found in tuning database, using dummy one instead') - self.arch_dict[arch] = EmptyKernelTuningDatabaseForArch(self._kdesc, arch) + if not self._build_for_tuning: + print(f'For kernel {self._kdesc.KERNEL_FAMILY}.{self._kdesc.name}, Architecture {arch} was not found in tuning database, using dummy one instead') + self.arch_dict[arch] = EmptyKernelTuningDatabaseForArch(self._kdesc, arch) + else: + self.arch_dict[arch] = BootstrapTuningDatabaseForArch(self._kdesc, arch) return self.arch_dict[arch].set_gpu(gpu, index) @property def empty(self): - return not self.arch_dict + return not self.arch_dict or self._build_for_tuning diff --git a/v2python/tuning_lut.py b/v2python/tuning_lut.py index 14cbfad4..1bd454b8 100644 --- a/v2python/tuning_lut.py +++ b/v2python/tuning_lut.py @@ -26,19 +26,27 @@ def __init__(self, self._fsels = fsels # print(f'{self._fsels=}') self._lut_dic = {} - self._autotune_keys = autotune_keys if autotune_keys is not None else None + self._autotune_keys = autotune_keys self._autotune_key_values = { key : set() for key, _ in autotune_keys } if autotune_keys is not None else None self._autotune_key_class = { key : klass for key, klass in autotune_keys } if autotune_keys is not None else None self._sigs = [] self._sig_dict = {} - if rows is None and autotune_keys is None: - self._lut_dtype = np.uint8 - self._lut_cdtype = f'uint8_t' - self._lut_tensor = np.array([0], dtype=np.uint8) + if autotune_keys is None: + self._lut_dtype = np.int8 + self._lut_cdtype = f'int8_t' + self._lut_tensor = np.array([0], dtype=np.int8) self._lut_cshape = ''.join([f'[{s}]' for s in self._lut_tensor.shape]) self._untuned = True - default_psels, default_co = dba.craft_perf_selection(None, perf_meta) - self._lut_dic[0] = self._allocate_sig(default_psels, default_co)[0] + # print(f'{dba.is_passthrough_tuning()=}') + if dba.is_passthrough_tuning(): + # print(f'KernelTuningEntryForFunctionalOnGPU.__init__ {len(rows)=}') + for row in rows: + psels, compiler_options = dba.craft_perf_selection(columns, row, perf_meta) + self._allocate_sig(psels, compiler_options) + self._lut_dic[0] = 0 + else: + default_psels, default_co = dba.craft_perf_selection(None, None, perf_meta) + self._lut_dic[0] = self._allocate_sig(default_psels, default_co)[0] return self._untuned = False # print(f'KernelTuningEntryForFunctionalOnGPU {fsels=}') @@ -65,6 +73,7 @@ def track_autotune_key_values(self, columns, row, tup): def _allocate_sig(self, psels, compiler_options): sig = KernelSignature(self._kdesc, self._fsels, psels, compiler_options, self._dba.gpu) + # print(f'_allocate_sig {sig.compact_signature}') compact = sig.compact_signature if compact not in self._sig_dict: self._sig_dict[compact] = (len(self._sigs), sig) @@ -79,11 +88,11 @@ def get_lut(self) -> 'tuple[np.ndarray, list[KernelSignature]': def _build_lut_tensor(self): self._autotune_key_buckets = [ klass(self._autotune_key_values[key]) for key,
klass in self._autotune_keys ] - for dtype in [np.uint8, np.uint16, np.uint32, np.uint64]: + for dtype in [np.int8, np.int16, np.int32, np.int64]: if len(self._sigs) < np.iinfo(dtype).max: break self._lut_dtype = dtype - self._lut_cdtype = f'uint{np.iinfo(dtype).bits}_t' + self._lut_cdtype = f'int{np.iinfo(dtype).bits}_t' self._lut_tensor = np.empty([bucket.nvalues for bucket in self._autotune_key_buckets], dtype=dtype) assert self._lut_tensor.size > 0, 'LUT tensor must be non-empty. Empty LUT is not constructed by _build_lut_tensor' self._lut_cshape = ''.join([f'[{s}]' for s in self._lut_tensor.shape]) @@ -92,7 +101,7 @@ def _build_lut_tensor(self): for indices, atk_values in zip(itertools.product(*list_of_atk_indices), itertools.product(*self._list_of_atk_representatives)): fs_atk_values = tuple(atk_values) - self._lut_tensor[indices] = self._lut_dic[fs_atk_values] + self._lut_tensor[indices] = self._lut_dic.get(fs_atk_values, -1) # FIXME: Debugging if False and self._kdesc.SHIM_KERNEL_NAME == 'attn_fwd': print(f'_build_lut_tensor {self._autotune_key_values=}') @@ -101,6 +110,7 @@ def _build_lut_tensor(self): def gen_kernel_symbols(self, kernel_image_dir): for sig in self._sigs: + # print(f"gen_kernel_symbols {sig.compact_signature=}") o = self._kdesc.build_object_file_description(kernel_image_dir, sig) yield o.c_identifier_signature, o._hsaco_kernel_path, o @@ -121,6 +131,18 @@ def codegen_incbin_names(self, kernel_image_dir, compressed=False): incbin_lines.append(f'"{incbin_symbol_name}"') return 'static const char* incbin_kernel_names[] = {\n ' + ",\n ".join(incbin_lines) + "\n};" + def codegen_kernel_psels(self, kernel_image_dir, compressed=False): + lines = [] + for sig in self._sigs: + lines.append(f'R"xyzw({sig.jsongen_psels()})xyzw"') + return 'static const char* kernel_psels[] = {\n ' + ",\n ".join(lines) + "\n};" + + def codegen_kernel_copts(self, kernel_image_dir, compressed=False): + lines = [] + for sig in self._sigs: + lines.append(f'R"xyzw({sig.jsongen_copts()})xyzw"') + return 'static const char* kernel_copts[] = {\n ' + ",\n ".join(lines) + "\n};" + def codegen_kernel_image_objects(self, kernel_image_dir): kernel_image_symbols = [] for incbin_symbol_name, _, o in self.gen_kernel_symbols(kernel_image_dir): @@ -158,6 +180,8 @@ def write_lut_source(self, outdir : 'pathlib.Path', compressed, bare_mode): d = { 'incbin_kernel_images' : self.codegen_incbin_code(gpu_kernel_image_dir, compressed=compressed), 'incbin_kernel_names' : self.codegen_incbin_names(gpu_kernel_image_dir, compressed=compressed), + 'kernel_psels' : self.codegen_kernel_psels(gpu_kernel_image_dir, compressed=compressed), + 'kernel_copts' : self.codegen_kernel_copts(gpu_kernel_image_dir, compressed=compressed), 'kernel_family_name' : self._kdesc.KERNEL_FAMILY, 'shim_kernel_name' : self._kdesc.SHIM_KERNEL_NAME, 'godel_number' : godel_number, diff --git a/v2src/CMakeLists.txt b/v2src/CMakeLists.txt index 02dd38e2..f7399fdc 100644 --- a/v2src/CMakeLists.txt +++ b/v2src/CMakeLists.txt @@ -25,8 +25,13 @@ message("AOTRITON_COMPILER ${AOTRITON_COMPILER}") # ) # add_dependencies(aotriton_v2_gen_compile aotriton_venv_triton) +if(AOTRITON_BUILD_FOR_TUNING) + set(GENERATE_OPTION "--build_for_tuning") +else(AOTRITON_BUILD_FOR_TUNING) + set(GENERATE_OPTION "") +endif() execute_process( - COMMAND ${CMAKE_COMMAND} -E env VIRTUAL_ENV=${VENV_DIR} "${VENV_DIR}/bin/python" -m v2python.generate_compile --target_gpus ${TARGET_GPUS} --build_dir "${AOTRITON_V2_BUILD_DIR}" --bare_mode + COMMAND ${CMAKE_COMMAND} -E env 
VIRTUAL_ENV=${VENV_DIR} "${VENV_DIR}/bin/python" -m v2python.generate_compile --target_gpus ${TARGET_GPUS} --build_dir "${AOTRITON_V2_BUILD_DIR}" --bare_mode ${GENERATE_OPTION} COMMAND_ECHO STDOUT WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_PARENT_DIR}" ) @@ -56,7 +61,7 @@ foreach(RULE IN LISTS HSACO_RULES) "--waves_per_eu" "${WAVESPEREU}" "--target" "${TGTGPU}" "--signature" "${SIG}" - COMMAND ${ZSTD_EXEC} "-f" ${HSACO} + COMMAND ${ZSTD_EXEC} "-q" "-f" ${HSACO} DEPENDS aotriton_venv_triton ) list(APPEND ALL_HSACOS "${HSACO}.zst") @@ -113,7 +118,7 @@ message(STATUS "AOTRITON_ZSTD_INCLUDE ${AOTRITON_ZSTD_INCLUDE}") message(STATUS "AOTRITON_SHIM_FLAGS ${AOTRITON_SHIM_FLAGS}") execute_process( - COMMAND ${CMAKE_COMMAND} -E env VIRTUAL_ENV=${VENV_DIR} "${VENV_DIR}/bin/python" -m v2python.generate_shim --target_gpus ${TARGET_GPUS} --build_dir ${AOTRITON_V2_BUILD_DIR} --bare_mode + COMMAND ${CMAKE_COMMAND} -E env VIRTUAL_ENV=${VENV_DIR} "${VENV_DIR}/bin/python" -m v2python.generate_shim --target_gpus ${TARGET_GPUS} --build_dir ${AOTRITON_V2_BUILD_DIR} --bare_mode ${GENERATE_OPTION} COMMAND_ECHO STDOUT WORKING_DIRECTORY "${CMAKE_CURRENT_LIST_DIR}/.." ) @@ -122,7 +127,7 @@ file(STRINGS "${AOTRITON_V2_BUILD_DIR}/Bare.shim" SHIM_CC_FILES) # CAVEAT: Actual shim code can only be generated during build phase because it # requires some kernel information. (Notably shared memory requirement) add_custom_target(aotriton_v2_gen_shim - COMMAND ${CMAKE_COMMAND} -E env VIRTUAL_ENV=${VENV_DIR} "${VENV_DIR}/bin/python" -m v2python.generate_shim --target_gpus ${TARGET_GPUS} --build_dir ${AOTRITON_V2_BUILD_DIR} ${AOTRITON_SHIM_FLAGS} + COMMAND ${CMAKE_COMMAND} -E env VIRTUAL_ENV=${VENV_DIR} "${VENV_DIR}/bin/python" -m v2python.generate_shim --target_gpus ${TARGET_GPUS} --build_dir ${AOTRITON_V2_BUILD_DIR} ${AOTRITON_SHIM_FLAGS} ${GENERATE_OPTION} BYPRODUCTS ${SHIM_CC_FILES} # Essential, otherwise add_library complains COMMAND_EXPAND_LISTS WORKING_DIRECTORY "${CMAKE_CURRENT_LIST_DIR}/.." @@ -148,6 +153,11 @@ target_include_directories(aotriton_v2 PUBLIC ${CMAKE_CURRENT_LIST_DIR}/../inclu target_include_directories(aotriton_v2 PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) target_include_directories(aotriton_v2 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/../third_party/incbin) target_compile_options(aotriton_v2 PRIVATE -fPIC --no-offload-arch=all) +if(AOTRITON_BUILD_FOR_TUNING) + target_compile_definitions(aotriton_v2 PRIVATE -DAOTRITON_BUILD_FOR_TUNING=1) +else(AOTRITON_BUILD_FOR_TUNING) + target_compile_definitions(aotriton_v2 PRIVATE -DAOTRITON_BUILD_FOR_TUNING=0) +endif(AOTRITON_BUILD_FOR_TUNING) # message(STATUS "AOTRITON_EXTRA_COMPILER_OPTIONS ${AOTRITON_EXTRA_COMPILER_OPTIONS}") # add_custom_target(aotriton_v2 diff --git a/v2src/flash/attn_bwd.cc b/v2src/flash/attn_bwd.cc index 245c6f25..fefe8d33 100644 --- a/v2src/flash/attn_bwd.cc +++ b/v2src/flash/attn_bwd.cc @@ -117,7 +117,8 @@ bwd_kernel_dk_dv(T4 q, uint64_t philox_seed, uint64_t philox_offset, bool is_causal, - aotriton::Stream stream_wrap) { + aotriton::Stream stream_wrap, + ExtraArguments* extargs) { hipError_t err; auto stream = stream_wrap.native(); auto arch = getArchFromStream(stream); @@ -127,6 +128,7 @@ bwd_kernel_dk_dv(T4 q, uint32_t(params.Q->size(1)), params.num_seqlens == 0 ? 
uint32_t(params.Q->size(0)) : params.num_seqlens, }; + // std::cerr << "bwd_kernel_dk_dv grid conf " << grid.x << " " << grid.y << " " << grid.z << std::endl; return grid; }; constexpr int kMinHeadDimCompiled = 16; @@ -163,6 +165,10 @@ bwd_kernel_dk_dv(T4 q, .PADDED_HEAD = head_size_rounded != head_size, .BIAS_TYPE = bias_type, }; +#if AOTRITON_BUILD_FOR_TUNING + if (extargs) + params._has_preferred_kernel = extargs->force_kernel_index; +#endif BwdKernelDkDvContext context; context.grid_calculator = grid_calculator; err = context.lookup_optimal(params, arch); @@ -194,7 +200,8 @@ bwd_kernel_dq(T4 q, uint64_t philox_seed, uint64_t philox_offset, bool is_causal, - aotriton::Stream stream_wrap) { + aotriton::Stream stream_wrap, + ExtraArguments* extargs) { hipError_t err; auto stream = stream_wrap.native(); auto arch = getArchFromStream(stream); @@ -204,6 +211,7 @@ bwd_kernel_dq(T4 q, uint32_t(params.Q->size(1)), params.num_seqlens == 0 ? uint32_t(params.Q->size(0)) : params.num_seqlens, }; + // std::cerr << "bwd_kernel_dq grid conf " << grid.x << " " << grid.y << " " << grid.z << std::endl; return grid; }; constexpr int kMinHeadDimCompiled = 16; @@ -240,6 +248,10 @@ bwd_kernel_dq(T4 q, .PADDED_HEAD = head_size_rounded != head_size, .BIAS_TYPE = bias_type, }; +#if AOTRITON_BUILD_FOR_TUNING + if (extargs) + params._has_preferred_kernel = extargs->force_kernel_index; +#endif BwdKernelDqContext context; context.grid_calculator = grid_calculator; err = context.lookup_optimal(params, arch); @@ -273,7 +285,8 @@ _attn_bwd_common(T4 q, uint64_t philox_seed, uint64_t philox_offset, bool is_causal, - aotriton::Stream stream) { + aotriton::Stream stream, + ExtraArguments* extargs) { hipError_t ret; if (num_seqlens == 0) ret = bwd_preprocess(out, dout, delta, stream); @@ -301,7 +314,8 @@ _attn_bwd_common(T4 q, philox_seed, philox_offset, is_causal, - stream); + stream, + extargs); if (ret != hipSuccess) return ret; @@ -325,7 +339,8 @@ _attn_bwd_common(T4 q, philox_seed, philox_offset, is_causal, - stream); + stream, + extargs); return ret; } @@ -347,7 +362,8 @@ attn_bwd(T4 q, uint64_t philox_seed, uint64_t philox_offset, bool is_causal, - aotriton::Stream stream) { + aotriton::Stream stream, + ExtraArguments* extargs) { auto null_t1 = T1::get_null_tensor(DType::kInt32); return _attn_bwd_common(q, k, @@ -371,7 +387,8 @@ attn_bwd(T4 q, philox_seed, philox_offset, is_causal, - stream); + stream, + extargs); } hipError_t @@ -396,13 +413,14 @@ attn_bwd_compact_varlen(T4 q, // 1 x num_heads x total_q x head_size, uint64_t philox_seed, uint64_t philox_offset, bool is_causal, - aotriton::Stream stream) { + aotriton::Stream stream, + ExtraArguments* extargs) { return _attn_bwd_common(q, k, v, cu_seqlens_q, cu_seqlens_k, - cu_seqlens_q.size(0) - 1, + cu_seqlens_q.size(0), max_seqlen_q, max_seqlen_k, b, @@ -419,7 +437,8 @@ attn_bwd_compact_varlen(T4 q, // 1 x num_heads x total_q x head_size, philox_seed, philox_offset, is_causal, - stream); + stream, + extargs); } } diff --git a/v2src/flash/attn_fwd.cc b/v2src/flash/attn_fwd.cc index 07428dc1..a1acf291 100644 --- a/v2src/flash/attn_fwd.cc +++ b/v2src/flash/attn_fwd.cc @@ -33,7 +33,8 @@ _attn_fwd_common(T4 q, uint64_t philox_offset, T4 encoded_softmax, bool is_causal, - aotriton::Stream stream_wrap) { + aotriton::Stream stream_wrap, + ExtraArguments* extargs) { hipError_t err; auto stream = stream_wrap.native(); auto arch = getArchFromStream(stream); @@ -54,6 +55,8 @@ _attn_fwd_common(T4 q, return grid; }; int head_size = q.size(3); + int num_head_q = q.size(1); + 
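// The Q and K head counts are forwarded separately so the kernel arguments + // can describe num_head_q != num_head_k (e.g. grouped-query attention). +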
int num_head_k = k.size(1); int head_dim_rounded = std::max(16, aotriton::bit_ceil(head_size)); int bias_type = 0; if (b) { @@ -69,11 +72,13 @@ _attn_fwd_common(T4 q, .encoded_softmax = &encoded_softmax, .sm_scale = sm_scale, .M = &softmax_lse, - .cu_seqlens_q = &cu_seqlens_q, - .cu_seqlens_k = &cu_seqlens_k, + .num_head_q = num_head_q, + .num_head_k = num_head_k, .num_seqlens = num_seqlens, .max_seqlen_q = max_seqlen_q, .max_seqlen_k = max_seqlen_k, + .cu_seqlens_q = &cu_seqlens_q, + .cu_seqlens_k = &cu_seqlens_k, .head_dim = static_cast<int32_t>(head_size), .dropout_p = dropout_p, .philox_seed = philox_seed, @@ -85,10 +90,21 @@ _attn_fwd_common(T4 q, .PADDED_HEAD = head_dim_rounded != head_size, .BIAS_TYPE = bias_type, }; +#if AOTRITON_BUILD_FOR_TUNING + if (extargs) + params._has_preferred_kernel = extargs->force_kernel_index; +#endif AttnFwdContext context; context.grid_calculator = grid_calculator; // .grid_calculator = grid_calculator err = context.lookup_optimal(params, arch); +#if AOTRITON_BUILD_FOR_TUNING + if (extargs) { + extargs->total_number_of_kernels = params._total_number_of_kernels; + extargs->selected_kernel_psels = params._preferred_kernel_psels; + extargs->selected_kernel_copts = params._preferred_kernel_copts; + } +#endif if (err != hipSuccess) { return err; } @@ -109,7 +125,8 @@ attn_fwd(T4 q, uint64_t philox_offset, T4 encoded_softmax, bool is_causal, - aotriton::Stream stream_wrap) { + aotriton::Stream stream_wrap, + ExtraArguments* extargs) { auto null_t1 = T1::get_null_tensor(DType::kInt32); return _attn_fwd_common(q, k, @@ -128,7 +145,8 @@ attn_fwd(T4 q, philox_offset, encoded_softmax, is_causal, - stream_wrap); + stream_wrap, + extargs); } hipError_t @@ -148,7 +166,8 @@ attn_fwd_compact_varlen(T4 q, // 1 x num_heads x total_q x head_size, uint64_t philox_offset, T4 encoded_softmax, bool is_causal, - aotriton::Stream stream_wrap) { + aotriton::Stream stream_wrap, + ExtraArguments* extargs) { return _attn_fwd_common(q, k, v, @@ -166,7 +185,8 @@ attn_fwd_compact_varlen(T4 q, // 1 x num_heads x total_q x head_size, philox_offset, encoded_softmax, is_causal, - stream_wrap); + stream_wrap, + extargs); } } diff --git a/v2src/template/autotune_table_entry.cc b/v2src/template/autotune_table_entry.cc index 1a13565b..f119c741 100644 --- a/v2src/template/autotune_table_entry.cc +++ b/v2src/template/autotune_table_entry.cc @@ -18,12 +18,24 @@ [[incbin_kernel_images]]; -#ifndef NDEBUG +#if !defined(NDEBUG) || AOTRITON_BUILD_FOR_TUNING [[incbin_kernel_names]]; #endif +#define ARRAY_SIZE(array) (sizeof(array) / sizeof(array[0])) + namespace { // Anonymous namespace +#if AOTRITON_BUILD_FOR_TUNING +static constexpr int incbin_num_kernels = ARRAY_SIZE(incbin_kernel_names); +#endif + +#if AOTRITON_BUILD_FOR_TUNING +// PSels and Copts in JSON String +[[kernel_psels]]; +[[kernel_copts]]; +#endif + struct PerfFields { [[perf_fields]]; }; @@ -45,8 +57,26 @@ namespace aotriton::v2::[[kernel_family_name]]::autotune { // using aotriton::v2::[[kernel_family_name]]::[[param_class_name]]; void CURRENT_ENTRY_PUBLIC::operator()([[param_class_name]]& params) { +#if AOTRITON_BUILD_FOR_TUNING + int preferred_index = params._has_preferred_kernel; + params._total_number_of_kernels = incbin_num_kernels; + if (preferred_index >= 0) { + if (preferred_index >= incbin_num_kernels) + return ; + params.selected_kernel = &image_list[preferred_index]; + params._debug_kernel_name = incbin_kernel_names[preferred_index]; + params._preferred_kernel_psels = kernel_psels[preferred_index]; + params._preferred_kernel_copts =
kernel_copts[preferred_index]; + const auto& perf = image_perf_list[preferred_index]; + [[perf_field_assignment]]; + return ; + } +#endif [[binning_autotune_keys]] auto kernel_index = lut[[binned_indices]]; + if (kernel_index < 0) { + return ; + } params.selected_kernel = &image_list[kernel_index]; #ifndef NDEBUG std::cerr << __FILE__ << " kernel_index = " << int(kernel_index) << std::endl; diff --git a/v2src/template/shim.h b/v2src/template/shim.h index 635174be..99150741 100644 --- a/v2src/template/shim.h +++ b/v2src/template/shim.h @@ -20,6 +20,12 @@ struct [[param_class_name]] { TritonKernel* selected_kernel = nullptr; const char* _debug_kernel_name = nullptr; +#if AOTRITON_BUILD_FOR_TUNING + int _has_preferred_kernel = -1; // For C++ based autotune database generation + int _total_number_of_kernels = -1; + const char* _preferred_kernel_psels = nullptr; + const char* _preferred_kernel_copts = nullptr; +#endif int64_t godel_number() const; }; diff --git a/v2src/util.cc b/v2src/util.cc index 28fca9de..2c47e1dc 100644 --- a/v2src/util.cc +++ b/v2src/util.cc @@ -30,6 +30,8 @@ struct LazyArch { std::unordered_map<std::string, GpuArch> LazyArch::string_to_arch = { {"gfx90a:sramecc+:xnack-", GPU_ARCH_AMD_GFX90A}, {"gfx942:sramecc+:xnack-", GPU_ARCH_AMD_GFX942}, + {"gfx1100", GPU_ARCH_AMD_GFX1100}, + {"gfx1101", GPU_ARCH_AMD_GFX1101}, }; GpuArch