Improve Reduction kernel api#151
Closed
qianfengz wants to merge 64 commits into
Closed
Conversation
* add DeviceGemmXdl * update script * fix naming issue * fix comment * output HostTensorDescriptor * rename * padded GEMM for fwd v4r4r4 nhwc * refactor * refactor * refactor * adding ckProfiler * adding ckProfiler * refactor * fix tuning parameter bug * add more gemm instances * add more fp16 GEMM instances * fix profiler driver * fix bug in tuning parameter * add fp32 gemm instances * small fix * refactor * rename * refactor gemm profiler; adding DeviceConv and conv profiler * refactor * fix * add conv profiler * refactor * adding more GEMM and Conv instance * Create README.md Add build instruction for ckProfiler * Create README.md Add Readme for gemm_xdl example * Update README.md Remove build instruction from top most folder * Update README.md * clean up
* start fixing 16bit data packing * adding StaticTensor * adding StaticTensor * adding StaticTensor * add missing constexpr * adding static tensor * adding static tensor * adding transpose * add inline asm for transpose 2x2 of half_t * add general transpose_vectors(), but have unnecessary register initialization using v_mov * fix unnecessary register initialization in transpose_vector by using more pass-by-reference * add hardcoded logic for NHWC wrw * improve asm for v_pack * make ThreadwiseTensorSliceTransfer_v3r2 support any tensor * tweak * reorganize file
* init StaticBufferV2 * clean * adopt old output stage for staticBufferV2 * clean * remove hack * clean * clean * add parameters * clean code * move c_buffer alloc into blockwise gemm * add adaptors for m/n_thread_data_on_grid * tweak gemm * adjust blockwise_gemm_xdlops * tweak * update conv * update script * adding bwd 1x1 * update script * adding 1x1 bwd * debugging bwd 1x1 failure * update script * update script * test * test v100 * add bf16_1k * clang-format * clean * add bfp16 for gfx908 * add verification * clean up * clean code * restore bfl16 * clean * add bfp16 support into gemm_driver * apply new generator to other drivers * add int8 support * cleanb * clean * clean * clean Co-authored-by: Chao Liu <chao.liu2@amd.com> Co-authored-by: Chao Liu <lc.roy86@gmail.com> Co-authored-by: root <root@hayabusa6111.amd.com>
…n building ckProfiler (#51) * fixed bfloat16 issues * refactor type_convert Co-authored-by: Chao Liu <chao.liu2@amd.com>
* fixed bfloat16 issues * refactor type_convert * fixed host_convolution_forward for ushort Co-authored-by: Chao Liu <chao.liu2@amd.com>
* init
* refactor for 1x1
* rename e0_e1
* add e1 with bugs
* debug
* fixed
* fixed e1
* add timer
* imprve threadwise gemm with dot2
* add e2
* tuning
* seperate c2
* add nhwc
* restore nchwc
* clean
* opt
* fixed; tuning
* add BGlobalMoveSliceWindowStepHacks{}
* tuning
* repeat running
* adjust
* merge v5r1 nchwc
* add adaptors
* split k0 k1 in c_thread_grid
* split h and w
* remove v5r1 nhwc
* clean for pr
* remove host_conv_add
* clean code
* clean
* add dynamic support
* static mode
* test static
* add conv+add fusion
* fixed validation
* naming fix
* use activ_enum
* make static
* refactor conv_add for InMem::add
* add bias
* add conv_out
* add configurable makeddesc
* add maxpool fusion
* add maxpool host for validation
* enable static desc
* conv-only use v5r1_add
* test
* test
* for binary dumps
* fixed incorrect results due to typo
* clean
* debugging maxpool
* workaround with offset trick
* clean code
* modularize ops of fusion
* add gridwise_gemm_v3
* create seperate fusion fun
* enable dynamic mode of conv and conv+resize_add
* add dynamic mode of maxpool
* add pass by point
* add activ_type as arguments
* merge develop
* clean
* reset config to old default
Co-authored-by: Chao Liu <chao.liu2@amd.com>
…rom pointer of scalars (#53) * reworking vector_type * use __builtin_memcpy for bit_cast and vector access of scalar pointer * clean up
* gemm+activation * move C pointwise operation into threadwise copy * add pointwise operation to A/B matrix * update ckProfiler * adding bias add * adding bias add * adding bias add * added bias add; worked around compiler issues * clean up * clean up * Update README.md * Update README.md * Update README.md * clean up * add conv_xdl example * adding conv_xdl_bias_relu_add example * add conv+bias+relu+add, but has register spill issue * tweak * tweak * refactor * Update README.md update readme for example/2_gemm_xdl_bias_relu_add * clean up * Update README.md update readme for example/3_conv_xdl * Update README.md
* fix relu * clean up * clean up
* fix relu * clean up * clean up * adding 1x1 conv * adding 1x1 conv * added 1x1 conv * refactor * refactor * refactor * added profiler for conv+bias+relu+add * clean up * adding conv+bias+relu * adding conv+bias+relu * added conv+bias+relu * Update README.md * update cpu verification * adding c shuffle * update static_tensor for dealing with invalid element * adding c shuffle * debugging * fix bug * convert to fp16 before shuffle * shuffle more than one M/NRepeat * clean up * remove coordinate step hack from GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 * clean up * remove coordinate step hack from all gridwise gemm xdl * clean up coordinate step hack * clean up coordinate step hack * ThreadwiseTensorSliceTransfer_v3r2 support pointwise op on both src and dst * adding output shuffle in conv+bias+relu+add * update * added conv+bias+relu+add with c shuffle * added conv+bias+relu+add with c shuffle * fix forward_sweep bugs in threadwise copy * clean up * refactor * clean up * clean up * added conv_c_shuffle+bias_relu * clean up * added conv+bias+relu+atomic_add * clean up * clean up * clean up * clean up * clean up * clean up * misc fixes; add 1x1 specialization * clean up * delete unused device op * clean up * add support for odd C value
* fix build issue
* [What] 1. Add DeviceGemmXdl_C_Shuffle 2. Revise example of gemm_xdl [Why] Prepare to add shuffle version of D = alpha * (A * B) + beta * C [How] Imitate DeviceGemmXdl and device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
* Do not hardcode the function parameter, use template instead. * [What] Remove AThreadTransferSrcResetCoordinateAfterRun and BThreadTransferSrcResetCoordinateAfterRun in host API [Why] "C_Shuffle" version is supposed to be similar to the vanilla one * Fix typo Let DeviceGemmXdl_C_Shuffle use kernel_gemm_xdlops_v3r1
* add DeviceGemmSplitKXdl * add file device_gemm_splitk_xdl.hpp * set c matrix zero * using atomic * add all tuning parameter to f32 mkkn * grid size change to 720 * add tunning parameter for NT * add tunning parameter for TN * add tunning parameter for TT * add m=96tunning parameter * add lost config * add element wise operation * fixed MPerBlock=96 * remove marco for slpitk swtich * add test * add new line at the end of device_gemm_xdl_instance.hpp * remove step hack * seperate split-k instance files * add tunning parameters * change disired grid size to parameters * remove slice length * add desiredgridsize parameter to ckProfiler * add losting file device_gemm_xdl_splitk_instance.hpp * change desired gride size to kbatch * format * format * clean up * add selection of device_instances * clean code * fix build issue Co-authored-by: ltqin <letaoqin@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com> Co-authored-by: Jing Zhang <jizhan@amd.com>
* test mfma builtins * add fp16 buildins * add int8 buildins * add bfl16 buildins * simplify host conv forward * clean * clean
* add reference * clean up * add reference for conv * rename Co-authored-by: ltqin <letaoqin@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
* tweak conv for odd C * update script * clean up elementwise op * fix build * clean up * added example for gemm+bias+relu+add * added example for gemm+bias+relu * add profiler for gemm_s_shuffle; re-org files * add profiler * fix build * clean up * clean up * clean up * fix build
- device_gemm_xdl_c_shuffle function signature matches split-k - retire host_driver since it is no longer maintained - linter error (unused variable) Co-authored-by: Chao Liu <chao.liu2@amd.com>
* [What] Add 2d version of bias, prepare to implement alpha / beta scaling * Add alpha / beta functor * Refine parameter of example * [What] Use real type instead of template [Why] Prevent implicit cast * Rename parameter for general operator * Remove redundant comment * Fix compile error Co-authored-by: rocking <chunylai@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
* prepare host for batched_gemm * init commit of batched kernels * fixed * refine transform with freeze * m/n padding * fixed a bug; clean * add small tiles * clean * clean code * clean code * add nt, tn, tt layout * add missing file * use StaticBufferTupleOfVector instead * add reference_batched_gemm * fixed a macro
* add DeviceGemmSplitKXdl * add file device_gemm_splitk_xdl.hpp * set c matrix zero * using atomic * add all tuning parameter to f32 mkkn * grid size change to 720 * add tunning parameter for NT * add tunning parameter for TN * add tunning parameter for TT * add m=96tunning parameter * add lost config * debug * fix sweep * add failed tuning params * fixed sweep logic * clean * add padding to M/N for irr tile size * clean code * add element wise operation * fixed MPerBlock=96 * remove marco for slpitk swtich * add test * add new line at the end of device_gemm_xdl_instance.hpp * remove step hack * seperate split-k instance files * add tunning parameters * change disired grid size to parameters * remove slice length * add desiredgridsize parameter to ckProfiler * add losting file device_gemm_xdl_splitk_instance.hpp * change desired gride size to kbatch * format * format * clean up * add selection of device_instances * clean code * clean code * add small tile size in fp16 nn * test for rocm 4.5 * merge develop * clean * clean * clean * remove no-use code * add padding switch to device_gemm_xdl * add padding switch for ksplit fp32 * clean * clean * add files * rename * Update profiler.cpp * format Co-authored-by: ltqin <letaoqin@amd.com> Co-authored-by: ltqin <letao.qin@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
* add fwd bf16 conv * change tunning parametor * add int8 for conv fwd * remove comments * change tunning parametor for int8 * change init int8 example * add test for conv2d fwd * change device operation file pos because merge develop * fwd int8 use reference * test_conv_fwd use reference * add braket for if statement * rename fwd example name * remove StaticBufferOfVectorTypeV2 * tweak example Co-authored-by: ltqin <letaoqin@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
* init for splitk f16 * a working prototype * debug * perf debug * update example * instances for mk kn * add instances for all layers * clean * clean * add tuning * format * add mn_padding into irregular tile * clean Co-authored-by: Chao Liu <chao.liu2@amd.com>
* add gitignore * host tensor: allow generating sequentially increasing value in a given dimension * gridwise gemm v3r1: allow distinct K0/K1 values for A/B block descriptor - remove dangling header include - modify example gemm_xdl accordingly - infer KPack value from M/NPerXdl - device conv2d fwd: update parameters accordingly for the underlying gridwise gemm v3r1 (API for conv2d fwd stays the same for now until we decide to expose individual K0s for activation and weight) * add LDS data dump utility * profiler: reflect API change for distinct K0/K1 for A/B matrices * profiler: add conflict-free LDS write FP16 kernel instances * fix accidental perf regression * address feedback; cosmetic changes * clang-format for new files * format Co-authored-by: Chao Liu <chao.liu2@amd.com>
… stage for tests (#88) * add docker file and make default target buildable * add Jenkinsfile * remove empty env block * fix package stage * remove render group from docker run * clean up Jenkins file * add cppcheck as dev dependency * update cmake file * Add profiler build stage * add hip_version config file for reduction operator * correct jenkins var name * Build release instead of debug * Update test CMakeLists.txt reorg test dir add test stage * reduce compile threads to prevent compiler crash * add optional debug stage, update second test * remove old test target * fix tests to return proper results and self review * Fix package name and make test run without args * change Dockerfile to ues rocm4.3.1 * remove parallelism from build * Lower paralellism Co-authored-by: Chao Liu <chao.liu2@amd.com>
…API (#92) * start conv2d bwd api * kernel running * add bwd reference * change to no shuffle * fix bwd reference * pass verification * add Filter1x1Stride1Pad0 and start testing * change some tuning parameter * fix test error * add fp16 tuning parameter * add bf16 tuning parameter * add int8 tuning parameters * change fp32 tuning parameter * add bwd to profiler * fix bug for bwd profiler * fix ckProfiler bug * change conv2d_bwd_xdl to fp16 * fix bug in comments * fix precompile id * fix enum conv name * chage _bwd_ to _bwd_data_ * change conv2d_bwd example id * bwd to bwd data * fix prehead * fix MakeDefaultBlock2CTileMap ,import form merge develop * format bwd instance * bwd to bwd data * change name bwd to bwd data * change name bwd to bwd data in example * formate code * change conv2d bwd data id in example * rewrite readme for example * fix CalculateMagicNumbers about div zero * add workaround CK_WORKAROUND_SWDEV_325164 * change test_conf2d_bwd_data show info * format * fix bug for workaround:CK_WORKAROUND_SWDEV_325164 * formate tuning parameters * formate tuning parameters again * formate tuning parameters 3 * formate tuning parameters 4 * remove add function template * format * update comment Co-authored-by: ltqin <letaoqin@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
* add space_filling_curve * cleanup and move space_filling_curve into test * WIP: start refactoring threadwise_transfer_v1r3 * threadwise_copy works but needs further refactoring * add some comments * add SpaceFillingCurve::GetIndices() * minor changes * removed GetIndices; refactored GetDstCoordinateResetStep * add DynamicBuffer::Transfer, but Add is not tested * rebased agaist develop * threadwise_copy_v6r1/v6r2/v6r3 using space-filling curve start to work * minor changes * refactored threadcopy v3r1, v2; removed old implementations * clang-format * cleanup * fix a typo in v6r3 * format Co-authored-by: Chao Liu <chao.liu2@amd.com>
* Add int8 of mk_nk_mn to the ckProfiler * Add example of int8 gemm * Fix typo, use ushort instead of half_t for bfloat16 * replace ushortXXX_t to bhalfXXX_t * rename ushort to bhalf_t * Add bf16 example * Add bf16 gemm to ckProfiler * Fix alignment * Fix typo * Add unit test for gemm_xdl int8 * Add gemm_xdl fp32 unit test * Add gemm_xdl bf16 unit test * fix build * fix build issue due to merge conflict * Fix build * Fix build error Co-authored-by: rocking <chunylai@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
* add wrw reference * start device * raw not split version * run simple example * start to use atomic add * simple transform result correct * first version that can run * fix atomic and set operator choice * add check split-k * format * change input parameter * add pad for t total * rename example index Co-authored-by: ltqin <letaoqin@amd.com>
* fix tests * remove useless file * fix test build * reduce parallelism when compiling * fix test
* Add int8 of mk_nk_mn to the ckProfiler * Add example of int8 gemm * Fix typo, use ushort instead of half_t for bfloat16 * replace ushortXXX_t to bhalfXXX_t * rename ushort to bhalf_t * Add bf16 example * Add bf16 gemm to ckProfiler * Fix alignment * Fix typo * Add unit test for gemm_xdl int8 * Add gemm_xdl fp32 unit test * Add gemm_xdl bf16 unit test * fix build * fix build issue due to merge conflict * Fix build * Fix build error * [What] gemm + relu inference [How] gemm + requant + relu + requant + clamp * clean Co-authored-by: rocking <chunylai@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
… used in threadwise copy) (#111)
* Initial adding of generic reduction * Initial adding of generic reduction ... * Updates to make compiling done * clang-format all files * clang-format some files again * Renaming in profiler/include/profile_reduce.hpp * Updates and make BlockWise cases passed * Updates and make ThreadWise and MultiBlockTwoCall cases passed * Remove the support for MUL and NORM1 reduceOp from the profiler and the device instances * Change to replace the dim0_max_vector_size/dim1_max_vector_size template argument in the device reduce classes * format * adding pooling * added max and average pooling * comment out cout and kernel timing * Tiny simplification in profiler/reduce_profiler.cpp * Add example for reduce_blockwise * Tiny updates * Change to pass the ElementWiseOp from device layer to kernel * Fix the vectorDim and vectorSize in Device layer * Enable vector load on both dim0 and dim1 for Threadwise method * Tiny updates * Change to let the user to pass the preUnaryOp and posUnaryOp * Make pooling example work * split device_reduce_instance into two libraries * Tiny update * Replace nanPropaOpt enum by boolean propagate_nan * Simplification in DeviceReduce layer codes * update build * Change to clarify the difference between ck::half_t and half_float::half * Renaming in all the reduction codes * Add VectorSize as template parameter for device layer * Add BetaIsZero as kernel template and as AccDataType for alpha * print * Small updates for pooling * Updates for host_generic_reduction for reference * Update to make AVG pooling pass * Update to make MAX pooling with indices output pass * fix * add OutDst vector store to threadwise reduction and pooling * tweak * turn off check_indices that caused build issue * refactor pooling * clean up * turn off check_indices for building issue for php-compiler * add more tile size for odd C * tweak conv for odd C * update script * clean up elementwise op * add hack in reduction_operator.hpp to avoid compile error. To fix it, need to use element_wise_op in reduction op * Add OutVectorSize as device and kernel tunable, also update to Elementwise Operations * Move reduce operator mapping to host layer file reduction_operator_mapping.hpp from reduction_operator.hpp * Change to the unary operators * Move the definitions of unary operations to element_wise_operation.hpp * re-org files * Refine in device interfaces and multiblock kernels * Split the reduction configurations into instances for specific methods * Update in getTypeString() of device pool2d * Renaming in host and kernel * Tiny update in profiler/src/profiler.cpp * Uncomment in device_operation/CMakeLists.txt to enable the building of all operations * Make check_indices a templated function to remove some linking issue * Renaming in the profiler reduce module * Add support for double Reduction (but disable MultiblockAtomicAdd for double) * Tiny correction of literal string * Rename DevicePoolFwd to DevicePool2dFwd * Split device_reduce_instance_xxx.cpp files according to the data types to speed up compiling * Add comments for lists of configurations, lists of instances and references of add_reduce_instances_xxx * Remove un-used header file gridwise_generic_reduction_wrapper_common.hpp * Renaming and refining in the Reduction codes * Tiny change in the unary operators * Renaming symbols and files * Renaming symbols in the kernels * Move kernel kernel_set_buffer_value to separate file * Add IndexDataType template parameter for kernels and use int32_t as index data type in device layer * Tiny update in the kernels * Remove definition of sqrtf()/isnan()/abs() for half_t due to some ADL issue * Simplify a helper function in device layer * Tiny adjustment in testing data initialization * Renaming in kernel/device/host * Add two testing scripts for reduction * Refine the Unary operators in element_wise_operation.hpp * Update in the reduce profiler module * Update to the reduction testing scripts * reduce compile parallelism * change CI docker to rocm5.0 * remove unused variables * fix build Co-authored-by: Chao Liu <chao.liu2@amd.com>
* delete obselete files * move files * build * update cmake * update cmake * fix build * reorg examples * update cmake for example and test
* Use thread cluster descriptor and explicit M_K 2d descriptor to simply Blockwise Reduction * Change by replacing ReduceDims by NumReduceDims as Device Reduce interface template parameter * Rename the folder name for the pool2d and reduce examples * Update to reduction test scripts * Add Readme for pool2d_fwd and reduce_blockwise examples * Tiny fix in reduce profiler and tiny update in reduce testing scripts * Tiny fix in testing script profile_reduce_no_index.sh * Tiny change in script/profile_reduce_with_index.sh * Renaming and refining in Reduction profiler/device layer/examples * Renaming and refining in Reduction profiler/device layer/examples * Renaming all NumReduceDims to NumReduceDim
* fixed a corner case in GetCoordinateResetStep * clean * rename num_accesses to num_access Co-authored-by: Chao Liu <chao.liu2@amd.com>
* [What] Separate fixpoint gemm from gemm example [Why] let example of gemm_int8 be pure gemm. [What] 1. Add gemm_requant_relu_requant, 2. Let CDataType be int32 in pure gemm, because no one use int8 CDataType. It is also part of gemm_requant_relu_requant * Fix path * Revise cmakelist due to merge develop Co-authored-by: rocking <chunylai@amd.com>
* fix bwd data filter1strid2 bug * fichangeshort to ck::bhalf_t * reset input to zero Co-authored-by: ltqin <letaoqin@amd.com>
* [What] Separate fixpoint gemm from gemm example [Why] let example of gemm_int8 be pure gemm. [What] 1. Add gemm_requant_relu_requant, 2. Let CDataType be int32 in pure gemm, because no one use int8 CDataType. It is also part of gemm_requant_relu_requant * Fix path * Revise cmakelist due to merge develop * Add gemm fp16 test * Extract PrepareGemmTensor * Extract TestGemm * Add test for different layout * Add 4 layouts of shuffle version of fp32 * Add 4 layouts of shuffle version of int8 * Add 4 layouts of shuffle version of bf16 * replace all DeviceGemmPtr_ with DeviceGemmNoOpPtr to fit naming convension * Add test for non-shuffle verstion of gemm * Fix typo * Print kernel information * Add rest of the fp32 kernel to the test * 1. Add rest of the fp16 device iop. 2. Mark the invalid device operation Co-authored-by: rocking <chunylai@amd.com>
… and int8 to profiler (#120) changed long_index_t to index_t when computing memory offset uncomment other ops in profiler added test for batched_gemm
* Use thread cluster descriptor and explicit M_K 2d descriptor to simply Blockwise Reduction * Change by replacing ReduceDims by NumReduceDims as Device Reduce interface template parameter * Rename the folder name for the pool2d and reduce examples * Update to reduction test scripts * Add Readme for pool2d_fwd and reduce_blockwise examples * Add support for int8_t reduction (ADD/AVG, MIN/MAX/AMAX) * Tiny fix in reduce profiler and tiny update in reduce testing scripts * Tiny fix in testing script profile_reduce_no_index.sh * Tiny fix in testing script profile_reduce_no_index.sh * Add support for bfp16 reduction (using bhalf_t = ushort) * Tiny fix in amd_buffer_addressing.hpp * Tiny change in script/profile_reduce_with_index.sh * Use AccDataType for Beta value and use element_wise::PassThrough * Use type_convert for type converting in host layer reduction * Renaming and refining in Reduction profiler/device layer/examples * Renaming and refining in Reduction profiler/device layer/examples * Renaming all NumReduceDims to NumReduceDim * Fix the leaked type_convert in ThreadwiseTensorSliceTransfer_v2 * Update to testing scripts to add bf16 support * added more static_assert * Remove buggy tunable configurations defined in device_reduce_instance_xxx.hpp * Add static_assert to give compile-time warning for incorrect thread slice-size/vector-size configurations * minor change * Refine and fix (in GetWorkspaceSizeInBytes of MultiBlockPartialReduce) to make int8 completely pass * Tiny renaming in gridwise_2d_reduction_multiblock_partial_reduce.hpp * Tiny fix in script/profile_reduce_no_index.sh * Refine in DeviceReduce layer with regard to using NumInvariantDim/NumReduceDim or InvariantDims/ReduceDims * Generic renaming in host reduction and DeviceReduce layer * Add support for 4-d all dimension reduction in the profiler and add_device_reduce_xxx instances * Use multi-thread and simplification for host Reduction implementation * Add ctest for reduction * Update to clarify the using of data init method in produce_reduce/example_reduce/test_reduce/ * Update to the reduce CTest executables to enable default testing behavior when no command argument * Renaming Co-authored-by: Jianfeng yan <jfyan008@gmail.com>
* init of grouped_gemm * 2 gemm test * perf test * clean * wrap desc into a struct * test cast static_arr to pointer * add ptr to GemmDesc * add grouped gemm profiler * fixed mem issue with unique_ptr * clean * clean * finished ckprofiler * Update README.md * readme * fixed readme * add example * improve code * fixed comments: reserve, seperate ptr and gemm_shapes * merge group and non-group * fixed comments: replace push_back with emplace_back to avoid copy constructor * fixed comments: unified blk2ctile; add test * ci fix * fixed ci * fixed ci * fixed ci
* add bf16 for batched gemm * batched_gemm_bf16 works * recover accidently changed files
…iseReduction api to simply the kernels
hyoon1
pushed a commit
to hyoon1/composable_kernel
that referenced
this pull request
Mar 19, 2026
* Fused Bwd (ROCm#137) * Fused with Good perf and stride fixed Fix fused bugs isolate failing case fix bug bring back test cases rm split impl in fused use exp2 is global variable now try oom fix save make fused the default limit to reproduce failure return default to split fix head size bug use exp2 back to true * new grid * BLK_SLICE_FACTOR = 1 * add tflops * new commit * test in parrallel * strides added by jusson * disable alibi * fix bugs again * default to fused * add bwd options for varlen * backend filter * default to jingning and batch 4 * best fwd config * fix TRITON_PRINT_AUTOTUNING flag bug * tune * Tuning fwd prefill * add if else * use flag * Minor mask fix * FLIP GRID * use best config for default * print when autotuning * test bfloat16 * fix k and v stride bugs * skip bfloat16 * test kvpacked * disable internal tests * pick default config based on arch * Add alibi in the new bwd kernel (ROCm#139) * enable alibi for jinging kernel enable alibi for jinging kernel match * save bad configs * fix alibi and causal bug * disable autotune by default * auto tune when benching is good * set best config * remove env var * Update amd_tests.yml * upgrad to triton==3.3.0 * increase shm * use 64 x 64 for now * save * handle 1d alibi * Add fp8 to fused kernel (ROCm#140) * fp8 stuff find test case compute delta fp8 basic fp8 config passing non causal path works * isolate bad case * fix fp8 bug * didnot fix fp8 bug * back to failing test * fp8 tests passing * skip * skip ref tests --------- Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com> * head, seq, batch (ROCm#141) * Fix keys (ROCm#144) * save * rm keys * fix keys * use GHA_RENDER_DEVICES * normal docker * Pad LSE (ROCm#148) * add round multiple * fix fwd * backward fix * use rounded lse flag * passing ROUNDED_LSE * default is new rounded mode * rename to fused_atmoics and fused_no_atomics * add test for torch_compile * add varlen torch compile test * add old one kernel for ref * fix varlen mismatch bug * fix shape issue in varlen but mismatch * sync torch compile kernel launch * simple varlen test * add debug code * rm old * ignore old impls * DEBUG flag works in interface only * ref uses the righ shape for lse * rm oldest bwd kernel * fix typo * fix varlen bug * fix bug. Get info from q for now * simple shape and stride checkout * add more tests * test kvcache * kvcache safe * match case * fix segfault due to bad return_softmax * run bench * run seperate for the main functions * just output benchmark * default csv format and time stamp files * non verbsoe bench * Sliding Window Forward (ROCm#151) * Compress SWA work test case set up debug inputs add fwd ref one mask ref fwd first pass save ref doesnot work for bigger seqlens save new version some causal cases failing found bad cases working new attn new atten works new attn_fwd works reorg n_extra_tokens use seqlen_delta_qk ref fwd works add sliding window to bwd ref test kvcache decode ref work with everything except sliding window add debug code for 12 failing sliding window cases for decode attention_decode_forward_ref_impl mostly works except for alibi fix alibi in attention_decode_forward_ref_impl ref works with normal, varlen & kvcache move stuff around figure out masking old attn inner two inner functions remove load_fn do Lk - Lq like ref unify IS_CAUSAL code in epilogue clean up add args rm inference stuff simplify compute_masking simpler compute mask stub out returning front masking variables remove pointer pass compute ptrs inloop compute block min and max window stub inside inner mask loop trying to use attn_fwd_mask causes issues fix compiler bug when front masking gen specifc types add sliding window and debug statements use identity for v add more taste cases add comments save use k_max_token for clarity disable debug configs basic NON-CAUSAL SLIDING WINDOW non causal sliding window works on the all the shapes non sliding window working in fwd clean up fused bwd seperate old fwd_prefill move configs to utils.py * fix bwd ref bug * skip local cases so that fa output * no sliding window causal green * add backward test skip for sliding window * clean reduce in fwd_kvcache. no is_CASUAL branching * add kvcache masking * kvcache working * fix some bugs in test.py * clean up * Fix Device Segfault (ROCm#152) * Compress segfault work fix backward segfault rework offset ignore .profile ignore .analysis save * assert the kernel launch device and tensor devices are the same * fix failing asserts * add asserts to fwd * Fix SDMASK bug * Log triton, torch and fa version * Fix fp8 import issues * fix docs (ROCm#154) * Sliding Window block classification logic (ROCm#155) * add aiter code * remove aiter stuff * sliding window non causal masking works * causal and sliding window block masking * extract common * clean up typo * helper for swa * ignore .amd * fix last block bug * Enable FA V3 (ROCm#157) * Compress PA work narrow pa test ref works on most cases inplace ref with new_kv inplace paged attention add pa ref save pa basic paged works save fix swa + causal in pa. Also new_kv only on pa path passing build fa v3 import interface from fa v3 copy fa tests use v3 api clean up rename to match old test support different head sizes remove fp8 basisc passing v3 cases test_flash_attn_varlen_output v3 working isolate bad case for kvcache case passing save use decode is seqused/ cacheseql is given use decode if not varlen basci kvcache v3 working kvcache enable more cases detect kvcache case if seqused_q is non and sequese_k is not None skip failing test find fp8 failing case mha fp8 works fix fp8 MQA/GQA bug clean up more clean up clean up more don't need fp8 dead code remove train code with fp8 stuff fp8 working in kvcache paged + fp8 seems to be working new_kv allowed * clean up * skip hopper race test * clean up more * fix paged + alibi * similar inner paged api * unify _attn_fwd_inner * AITER integration (ROCm#159) * clean up v2 interface * assert fp8 scale shapes * rotary working * move rotary to impl layers * remove einops * enable rotarry in v3 * create interface * fix descale assert * unify bwd * lint from aiter * clean fp8 api * add api change * assert shapes for v2 * remove ref and bench.py * remove metadata class and clean up * bwd_prefill * one bwd.py * rename * lint * add bwd_change (ROCm#156) * Tune FP8 Perf (ROCm#160) * check cu count for gfx942 * create get_cu_count * update repo root * update forward tune * clean up load * use float8_e4m3fnuz * save * show bwd mode * recommend fp8 * use torch.float32 for fp8 kernel * add both best fp16 and fp8 config * tune fp8 backward * descale factors should be b, hk * fp8 bwd working on all primus configs * tune bwd configs * fa v3 tests passing * better warning * clean up bwd launcher * v3 passing * tune more * improve perf * clean up * lint * clean * start tuning gfx950 * tune non causal path * fix bug * save * Skip configs where BLOCK_M2 % BLOCK_N2 != 0 * skip more * stop tuning * fix varlen bug * fix dropout & causal/swa segfault * update the to machine new changes * save * fix more bugs * remove random seed * clean up * update readme * print tensor stats for debug * disable sliding window tests * add rdna configs * fix k partial bug * fix block_size_n bug * fix type check bug --------- Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com> Co-authored-by: Tianxing Wu <tianxing.wu@amd.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
[M, K]static buffer to[M]static buffer