From 8ec76ac23028ab12b6f9f8c727eadfbefc9b63af Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Tue, 11 Feb 2025 17:03:50 -0700 Subject: [PATCH 01/69] - Add SBCC partial pass kernel generator. --- library/src/CMakeLists.txt | 2 + library/src/device/generator/generator.h | 4 + library/src/device/generator/stockham_gen.cpp | 6 +- library/src/device/generator/stockham_gen.h | 3 + .../src/device/generator/stockham_gen_rr.h | 1 + .../src/device/generator/stockham_pp_gen_cc.h | 1109 +++++++++++++++++ .../src/device/generator/stockham_pp_gen_rr.h | 35 + library/src/device/kernel-generator.py | 2 +- library/src/device/kernels/common.h | 7 + library/src/include/rtc_stockham_gen.h | 2 + library/src/rocfft_aot_helper.cpp | 13 +- library/src/rocfft_kernel_config_search.cpp | 2 + library/src/rtc_stockham_gen.cpp | 83 +- library/src/rtc_stockham_kernel.cpp | 24 +- 14 files changed, 1278 insertions(+), 15 deletions(-) create mode 100644 library/src/device/generator/stockham_pp_gen_cc.h create mode 100644 library/src/device/generator/stockham_pp_gen_rr.h diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index b625209dd32..0be67d2f95e 100644 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -158,9 +158,11 @@ set( kgen_logic_files ${CMAKE_SOURCE_DIR}/library/src/device/generator/stockham_gen_2d.h ${CMAKE_SOURCE_DIR}/library/src/device/generator/stockham_gen_base.h ${CMAKE_SOURCE_DIR}/library/src/device/generator/stockham_gen_cc.h + ${CMAKE_SOURCE_DIR}/library/src/device/generator/stockham_pp_gen_cc.h ${CMAKE_SOURCE_DIR}/library/src/device/generator/stockham_gen_cr.h ${CMAKE_SOURCE_DIR}/library/src/device/generator/stockham_gen_rc.h ${CMAKE_SOURCE_DIR}/library/src/device/generator/stockham_gen_rr.h + ${CMAKE_SOURCE_DIR}/library/src/device/generator/stockham_pp_gen_rr.h ${CMAKE_SOURCE_DIR}/library/src/device/generator/bluestein_generator.h ${CMAKE_SOURCE_DIR}/library/src/rtc_compile.cpp ${CMAKE_SOURCE_DIR}/library/src/include/rtc_stockham_gen.h diff --git a/library/src/device/generator/generator.h b/library/src/device/generator/generator.h index 7dbcc40c11f..a4d80ee83fd 100644 --- a/library/src/device/generator/generator.h +++ b/library/src/device/generator/generator.h @@ -729,6 +729,10 @@ static Assign AddAssign(const Variable& lhs, const Expression& rhs) { return Assign(lhs, rhs, "+="); } +static Assign DivideAssign(const Variable& lhs, const Expression& rhs) +{ + return Assign(lhs, rhs, "/="); +} static Assign MultiplyAssign(const Variable& lhs, const Expression& rhs) { return Assign(lhs, rhs, "*="); diff --git a/library/src/device/generator/stockham_gen.cpp b/library/src/device/generator/stockham_gen.cpp index b76aca04621..1ea6da29f09 100644 --- a/library/src/device/generator/stockham_gen.cpp +++ b/library/src/device/generator/stockham_gen.cpp @@ -33,6 +33,8 @@ using namespace std::placeholders; #include "stockham_gen_cr.h" #include "stockham_gen_rc.h" #include "stockham_gen_rr.h" +#include "stockham_pp_gen_cc.h" +#include "stockham_pp_gen_rr.h" #include "stockham_gen_2d.h" @@ -323,14 +325,14 @@ int main() ++arg; factors = parse_uints_csv(*arg); - StockhamGeneratorSpecs specs(factors, factors2d, precisions, workgroup_size, scheme); + StockhamGeneratorSpecs specs(factors, {}, factors2d, precisions, workgroup_size, scheme); specs.half_lds = half_lds; specs.direct_to_from_reg = direct_to_from_reg; specs.threads_per_transform = threads_per_transform.front(); // second dimension for 2D_SINGLE - StockhamGeneratorSpecs specs2d(factors2d, factors, precisions, workgroup_size, scheme); + StockhamGeneratorSpecs specs2d(factors2d, {}, factors, precisions, workgroup_size, scheme); if(!threads_per_transform.empty()) specs2d.threads_per_transform = threads_per_transform.back(); diff --git a/library/src/device/generator/stockham_gen.h b/library/src/device/generator/stockham_gen.h index a0fb18204b1..62497ec002d 100644 --- a/library/src/device/generator/stockham_gen.h +++ b/library/src/device/generator/stockham_gen.h @@ -30,11 +30,13 @@ struct StockhamGeneratorSpecs { StockhamGeneratorSpecs(const std::vector& factors, + const std::vector& factors_pp, const std::vector& factors2d, const std::vector& precisions, unsigned int workgroup_size, const std::string& scheme) : factors(factors) + , factors_pp(factors_pp) , factors2d(factors2d) , precisions(precisions) , length(product(factors.begin(), factors.end())) @@ -45,6 +47,7 @@ struct StockhamGeneratorSpecs } std::vector factors; + std::vector factors_pp; std::vector factors2d; std::vector precisions; // mapped from rocfft_precision unsigned int length; diff --git a/library/src/device/generator/stockham_gen_rr.h b/library/src/device/generator/stockham_gen_rr.h index 2f8614436b7..a508fa1a947 100644 --- a/library/src/device/generator/stockham_gen_rr.h +++ b/library/src/device/generator/stockham_gen_rr.h @@ -18,6 +18,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +#pragma once #include "../../../../shared/arithmetic.h" #include "stockham_gen_base.h" diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h new file mode 100644 index 00000000000..ed0962e7826 --- /dev/null +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -0,0 +1,1109 @@ +// Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once +#include "stockham_gen_cc.h" + +struct StockhamPartialPassKernelCC : public StockhamKernelCC +{ + explicit StockhamPartialPassKernelCC(const StockhamGeneratorSpecs& specs, + bool largeTwdBatchIsTransformCount) + : StockhamKernelCC(specs, largeTwdBatchIsTransformCount, false) + { + large_twiddle_steps.decl_default = 3; + large_twiddle_base.decl_default = 8; + + // TODO: Address and test all "lds_linear=false" cases + + // TODO: revisit this. Test with factors_pp.size() > 1 + max_factor_pp = *std::max_element(specs.factors_pp.begin(), specs.factors_pp.end()); + + // TODO: transforms_per_block_pp or threads_per_transform? Revisit all usages + transforms_per_block_pp = transforms_per_block / max_factor_pp; + } + + unsigned int transforms_per_block_pp; + unsigned int max_factor_pp; + + Variable thread_lds{"thread_lds", "unsigned int"}; + Variable idx_lds{"idx_lds", "unsigned int"}; + Variable stride_lds_pp{"stride_lds_pp", "unsigned int"}; + Variable offset_lds_pp{"offset_lds_pp", "unsigned int"}; + + Variable tid_hor_lds{"tid_hor_lds", "unsigned int"}; + Variable offfset_unbatched{"offfset_unbatched", "unsigned int"}; + Variable tid_hor_pp{"tid_hor_pp", "unsigned int"}; + Variable offset_tid_hor{"offset_tid_hor", "unsigned int"}; + Variable offset_pp{"offset_pp", "unsigned int"}; + Variable thread_new{"thread_new", "unsigned int"}; + Variable batch_new{"batch_new", "unsigned int"}; + + Variable thread_idx{"thread_idx", "unsigned int"}; + Variable block_idx{"block_idx", "unsigned int"}; + + Variable thread_in_device_twd{"thread_in_device_twd", "unsigned int"}; + + Variable global_idx{"global_idx", "unsigned int"}; + Variable transpose_idx{"transpose_idx", "unsigned int"}; + + StatementList load_global_generator(unsigned int h, + unsigned int hr, + unsigned int width, + unsigned int dt, + Expression guard, + bool intrinsic, + Expression pred) const + { + if(hr == 0) + hr = h; + StatementList load; + + for(unsigned int w = 0; w < width; ++w) + { + auto tid = Parens{thread + dt + h * threads_per_transform}; + auto idx = Parens{tid + w * length / width}; + + if(intrinsic) + { + // no need to and with trivial "true" + load += Assign{ + R[hr * width + w], + IntrinsicLoad{ + {buf, + tid_hor * stride[1] + Parens{Expression{idx}} * stride0, + offset, + std::holds_alternative(guard) ? pred : (guard && pred)}}}; + } + else + { + load += Assign{ + R[hr * width + w], + LoadGlobal{buf, + offset + tid_hor * stride[1] + Parens{Expression{idx}} * stride0}}; + } + } + return load; + } + + StatementList store_pp_step_1_2_lds_generator( + unsigned int h, unsigned int hr, unsigned int width, unsigned int dt, Expression guard) + { + if(hr == 0) + hr = h; + StatementList work; + + for(unsigned int w = 0; w < width; ++w) + //TODO: lstride not used here, address to have input/output strides working + work += Assign(lds_complex[offset_lds + (w * stride_lds)], R[w]); + + return work; + } + + Function generate_lds_from_reg_output_pp_step_1_2_function() + { + std::string function_name + = "lds_from_reg_output_pp_step_1_2_length" + std::to_string(length) + "_device"; + + Function f{function_name}; + f.templates = device_lds_reg_inout_templates(); + f.arguments = device_lds_reg_inout_pp_arguments(); + f.qualifier = "__device__"; + + StatementList& body = f.body; + body += Declaration{ + lstride, Ternary{Parens{stride_type == "SB_UNIT"}, Parens{1}, Parens{stride_lds}}}; + + auto store_lds = std::mem_fn(&StockhamPartialPassKernelCC::store_pp_step_1_2_lds_generator); + // last pass of store (full) + unsigned int width = factors.back(); + float height = static_cast(length) / width / threads_per_transform; + body += SyncThreads(); + body += add_work(std::bind(store_lds, this, _1, _2, _3, _4, _5), + width, + height, + ThreadGuardMode::NO_GUARD); + return f; + } + + StatementList load_lds_step_1_2_generator( + unsigned int h, unsigned int hr, unsigned int width, unsigned int dt, Expression guard) + { + if(hr == 0) + hr = h; + StatementList work; + + for(unsigned int w = 0; w < width; ++w) + //TODO: lstride not used here, address to have input/output strides working + work += Assign(R[w], lds_complex[offset_lds + (w * stride_lds)]); + + return work; + } + + ArgumentList device_lds_reg_inout_pp_arguments() + { + ArgumentList args{R, lds_complex, stride_lds, offset_lds}; + return args; + } + + std::vector device_lds_reg_inout_pp_device_call_arguments() + { + return {R, lds_complex, stride_lds_pp, offset_lds_pp}; + } + + Function generate_lds_to_reg_input_step_1_2_function() + { + std::string function_name + = "lds_to_reg_input_pp_step_1_2_length" + std::to_string(length) + "_device"; + + Function f{function_name}; + f.templates = device_lds_reg_inout_templates(); + f.arguments = device_lds_reg_inout_pp_arguments(); + f.qualifier = "__device__"; + + StatementList& body = f.body; + + auto load_lds = std::mem_fn(&StockhamPartialPassKernelCC::load_lds_step_1_2_generator); + // first pass of load (full) + unsigned int width = factors[0]; + float height = static_cast(length) / width / threads_per_transform; + body += SyncThreads(); + body += add_work(std::bind(load_lds, this, _1, _2, _3, _4, _5), + width, + height, + ThreadGuardMode::NO_GUARD); + + return f; + } + + StatementList calculate_offsets() override + { + Variable d{"d", "int"}; + Variable index_along_d{"index_along_d", "size_t"}; + Variable remaining{"remaining", "size_t"}; + Variable plength{"plength", "size_t"}; + Variable global_stride_in{"global_stride_in", "const size_t"}; + Variable global_stride_out{"global_stride_out", "const size_t"}; + + StatementList stmts; + stmts += Declaration{tile_index}; + stmts += Declaration{num_of_tiles}; + + stmts += LineBreak{}; + stmts += CommentLines{"calculate offset for each tile:", + " tile_index now means index of the tile along dim1", + " num_of_tiles now means number of tiles along dim1"}; + stmts += Declaration{plength, 1}; + stmts += Declaration{remaining}; + stmts += Declaration{index_along_d}; + stmts += Assign{num_of_tiles, (lengths[1] - 1) / transforms_per_block_pp + 1}; + stmts += Assign{plength, num_of_tiles}; + stmts += Assign{tile_index, block_id % num_of_tiles}; + //TODO figure out mod 128 for other lengths + // mod 128 required to work with nbatch > 1 + stmts += Assign{remaining, (block_id % 128) / num_of_tiles}; + stmts += Assign{offset, tile_index * transforms_per_block_pp * stride[1]}; + stmts += For{d, + 2, + d < dim, + 1, + {Assign{plength, plength * lengths[d]}, + Assign{index_along_d, remaining % lengths[d]}, + Assign{remaining, remaining / lengths[d]}, + Assign{offset, offset + index_along_d * stride[d]}}}; + + stmts += LineBreak{}; + + stmts += Assign{batch, block_id / plength}; + //stmts += Assign{offset, offset + batch * stride[dim]}; + if(!direct_to_from_reg) + { + // TODO: figure out this branch + stmts + += Assign{transform, + tile_index * transforms_per_block_pp + thread_id / threads_per_transform}; + stmts += Assign{stride_lds, (length + get_lds_padding())}; + + // TODO: figure out factor 4 for other lengths + stmts += MultiplyAssign(stride_lds, Literal{max_factor_pp}); + + //stmts += Assign{offset_lds, stride_lds * (transform % transforms_per_block_pp)}; + } + else + { + stmts += Assign{ + transform, + Ternary{lds_linear, + tile_index * transforms_per_block_pp + thread_id / threads_per_transform, + tile_index * transforms_per_block_pp + + thread_id % transforms_per_block_pp}}; + stmts += Assign{stride_lds, + Ternary{lds_linear, + length + get_lds_padding(), + transforms_per_block_pp + get_lds_padding()}}; + // TODO: figure out factor 4 for other lengths + stmts += MultiplyAssign(stride_lds, Literal{max_factor_pp}); + + // stmts += Assign{offset_lds, + // Ternary{lds_linear, + // stride_lds * (transform % transforms_per_block_pp), + // thread_id % transforms_per_block_pp}}; + } + + stmts += Declaration{ + in_bound, + Ternary{ + Parens((tile_index + 1) * transforms_per_block_pp > lengths[1]), "false", "true"}}; + + // [dim0, dim1] = [tid_ver, tid_hor] : + // each thread reads position [tid_ver, tid_hor], [tid_ver+step_height*1, tid_hor] , [tid_ver+step_height*2, tid_hor]... + // tid_ver walks the columns; tid_hor walks the rows + stmts += Declaration{thread, thread_id / transforms_per_block_pp}; + stmts += Declaration{tid_hor, thread_id % transforms_per_block_pp}; + + stmts += Declaration{thread_lds, thread_id / transforms_per_block_pp}; + stmts += Declaration{tid_hor_lds, thread_id % transforms_per_block_pp}; + + // TODO: figure out factor 4 here for other lengths + stmts += Declaration( + tid_hor_pp, thread_id % transforms_per_block_pp + length * (thread % max_factor_pp)); + stmts += Declaration(thread_new, thread_id / (transforms_per_block_pp * max_factor_pp)); + stmts += Declaration(batch_new, block_id / (plength / max_factor_pp)); + + stmts += Declaration(thread_idx, thread_id); + stmts += Declaration(block_idx, block_id); + + // TODO: figure out factor 192 here for other lengths + stmts += Declaration( + offset_pp, offset + Parens(offset / length) * Literal{192} + batch_new * stride[dim]); + stmts += Declaration(offset_tid_hor, offset_pp + tid_hor_pp * stride[1]); + + // TODO: figure out factor 4 here for other lengths + if(!direct_to_from_reg) + stmts += Assign{transform, + tile_index * transforms_per_block_pp + + thread_id / (threads_per_transform * max_factor_pp)}; + else + stmts += Assign{transform, + Ternary{lds_linear, + tile_index * transforms_per_block_pp + + thread_id / (threads_per_transform * max_factor_pp), + tile_index * transforms_per_block_pp + + thread_id % transforms_per_block_pp}}; + + stmts += Assign{offset_lds, + Ternary{lds_linear, + stride_lds * (transform % transforms_per_block_pp), + thread_id % transforms_per_block_pp}}; + + return stmts; + } + + StatementList load_from_global(bool load_registers) override + { + StatementList stmts; + StatementList tmp_stmts; + Expression pred{tile_index * transforms_per_block_pp + tid_hor < lengths[1]}; + + if(!load_registers) + { + auto stripmine_w = transforms_per_block; + auto stripmine_h = workgroup_size / stripmine_w; + + auto offset_tile_rbuf + = [&](unsigned int i) { return (thread_new + i * stripmine_h) * stride0; }; + auto offset_tile_wlds = [&](unsigned int i) { + return tid_hor_lds * stride_lds + + (thread_lds + i * stripmine_h * max_factor_pp) * 1; + }; + + for(unsigned int i = 0; i < length / stripmine_h; ++i) + tmp_stmts += Assign{lds_complex[offset_tile_wlds(i)], + LoadGlobal{buf, offset_tid_hor + offset_tile_rbuf(i)}}; + + stmts += CommentLines{ + "no intrinsic when load to lds. FIXME- check why use nested branch is better"}; + stmts += If{in_bound, tmp_stmts}; + stmts += If{Not{in_bound}, {If{pred, tmp_stmts}}}; + // stmts += Else{{If{pred, tmp_stmts}}}; // FIXME: Need to check with compiler team. + } + else + { + // TODO: Figure out this branch + StatementList intrinsic_stmts; + StatementList non_intrinsic_stmts; + + unsigned int width = factors[0]; + auto height = static_cast(length) / width / threads_per_transform; + + auto load_global = std::mem_fn(&StockhamPartialPassKernelCC::load_global_generator); + intrinsic_stmts += CommentLines{"use intrinsic load"}; + intrinsic_stmts += CommentLines{"evaluate all flags as one rw argument"}; + intrinsic_stmts += add_work(std::bind(load_global, + this, + _1, + _2, + _3, + _4, + _5, + true, + Expression{Parens(in_bound || pred)}), + width, + height, + ThreadGuardMode::GURAD_BY_FUNC_ARG, + true); + + tmp_stmts += add_work( + std::bind(load_global, this, _1, _2, _3, _4, _5, false, Expression{in_bound}), + width, + height, + ThreadGuardMode::GUARD_BY_IF, + true); + non_intrinsic_stmts += CommentLines{"can't use intrinsic load"}; + non_intrinsic_stmts += If{in_bound, tmp_stmts}; + non_intrinsic_stmts += If{!in_bound, {If{pred, tmp_stmts}}}; + + stmts += If{intrinsic_mode != "IntrinsicAccessType::DISABLE_BOTH", intrinsic_stmts}; + stmts += Else{non_intrinsic_stmts}; + // stmts += Else{{If{in_bound, tmp_stmts}}}; + } + + return stmts; + } + + StatementList store_global_generator(unsigned int h, + unsigned int hr, + unsigned int width, + unsigned int dt, + Expression guard, + unsigned int cumheight, + bool intrinsic, + Expression pred) + { + if(hr == 0) + hr = h; + StatementList work; + for(unsigned int w = 0; w < width; ++w) + { + auto tid = Parens{thread + dt + h * threads_per_transform}; + auto idx + = Parens{tid / cumheight} * (width * cumheight) + tid % cumheight + w * cumheight; + + if(intrinsic) + { + // no need to and with trivial "true" + work += IntrinsicStore{buf, + tid_hor * stride[1] + Parens{Expression{idx}} * stride0, + offset, + R[hr * width + w], + std::holds_alternative(guard) ? pred + : (guard && pred)}; + } + else + { + work + += StoreGlobal{buf, + offset + tid_hor * stride[1] + Parens{Expression{idx}} * stride0, + R[hr * width + w]}; + } + } + return work; + } + + StatementList store_to_global(bool store_registers) override + { + StatementList stmts; + StatementList tmp_stmts; + Expression pred{tile_index * transforms_per_block_pp + tid_hor < lengths[1]}; + + if(!store_registers) + { + auto stripmine_w = transforms_per_block; + auto stripmine_h = workgroup_size / stripmine_w; + + // TODO: stride[1] not being handle here, address this to have output strides working + auto offset_tile_wbuf = [&](unsigned int i) { + return offset_tid_hor + (thread_new + i * stripmine_h) * stride0; + }; + auto offset_tile_rlds = [&](unsigned int i) { + return tid_hor_lds * stride_lds + + (thread_lds + i * stripmine_h * max_factor_pp) * 1; + }; + + for(unsigned int i = 0; i < length / stripmine_h; ++i) + tmp_stmts += StoreGlobal{ + buf, + CallExpr{"local_transpose_pp_length" + std::to_string(length) + "_device", + {offset_tile_wbuf(i)}}, + lds_complex[offset_tile_rlds(i)]}; + + stmts += CommentLines{ + "no intrinsic when store from lds. FIXME- check why use nested branch is better"}; + stmts += If{in_bound, tmp_stmts}; + stmts += If{Not{in_bound}, {If{pred, tmp_stmts}}}; + // stmts += Else{{If{pred, tmp_stmts}}}; // FIXME: Need to check with compiler team. + } + else + { + // TODO: figure out this branch + StatementList intrinsic_stmts; + StatementList non_intrinsic_stmts; + + auto width = factors.back(); + auto cumheight = product(factors.begin(), factors.begin() + (factors.size() - 1)); + auto height = static_cast(length) / width / threads_per_transform; + + auto store_global = std::mem_fn(&StockhamKernelCC::store_global_generator); + intrinsic_stmts += CommentLines{"use intrinsic store"}; + intrinsic_stmts += add_work(std::bind(store_global, + this, + _1, + _2, + _3, + _4, + _5, + cumheight, + true, + Expression{Parens(in_bound || pred)}), + width, + height, + ThreadGuardMode::GURAD_BY_FUNC_ARG); + + tmp_stmts += add_work( + std::bind( + store_global, this, _1, _2, _3, _4, _5, cumheight, false, Expression{in_bound}), + width, + height, + ThreadGuardMode::GUARD_BY_IF); + non_intrinsic_stmts += CommentLines{"can't use intrinsic store"}; + non_intrinsic_stmts += If{in_bound, tmp_stmts}; + non_intrinsic_stmts += If{!in_bound, {If{pred, tmp_stmts}}}; + + stmts += If{intrinsic_mode == "IntrinsicAccessType::ENABLE_BOTH", intrinsic_stmts}; + stmts += Else{non_intrinsic_stmts}; + // stmts += Else{{If{in_bound, {If{pred, tmp_stmts}}}}}; + } + + return stmts; + } + + StatementList load_lds_pp_generator(unsigned int h, + unsigned int hr, + unsigned int width, + unsigned int dt, + Expression guard, + Component component) + { + if(hr == 0) + hr = h; + StatementList work; + + for(unsigned int w = 0; w < width; ++w) + { + const auto tid = Parens{thread + dt + h * threads_per_transform}; + const auto idx = offset_lds + (tid + w * (length / width) * max_factor_pp) * lstride; + work += Assign(l_offset, idx); + + switch(component) + { + case Component::REAL: + work += Assign(R[hr * width + w].x(), lds_real[l_offset]); + break; + case Component::IMAG: + work += Assign(R[hr * width + w].y(), lds_real[l_offset]); + break; + case Component::BOTH: + work += Assign(R[hr * width + w], lds_complex[l_offset]); + break; + } + } + + return work; + } + + StatementList load_lds_generator(unsigned int h, + unsigned int hr, + unsigned int width, + unsigned int dt, + Expression guard, + Component component) + { + if(hr == 0) + hr = h; + StatementList work; + + for(unsigned int w = 0; w < width; ++w) + { + const auto tid = Parens{thread + dt + h * threads_per_transform}; + const auto idx = offset_lds + (tid + w * (length / width) * max_factor_pp) * lstride; + work += Assign(l_offset, idx); + + switch(component) + { + case Component::REAL: + work += Assign(R[hr * width + w].x(), lds_real[l_offset]); + break; + case Component::IMAG: + work += Assign(R[hr * width + w].y(), lds_real[l_offset]); + break; + case Component::BOTH: + work += Assign(R[hr * width + w], lds_complex[l_offset]); + break; + } + } + + return work; + } + + Function generate_lds_to_reg_input_pp_function() + { + std::string function_name + = "lds_to_reg_input_pp_length" + std::to_string(length) + "_device"; + + Function f{function_name}; + f.templates = device_lds_reg_inout_templates(); + f.arguments = device_lds_reg_inout_arguments(); + f.qualifier = "__device__"; + + StatementList& body = f.body; + body += Declaration{ + lstride, Ternary{Parens{stride_type == "SB_UNIT"}, Parens{1}, Parens{stride_lds}}}; + + body += Declaration{l_offset}; + + auto load_lds = std::mem_fn(&StockhamPartialPassKernelCC::load_lds_pp_generator); + // first pass of load (full) + unsigned int width = factors[0]; + float height = static_cast(length) / width / threads_per_transform; + body += SyncThreads(); + body += add_work(std::bind(load_lds, this, _1, _2, _3, _4, _5, Component::BOTH), + width, + height, + ThreadGuardMode::NO_GUARD); + + return f; + } + + StatementList store_lds_pp_generator(unsigned int h, + unsigned int hr, + unsigned int width, + unsigned int dt, + Expression guard, + Component component, + unsigned int cumheight) + { + if(hr == 0) + hr = h; + StatementList work; + + for(unsigned int w = 0; w < width; ++w) + { + const auto tid = thread + dt + h * threads_per_transform; + const auto idx = offset_lds + + (Parens{tid / (max_factor_pp * cumheight)} * (width * cumheight) + + tid % (max_factor_pp * cumheight) + w * max_factor_pp * cumheight) + * lstride; + work += Assign(l_offset, idx); + + switch(component) + { + case Component::REAL: + work += Assign(lds_real[l_offset], R[hr * width + w].x()); + break; + case Component::IMAG: + work += Assign(lds_real[l_offset], R[hr * width + w].y()); + break; + case Component::BOTH: + work += Assign(lds_complex[l_offset], R[hr * width + w]); + break; + } + } + + return work; + } + + StatementList store_lds_generator(unsigned int h, + unsigned int hr, + unsigned int width, + unsigned int dt, + Expression guard, + Component component, + unsigned int cumheight) + { + if(hr == 0) + hr = h; + StatementList work; + + for(unsigned int w = 0; w < width; ++w) + { + const auto tid = thread + dt + h * threads_per_transform; + const auto idx + = offset_lds + + (Parens{tid / (cumheight * max_factor_pp)} * (width * cumheight * max_factor_pp) + + tid % (cumheight * max_factor_pp) + w * cumheight * max_factor_pp) + * lstride; + work += Assign(l_offset, idx); + + switch(component) + { + case Component::REAL: + work += Assign(lds_real[l_offset], R[hr * width + w].x()); + break; + case Component::IMAG: + work += Assign(lds_real[l_offset], R[hr * width + w].y()); + break; + case Component::BOTH: + work += Assign(lds_complex[l_offset], R[hr * width + w]); + break; + } + } + + return work; + } + + Function generate_lds_from_reg_output_pp_function() + { + std::string function_name + = "lds_from_reg_output_pp_length" + std::to_string(length) + "_device"; + + Function f{function_name}; + f.templates = device_lds_reg_inout_templates(); + f.arguments = device_lds_reg_inout_arguments(); + f.qualifier = "__device__"; + + StatementList& body = f.body; + body += Declaration{ + lstride, Ternary{Parens{stride_type == "SB_UNIT"}, Parens{1}, Parens{stride_lds}}}; + + body += Declaration{l_offset}; + + auto store_lds = std::mem_fn(&StockhamPartialPassKernelCC::store_lds_pp_generator); + // last pass of store (full) + unsigned int width = factors.back(); + float height = static_cast(length) / width / threads_per_transform; + unsigned int cumheight = product(factors.begin(), factors.end() - 1); + body += SyncThreads(); + body += add_work(std::bind(store_lds, this, _1, _2, _3, _4, _5, Component::BOTH, cumheight), + width, + height, + ThreadGuardMode::GUARD_BY_IF); + return f; + } + + Function generate_local_transpose_pp_function() + { + std::string function_name + = "local_transpose_pp_length" + std::to_string(length) + "_device"; + + Function f{function_name}; + f.arguments = ArgumentList{global_idx}; + f.return_type = "unsigned int"; + f.qualifier = "__device__"; + + StatementList& body = f.body; + + // TODO: figure out these factors for other lengths + auto factor_transpose_1 = (length * length) / max_factor_pp; + auto factor_transpose_2 = length * max_factor_pp; + auto factor_transpose_3 = length * length; + auto factor_transpose_4 = length * length - length; + auto factor_transpose_5 = length * length * length; + + body += Declaration{transpose_idx, global_idx % factor_transpose_5}; + body += Assign{transpose_idx, + Parens((transpose_idx % length) + + Parens(Parens(transpose_idx % factor_transpose_1) / length) + * factor_transpose_1) + % factor_transpose_4 + + (Parens(transpose_idx / factor_transpose_1) * factor_transpose_2) + + Parens(transpose_idx / factor_transpose_3) + * (factor_transpose_3 - factor_transpose_1)}; + + // TODO: clean-up this expression: global_idx / factor_transpose_5 * factor_transpose_5 + body += Assign{transpose_idx, + transpose_idx + global_idx / factor_transpose_5 * factor_transpose_5}; + + body += ReturnExpr(transpose_idx); + + return f; + } + + // TODO: Move this to a device function + StatementList perform_partial_pass_step_1_2() + { + StatementList stmts; + + // TODO: figure out factor 1 here (what happens with different in/out strides and lengths) + stmts += Declaration{stride_lds_pp, Literal{1}}; + stmts += Declaration{offset_lds_pp, thread_id * transforms_per_block_pp}; + + auto pre_post_lds_tmpl = device_lds_reg_inout_device_call_templates(); + auto pre_post_lds_args = device_lds_reg_inout_pp_device_call_arguments(); + pre_post_lds_tmpl.set_value(stride_type.name, "lds_linear ? SB_UNIT : SB_NONUNIT"); + + // TODO: handle direct_to_from_reg + StatementList preLoad; + preLoad += Call{"lds_to_reg_input_pp_step_1_2_length" + std::to_string(length) + "_device", + pre_post_lds_tmpl, + pre_post_lds_args}; + stmts += preLoad; + + for(unsigned int npass = 0; npass < factors_pp.size(); ++npass) + { + unsigned int width = factors_pp[npass]; + unsigned int height = threads_per_transform / max_factor_pp; + + auto butterfly = std::mem_fn(&StockhamKernel::butterfly_generator); + stmts += add_work(std::bind(butterfly, this, _1, _2, _3, _4, _5), + width, + height, + ThreadGuardMode::NO_GUARD); + } + + StatementList postStore; + postStore + += Call{"lds_from_reg_output_pp_step_1_2_length" + std::to_string(length) + "_device", + pre_post_lds_tmpl, + pre_post_lds_args}; + stmts += postStore; + + return stmts; + } + + // The "stacked" twiddle table starts at the second factor, since + // the first factor's values are not actually needed for + // anything. It still counts towards cumulative height, but we + // subtract it from the twiddle table offset when computing an + // index. + StatementList apply_twiddle_generator(unsigned int h, + unsigned int hr, + unsigned int width, + unsigned int dt, + Expression guard, + unsigned int cumheight, + unsigned int firstFactor) + { + if(hr == 0) + hr = h; + StatementList work; + Expression loadFlag{thread < length / width}; + for(unsigned int w = 1; w < width; ++w) + { + auto tid = thread_in_device_twd + dt + h * threads_per_transform; + auto tidx = cumheight - firstFactor + w - 1 + (width - 1) * (tid % cumheight); + auto ridx = hr * width + w; + + // TODO- Can try IntrinsicLoadToDest, but should not be a bottleneck + work += Assign(W, twiddles[tidx]); + work += Assign(t, TwiddleMultiply(R[ridx], W)); + work += Assign(R[ridx], t); + } + return work; + } + + ArgumentList device_arguments() override + { + ArgumentList args = StockhamKernel::device_arguments(); + args.append(large_twiddles); + args.append(trans_local); + args.append(thread_in_device_twd); + return args; + } + + std::vector device_call_arguments(unsigned int call_iter) override + { + std::vector args = StockhamKernel::device_call_arguments(call_iter); + auto which = Ternary{Parens{And{apply_large_twiddle, large_twiddle_base < 8}}, + Parens{large_twd_lds}, + Parens{large_twiddles}}; + args.push_back(which); + args.push_back(largeTwdBatchIsTransformCount ? batch : transform); + args.push_back(thread_in_device_twd); + return args; + } + + // TODO: Stopped here: implement device function. Kernel is not getting launched, found out why + Function generate_device_function() + { + std::string function_name + = "forward_pp_length" + std::to_string(length) + "_" + tiling_name() + "_device"; + + Function f{function_name}; + f.arguments = device_arguments(); + f.templates = device_templates(); + f.qualifier = "__device__"; + if(length == 1) + { + return f; + } + + StatementList& body = f.body; + body += Declaration{W}; + body += Declaration{t}; + body += Declaration{ + lstride, Ternary{Parens{stride_type == "SB_UNIT"}, Parens{1}, Parens{stride_lds}}}; + body += Declaration{l_offset}; + + for(unsigned int npass = 0; npass < factors.size(); ++npass) + { + // width is the butterfly width, Radix-n. + unsigned int width = factors[npass]; + // height is how many butterflies per thread will do on average + float height = static_cast(length) / width / threads_per_transform; + + unsigned int cumheight + = product(factors.begin(), + factors.begin() + npass); // cumheight is irrelevant to the above height, + // is used for twiddle multiplication and lds writing. + + body += LineBreak{}; + body += CommentLines{ + "pass " + std::to_string(npass) + ", width " + std::to_string(width), + "using " + std::to_string(threads_per_transform) + " threads we need to do " + + std::to_string(length / width) + " radix-" + std::to_string(width) + + " butterflies", + "therefore each thread will do " + std::to_string(height) + " butterflies"}; + + auto load_lds = std::mem_fn(&StockhamPartialPassKernelCC::load_lds_generator); + auto store_lds = std::mem_fn(&StockhamPartialPassKernelCC::store_lds_generator); + + if(npass > 0) + { + // internal full lds2reg (both linear/nonlinear variants) + StatementList lds2reg_full; + lds2reg_full += SyncThreads(); + lds2reg_full + += add_work(std::bind(load_lds, this, _1, _2, _3, _4, _5, Component::BOTH), + width, + height, + ThreadGuardMode::GUARD_BY_IF); + body += If{Not{lds_is_real}, lds2reg_full}; + + auto apply_twiddle + = std::mem_fn(&StockhamPartialPassKernelCC::apply_twiddle_generator); + body += add_work( + std::bind(apply_twiddle, this, _1, _2, _3, _4, _5, cumheight, factors.front()), + width, + height, + ThreadGuardMode::NO_GUARD); + } + + auto butterfly = std::mem_fn(&StockhamKernel::butterfly_generator); + body += add_work(std::bind(butterfly, this, _1, _2, _3, _4, _5), + width, + height, + ThreadGuardMode::NO_GUARD); + + if(npass == factors.size() - 1) + body += large_twiddles_multiply(width, height, cumheight); + + // internal lds store (half-with-linear and full-with-linear/nonlinear) + StatementList reg2lds_full; + StatementList reg2lds_half; + if(npass < factors.size() - 1) + { + // linear variant store (half) and load (half) + for(auto component : {Component::REAL, Component::IMAG}) + { + bool isFirstStore = (npass == 0) && (component == Component::REAL); + auto half_width = factors[npass]; + auto half_height + = static_cast(length) / half_width / threads_per_transform; + // minimize sync as possible + if(!isFirstStore) + reg2lds_half += SyncThreads(); + reg2lds_half += add_work( + std::bind(store_lds, this, _1, _2, _3, _4, _5, component, cumheight), + half_width, + half_height, + ThreadGuardMode::GUARD_BY_IF); + + half_width = factors[npass + 1]; + half_height = static_cast(length) / half_width / threads_per_transform; + reg2lds_half += SyncThreads(); + reg2lds_half + += add_work(std::bind(load_lds, this, _1, _2, _3, _4, _5, component), + half_width, + half_height, + ThreadGuardMode::GUARD_BY_IF); + } + + // internal full lds store (both linear/nonlinear variants) + if(npass == 0) + reg2lds_full += If{!direct_load_to_reg, {SyncThreads()}}; + else + reg2lds_full += SyncThreads(); + reg2lds_full += add_work( + std::bind(store_lds, this, _1, _2, _3, _4, _5, Component::BOTH, cumheight), + width, + height, + ThreadGuardMode::GUARD_BY_IF); + + body += If{Not{lds_is_real}, reg2lds_full}; + body += Else{reg2lds_half}; + } + } + return f; + } + + Function generate_global_function() override + { + Function f("forward_pp_length" + std::to_string(length) + "_" + tiling_name()); + f.qualifier = "__global__"; + f.launch_bounds = workgroup_size; + + StatementList& body = f.body; + body += CommentLines{ + "this kernel:", + " uses " + std::to_string(threads_per_transform) + " threads per transform", + " does " + std::to_string(transforms_per_block) + " transforms per thread block", + "therefore it should be called with " + std::to_string(workgroup_size) + + " threads per thread block"}; + body += Declaration{R}; + body += LDSDeclaration{scalar_type.name}; + body += Declaration{offset, 0}; + body += Declaration{offset_lds}; + body += Declaration{stride_lds}; + body += Declaration{batch}; + body += Declaration{transform}; + + // TODO- don't override, unify them + body += set_direct_to_from_registers(); + + body += Declaration{lds_is_real, Literal{"false"}}; + + body += CallbackLoadDeclaration{scalar_type.name, callback_type.name}; + body += CallbackStoreDeclaration{scalar_type.name, callback_type.name}; + + body += LineBreak{}; + body += CommentLines{"large twiddles"}; + body += large_twiddles_load(); + + body += LineBreak{}; + body += CommentLines{"offsets"}; + collect_length_stride(body); + body += calculate_offsets(); + body += LineBreak{}; + + StatementList loadlds; + loadlds += CommentLines{"load global into lds"}; + loadlds += load_from_global(false); + loadlds += LineBreak{}; + + if(!direct_to_from_reg) + { + body += loadlds; + } + else + { + StatementList loadr; + loadr += CommentLines{"load global into registers"}; + loadr += load_from_global(true); + + body += If{direct_load_to_reg, loadr}; + body += Else{loadlds}; + } + + // partial pass here + body += perform_partial_pass_step_1_2(); + + body += LineBreak{}; + body += CommentLines{"calc the thread_in_device value once and for all device funcs"}; + body += Declaration{thread_in_device, + Ternary{lds_linear, + thread_id % (threads_per_transform * max_factor_pp), + thread_id / transforms_per_block}}; + body += Declaration{thread_in_device_twd, + Parens(thread_id / max_factor_pp) % threads_per_transform}; + + // before starting the transform job (core device function) + // we call a re-load lds-to-reg function here, but it's not always doing things. + // If we're doing direct-to-reg, this function simply returns. + body += LineBreak{}; + body += CommentLines{"call a pre-load from lds to registers (if necessary)"}; + auto pre_post_lds_tmpl = device_lds_reg_inout_device_call_templates(); + auto pre_post_lds_args = device_lds_reg_inout_device_call_arguments(); + pre_post_lds_tmpl.set_value(stride_type.name, "lds_linear ? SB_UNIT : SB_NONUNIT"); + StatementList preLoad; + preLoad += Call{"lds_to_reg_input_pp_length" + std::to_string(length) + "_device", + pre_post_lds_tmpl, + pre_post_lds_args}; + if(!direct_to_from_reg) + body += preLoad; + else + body += If{!direct_load_to_reg, preLoad}; + + body += LineBreak{}; + body += CommentLines{"transform"}; + for(unsigned int c = 0; c < n_device_calls; ++c) + { + auto templates = device_call_templates(); + auto arguments = device_call_arguments(c); + + templates.set_value(stride_type.name, "lds_linear ? SB_UNIT : SB_NONUNIT"); + + body += Call{"forward_pp_length" + std::to_string(length) + "_" + tiling_name() + + "_device", + templates, + arguments}; + body += LineBreak{}; + } + + // after finishing the transform job (core device function) + // we call a post-store reg-to-lds function here, but it's not always doing things. + // If we're doing direct-from-reg, this function simply returns. + body += LineBreak{}; + body += CommentLines{"call a post-store from registers to lds (if necessary)"}; + StatementList postStore; + postStore += Call{"lds_from_reg_output_pp_length" + std::to_string(length) + "_device", + pre_post_lds_tmpl, + pre_post_lds_args}; + if(!direct_to_from_reg) + body += postStore; + else + body += If{!direct_store_from_reg, postStore}; + + body += LineBreak{}; + StatementList storelds; + storelds += LineBreak{}; + + storelds += LineBreak{}; + storelds += CommentLines{"store global"}; + storelds += SyncThreads{}; + storelds += store_to_global(false); + + if(!direct_to_from_reg) + { + body += storelds; + } + else + { + StatementList storer; + storer += CommentLines{"store registers into global"}; + storer += store_to_global(true); + + body += If{direct_store_from_reg, storer}; + body += Else{storelds}; + } + + f.templates = global_templates(); + f.arguments = global_arguments(); + return f; + } +}; diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h new file mode 100644 index 00000000000..46f839708da --- /dev/null +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -0,0 +1,35 @@ +// Copyright (C) 2021 - 2022 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once +#include "stockham_gen_rr.h" + +struct StockhamPartialPassKernelRR : public StockhamKernelRR +{ + explicit StockhamPartialPassKernelRR(const StockhamGeneratorSpecs& specs) + : StockhamKernelRR(specs) + { + } + + std::string tiling_name() override + { + return "SBRR_PP"; + } +}; diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index 60baddae7aa..9633f0bc924 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -887,7 +887,7 @@ def list_large_kernels(): NS(length=60, factors=[6, 10], use_3steps_large_twd={ 'sp': 'false', 'dp': 'false'}), NS(length=64, factors=[8, 8], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'false'}, workgroup_size=256), + 'sp': 'true', 'dp': 'false'}, workgroup_size=256, direct_to_from_reg=False), NS(length=72, factors=[8, 3, 3], use_3steps_large_twd={ 'sp': 'true', 'dp': 'false'}), NS(length=80, factors=[10, 8], use_3steps_large_twd={ diff --git a/library/src/device/kernels/common.h b/library/src/device/kernels/common.h index 832c0e9e5f2..919aa82217e 100644 --- a/library/src/device/kernels/common.h +++ b/library/src/device/kernels/common.h @@ -178,6 +178,13 @@ enum BluesteinFuseType BFT_INV_CHIRP_MUL, // fused convolution Hadamard product + inverse fft + chirp Hadamard product }; +enum PartialPassType +{ + PPT_NONE, + PPT_SBCC, + PPT_SBRR, +}; + template struct real_type; diff --git a/library/src/include/rtc_stockham_gen.h b/library/src/include/rtc_stockham_gen.h index 78fdcf4c219..5b82ca17f4b 100644 --- a/library/src/include/rtc_stockham_gen.h +++ b/library/src/include/rtc_stockham_gen.h @@ -50,6 +50,7 @@ std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, SBRC_TRANSPOSE_TYPE transpose_type, CallbackType cbtype, BluesteinFuseType fuseBlue, + PartialPassType ppType, const LoadOps& loadOps, const StoreOps& storeOps); @@ -76,6 +77,7 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, SBRC_TRANSPOSE_TYPE transpose_type, CallbackType cbtype, const BluesteinFuseType& fuseBlue, + const PartialPassType& ppType, const LoadOps& loadOps, const StoreOps& storeOps); diff --git a/library/src/rocfft_aot_helper.cpp b/library/src/rocfft_aot_helper.cpp index f0a889f4e82..0824685495e 100644 --- a/library/src/rocfft_aot_helper.cpp +++ b/library/src/rocfft_aot_helper.cpp @@ -225,8 +225,9 @@ void build_stockham_function_pool(CompileQueue& queue) // build everything in the function pool function_pool& fp = function_pool::get_function_pool(); - // fused Bluestein kernels are always built at runtime + // fused Bluestein and partial-pass kernels are always built at runtime auto fuseBlue = BluesteinFuseType::BFT_NONE; + auto ppType = PartialPassType::PPT_NONE; for(const auto& i : fp.get_map()) { @@ -242,6 +243,7 @@ void build_stockham_function_pool(CompileQueue& queue) std::copy(i.second.factors.begin(), i.second.factors.end(), std::back_inserter(factors)); StockhamGeneratorSpecs specs{factors, + {}, {}, {static_cast(precision)}, static_cast(i.second.workgroup_size), @@ -298,11 +300,13 @@ void build_stockham_function_pool(CompileQueue& queue) sbrc_trans_type, cbtype, fuseBlue, + ppType, {}, {}); std::function generate_src = [=](const std::string& kernel_name) -> std::string { StockhamGeneratorSpecs specs{factors, + {}, {}, {static_cast(precision)}, static_cast(i.second.workgroup_size), @@ -330,6 +334,7 @@ void build_stockham_function_pool(CompileQueue& queue) sbrc_trans_type, cbtype, fuseBlue, + ppType, {}, {}); }; @@ -616,8 +621,9 @@ void build_solution_kernels(CompileQueue& queue) std::vector kernel_nodes; solmap.get_all_kernels(kernel_nodes, true); - // fused Bluestein kernels are always built at runtime + // fused Bluestein and partial-pass kernels are always built at runtime auto fuseBlue = BluesteinFuseType::BFT_NONE; + auto ppType = PartialPassType::PPT_NONE; for(const SolutionNode& kernel_sol : kernel_nodes) { @@ -663,6 +669,7 @@ void build_solution_kernels(CompileQueue& queue) } StockhamGeneratorSpecs specs{factors, + {}, {}, {static_cast(precision)}, static_cast(config.workgroup_size), @@ -693,6 +700,7 @@ void build_solution_kernels(CompileQueue& queue) sbrc_trans_type, cbtype, fuseBlue, + ppType, {}, {}); @@ -718,6 +726,7 @@ void build_solution_kernels(CompileQueue& queue) sbrc_trans_type, cbtype, fuseBlue, + ppType, {}, {}); }; diff --git a/library/src/rocfft_kernel_config_search.cpp b/library/src/rocfft_kernel_config_search.cpp index 52b64ee0c5e..5fb5935db09 100644 --- a/library/src/rocfft_kernel_config_search.cpp +++ b/library/src/rocfft_kernel_config_search.cpp @@ -157,6 +157,7 @@ std::string test_kernel_src(const std::string& kernel_name, bool direct_to_from_reg) { StockhamGeneratorSpecs specs{factorization, + {}, {}, {static_cast(rocfft_precision_single)}, wgs, @@ -186,6 +187,7 @@ std::string test_kernel_src(const std::string& kernel_name, SBRC_TRANSPOSE_TYPE::NONE, CallbackType::NONE, BluesteinFuseType::BFT_NONE, + PartialPassType::PPT_NONE, {}, {}); } diff --git a/library/src/rtc_stockham_gen.cpp b/library/src/rtc_stockham_gen.cpp index 264d6e20bb8..e3310a518e3 100644 --- a/library/src/rtc_stockham_gen.cpp +++ b/library/src/rtc_stockham_gen.cpp @@ -35,6 +35,8 @@ using namespace std::placeholders; #include "device/generator/stockham_gen_cr.h" #include "device/generator/stockham_gen_rc.h" #include "device/generator/stockham_gen_rr.h" +#include "device/generator/stockham_pp_gen_cc.h" +#include "device/generator/stockham_pp_gen_rr.h" #include "device/generator/stockham_gen_2d.h" @@ -59,6 +61,7 @@ std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, SBRC_TRANSPOSE_TYPE transpose_type, CallbackType cbtype, BluesteinFuseType fuseBlue, + PartialPassType ppType, const LoadOps& loadOps, const StoreOps& storeOps) { @@ -69,6 +72,15 @@ std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, else kernel_name += "_back"; + switch(ppType) + { + case PPT_NONE: + break; + case PPT_SBCC: + case PPT_SBRR: + kernel_name += "_pp"; + } + kernel_name += "_len"; kernel_name += std::to_string(specs.length); if(scheme == CS_KERNEL_2D_SINGLE) @@ -256,10 +268,13 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, SBRC_TRANSPOSE_TYPE transpose_type, CallbackType cbtype, const BluesteinFuseType& fuseBlue, + const PartialPassType& ppType, const LoadOps& loadOps, const StoreOps& storeOps) { std::unique_ptr lds2reg, reg2lds, device; + std::unique_ptr lds2reg_pp_steps, reg2lds_pp_steps; + std::unique_ptr local_transpose_pp; std::unique_ptr lds2reg1, reg2lds1, device1; std::unique_ptr bluestein_load, bluestein_intrinsic_load; std::unique_ptr bluestein_store, bluestein_intrinsic_store; @@ -297,10 +312,21 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, { std::unique_ptr kernel; if(scheme == CS_KERNEL_STOCKHAM) - kernel = std::make_unique(specs); + { + if(ppType == PartialPassType::PPT_SBRR) + kernel = std::make_unique(specs); + else + kernel = std::make_unique(specs); + } else if(scheme == CS_KERNEL_STOCKHAM_BLOCK_CC) - kernel = std::make_unique( - specs, largeTwdBatchIsTransformCount, fuseBluestein); + { + if(ppType == PartialPassType::PPT_SBCC) + kernel = std::make_unique( + specs, largeTwdBatchIsTransformCount); + else + kernel = std::make_unique( + specs, largeTwdBatchIsTransformCount, fuseBluestein); + } else if(scheme == CS_KERNEL_STOCKHAM_BLOCK_CR) kernel = std::make_unique(specs); else if(scheme == CS_KERNEL_STOCKHAM_BLOCK_RC) @@ -315,9 +341,41 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, throw std::runtime_error("unhandled scheme"); if(transforms_per_block) *transforms_per_block = kernel->transforms_per_block; - lds2reg = std::make_unique(kernel->generate_lds_to_reg_input_function()); - reg2lds = std::make_unique(kernel->generate_lds_from_reg_output_function()); - device = std::make_unique(kernel->generate_device_function()); + + switch(ppType) + { + case PPT_NONE: + { + lds2reg = std::make_unique(kernel->generate_lds_to_reg_input_function()); + reg2lds = std::make_unique(kernel->generate_lds_from_reg_output_function()); + device = std::make_unique(kernel->generate_device_function()); + break; + } + case PPT_SBCC: + { + auto kernel_pp = static_cast(kernel.get()); + + lds2reg + = std::make_unique(kernel_pp->generate_lds_to_reg_input_pp_function()); + reg2lds + = std::make_unique(kernel_pp->generate_lds_from_reg_output_pp_function()); + lds2reg_pp_steps = std::make_unique( + kernel_pp->generate_lds_to_reg_input_step_1_2_function()); + reg2lds_pp_steps = std::make_unique( + kernel_pp->generate_lds_from_reg_output_pp_step_1_2_function()); + local_transpose_pp + = std::make_unique(kernel_pp->generate_local_transpose_pp_function()); + device = std::make_unique(kernel_pp->generate_device_function()); + break; + } + case PPT_SBRR: + { + lds2reg = std::make_unique(kernel->generate_lds_to_reg_input_function()); + reg2lds = std::make_unique(kernel->generate_lds_from_reg_output_function()); + device = std::make_unique(kernel->generate_device_function()); + break; + } + } if(fuseBluestein) { @@ -384,6 +442,10 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, src += large_twiddles_h; // append the neccessary functions only append_radix_h(src, all_factors); + + if(ppType != PPT_NONE) + append_radix_h(src, specs.factors_pp); + // SBCCs don't need this if(scheme != CS_KERNEL_STOCKHAM_BLOCK_CC) src += real2complex_device_h; @@ -391,6 +453,15 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, src += lds2reg->render(); src += reg2lds->render(); src += device->render(); + + // TODO: remove null pointer check + if(ppType != PPT_NONE && lds2reg_pp_steps && reg2lds_pp_steps && local_transpose_pp) + { + src += lds2reg_pp_steps->render(); + src += reg2lds_pp_steps->render(); + src += local_transpose_pp->render(); + } + if(lds2reg1) src += lds2reg1->render(); if(reg2lds1) diff --git a/library/src/rtc_stockham_kernel.cpp b/library/src/rtc_stockham_kernel.cpp index f85d0646c20..39dbe8e97ab 100644 --- a/library/src/rtc_stockham_kernel.cpp +++ b/library/src/rtc_stockham_kernel.cpp @@ -38,6 +38,10 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& std::optional specs; std::optional specs2d; + std::vector factors_pp; + std::copy( + node.kernelFactorsPP.begin(), node.kernelFactorsPP.end(), std::back_inserter(factors_pp)); + // SBRC variants look in the function pool for plain BLOCK_RC to // learn the block width, then decide on the transpose type once // that's known. @@ -65,10 +69,6 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& case CS_KERNEL_STOCKHAM_BLOCK_CR: case CS_KERNEL_STOCKHAM_BLOCK_RC: { - // Partial-pass nodes have their own generators - if(node.applyPartialPass) - return generator; - // for sbrc variant, the sbrcTranstype should be assigned when we are here // since the value is assigned in KernelCheck() if((pool_scheme == CS_KERNEL_STOCKHAM_BLOCK_RC) && (node.sbrcTranstype == NONE)) @@ -91,6 +91,7 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& std::vector precisions = {static_cast(node.precision)}; specs.emplace(factors, + factors_pp, std::vector(), precisions, static_cast(kernel->workgroup_size), @@ -129,6 +130,7 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& } specs.emplace(factors1d, + factors_pp, factors2d, precisions, static_cast(kernel->workgroup_size), @@ -137,6 +139,7 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& specs->half_lds = kernel->half_lds; specs2d.emplace(factors2d, + factors_pp, factors1d, precisions, static_cast(kernel->workgroup_size), @@ -168,6 +171,17 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& bool unit_stride = node.inStride.front() == 1 && node.outStride.front() == 1; + auto ppType = PartialPassType::PPT_NONE; + if(node.applyPartialPass) + { + if(node.scheme == CS_KERNEL_STOCKHAM_BLOCK_CC) + ppType = PartialPassType::PPT_SBCC; + else if(node.scheme == CS_KERNEL_STOCKHAM) + ppType = PartialPassType::PPT_SBRR; + else + throw std::runtime_error("Invalid scheme for partial pass"); + } + generator.generate_name = [=, &node]() { return stockham_rtc_kernel_name(*specs, specs2d ? *specs2d : *specs, @@ -187,6 +201,7 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& node.sbrcTranstype, node.GetCallbackType(enable_callbacks), node.fuseBlue, + ppType, node.loadOps, node.storeOps); }; @@ -216,6 +231,7 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& node.sbrcTranstype, node.GetCallbackType(enable_callbacks), node.fuseBlue, + ppType, node.loadOps, node.storeOps); }; From 1bc8bee0512a27bf86d24e4c7d3bcaf422da9a79 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 13 Feb 2025 14:07:52 -0700 Subject: [PATCH 02/69] Add SBRR partial pass kernel generator. --- .../src/device/generator/stockham_pp_gen_cc.h | 35 +- .../src/device/generator/stockham_pp_gen_rr.h | 466 +++++++++++++++++- library/src/device/kernel-generator.py | 2 +- library/src/rtc_stockham_gen.cpp | 39 +- 4 files changed, 512 insertions(+), 30 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index ed0962e7826..947cdf77a84 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -21,6 +21,13 @@ #pragma once #include "stockham_gen_cc.h" +// TODO: - Kernel is not getting launched, found out why +// - Check launch bounds. +// - Implementation here used a kernel with work_group_size = 256, however, the prototype was using 64. +// Change kernel_generator.py to use 64 and fix all the issues, comparing again with the prototype. +// - Start testing with different threads_per_transform once the original configuration works. +// - Then test with other lengths and direct_from_reg=true, half_lds=true, etc. + struct StockhamPartialPassKernelCC : public StockhamKernelCC { explicit StockhamPartialPassKernelCC(const StockhamGeneratorSpecs& specs, @@ -102,7 +109,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC return load; } - StatementList store_pp_step_1_2_lds_generator( + StatementList store_pp_step_3_4_lds_generator( unsigned int h, unsigned int hr, unsigned int width, unsigned int dt, Expression guard) { if(hr == 0) @@ -116,10 +123,10 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC return work; } - Function generate_lds_from_reg_output_pp_step_1_2_function() + Function generate_lds_from_reg_output_pp_step_3_4_function() { std::string function_name - = "lds_from_reg_output_pp_step_1_2_length" + std::to_string(length) + "_device"; + = "lds_from_reg_output_pp_step_3_4_length" + std::to_string(length) + "_device"; Function f{function_name}; f.templates = device_lds_reg_inout_templates(); @@ -130,7 +137,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC body += Declaration{ lstride, Ternary{Parens{stride_type == "SB_UNIT"}, Parens{1}, Parens{stride_lds}}}; - auto store_lds = std::mem_fn(&StockhamPartialPassKernelCC::store_pp_step_1_2_lds_generator); + auto store_lds = std::mem_fn(&StockhamPartialPassKernelCC::store_pp_step_3_4_lds_generator); // last pass of store (full) unsigned int width = factors.back(); float height = static_cast(length) / width / threads_per_transform; @@ -142,7 +149,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC return f; } - StatementList load_lds_step_1_2_generator( + StatementList load_lds_step_3_4_generator( unsigned int h, unsigned int hr, unsigned int width, unsigned int dt, Expression guard) { if(hr == 0) @@ -167,10 +174,10 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC return {R, lds_complex, stride_lds_pp, offset_lds_pp}; } - Function generate_lds_to_reg_input_step_1_2_function() + Function generate_lds_to_reg_input_step_3_4_function() { std::string function_name - = "lds_to_reg_input_pp_step_1_2_length" + std::to_string(length) + "_device"; + = "lds_to_reg_input_pp_step_3_4_length" + std::to_string(length) + "_device"; Function f{function_name}; f.templates = device_lds_reg_inout_templates(); @@ -179,7 +186,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC StatementList& body = f.body; - auto load_lds = std::mem_fn(&StockhamPartialPassKernelCC::load_lds_step_1_2_generator); + auto load_lds = std::mem_fn(&StockhamPartialPassKernelCC::load_lds_step_3_4_generator); // first pass of load (full) unsigned int width = factors[0]; float height = static_cast(length) / width / threads_per_transform; @@ -745,7 +752,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC } // TODO: Move this to a device function - StatementList perform_partial_pass_step_1_2() + StatementList perform_partial_pass_step_3_4() { StatementList stmts; @@ -759,14 +766,15 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC // TODO: handle direct_to_from_reg StatementList preLoad; - preLoad += Call{"lds_to_reg_input_pp_step_1_2_length" + std::to_string(length) + "_device", + preLoad += Call{"lds_to_reg_input_pp_step_3_4_length" + std::to_string(length) + "_device", pre_post_lds_tmpl, pre_post_lds_args}; stmts += preLoad; for(unsigned int npass = 0; npass < factors_pp.size(); ++npass) { - unsigned int width = factors_pp[npass]; + unsigned int width = factors_pp[npass]; + // TODO: revisit this. Different from same function in stockham_pp_gen_rr.h unsigned int height = threads_per_transform / max_factor_pp; auto butterfly = std::mem_fn(&StockhamKernel::butterfly_generator); @@ -778,7 +786,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC StatementList postStore; postStore - += Call{"lds_from_reg_output_pp_step_1_2_length" + std::to_string(length) + "_device", + += Call{"lds_from_reg_output_pp_step_3_4_length" + std::to_string(length) + "_device", pre_post_lds_tmpl, pre_post_lds_args}; stmts += postStore; @@ -838,7 +846,6 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC return args; } - // TODO: Stopped here: implement device function. Kernel is not getting launched, found out why Function generate_device_function() { std::string function_name @@ -1021,7 +1028,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC } // partial pass here - body += perform_partial_pass_step_1_2(); + body += perform_partial_pass_step_3_4(); body += LineBreak{}; body += CommentLines{"calc the thread_in_device value once and for all device funcs"}; diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index 46f839708da..340975e141a 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -21,15 +21,477 @@ #pragma once #include "stockham_gen_rr.h" +// TODO: transform_per_block or max_factor_pp? Revisit all usages + struct StockhamPartialPassKernelRR : public StockhamKernelRR { explicit StockhamPartialPassKernelRR(const StockhamGeneratorSpecs& specs) : StockhamKernelRR(specs) { + // TODO: revisit this. Test with factors_pp.size() > 1 + max_factor_pp = *std::max_element(specs.factors_pp.begin(), specs.factors_pp.end()); + } + + unsigned int max_factor_pp; + Variable offset_pp{"offset_pp", "size_t"}; + Variable stride_lds_pp{"stride_lds_pp", "size_t"}; + Variable offset_lds_pp{"offset_lds_pp", "size_t"}; + + // TODO: this should be __restrict__ + Variable twiddles_pp{"twiddles_pp", "const scalar_type", true, true}; + + StatementList calculate_offsets() override + { + Variable d{"d", "int"}; + Variable index_along_d{"index_along_d", "size_t"}; + Variable remaining{"remaining", "size_t"}; + Variable remaining_pp{"remaining_pp", "size_t"}; + + StatementList stmts; + stmts += Declaration{thread}; + stmts += Declaration(remaining); + stmts += Declaration(index_along_d); + stmts += Declaration(remaining_pp); + stmts += Declaration(offset_pp); + stmts += Assign{transform, + block_id * transforms_per_block + thread_id / threads_per_transform}; + stmts += Assign{remaining, transform}; + stmts += Assign{remaining_pp, + length * Parens(transform / length) + + Parens(transform % length) / max_factor_pp + + Parens(transform * (length / max_factor_pp)) % length}; + + stmts += For{d, + 1, + d < dim, + 1, + { + Assign{remaining, remaining / lengths[d]}, + Assign{index_along_d, remaining_pp % lengths[d]}, + Assign{remaining_pp, remaining_pp / lengths[d]}, + Assign{offset_pp, offset_pp + index_along_d * stride[d]}, + }}; + + stmts += Assign{batch, remaining}; + stmts += Assign{offset_pp, offset_pp + batch * stride[dim]}; + stmts += Assign{stride_lds, (length + get_lds_padding())}; + stmts += Assign{offset_lds, stride_lds * Parens{transform % transforms_per_block}}; + + stmts += Declaration{inbound, batch < nbatch}; + + return stmts; + } + + StatementList load_from_global(bool load_registers) override + { + StatementList stmts; + stmts += Assign{thread, thread_id % threads_per_transform}; + + if(!load_registers) + { + unsigned int width = threads_per_transform; + unsigned int height = length / width; + + for(unsigned int h = 0; h < height; ++h) + { + auto idx = thread + h * width; + stmts += Assign{lds_complex[offset_lds + idx], + LoadGlobal{buf, offset_pp + idx * stride0}}; + } + stmts += LineBreak(); + stmts += CommentLines{"append extra global loading for C2Real pre-process only"}; + + StatementList stmts_c2real_pre; + stmts_c2real_pre += CommentLines{ + "use the last thread of each transform to load one more element per row"}; + stmts_c2real_pre += If{ + thread == threads_per_transform - 1, + {Assign{lds_complex[offset_lds + thread + (height - 1) * width + 1], + LoadGlobal{buf, offset + (thread + (height - 1) * width + 1) * stride0}}}}; + stmts += If{embedded_type == Literal{"EmbeddedType::C2Real_PRE"}, stmts_c2real_pre}; + } + else + { + unsigned int width = factors[0]; + auto height = static_cast(length) / width / threads_per_transform; + + auto load_global = std::mem_fn(&StockhamKernel::load_global_generator); + stmts += add_work(std::bind(load_global, this, _1, _2, _3, _4, _5), + width, + height, + ThreadGuardMode::GUARD_BY_IF); + } + + return {If{inbound, stmts}}; + } + + StatementList store_to_global(bool store_registers) override + { + StatementList stmts; + + if(!store_registers) + { + auto width = threads_per_transform; + auto height = length / width; + for(unsigned int h = 0; h < height; ++h) + { + auto idx = thread + h * width; + stmts += StoreGlobal{buf, offset_pp + idx * stride0, lds_complex[offset_lds + idx]}; + } + + stmts += LineBreak{}; + stmts += CommentLines{"append extra global write for Real2C post-process only"}; + StatementList stmts_real2c_post; + stmts_real2c_post += CommentLines{ + "use the last thread of each transform to write one more element per row"}; + stmts_real2c_post + += If{Equal{thread, threads_per_transform - 1}, + {StoreGlobal{buf, + offset + (thread + (height - 1) * width + 1) * stride0, + lds_complex[offset_lds + thread + (height - 1) * width + 1]}}}; + stmts += If{Equal{embedded_type, "EmbeddedType::Real2C_POST"}, stmts_real2c_post}; + } + else + { + auto width = factors.back(); + auto cumheight = product(factors.begin(), factors.begin() + (factors.size() - 1)); + auto height = static_cast(length) / width / threads_per_transform; + + auto store_global = std::mem_fn(&StockhamKernel::store_global_generator); + stmts += add_work(std::bind(store_global, this, _1, _2, _3, _4, _5, cumheight), + width, + height, + ThreadGuardMode::GUARD_BY_IF); + } + + return {If{inbound, stmts}}; } - std::string tiling_name() override + StatementList load_lds_step_1_2_generator( + unsigned int h, unsigned int hr, unsigned int width, unsigned int dt, Expression guard) { - return "SBRR_PP"; + if(hr == 0) + hr = h; + StatementList work; + + for(unsigned int w = 0; w < width; ++w) + //TODO: lstride not used here, address to have input/output strides working + work += Assign(R[w], lds_complex[offset_lds + (w * stride_lds)]); + + return work; + } + + ArgumentList device_lds_reg_inout_pp_arguments() + { + ArgumentList args{R, lds_complex, stride_lds, offset_lds}; + return args; + } + + std::vector device_lds_reg_inout_pp_device_call_arguments() + { + return {R, lds_complex, stride_lds_pp, offset_lds_pp}; + } + + Function generate_lds_to_reg_input_step_1_2_function() + { + std::string function_name + = "lds_to_reg_input_pp_step_1_2_length" + std::to_string(length) + "_device"; + + Function f{function_name}; + f.templates = device_lds_reg_inout_templates(); + f.arguments = device_lds_reg_inout_pp_arguments(); + f.qualifier = "__device__"; + + StatementList& body = f.body; + + auto load_lds = std::mem_fn(&StockhamPartialPassKernelRR::load_lds_step_1_2_generator); + // first pass of load (full) + // TODO: revisit width. it used to be factors[0] + unsigned int width = max_factor_pp; + float height = static_cast(length) / width / threads_per_transform; + body += SyncThreads(); + body += add_work(std::bind(load_lds, this, _1, _2, _3, _4, _5), + width, + height, + ThreadGuardMode::NO_GUARD); + + return f; + } + + StatementList store_pp_step_1_2_lds_generator( + unsigned int h, unsigned int hr, unsigned int width, unsigned int dt, Expression guard) + { + if(hr == 0) + hr = h; + StatementList work; + + for(unsigned int w = 0; w < width; ++w) + //TODO: lstride not used here, address to have input/output strides working + work += Assign(lds_complex[offset_lds + (w * stride_lds)], R[w]); + + return work; + } + + Function generate_lds_from_reg_output_pp_step_1_2_function() + { + std::string function_name + = "lds_from_reg_output_pp_step_1_2_length" + std::to_string(length) + "_device"; + + Function f{function_name}; + f.templates = device_lds_reg_inout_templates(); + f.arguments = device_lds_reg_inout_pp_arguments(); + f.qualifier = "__device__"; + + StatementList& body = f.body; + + auto store_lds = std::mem_fn(&StockhamPartialPassKernelRR::store_pp_step_1_2_lds_generator); + // last pass of store (full) + // TODO: revisit width. it used to be factors.back() + unsigned int width = max_factor_pp; + float height = static_cast(length) / width / threads_per_transform; + body += SyncThreads(); + body += add_work(std::bind(store_lds, this, _1, _2, _3, _4, _5), + width, + height, + ThreadGuardMode::NO_GUARD); + return f; + } + + Function generate_twiddle_multiply_pp_function() + { + std::string function_name + = "twiddle_multiply_pp_length" + std::to_string(length) + "_device"; + + Function f{function_name}; + + TemplateList tpls = {scalar_type}; + f.templates = tpls; + + f.arguments = ArgumentList{R, thread, twiddles_pp}; + + f.return_type = "void"; + f.qualifier = "__device__"; + + StatementList& body = f.body; + + for(unsigned int w = 0; w < max_factor_pp; ++w) + body += Assign(R[w], twiddles_pp[thread * length + w] * R[w]); + + return f; + } + + // TODO: Move this to a device function + StatementList perform_partial_pass_step_1_2() + { + StatementList stmts; + + // TODO: figure out factor 1 here (what happens with different in/out strides and lengths) + stmts += Declaration{stride_lds_pp, length}; + stmts += Declaration{offset_lds_pp, + Parens(block_id * transforms_per_block + thread_id) % length}; + + auto pre_post_lds_tmpl = device_lds_reg_inout_device_call_templates(); + auto pre_post_lds_args = device_lds_reg_inout_pp_device_call_arguments(); + pre_post_lds_tmpl.set_value(stride_type.name, "lds_linear ? SB_UNIT : SB_NONUNIT"); + + // TODO: handle direct_to_from_reg + StatementList preLoad; + preLoad += Call{"lds_to_reg_input_pp_step_1_2_length" + std::to_string(length) + "_device", + pre_post_lds_tmpl, + pre_post_lds_args}; + stmts += preLoad; + + for(unsigned int npass = 0; npass < factors_pp.size(); ++npass) + { + unsigned int width = factors_pp[npass]; + // TODO: revisit this. Different from same function in stockham_pp_gen_cc.h + unsigned int height = transforms_per_block / max_factor_pp; + + auto butterfly = std::mem_fn(&StockhamKernel::butterfly_generator); + stmts += add_work(std::bind(butterfly, this, _1, _2, _3, _4, _5), + width, + height, + ThreadGuardMode::NO_GUARD); + } + + TemplateList pre_twd_mul_tmpl = TemplateList{scalar_type}; + std::vector pre_twd_mul_args + = {R, block_id % (length / max_factor_pp), twiddles_pp}; + StatementList twdMul; + twdMul += Call{"twiddle_multiply_pp_length" + std::to_string(length) + "_device", + pre_twd_mul_tmpl, + pre_twd_mul_args}; + + stmts += twdMul; + + StatementList postStore; + postStore + += Call{"lds_from_reg_output_pp_step_1_2_length" + std::to_string(length) + "_device", + pre_post_lds_tmpl, + pre_post_lds_args}; + stmts += postStore; + + return stmts; + + return stmts; + } + + ArgumentList global_arguments() override + { + auto arguments + = static_dim + ? ArgumentList{twiddles_pp, twiddles, lengths, stride, nbatch, lds_padding} + : ArgumentList{twiddles_pp, twiddles, dim, lengths, stride, nbatch, lds_padding}; + for(const auto& arg : get_callback_args().arguments) + arguments.append(arg); + arguments.append(buf); + return arguments; + } + + Function generate_global_function() override + { + Function f("forward_length" + std::to_string(length) + "_" + tiling_name()); + f.qualifier = "__global__"; + f.launch_bounds = workgroup_size; + + StatementList& body = f.body; + body += CommentLines{ + "this kernel:", + " uses " + std::to_string(threads_per_transform) + " threads per transform", + " does " + std::to_string(transforms_per_block) + " transforms per thread block", + "therefore it should be called with " + std::to_string(workgroup_size) + + " threads per thread block"}; + body += Declaration{R}; + body += LDSDeclaration{scalar_type.name}; + body += Declaration{offset, 0}; + body += Declaration{offset_lds}; + body += Declaration{stride_lds}; + body += Declaration{batch}; + body += Declaration{transform}; + + // TODO- don't override, unify them + body += set_direct_to_from_registers(); + + // half-lds + body += set_lds_is_real(); + + body += CallbackLoadDeclaration{scalar_type.name, callback_type.name}; + body += CallbackStoreDeclaration{scalar_type.name, callback_type.name}; + + body += LineBreak{}; + body += CommentLines{"large twiddles"}; + body += large_twiddles_load(); + + body += LineBreak{}; + body += CommentLines{"offsets"}; + collect_length_stride(body); + body += calculate_offsets(); + body += LineBreak{}; + + StatementList loadlds; + loadlds += CommentLines{"load global into lds"}; + loadlds += load_from_global(false); + loadlds += LineBreak{}; + // handle even-length real to complex pre-process in lds before transform + loadlds += real_trans_pre_post(ProcessingType::PRE); + + if(!direct_to_from_reg) + { + body += loadlds; + } + else + { + StatementList loadr; + loadr += CommentLines{"load global into registers"}; + loadr += load_from_global(true); + + body += If{direct_load_to_reg, loadr}; + body += Else{loadlds}; + } + + body += LineBreak{}; + body += CommentLines{"calc the thread_in_device value once and for all device funcs"}; + body += Declaration{thread_in_device, + Ternary{lds_linear, + thread_id % threads_per_transform, + thread_id / transforms_per_block}}; + + // before starting the transform job (core device function) + // we call a re-load lds-to-reg function here, but it's not always doing things. + // If we're doing direct-to-reg, this function simply returns. + body += LineBreak{}; + body += CommentLines{"call a pre-load from lds to registers (if necessary)"}; + auto pre_post_lds_tmpl = device_lds_reg_inout_device_call_templates(); + auto pre_post_lds_args = device_lds_reg_inout_device_call_arguments(); + pre_post_lds_tmpl.set_value(stride_type.name, "lds_linear ? SB_UNIT : SB_NONUNIT"); + StatementList preLoad; + preLoad += Call{"lds_to_reg_input_length" + std::to_string(length) + "_device", + pre_post_lds_tmpl, + pre_post_lds_args}; + if(!direct_to_from_reg) + body += preLoad; + else + body += If{!direct_load_to_reg, preLoad}; + + body += LineBreak{}; + body += CommentLines{"transform"}; + for(unsigned int c = 0; c < n_device_calls; ++c) + { + auto templates = device_call_templates(); + auto arguments = device_call_arguments(c); + + templates.set_value(stride_type.name, "lds_linear ? SB_UNIT : SB_NONUNIT"); + + body + += Call{"forward_length" + std::to_string(length) + "_" + tiling_name() + "_device", + templates, + arguments}; + body += LineBreak{}; + } + + // after finishing the transform job (core device function) + // we call a post-store reg-to-lds function here, but it's not always doing things. + // If we're doing direct-from-reg, this function simply returns. + body += LineBreak{}; + body += CommentLines{"call a post-store from registers to lds (if necessary)"}; + StatementList postStore; + postStore += Call{"lds_from_reg_output_length" + std::to_string(length) + "_device", + pre_post_lds_tmpl, + pre_post_lds_args}; + if(!direct_to_from_reg) + body += postStore; + else + body += If{!direct_store_from_reg, postStore}; + + // partial pass here + body += perform_partial_pass_step_1_2(); + + body += LineBreak{}; + StatementList storelds; + storelds += LineBreak{}; + // handle even-length complex to real post-process in lds after transform + storelds += real_trans_pre_post(ProcessingType::POST); + storelds += LineBreak{}; + storelds += CommentLines{"store global"}; + storelds += SyncThreads{}; + storelds += store_to_global(false); + + if(!direct_to_from_reg) + { + body += storelds; + } + else + { + StatementList storer; + storer += CommentLines{"store registers into global"}; + storer += store_to_global(true); + + body += If{direct_store_from_reg, storer}; + body += Else{storelds}; + } + + f.templates = global_templates(); + f.arguments = global_arguments(); + return f; } }; diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index 9633f0bc924..729e6aced98 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -309,7 +309,7 @@ def list_small_kernels(): NS(length= 56, workgroup_size=128, threads_per_transform= 8, factors=(7, 8)), NS(length= 60, workgroup_size= 64, threads_per_transform= 10, factors=(6, 10)), NS(length= 63, workgroup_size=256, threads_per_transform= 21, factors=(3, 3, 7), half_lds=False, runtime_compile=True), - NS(length= 64, workgroup_size= 64, threads_per_transform= 16, factors=(4, 4, 4), half_lds=False, direct_to_from_reg=True), + NS(length= 64, workgroup_size=128, threads_per_transform= 8, factors=(8, 8), half_lds=False, direct_to_from_reg=False), NS(length= 65, workgroup_size=256, threads_per_transform= 13, factors=(13, 5), runtime_compile=True), NS(length= 66, workgroup_size=256, threads_per_transform= 11, factors=(6, 11), half_lds=False, runtime_compile=True), NS(length= 68, workgroup_size=256, threads_per_transform= 17, factors=(17, 4), runtime_compile=True), diff --git a/library/src/rtc_stockham_gen.cpp b/library/src/rtc_stockham_gen.cpp index e3310a518e3..b6c75ac3147 100644 --- a/library/src/rtc_stockham_gen.cpp +++ b/library/src/rtc_stockham_gen.cpp @@ -274,7 +274,7 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, { std::unique_ptr lds2reg, reg2lds, device; std::unique_ptr lds2reg_pp_steps, reg2lds_pp_steps; - std::unique_ptr local_transpose_pp; + std::unique_ptr twiddle_multiply_pp, local_transpose_pp; std::unique_ptr lds2reg1, reg2lds1, device1; std::unique_ptr bluestein_load, bluestein_intrinsic_load; std::unique_ptr bluestein_store, bluestein_intrinsic_store; @@ -351,6 +351,22 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, device = std::make_unique(kernel->generate_device_function()); break; } + case PPT_SBRR: + { + auto kernel_pp = static_cast(kernel.get()); + + lds2reg = std::make_unique(kernel_pp->generate_lds_to_reg_input_function()); + reg2lds + = std::make_unique(kernel_pp->generate_lds_from_reg_output_function()); + lds2reg_pp_steps = std::make_unique( + kernel_pp->generate_lds_to_reg_input_step_1_2_function()); + reg2lds_pp_steps = std::make_unique( + kernel_pp->generate_lds_from_reg_output_pp_step_1_2_function()); + twiddle_multiply_pp + = std::make_unique(kernel_pp->generate_twiddle_multiply_pp_function()); + device = std::make_unique(kernel_pp->generate_device_function()); + break; + } case PPT_SBCC: { auto kernel_pp = static_cast(kernel.get()); @@ -360,21 +376,14 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, reg2lds = std::make_unique(kernel_pp->generate_lds_from_reg_output_pp_function()); lds2reg_pp_steps = std::make_unique( - kernel_pp->generate_lds_to_reg_input_step_1_2_function()); + kernel_pp->generate_lds_to_reg_input_step_3_4_function()); reg2lds_pp_steps = std::make_unique( - kernel_pp->generate_lds_from_reg_output_pp_step_1_2_function()); + kernel_pp->generate_lds_from_reg_output_pp_step_3_4_function()); local_transpose_pp = std::make_unique(kernel_pp->generate_local_transpose_pp_function()); device = std::make_unique(kernel_pp->generate_device_function()); break; } - case PPT_SBRR: - { - lds2reg = std::make_unique(kernel->generate_lds_to_reg_input_function()); - reg2lds = std::make_unique(kernel->generate_lds_from_reg_output_function()); - device = std::make_unique(kernel->generate_device_function()); - break; - } } if(fuseBluestein) @@ -454,12 +463,16 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, src += reg2lds->render(); src += device->render(); - // TODO: remove null pointer check - if(ppType != PPT_NONE && lds2reg_pp_steps && reg2lds_pp_steps && local_transpose_pp) + if(ppType != PPT_NONE) { src += lds2reg_pp_steps->render(); src += reg2lds_pp_steps->render(); - src += local_transpose_pp->render(); + + if(ppType == PPT_SBRR) + src += twiddle_multiply_pp->render(); + + if(ppType == PPT_SBCC) + src += local_transpose_pp->render(); } if(lds2reg1) From 2b0c5d35091659c26347f6f35a45955e86d50ec8 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 14 Feb 2025 14:24:10 -0700 Subject: [PATCH 03/69] - Remove fixed length-64 partial pass kernels. --- library/src/CMakeLists.txt | 2 - .../include/rtc_partial_pass_sbcc_64_64_64.h | 56 - .../include/rtc_partial_pass_sbrr_64_64_64.h | 55 - library/src/rtc_kernel.cpp | 8 - .../src/rtc_partial_pass_sbcc_64_64_64.cpp | 1149 ----------------- .../src/rtc_partial_pass_sbrr_64_64_64.cpp | 1019 --------------- library/src/rtc_stockham_kernel.cpp | 4 +- library/src/tree_node_1D.cpp | 43 +- 8 files changed, 10 insertions(+), 2326 deletions(-) delete mode 100644 library/src/include/rtc_partial_pass_sbcc_64_64_64.h delete mode 100644 library/src/include/rtc_partial_pass_sbrr_64_64_64.h delete mode 100644 library/src/rtc_partial_pass_sbcc_64_64_64.cpp delete mode 100644 library/src/rtc_partial_pass_sbrr_64_64_64.cpp diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index 0be67d2f95e..a1d528b3eaa 100644 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -296,8 +296,6 @@ add_library( rocfft-rtc-launch OBJECT rtc_transpose_kernel.cpp rtc_twiddle_kernel.cpp rtc_chirp_kernel.cpp - rtc_partial_pass_sbcc_64_64_64.cpp - rtc_partial_pass_sbrr_64_64_64.cpp load_store_ops_kernel.cpp tree_node_callback.cpp ) diff --git a/library/src/include/rtc_partial_pass_sbcc_64_64_64.h b/library/src/include/rtc_partial_pass_sbcc_64_64_64.h deleted file mode 100644 index 829692cd6a3..00000000000 --- a/library/src/include/rtc_partial_pass_sbcc_64_64_64.h +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -#ifndef RTC_PARTIAL_PASS_SBCC_64_64_64_GEN -#define RTC_PARTIAL_PASS_SBCC_64_64_64_GEN - -#include "rocfft/rocfft.h" -#include "rtc_kernel.h" -#include "rtc_partial_pass_sbcc_64_64_64.h" -#include -#include - -static const unsigned int PARTIAL_PASS_SBCC_64_64_64_THREADS = 64; - -// generate name for partial-pass sbcc 64x64x64 compute kernel -std::string partial_pass_64_64_64_sbcc_rtc_kernel_name(rocfft_precision precision); -// generate source for partial-pass sbcc 64x64x64 compute kernel -std::string partial_pass_64_64_64_sbcc_rtc(const std::string& kernel_name, - rocfft_precision precision); - -struct RTCKernelPartialPassSBCC64Cubed : public RTCKernel -{ - static RTCKernel::RTCGenerator generate_from_node(const LeafNode& node, - const std::string& gpu_arch, - bool enable_callbacks); - - virtual RTCKernelArgs get_launch_args(DeviceCallIn& data) override; - -protected: - RTCKernelPartialPassSBCC64Cubed(const std::string& kernel_name, - const std::vector& code, - dim3 gridDim, - dim3 blockDim) - : RTCKernel(kernel_name, code, gridDim, blockDim) - { - } -}; - -#endif // RTC_PARTIAL_PASS_SBCC_64_64_64_GEN diff --git a/library/src/include/rtc_partial_pass_sbrr_64_64_64.h b/library/src/include/rtc_partial_pass_sbrr_64_64_64.h deleted file mode 100644 index c321edce7e8..00000000000 --- a/library/src/include/rtc_partial_pass_sbrr_64_64_64.h +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -#ifndef ROCFFT_RTC_PARTIAL_PASS_SBRR_64_64_64_KERNEL_H -#define ROCFFT_RTC_PARTIAL_PASS_SBRR_64_64_64_KERNEL_H - -#include "rocfft/rocfft.h" -#include "rtc_kernel.h" -#include -#include - -static const unsigned int PARTIAL_PASS_SBRR_64_64_64_THREADS = 128; - -// generate name for partial-pass sbrr 64x64x64 compute kernel -std::string partial_pass_64_64_64_sbrr_rtc_kernel_name(rocfft_precision precision); -// generate source for partial-pass sbrr 64x64x64 compute kernel -std::string partial_pass_64_64_64_sbrr_rtc(const std::string& kernel_name, - rocfft_precision precision); - -struct RTCKernelPartialPassSBRR64Cubed : public RTCKernel -{ - static RTCKernel::RTCGenerator generate_from_node(const LeafNode& node, - const std::string& gpu_arch, - bool enable_callbacks); - - virtual RTCKernelArgs get_launch_args(DeviceCallIn& data) override; - -protected: - RTCKernelPartialPassSBRR64Cubed(const std::string& kernel_name, - const std::vector& code, - dim3 gridDim, - dim3 blockDim) - : RTCKernel(kernel_name, code, gridDim, blockDim) - { - } -}; - -#endif // ROCFFT_RTC_PARTIAL_PASS_SBRR_64_64_64_KERNEL_H diff --git a/library/src/rtc_kernel.cpp b/library/src/rtc_kernel.cpp index 42a14eb31fe..8f72345d027 100644 --- a/library/src/rtc_kernel.cpp +++ b/library/src/rtc_kernel.cpp @@ -29,8 +29,6 @@ #include "logging.h" #include "rtc_bluestein_kernel.h" #include "rtc_cache.h" -#include "rtc_partial_pass_sbcc_64_64_64.h" -#include "rtc_partial_pass_sbrr_64_64_64.h" #include "rtc_realcomplex_kernel.h" #include "rtc_stockham_kernel.h" #include "rtc_transpose_kernel.h" @@ -153,12 +151,6 @@ std::shared_future> generator = RTCKernelBluesteinSingle::generate_from_node(node, gpu_arch, enable_callbacks); if(!generator.valid()) generator = RTCKernelBluesteinMulti::generate_from_node(node, gpu_arch, enable_callbacks); - if(!generator.valid()) - generator - = RTCKernelPartialPassSBRR64Cubed::generate_from_node(node, gpu_arch, enable_callbacks); - if(!generator.valid()) - generator - = RTCKernelPartialPassSBCC64Cubed::generate_from_node(node, gpu_arch, enable_callbacks); if(generator.valid()) { diff --git a/library/src/rtc_partial_pass_sbcc_64_64_64.cpp b/library/src/rtc_partial_pass_sbcc_64_64_64.cpp deleted file mode 100644 index 3f13e3a2431..00000000000 --- a/library/src/rtc_partial_pass_sbcc_64_64_64.cpp +++ /dev/null @@ -1,1149 +0,0 @@ -#include "rtc_partial_pass_sbcc_64_64_64.h" -#include "../../shared/arithmetic.h" -#include "device/kernel-generator-embed.h" -#include "include/kernel_launch.h" -#include "rtc_kernel.h" -#include "tree_node.h" -#include - -std::string partial_pass_64_64_64_sbcc_rtc_kernel_name(rocfft_precision precision, - int direction, - rocfft_result_placement placement, - rocfft_array_type inArrayType, - rocfft_array_type outArrayType, - CallbackType cbtype) -{ - std::string kernel_name = "sbcc_64_64_64_partial_pass"; - - if(direction == -1) - kernel_name += "_fwd"; - else - kernel_name += "_bck"; - - if(placement == rocfft_placement_inplace) - { - kernel_name += "_ip"; - kernel_name += rtc_array_type_name(inArrayType); - } - else - { - kernel_name += "_op"; - kernel_name += rtc_array_type_name(inArrayType); - kernel_name += rtc_array_type_name(outArrayType); - } - - kernel_name += rtc_precision_name(precision); - - kernel_name += rtc_cbtype_name(cbtype); - - return kernel_name; -} - -static std::string apply_local_transpose() -{ - std::string body = R"_SRC( - __device__ size_t apply_local_transpose(size_t index) -{ - // wrap around index around first batch - auto index_transpose = index % (64 * 64 * 64); - - // apply local transpose transformation on first batch - index_transpose = ((index_transpose % 64) + ((index_transpose % 1024) / 64) * 1024) % (4096 - 64) + - ((index_transpose / 1024) * 256) + (index_transpose / 4096) * (4096 - 1024); - - // move transformed index to correct batch - index_transpose = index_transpose + (index / (64 * 64 * 64) * (64 * 64 * 64)); - - return index_transpose; -} -)_SRC"; - - return body; -} - -static std::string lds_to_reg() -{ - std::string body = R"_SRC( - template -__device__ void lds_to_reg_input_length64_device_sbcc(scalar_type *R, - scalar_type *__restrict__ lds_complex, - unsigned int stride_lds, - unsigned int offset_lds, - unsigned int thread, - bool write) -{ - const unsigned int lstride = (sb == SB_UNIT) ? (1) : (stride_lds); - unsigned int l_offset; - __syncthreads(); - l_offset = offset_lds + ((thread + 0 + 0) + 0) * lstride; - R[0] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 8 * 4) * lstride; - R[1] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 16 * 4) * lstride; - R[2] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 24 * 4) * lstride; - R[3] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 32 * 4) * lstride; - R[4] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 40 * 4) * lstride; - R[5] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 48 * 4) * lstride; - R[6] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 56 * 4) * lstride; - R[7] = lds_complex[l_offset]; -} -)_SRC"; - - return body; -} - -static std::string reg_to_lds() -{ - std::string body = R"_SRC( - template -__device__ void lds_from_reg_output_length64_device_sbcc(scalar_type *R, - scalar_type *__restrict__ lds_complex, - unsigned int stride_lds, - unsigned int offset_lds, - unsigned int thread, - bool write) -{ - const unsigned int lstride = (sb == SB_UNIT) ? (1) : (stride_lds); - unsigned int l_offset; - __syncthreads(); - l_offset = offset_lds + (((thread + 0 + 0) / (8 * 4)) * 64 + (thread + 0 + 0) % (8 * 4) + 0 * 4) * lstride; - lds_complex[l_offset] = R[0]; - l_offset = offset_lds + (((thread + 0 + 0) / (8 * 4)) * 64 + (thread + 0 + 0) % (8 * 4) + 8 * 4) * lstride; - lds_complex[l_offset] = R[1]; - l_offset = offset_lds + (((thread + 0 + 0) / (8 * 4)) * 64 + (thread + 0 + 0) % (8 * 4) + 16 * 4) * lstride; - lds_complex[l_offset] = R[2]; - l_offset = offset_lds + (((thread + 0 + 0) / (8 * 4)) * 64 + (thread + 0 + 0) % (8 * 4) + 24 * 4) * lstride; - lds_complex[l_offset] = R[3]; - l_offset = offset_lds + (((thread + 0 + 0) / (8 * 4)) * 64 + (thread + 0 + 0) % (8 * 4) + 32 * 4) * lstride; - lds_complex[l_offset] = R[4]; - l_offset = offset_lds + (((thread + 0 + 0) / (8 * 4)) * 64 + (thread + 0 + 0) % (8 * 4) + 40 * 4) * lstride; - lds_complex[l_offset] = R[5]; - l_offset = offset_lds + (((thread + 0 + 0) / (8 * 4)) * 64 + (thread + 0 + 0) % (8 * 4) + 48 * 4) * lstride; - lds_complex[l_offset] = R[6]; - l_offset = offset_lds + (((thread + 0 + 0) / (8 * 4)) * 64 + (thread + 0 + 0) % (8 * 4) + 56 * 4) * lstride; - lds_complex[l_offset] = R[7]; -} -)_SRC"; - - return body; -} - -static std::string lds_to_reg_pp() -{ - std::string body = R"_SRC( - template -__device__ void lds_to_reg_4_input_length64_device_pp(scalar_type *R, - scalar_type *__restrict__ lds_complex, - unsigned int stride, - unsigned int offset) -{ - unsigned int idx, thread; - __syncthreads(); - - thread = 0; - idx = offset + thread * stride; - R[0] = lds_complex[idx]; - - thread = 1; - idx = offset + thread * stride; - R[1] = lds_complex[idx]; - - thread = 2; - idx = offset + thread * stride; - R[2] = lds_complex[idx]; - - thread = 3; - idx = offset + thread * stride; - R[3] = lds_complex[idx]; - - thread = 4; - idx = offset + thread * stride; - R[4] = lds_complex[idx]; - - thread = 5; - idx = offset + thread * stride; - R[5] = lds_complex[idx]; - - thread = 6; - idx = offset + thread * stride; - R[6] = lds_complex[idx]; - - thread = 7; - idx = offset + thread * stride; - R[7] = lds_complex[idx]; -} -)_SRC"; - - return body; -} - -static std::string reg_to_lds_pp() -{ - std::string body = R"_SRC( - template -__device__ void lds_from_reg_4_output_length64_device_pp(scalar_type *R, - scalar_type *__restrict__ lds_complex, - unsigned int stride, - unsigned int offset) -{ - unsigned int idx, thread; - __syncthreads(); - - thread = 0; - idx = offset + thread * stride; - lds_complex[idx] = R[0]; - - thread = 1; - idx = offset + thread * stride; - lds_complex[idx] = R[1]; - - thread = 2; - idx = offset + thread * stride; - lds_complex[idx] = R[2]; - - thread = 3; - idx = offset + thread * stride; - lds_complex[idx] = R[3]; - - thread = 4; - idx = offset + thread * stride; - lds_complex[idx] = R[4]; - - thread = 5; - idx = offset + thread * stride; - lds_complex[idx] = R[5]; - - thread = 6; - idx = offset + thread * stride; - lds_complex[idx] = R[6]; - - thread = 7; - idx = offset + thread * stride; - lds_complex[idx] = R[7]; -} -)_SRC"; - - return body; -} - -static std::string fwd_len64() -{ - std::string body = R"_SRC( - template -__device__ void forward_length64_SBCC_device_pp(scalar_type *R, - real_type_t *__restrict__ lds_real, - scalar_type *__restrict__ lds_complex, - const scalar_type *__restrict__ twiddles, - unsigned int stride_lds, - unsigned int offset_lds, - unsigned int thread_twd, - unsigned int thread, - bool write, - const scalar_type *large_twiddles, - size_t trans_local) -{ - scalar_type W; - scalar_type t; - const unsigned int lstride = (sb == SB_UNIT) ? (1) : (stride_lds); - unsigned int l_offset; - - // pass 0, width 8 - // using 8 threads we need to do 8 radix-8 butterflies - // therefore each thread will do 1.000000 butterflies - FwdRad8B1(R + 0, R + 1, R + 2, R + 3, R + 4, R + 5, R + 6, R + 7); - if (!lds_is_real) - { - if (!direct_load_to_reg) - { - __syncthreads(); - } - - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (0 * 4)) * lstride; - lds_complex[l_offset] = R[0]; - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (1 * 4)) * lstride; - lds_complex[l_offset] = R[1]; - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (2 * 4)) * lstride; - lds_complex[l_offset] = R[2]; - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (3 * 4)) * lstride; - lds_complex[l_offset] = R[3]; - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (4 * 4)) * lstride; - lds_complex[l_offset] = R[4]; - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (5 * 4)) * lstride; - lds_complex[l_offset] = R[5]; - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (6 * 4)) * lstride; - lds_complex[l_offset] = R[6]; - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (7 * 4)) * lstride; - lds_complex[l_offset] = R[7]; - } - - else - { - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 0) * lstride; - lds_real[l_offset] = R[0].x; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 1) * lstride; - lds_real[l_offset] = R[1].x; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 2) * lstride; - lds_real[l_offset] = R[2].x; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 3) * lstride; - lds_real[l_offset] = R[3].x; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 4) * lstride; - lds_real[l_offset] = R[4].x; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 5) * lstride; - lds_real[l_offset] = R[5].x; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 6) * lstride; - lds_real[l_offset] = R[6].x; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 7) * lstride; - lds_real[l_offset] = R[7].x; - __syncthreads(); - l_offset = offset_lds + ((thread + 0 + 0) + 0) * lstride; - R[0].x = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 8) * lstride; - R[1].x = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 16) * lstride; - R[2].x = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 24) * lstride; - R[3].x = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 32) * lstride; - R[4].x = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 40) * lstride; - R[5].x = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 48) * lstride; - R[6].x = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 56) * lstride; - R[7].x = lds_real[l_offset]; - __syncthreads(); - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 0) * lstride; - lds_real[l_offset] = R[0].y; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 1) * lstride; - lds_real[l_offset] = R[1].y; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 2) * lstride; - lds_real[l_offset] = R[2].y; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 3) * lstride; - lds_real[l_offset] = R[3].y; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 4) * lstride; - lds_real[l_offset] = R[4].y; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 5) * lstride; - lds_real[l_offset] = R[5].y; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 6) * lstride; - lds_real[l_offset] = R[6].y; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 7) * lstride; - lds_real[l_offset] = R[7].y; - __syncthreads(); - l_offset = offset_lds + ((thread + 0 + 0) + 0) * lstride; - R[0].y = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 8) * lstride; - R[1].y = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 16) * lstride; - R[2].y = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 24) * lstride; - R[3].y = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 32) * lstride; - R[4].y = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 40) * lstride; - R[5].y = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 48) * lstride; - R[6].y = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 56) * lstride; - R[7].y = lds_real[l_offset]; - } - - // pass 1, width 8 - // using 8 threads we need to do 8 radix-8 butterflies - // therefore each thread will do 1.000000 butterflies - if (!lds_is_real) - { - __syncthreads(); - l_offset = offset_lds + ((thread + 0 + 0) + (0 * 4)) * lstride; - R[0] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + (8 * 4)) * lstride; - R[1] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + (16 * 4)) * lstride; - R[2] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + (24 * 4)) * lstride; - R[3] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + (32 * 4)) * lstride; - R[4] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + (40 * 4)) * lstride; - R[5] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + (48 * 4)) * lstride; - R[6] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + (56 * 4)) * lstride; - R[7] = lds_complex[l_offset]; - } - - W = twiddles[0 + 7 * ((thread_twd + 0 + 0) % 8)]; - t = {R[1].x * W.x - R[1].y * W.y, R[1].y * W.x + R[1].x * W.y}; - R[1] = t; - W = twiddles[1 + 7 * ((thread_twd + 0 + 0) % 8)]; - t = {R[2].x * W.x - R[2].y * W.y, R[2].y * W.x + R[2].x * W.y}; - R[2] = t; - W = twiddles[2 + 7 * ((thread_twd + 0 + 0) % 8)]; - t = {R[3].x * W.x - R[3].y * W.y, R[3].y * W.x + R[3].x * W.y}; - R[3] = t; - W = twiddles[3 + 7 * ((thread_twd + 0 + 0) % 8)]; - t = {R[4].x * W.x - R[4].y * W.y, R[4].y * W.x + R[4].x * W.y}; - R[4] = t; - W = twiddles[4 + 7 * ((thread_twd + 0 + 0) % 8)]; - t = {R[5].x * W.x - R[5].y * W.y, R[5].y * W.x + R[5].x * W.y}; - R[5] = t; - W = twiddles[5 + 7 * ((thread_twd + 0 + 0) % 8)]; - t = {R[6].x * W.x - R[6].y * W.y, R[6].y * W.x + R[6].x * W.y}; - R[6] = t; - W = twiddles[6 + 7 * ((thread_twd + 0 + 0) % 8)]; - t = {R[7].x * W.x - R[7].y * W.y, R[7].y * W.x + R[7].x * W.y}; - R[7] = t; - FwdRad8B1(R + 0, R + 1, R + 2, R + 3, R + 4, R + 5, R + 6, R + 7); -} -)_SRC"; - - return body; -} - -static std::string inv_len64() -{ - std::string body = R"_SRC( - template -__device__ void inverse_length64_SBCC_device_pp(scalar_type *R, - real_type_t *__restrict__ lds_real, - scalar_type *__restrict__ lds_complex, - const scalar_type *__restrict__ twiddles, - unsigned int stride_lds, - unsigned int offset_lds, - unsigned int thread_twd, - unsigned int thread, - bool write, - const scalar_type *large_twiddles, - size_t trans_local) -{ - scalar_type W; - scalar_type t; - const unsigned int lstride = (sb == SB_UNIT) ? (1) : (stride_lds); - unsigned int l_offset; - - // pass 0, width 8 - // using 8 threads we need to do 8 radix-8 butterflies - // therefore each thread will do 1.000000 butterflies - InvRad8B1(R + 0, R + 1, R + 2, R + 3, R + 4, R + 5, R + 6, R + 7); - if (!lds_is_real) - { - if (!direct_load_to_reg) - { - __syncthreads(); - } - - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (0 * 4)) * lstride; - lds_complex[l_offset] = R[0]; - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (1 * 4)) * lstride; - lds_complex[l_offset] = R[1]; - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (2 * 4)) * lstride; - lds_complex[l_offset] = R[2]; - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (3 * 4)) * lstride; - lds_complex[l_offset] = R[3]; - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (4 * 4)) * lstride; - lds_complex[l_offset] = R[4]; - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (5 * 4)) * lstride; - lds_complex[l_offset] = R[5]; - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (6 * 4)) * lstride; - lds_complex[l_offset] = R[6]; - l_offset = offset_lds + (((thread + 0 + 0) / (1 * 4)) * (8 * 4) + (thread + 0 + 0) % (1 * 4) + (7 * 4)) * lstride; - lds_complex[l_offset] = R[7]; - } - - else - { - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 0) * lstride; - lds_real[l_offset] = R[0].x; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 1) * lstride; - lds_real[l_offset] = R[1].x; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 2) * lstride; - lds_real[l_offset] = R[2].x; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 3) * lstride; - lds_real[l_offset] = R[3].x; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 4) * lstride; - lds_real[l_offset] = R[4].x; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 5) * lstride; - lds_real[l_offset] = R[5].x; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 6) * lstride; - lds_real[l_offset] = R[6].x; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 7) * lstride; - lds_real[l_offset] = R[7].x; - __syncthreads(); - l_offset = offset_lds + ((thread + 0 + 0) + 0) * lstride; - R[0].x = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 8) * lstride; - R[1].x = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 16) * lstride; - R[2].x = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 24) * lstride; - R[3].x = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 32) * lstride; - R[4].x = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 40) * lstride; - R[5].x = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 48) * lstride; - R[6].x = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 56) * lstride; - R[7].x = lds_real[l_offset]; - __syncthreads(); - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 0) * lstride; - lds_real[l_offset] = R[0].y; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 1) * lstride; - lds_real[l_offset] = R[1].y; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 2) * lstride; - lds_real[l_offset] = R[2].y; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 3) * lstride; - lds_real[l_offset] = R[3].y; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 4) * lstride; - lds_real[l_offset] = R[4].y; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 5) * lstride; - lds_real[l_offset] = R[5].y; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 6) * lstride; - lds_real[l_offset] = R[6].y; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 7) * lstride; - lds_real[l_offset] = R[7].y; - __syncthreads(); - l_offset = offset_lds + ((thread + 0 + 0) + 0) * lstride; - R[0].y = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 8) * lstride; - R[1].y = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 16) * lstride; - R[2].y = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 24) * lstride; - R[3].y = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 32) * lstride; - R[4].y = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 40) * lstride; - R[5].y = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 48) * lstride; - R[6].y = lds_real[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 56) * lstride; - R[7].y = lds_real[l_offset]; - } - - // pass 1, width 8 - // using 8 threads we need to do 8 radix-8 butterflies - // therefore each thread will do 1.000000 butterflies - if (!lds_is_real) - { - __syncthreads(); - l_offset = offset_lds + ((thread + 0 + 0) + (0 * 4)) * lstride; - R[0] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + (8 * 4)) * lstride; - R[1] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + (16 * 4)) * lstride; - R[2] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + (24 * 4)) * lstride; - R[3] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + (32 * 4)) * lstride; - R[4] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + (40 * 4)) * lstride; - R[5] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + (48 * 4)) * lstride; - R[6] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + (56 * 4)) * lstride; - R[7] = lds_complex[l_offset]; - } - - W = twiddles[0 + 7 * ((thread_twd + 0 + 0) % 8)]; - t = {R[1].x * W.x + R[1].y * W.y, R[1].y * W.x - R[1].x * W.y}; - R[1] = t; - W = twiddles[1 + 7 * ((thread_twd + 0 + 0) % 8)]; - t = {R[2].x * W.x + R[2].y * W.y, R[2].y * W.x - R[2].x * W.y}; - R[2] = t; - W = twiddles[2 + 7 * ((thread_twd + 0 + 0) % 8)]; - t = {R[3].x * W.x + R[3].y * W.y, R[3].y * W.x - R[3].x * W.y}; - R[3] = t; - W = twiddles[3 + 7 * ((thread_twd + 0 + 0) % 8)]; - t = {R[4].x * W.x + R[4].y * W.y, R[4].y * W.x - R[4].x * W.y}; - R[4] = t; - W = twiddles[4 + 7 * ((thread_twd + 0 + 0) % 8)]; - t = {R[5].x * W.x + R[5].y * W.y, R[5].y * W.x - R[5].x * W.y}; - R[5] = t; - W = twiddles[5 + 7 * ((thread_twd + 0 + 0) % 8)]; - t = {R[6].x * W.x + R[6].y * W.y, R[6].y * W.x - R[6].x * W.y}; - R[6] = t; - W = twiddles[6 + 7 * ((thread_twd + 0 + 0) % 8)]; - t = {R[7].x * W.x + R[7].y * W.y, R[7].y * W.x - R[7].x * W.y}; - R[7] = t; - InvRad8B1(R + 0, R + 1, R + 2, R + 3, R + 4, R + 5, R + 6, R + 7); -} -)_SRC"; - - return body; -} - -static std::string pp_fwd_step_3_4_radix4() -{ - std::string body = R"_SRC( - // Partial pass step 3: local transposition performed on store to global memory below - - // Partial pass step 4: length-4 DFT on off-dimension - - // Radix-4 pass - FwdRad4B1(R + 0, R + 1, R + 2, R + 3); - - // Radix-4 pass - FwdRad4B1(R + 4, R + 5, R + 6, R + 7); -)_SRC"; - - return body; -} - -static std::string pp_inv_step_3_4_radix4() -{ - std::string body = R"_SRC( - // Partial pass step 3: local transposition performed on store to global memory below - - // Partial pass step 4: length-4 DFT on off-dimension - - // Radix-4 pass - InvRad4B1(R + 0, R + 1, R + 2, R + 3); - - // Radix-4 pass - InvRad4B1(R + 4, R + 5, R + 6, R + 7); -)_SRC"; - - return body; -} - -static std::string partial_pass_sbcc_64_64_64_rtc_body(const std::string& kernel_name, - int direction) -{ - std::string body; - - body += R"_SRC( - extern "C" __global__ - __launch_bounds__(256) void - )_SRC"; - - body += kernel_name; - - body += R"_SRC( - ( - const scalar_type *__restrict__ twiddles, - const scalar_type *large_twiddles, - const size_t *__restrict__ lengths, - const size_t *__restrict__ stride, - const size_t nbatch, - const unsigned int lds_padding, - void *__restrict__ load_cb_fn, - void *__restrict__ load_cb_data, - unsigned int load_cb_lds_bytes, - void *__restrict__ store_cb_fn, - void *__restrict__ store_cb_data, - scalar_type *__restrict__ ibuf, - scalar_type *__restrict__ obuf) -{ - auto const sb = SB_NONUNIT; - auto const ebtype = EmbeddedType::NONE; - auto const sbrc_type = SBRC_2D; - auto const transpose_type = NONE; - auto const drtype = DirectRegType::FORCE_OFF_OR_NOT_SUPPORT; - auto const apply_large_twiddle = false; - auto const intrinsic_mode = IntrinsicAccessType::DISABLE_BOTH; - const size_t large_twiddle_base = 8; - const size_t large_twiddle_steps = 0; - - scalar_type R[8]; - extern __shared__ unsigned char __attribute__((aligned(sizeof(scalar_type)))) lds_uchar[]; - real_type_t *__restrict__ lds_real = reinterpret_cast *>(lds_uchar); - scalar_type *__restrict__ lds_complex = reinterpret_cast(lds_uchar); - size_t offset = 0; - unsigned int offset_lds; - unsigned int stride_lds; - size_t batch; - size_t transform; - const bool direct_load_to_reg = drtype == DirectRegType::TRY_ENABLE_IF_SUPPORT; - const bool direct_store_from_reg = direct_load_to_reg; - const bool lds_linear = !direct_load_to_reg; - const bool lds_is_real = false; - auto load_cb = get_load_cb(load_cb_fn); - auto store_cb = get_store_cb(store_cb_fn); - - // offsets - const size_t dim = 3; - const size_t stride0 = (sb == SB_UNIT) ? (1) : (stride[0]); - size_t tile_index; - size_t num_of_tiles; - - // calculate offset for each tile: - // tile_index now means index of the tile along dim1 - // num_of_tiles now means number of tiles along dim1 - size_t plength = 1; - size_t remaining; - size_t index_along_d; - num_of_tiles = (lengths[1] - 1) / 8 + 1; - plength = num_of_tiles; - tile_index = blockIdx.x % num_of_tiles; - - // mod 128 required to work with nbatch > 1 - remaining = (blockIdx.x % 128) / num_of_tiles; - offset = tile_index * 8 * stride[1]; - for (int d = 2; d < dim; ++d) - { - plength = plength * lengths[d]; - - index_along_d = remaining % lengths[d]; - remaining = remaining / lengths[d]; - offset = offset + index_along_d * stride[d]; - } - - batch = blockIdx.x / plength; - // offset = offset + batch * stride[dim]; // don't add batch here - transform = lds_linear ? tile_index * 8 + threadIdx.x / 8 : tile_index * 8 + threadIdx.x % 8; - stride_lds = lds_linear ? 64 + (ebtype == EmbeddedType::NONE ? 0 : lds_padding) - : 8 + (ebtype == EmbeddedType::NONE ? 0 : lds_padding); - stride_lds *= 4; - - offset_lds = lds_linear ? stride_lds * (transform % 8) : threadIdx.x % 8; - bool in_bound = ((tile_index + 1) * 8 > lengths[1]) ? false : true; - unsigned int thread = threadIdx.x / 8; - unsigned int tid_hor = threadIdx.x % 8; - - unsigned int thread_lds = threadIdx.x / 8; - unsigned int tid_hor_lds = threadIdx.x % 8; - - auto tid_hor_pp = threadIdx.x % 8 + 64 * (thread % 4); - - auto thread_new = threadIdx.x / (8 * 4); - auto batch_new = blockIdx.x / (plength / 4); - - auto thread_idx = threadIdx.x; - auto block_idx = blockIdx.x; - - auto offset_pp = offset + (offset / 64) * 192 + batch_new * stride[dim]; - - auto offset_tid_hor = (offset_pp + tid_hor_pp * stride[1]); - - transform = lds_linear ? tile_index * 8 + threadIdx.x / (8 * 4) : tile_index * 8 + threadIdx.x % 8; - offset_lds = lds_linear ? stride_lds * (transform % 8) : threadIdx.x % 8; - - size_t global_mem_idx = 0; - - // load global into lds - // no intrinsic when load to lds. FIXME- check why use nested branch is better - if (in_bound) - { - global_mem_idx = offset_tid_hor + (thread_new + 0) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 0) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - nullptr); - global_mem_idx = offset_tid_hor + (thread_new + 8) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 8 * 4) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - nullptr); - global_mem_idx = offset_tid_hor + (thread_new + 16) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 16 * 4) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - nullptr); - global_mem_idx = offset_tid_hor + (thread_new + 24) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 24 * 4) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - - nullptr); - global_mem_idx = offset_tid_hor + (thread_new + 32) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 32 * 4) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - nullptr); - global_mem_idx = offset_tid_hor + (thread_new + 40) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 40 * 4) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - nullptr); - global_mem_idx = offset_tid_hor + (thread_new + 48) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 48 * 4) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - nullptr); - global_mem_idx = offset_tid_hor + (thread_new + 56) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 56 * 4) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - nullptr); - } - - if (!in_bound) - { - if (tile_index * 8 + tid_hor < lengths[1]) - { - global_mem_idx = offset_tid_hor + (thread_new + 0) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 0) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - nullptr); - global_mem_idx = offset_tid_hor + (thread_new + 8) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 8 * 4) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - nullptr); - global_mem_idx = offset_tid_hor + (thread_new + 16) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 16 * 4) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - nullptr); - global_mem_idx = offset_tid_hor + (thread_new + 24) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 24 * 4) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - nullptr); - global_mem_idx = offset_tid_hor + (thread_new + 32) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 32 * 4) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - nullptr); - global_mem_idx = offset_tid_hor + (thread_new + 40) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 40 * 4) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - nullptr); - global_mem_idx = offset_tid_hor + (thread_new + 48) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 48 * 4) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - nullptr); - global_mem_idx = offset_tid_hor + (thread_new + 56) * stride0; - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 56 * 4) * 1] = load_cb(ibuf, - global_mem_idx, - load_cb_data, - nullptr); - } - } - - auto stride_lds_pp = 1; - auto offset_lds_pp = threadIdx.x * 8; - - // call a pre-load from lds to registers (if necessary) - lds_to_reg_4_input_length64_device_pp(R, lds_complex, stride_lds_pp, offset_lds_pp); - )_SRC"; - - if(direction == -1) - body += pp_fwd_step_3_4_radix4(); - else if(direction == 1) - body += pp_inv_step_3_4_radix4(); - - body += R"_SRC( - // call a post-store from registers to lds (if necessary) - lds_from_reg_4_output_length64_device_pp(R, lds_complex, stride_lds_pp, offset_lds_pp); - - // calc the thread_in_device value once and for all device funcs - auto thread_in_device = lds_linear ? threadIdx.x % (8 * 4) : threadIdx.x / 8; - auto thread_in_device_twd = (threadIdx.x / 4) % (8); - - // call a pre-load from lds to registers (if necessary) - lds_to_reg_input_length64_device_sbcc( - R, lds_complex, stride_lds, offset_lds, thread_in_device, true); - )_SRC"; - - if(direction == -1) - { - body += R"_SRC( - // transform - forward_length64_SBCC_device_pp( - R, - lds_real, - lds_complex, - twiddles, - stride_lds, - offset_lds, - thread_in_device_twd, - thread_in_device, - true, - large_twiddles, - transform); - )_SRC"; - } - else if(direction == 1) - { - body += R"_SRC( - // transform - inverse_length64_SBCC_device_pp( - R, - lds_real, - lds_complex, - twiddles, - stride_lds, - offset_lds, - thread_in_device_twd, - thread_in_device, - true, - large_twiddles, - transform); - )_SRC"; - } - - body += R"_SRC( - // call a post-store from registers to lds (if necessary) - lds_from_reg_output_length64_device_sbcc( - R, lds_complex, stride_lds, offset_lds, thread_in_device, true); - - // store global - __syncthreads(); - // no intrinsic when store from lds. FIXME- check why use nested branch is better - if (in_bound) - { - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 0) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 0) * 1], - store_cb_data, - nullptr); - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 8) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 8 * 4) * 1], - store_cb_data, - nullptr); - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 16) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 16 * 4) * 1], - store_cb_data, - nullptr); - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 24) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 24 * 4) * 1], - store_cb_data, - nullptr); - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 32) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 32 * 4) * 1], - store_cb_data, - nullptr); - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 40) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 40 * 4) * 1], - store_cb_data, - nullptr); - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 48) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 48 * 4) * 1], - store_cb_data, - nullptr); - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 56) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 56 * 4) * 1], - store_cb_data, - nullptr); - } - - if (!in_bound) - { - if (tile_index * 8 + tid_hor < lengths[1]) - { - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 0) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 0 * 4) * 1], - store_cb_data, - nullptr); - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 8) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 8 * 4) * 1], - store_cb_data, - nullptr); - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 16) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 16 * 4) * 1], - store_cb_data, - nullptr); - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 24) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 24 * 4) * 1], - store_cb_data, - nullptr); - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 32) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 32 * 4) * 1], - store_cb_data, - nullptr); - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 40) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 40 * 4) * 1], - store_cb_data, - nullptr); - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 48) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 48 * 4) * 1], - store_cb_data, - nullptr); - global_mem_idx = apply_local_transpose(offset_tid_hor + (thread_new + 56) * stride0); - store_cb(obuf, - global_mem_idx, - lds_complex[tid_hor_lds * stride_lds + (thread_lds + 56 * 4) * 1], - store_cb_data, - nullptr); - } - } -} - )_SRC"; - - return body; -} - -std::string partial_pass_sbcc_64_64_64_rtc(const std::string& kernel_name, - const std::vector all_factors, - int direction, - rocfft_precision precision, - CallbackType cbtype) -{ - std::string src; - - // start off with includes - src += rocfft_complex_h; - src += common_h; - src += memory_gfx_h; - src += callback_h; - src += butterfly_constant_h; - src += large_twiddles_h; - - append_radix_h(src, all_factors); - - src += rtc_precision_type_decl(precision); - src += apply_local_transpose(); - src += lds_to_reg(); - src += reg_to_lds(); - src += lds_to_reg_pp(); - src += reg_to_lds_pp(); - if(direction == -1) - src += fwd_len64(); - else if(direction == 1) - src += inv_len64(); - src += rtc_const_cbtype_decl(cbtype); - src += partial_pass_sbcc_64_64_64_rtc_body(kernel_name, direction); - - return src; -} - -RTCKernel::RTCGenerator RTCKernelPartialPassSBCC64Cubed::generate_from_node( - const LeafNode& node, const std::string& gpu_arch, bool enable_callbacks) -{ - RTCGenerator generator; - - auto scheme = node.scheme; - if(!(scheme == CS_KERNEL_STOCKHAM_BLOCK_CC && node.applyPartialPass)) - return generator; - - auto workgroup_size = PARTIAL_PASS_SBCC_64_64_64_THREADS; - auto threads_per_transform = 8; - auto transforms_per_block = workgroup_size / threads_per_transform; - - auto bwd = transforms_per_block; - - auto b_x = ((node.length[1]) - 1) / bwd + 1; - b_x *= product(node.length.begin() + 2, node.length.end()) * node.batch; - auto wgs_x = workgroup_size; - - b_x /= 4; - wgs_x *= 4; - - generator.gridDim = {static_cast(b_x), 1, 1}; - generator.blockDim = {wgs_x, 1, 1}; - - auto precision = node.precision; - auto direction = node.direction; - auto placement = node.placement; - auto inArrayType = node.inArrayType; - auto outArrayType = node.outArrayType; - auto cbType = node.GetCallbackType(enable_callbacks); - - auto kernelFactors = node.kernelFactors; - auto kernelFactorsPP = node.kernelFactorsPP; - - std::vector all_factors(kernelFactors.begin(), kernelFactors.end()); - all_factors.insert(all_factors.end(), kernelFactorsPP.begin(), kernelFactorsPP.end()); - - generator.generate_name = [=]() { - return partial_pass_64_64_64_sbcc_rtc_kernel_name( - precision, direction, placement, inArrayType, outArrayType, cbType); - }; - - generator.generate_src = [=](const std::string& kernel_name) { - return partial_pass_sbcc_64_64_64_rtc( - kernel_name, all_factors, direction, precision, cbType); - }; - - generator.construct_rtckernel = [=](const std::string& kernel_name, - const std::vector& code, - dim3 gridDim, - dim3 blockDim) { - return std::unique_ptr( - new RTCKernelPartialPassSBCC64Cubed(kernel_name, code, gridDim, blockDim)); - }; - - return generator; -} - -RTCKernelArgs RTCKernelPartialPassSBCC64Cubed::get_launch_args(DeviceCallIn& data) -{ - RTCKernelArgs kargs; - - kargs.append_ptr(data.node->twiddles); - kargs.append_ptr(data.node->twiddles_large); - - kargs.append_ptr(kargs_lengths(data.node->devKernArg)); - kargs.append_ptr(kargs_stride_in(data.node->devKernArg)); - kargs.append_size_t(data.node->batch); - kargs.append_size_t(data.node->lds_padding); - - // callback params - kargs.append_ptr(data.callbacks.load_cb_fn); - kargs.append_ptr(data.callbacks.load_cb_data); - kargs.append_unsigned_int(data.callbacks.load_cb_lds_bytes); - kargs.append_ptr(data.callbacks.store_cb_fn); - kargs.append_ptr(data.callbacks.store_cb_data); - append_load_store_args(kargs, *data.node); - - kargs.append_ptr(data.bufIn[0]); - kargs.append_ptr(data.bufOut[0]); - - return kargs; -} diff --git a/library/src/rtc_partial_pass_sbrr_64_64_64.cpp b/library/src/rtc_partial_pass_sbrr_64_64_64.cpp deleted file mode 100644 index d9d0ebaedd0..00000000000 --- a/library/src/rtc_partial_pass_sbrr_64_64_64.cpp +++ /dev/null @@ -1,1019 +0,0 @@ -#include "rtc_partial_pass_sbrr_64_64_64.h" -#include "device/kernel-generator-embed.h" -#include "include/kernel_launch.h" -#include "rtc_chirp_kernel.h" -#include "rtc_kernel.h" -#include "tree_node.h" - -std::string partial_pass_64_64_64_sbrr_rtc_kernel_name(rocfft_precision precision, - int direction, - rocfft_result_placement placement, - rocfft_array_type inArrayType, - rocfft_array_type outArrayType, - CallbackType cbtype) -{ - std::string kernel_name = "sbrr_64_64_64_partial_pass"; - - if(direction == -1) - kernel_name += "_fwd"; - else - kernel_name += "_bck"; - - if(placement == rocfft_placement_inplace) - { - kernel_name += "_ip"; - kernel_name += rtc_array_type_name(inArrayType); - } - else - { - kernel_name += "_op"; - kernel_name += rtc_array_type_name(inArrayType); - kernel_name += rtc_array_type_name(outArrayType); - } - - kernel_name += rtc_precision_name(precision); - - kernel_name += rtc_cbtype_name(cbtype); - - return kernel_name; -} - -static std::string lds_to_reg() -{ - std::string body = R"_SRC( - template -__device__ void lds_to_reg_input_length64_device_sbrr(scalar_type *R, - scalar_type *__restrict__ lds_complex, - unsigned int stride_lds, - unsigned int offset_lds, - unsigned int thread, - bool write) -{ - const unsigned int lstride = (sb == SB_UNIT) ? (1) : (stride_lds); - unsigned int l_offset; - __syncthreads(); - l_offset = offset_lds + ((thread + 0 + 0) + 0) * lstride; - R[0] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 8) * lstride; - R[1] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 16) * lstride; - R[2] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 24) * lstride; - R[3] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 32) * lstride; - R[4] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 40) * lstride; - R[5] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 48) * lstride; - R[6] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 56) * lstride; - R[7] = lds_complex[l_offset]; -} -)_SRC"; - - return body; -} - -static std::string reg_to_lds() -{ - std::string body = R"_SRC( -template -__device__ void lds_from_reg_output_length64_device_sbrr(scalar_type *R, - scalar_type *__restrict__ lds_complex, - unsigned int stride_lds, - unsigned int offset_lds, - unsigned int thread, - bool write) -{ - const unsigned int lstride = (sb == SB_UNIT) ? (1) : (stride_lds); - unsigned int l_offset; - __syncthreads(); - l_offset = offset_lds + (((thread + 0 + 0) / 8) * 64 + (thread + 0 + 0) % 8 + 0) * lstride; - lds_complex[l_offset] = R[0]; - l_offset = offset_lds + (((thread + 0 + 0) / 8) * 64 + (thread + 0 + 0) % 8 + 8) * lstride; - lds_complex[l_offset] = R[1]; - l_offset = offset_lds + (((thread + 0 + 0) / 8) * 64 + (thread + 0 + 0) % 8 + 16) * lstride; - lds_complex[l_offset] = R[2]; - l_offset = offset_lds + (((thread + 0 + 0) / 8) * 64 + (thread + 0 + 0) % 8 + 24) * lstride; - lds_complex[l_offset] = R[3]; - l_offset = offset_lds + (((thread + 0 + 0) / 8) * 64 + (thread + 0 + 0) % 8 + 32) * lstride; - lds_complex[l_offset] = R[4]; - l_offset = offset_lds + (((thread + 0 + 0) / 8) * 64 + (thread + 0 + 0) % 8 + 40) * lstride; - lds_complex[l_offset] = R[5]; - l_offset = offset_lds + (((thread + 0 + 0) / 8) * 64 + (thread + 0 + 0) % 8 + 48) * lstride; - lds_complex[l_offset] = R[6]; - l_offset = offset_lds + (((thread + 0 + 0) / 8) * 64 + (thread + 0 + 0) % 8 + 56) * lstride; - lds_complex[l_offset] = R[7]; -} -)_SRC"; - - return body; -} - -static std::string lds_to_reg_pp() -{ - std::string body = R"_SRC( - template -__device__ void lds_to_reg_16_input_length64_device_pp(scalar_type *R, - scalar_type *__restrict__ lds_complex, - unsigned int stride, - unsigned int offset) -{ - unsigned int idx, thread; - __syncthreads(); - - thread = 0; - idx = offset + thread * stride; - R[0] = lds_complex[idx]; - - thread = 1; - idx = offset + thread * stride; - R[1] = lds_complex[idx]; - - thread = 2; - idx = offset + thread * stride; - R[2] = lds_complex[idx]; - - thread = 3; - idx = offset + thread * stride; - R[3] = lds_complex[idx]; - - thread = 4; - idx = offset + thread * stride; - R[4] = lds_complex[idx]; - - thread = 5; - idx = offset + thread * stride; - R[5] = lds_complex[idx]; - - thread = 6; - idx = offset + thread * stride; - R[6] = lds_complex[idx]; - - thread = 7; - idx = offset + thread * stride; - R[7] = lds_complex[idx]; - - thread = 8; - idx = offset + thread * stride; - R[8] = lds_complex[idx]; - - thread = 9; - idx = offset + thread * stride; - R[9] = lds_complex[idx]; - - thread = 10; - idx = offset + thread * stride; - R[10] = lds_complex[idx]; - - thread = 11; - idx = offset + thread * stride; - R[11] = lds_complex[idx]; - - thread = 12; - idx = offset + thread * stride; - R[12] = lds_complex[idx]; - - thread = 13; - idx = offset + thread * stride; - R[13] = lds_complex[idx]; - - thread = 14; - idx = offset + thread * stride; - R[14] = lds_complex[idx]; - - thread = 15; - idx = offset + thread * stride; - R[15] = lds_complex[idx]; -} -)_SRC"; - - return body; -} - -static std::string reg_to_lds_pp() -{ - std::string body = R"_SRC( - template -__device__ void lds_from_reg_16_output_length64_device_pp(scalar_type *R, - scalar_type *__restrict__ lds_complex, - unsigned int stride, - unsigned int offset) -{ - unsigned int idx, thread; - __syncthreads(); - - thread = 0; - idx = offset + thread * stride; - lds_complex[idx] = R[0]; - - thread = 1; - idx = offset + thread * stride; - lds_complex[idx] = R[1]; - - thread = 2; - idx = offset + thread * stride; - lds_complex[idx] = R[2]; - - thread = 3; - idx = offset + thread * stride; - lds_complex[idx] = R[3]; - - thread = 4; - idx = offset + thread * stride; - lds_complex[idx] = R[4]; - - thread = 5; - idx = offset + thread * stride; - lds_complex[idx] = R[5]; - - thread = 6; - idx = offset + thread * stride; - lds_complex[idx] = R[6]; - - thread = 7; - idx = offset + thread * stride; - lds_complex[idx] = R[7]; - - thread = 8; - idx = offset + thread * stride; - lds_complex[idx] = R[8]; - - thread = 9; - idx = offset + thread * stride; - lds_complex[idx] = R[9]; - - thread = 10; - idx = offset + thread * stride; - lds_complex[idx] = R[10]; - - thread = 11; - idx = offset + thread * stride; - lds_complex[idx] = R[11]; - - thread = 12; - idx = offset + thread * stride; - lds_complex[idx] = R[12]; - - thread = 13; - idx = offset + thread * stride; - lds_complex[idx] = R[13]; - - thread = 14; - idx = offset + thread * stride; - lds_complex[idx] = R[14]; - - thread = 15; - idx = offset + thread * stride; - lds_complex[idx] = R[15]; -} -)_SRC"; - - return body; -} - -static std::string twiddle_multiply_pp_fwd() -{ - std::string body = R"_SRC( - template -__device__ void twiddle_multiple_pp_fwd(scalar_type *R, - unsigned int thread, - const scalar_type *__restrict__ twiddles_pp) -{ - scalar_type t; - scalar_type W; - - W = twiddles_pp[thread * 64 + 0]; - t = {R[0].x * W.x - R[0].y * W.y, R[0].y * W.x + R[0].x * W.y}; - R[0] = t; - - W = twiddles_pp[thread * 64 + 1]; - t = {R[1].x * W.x - R[1].y * W.y, R[1].y * W.x + R[1].x * W.y}; - R[1] = t; - - W = twiddles_pp[thread * 64 + 2]; - t = {R[2].x * W.x - R[2].y * W.y, R[2].y * W.x + R[2].x * W.y}; - R[2] = t; - - W = twiddles_pp[thread * 64 + 3]; - t = {R[3].x * W.x - R[3].y * W.y, R[3].y * W.x + R[3].x * W.y}; - R[3] = t; - - W = twiddles_pp[thread * 64 + 4]; - t = {R[4].x * W.x - R[4].y * W.y, R[4].y * W.x + R[4].x * W.y}; - R[4] = t; - - W = twiddles_pp[thread * 64 + 5]; - t = {R[5].x * W.x - R[5].y * W.y, R[5].y * W.x + R[5].x * W.y}; - R[5] = t; - - W = twiddles_pp[thread * 64 + 6]; - t = {R[6].x * W.x - R[6].y * W.y, R[6].y * W.x + R[6].x * W.y}; - R[6] = t; - - W = twiddles_pp[thread * 64 + 7]; - t = {R[7].x * W.x - R[7].y * W.y, R[7].y * W.x + R[7].x * W.y}; - R[7] = t; - - W = twiddles_pp[thread * 64 + 8]; - t = {R[8].x * W.x - R[8].y * W.y, R[8].y * W.x + R[8].x * W.y}; - R[8] = t; - - W = twiddles_pp[thread * 64 + 9]; - t = {R[9].x * W.x - R[9].y * W.y, R[9].y * W.x + R[9].x * W.y}; - R[9] = t; - - W = twiddles_pp[thread * 64 + 10]; - t = {R[10].x * W.x - R[10].y * W.y, R[10].y * W.x + R[10].x * W.y}; - R[10] = t; - - W = twiddles_pp[thread * 64 + 11]; - t = {R[11].x * W.x - R[11].y * W.y, R[11].y * W.x + R[11].x * W.y}; - R[11] = t; - - W = twiddles_pp[thread * 64 + 12]; - t = {R[12].x * W.x - R[12].y * W.y, R[12].y * W.x + R[12].x * W.y}; - R[12] = t; - - W = twiddles_pp[thread * 64 + 13]; - t = {R[13].x * W.x - R[13].y * W.y, R[13].y * W.x + R[13].x * W.y}; - R[13] = t; - - W = twiddles_pp[thread * 64 + 14]; - t = {R[14].x * W.x - R[14].y * W.y, R[14].y * W.x + R[14].x * W.y}; - R[14] = t; - - W = twiddles_pp[thread * 64 + 15]; - t = {R[15].x * W.x - R[15].y * W.y, R[15].y * W.x + R[15].x * W.y}; - R[15] = t; -} -)_SRC"; - - return body; -} - -static std::string twiddle_multiply_pp_inv() -{ - std::string body = R"_SRC( - template -__device__ void twiddle_multiple_pp_inv(scalar_type *R, - unsigned int thread, - const scalar_type *__restrict__ twiddles_pp) -{ - scalar_type t; - scalar_type W; - - W = twiddles_pp[thread * 64 + 0]; - t = {R[0].x * W.x + R[0].y * W.y, R[0].y * W.x - R[0].x * W.y}; - R[0] = t; - - W = twiddles_pp[thread * 64 + 1]; - t = {R[1].x * W.x + R[1].y * W.y, R[1].y * W.x - R[1].x * W.y}; - R[1] = t; - - W = twiddles_pp[thread * 64 + 2]; - t = {R[2].x * W.x + R[2].y * W.y, R[2].y * W.x - R[2].x * W.y}; - R[2] = t; - - W = twiddles_pp[thread * 64 + 3]; - t = {R[3].x * W.x + R[3].y * W.y, R[3].y * W.x - R[3].x * W.y}; - R[3] = t; - - W = twiddles_pp[thread * 64 + 4]; - t = {R[4].x * W.x + R[4].y * W.y, R[4].y * W.x - R[4].x * W.y}; - R[4] = t; - - W = twiddles_pp[thread * 64 + 5]; - t = {R[5].x * W.x + R[5].y * W.y, R[5].y * W.x - R[5].x * W.y}; - R[5] = t; - - W = twiddles_pp[thread * 64 + 6]; - t = {R[6].x * W.x + R[6].y * W.y, R[6].y * W.x - R[6].x * W.y}; - R[6] = t; - - W = twiddles_pp[thread * 64 + 7]; - t = {R[7].x * W.x + R[7].y * W.y, R[7].y * W.x - R[7].x * W.y}; - R[7] = t; - - W = twiddles_pp[thread * 64 + 8]; - t = {R[8].x * W.x + R[8].y * W.y, R[8].y * W.x - R[8].x * W.y}; - R[8] = t; - - W = twiddles_pp[thread * 64 + 9]; - t = {R[9].x * W.x + R[9].y * W.y, R[9].y * W.x - R[9].x * W.y}; - R[9] = t; - - W = twiddles_pp[thread * 64 + 10]; - t = {R[10].x * W.x + R[10].y * W.y, R[10].y * W.x - R[10].x * W.y}; - R[10] = t; - - W = twiddles_pp[thread * 64 + 11]; - t = {R[11].x * W.x + R[11].y * W.y, R[11].y * W.x - R[11].x * W.y}; - R[11] = t; - - W = twiddles_pp[thread * 64 + 12]; - t = {R[12].x * W.x + R[12].y * W.y, R[12].y * W.x - R[12].x * W.y}; - R[12] = t; - - W = twiddles_pp[thread * 64 + 13]; - t = {R[13].x * W.x + R[13].y * W.y, R[13].y * W.x - R[13].x * W.y}; - R[13] = t; - - W = twiddles_pp[thread * 64 + 14]; - t = {R[14].x * W.x + R[14].y * W.y, R[14].y * W.x - R[14].x * W.y}; - R[14] = t; - - W = twiddles_pp[thread * 64 + 15]; - t = {R[15].x * W.x + R[15].y * W.y, R[15].y * W.x - R[15].x * W.y}; - R[15] = t; -} -)_SRC"; - - return body; -} - -static std::string fwd_len64() -{ - std::string body = R"_SRC( - template -__device__ void forward_length64_SBRR_device_pp(scalar_type *R, - real_type_t *__restrict__ lds_real, - scalar_type *__restrict__ lds_complex, - const scalar_type *__restrict__ twiddles, - unsigned int stride_lds, - unsigned int offset_lds, - unsigned int thread, - bool write) -{ - scalar_type W; - scalar_type t; - const unsigned int lstride = (sb == SB_UNIT) ? (1) : (stride_lds); - unsigned int l_offset; - - // pass 0, width 8 - // using 8 threads we need to do 8 radix-8 butterflies - // therefore each thread will do 1.000000 butterflies - FwdRad8B1(R + 0, R + 1, R + 2, R + 3, R + 4, R + 5, R + 6, R + 7); - - __syncthreads(); - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 0) * lstride; - lds_complex[l_offset] = R[0]; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 1) * lstride; - lds_complex[l_offset] = R[1]; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 2) * lstride; - lds_complex[l_offset] = R[2]; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 3) * lstride; - lds_complex[l_offset] = R[3]; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 4) * lstride; - lds_complex[l_offset] = R[4]; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 5) * lstride; - lds_complex[l_offset] = R[5]; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 6) * lstride; - lds_complex[l_offset] = R[6]; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 7) * lstride; - lds_complex[l_offset] = R[7]; - - // pass 1, width 8 - // using 8 threads we need to do 8 radix-8 butterflies - // therefore each thread will do 1.000000 butterflies - __syncthreads(); - l_offset = offset_lds + ((thread + 0 + 0) + 0) * lstride; - R[0] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 8) * lstride; - R[1] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 16) * lstride; - R[2] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 24) * lstride; - R[3] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 32) * lstride; - R[4] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 40) * lstride; - R[5] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 48) * lstride; - R[6] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 56) * lstride; - R[7] = lds_complex[l_offset]; - - W = twiddles[0 + 7 * ((thread + 0 + 0) % 8)]; - t = {R[1].x * W.x - R[1].y * W.y, R[1].y * W.x + R[1].x * W.y}; - R[1] = t; - W = twiddles[1 + 7 * ((thread + 0 + 0) % 8)]; - t = {R[2].x * W.x - R[2].y * W.y, R[2].y * W.x + R[2].x * W.y}; - R[2] = t; - W = twiddles[2 + 7 * ((thread + 0 + 0) % 8)]; - t = {R[3].x * W.x - R[3].y * W.y, R[3].y * W.x + R[3].x * W.y}; - R[3] = t; - W = twiddles[3 + 7 * ((thread + 0 + 0) % 8)]; - t = {R[4].x * W.x - R[4].y * W.y, R[4].y * W.x + R[4].x * W.y}; - R[4] = t; - W = twiddles[4 + 7 * ((thread + 0 + 0) % 8)]; - t = {R[5].x * W.x - R[5].y * W.y, R[5].y * W.x + R[5].x * W.y}; - R[5] = t; - W = twiddles[5 + 7 * ((thread + 0 + 0) % 8)]; - t = {R[6].x * W.x - R[6].y * W.y, R[6].y * W.x + R[6].x * W.y}; - R[6] = t; - W = twiddles[6 + 7 * ((thread + 0 + 0) % 8)]; - t = {R[7].x * W.x - R[7].y * W.y, R[7].y * W.x + R[7].x * W.y}; - R[7] = t; - FwdRad8B1(R + 0, R + 1, R + 2, R + 3, R + 4, R + 5, R + 6, R + 7); -} -)_SRC"; - - return body; -} - -static std::string inv_len64() -{ - std::string body = R"_SRC( - template -__device__ void inverse_length64_SBRR_device_pp(scalar_type *R, - real_type_t *__restrict__ lds_real, - scalar_type *__restrict__ lds_complex, - const scalar_type *__restrict__ twiddles, - unsigned int stride_lds, - unsigned int offset_lds, - unsigned int thread, - bool write) -{ - scalar_type W; - scalar_type t; - const unsigned int lstride = (sb == SB_UNIT) ? (1) : (stride_lds); - unsigned int l_offset; - - // pass 0, width 8 - // using 8 threads we need to do 8 radix-8 butterflies - // therefore each thread will do 1.000000 butterflies - InvRad8B1(R + 0, R + 1, R + 2, R + 3, R + 4, R + 5, R + 6, R + 7); - - __syncthreads(); - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 0) * lstride; - lds_complex[l_offset] = R[0]; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 1) * lstride; - lds_complex[l_offset] = R[1]; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 2) * lstride; - lds_complex[l_offset] = R[2]; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 3) * lstride; - lds_complex[l_offset] = R[3]; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 4) * lstride; - lds_complex[l_offset] = R[4]; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 5) * lstride; - lds_complex[l_offset] = R[5]; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 6) * lstride; - lds_complex[l_offset] = R[6]; - l_offset = offset_lds + (((thread + 0 + 0) / 1) * 8 + (thread + 0 + 0) % 1 + 7) * lstride; - lds_complex[l_offset] = R[7]; - - // pass 1, width 8 - // using 8 threads we need to do 8 radix-8 butterflies - // therefore each thread will do 1.000000 butterflies - __syncthreads(); - l_offset = offset_lds + ((thread + 0 + 0) + 0) * lstride; - R[0] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 8) * lstride; - R[1] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 16) * lstride; - R[2] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 24) * lstride; - R[3] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 32) * lstride; - R[4] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 40) * lstride; - R[5] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 48) * lstride; - R[6] = lds_complex[l_offset]; - l_offset = offset_lds + ((thread + 0 + 0) + 56) * lstride; - R[7] = lds_complex[l_offset]; - - W = twiddles[0 + 7 * ((thread + 0 + 0) % 8)]; - t = {R[1].x * W.x + R[1].y * W.y, R[1].y * W.x - R[1].x * W.y}; - R[1] = t; - W = twiddles[1 + 7 * ((thread + 0 + 0) % 8)]; - t = {R[2].x * W.x + R[2].y * W.y, R[2].y * W.x - R[2].x * W.y}; - R[2] = t; - W = twiddles[2 + 7 * ((thread + 0 + 0) % 8)]; - t = {R[3].x * W.x + R[3].y * W.y, R[3].y * W.x - R[3].x * W.y}; - R[3] = t; - W = twiddles[3 + 7 * ((thread + 0 + 0) % 8)]; - t = {R[4].x * W.x + R[4].y * W.y, R[4].y * W.x - R[4].x * W.y}; - R[4] = t; - W = twiddles[4 + 7 * ((thread + 0 + 0) % 8)]; - t = {R[5].x * W.x + R[5].y * W.y, R[5].y * W.x - R[5].x * W.y}; - R[5] = t; - W = twiddles[5 + 7 * ((thread + 0 + 0) % 8)]; - t = {R[6].x * W.x + R[6].y * W.y, R[6].y * W.x - R[6].x * W.y}; - R[6] = t; - W = twiddles[6 + 7 * ((thread + 0 + 0) % 8)]; - t = {R[7].x * W.x + R[7].y * W.y, R[7].y * W.x - R[7].x * W.y}; - R[7] = t; - InvRad8B1(R + 0, R + 1, R + 2, R + 3, R + 4, R + 5, R + 6, R + 7); -} -)_SRC"; - - return body; -} - -static std::string pp_fwd_step_1_2_radix16() -{ - std::string body = R"_SRC( - // Partial pass step 1: length-16 DFT on off-dimension - - // Radix-16 pass - FwdRad16B1(R + 0, R + 1, R + 2, R + 3, R + 4, R + 5, R + 6, R + 7, R + 8, R + 9, R + 10, R + 11, R + 12, R + 13, R + 14, R + 15); - - // Partial pass step 2: Hadamard product with twiddle factors - twiddle_multiple_pp_fwd(R, blockIdx.x % 4, twiddles_pp); -)_SRC"; - - return body; -} - -static std::string pp_inv_step_1_2_radix16() -{ - std::string body = R"_SRC( - // Partial pass step 1: length-16 DFT on off-dimension - - // Radix-16 pass - InvRad16B1(R + 0, R + 1, R + 2, R + 3, R + 4, R + 5, R + 6, R + 7, R + 8, R + 9, R + 10, R + 11, R + 12, R + 13, R + 14, R + 15); - - // Partial pass step 2: Hadamard product with twiddle factors - twiddle_multiple_pp_inv(R, blockIdx.x % 4, twiddles_pp); -)_SRC"; - - return body; -} - -// TODO global function name is hardcoded -static std::string partial_pass_sbrr_64_64_64_rtc_body(const std::string& kernel_name, - int direction) -{ - std::string body; - - body += R"_SRC( - extern "C" __global__ - __launch_bounds__(128) void - )_SRC"; - - body += kernel_name; - - body += R"_SRC( - ( - const scalar_type *__restrict__ twiddles_pp, - const scalar_type *__restrict__ twiddles, - const size_t *__restrict__ lengths, - const size_t *__restrict__ stride, - const size_t nbatch, - const unsigned int lds_padding, - void *__restrict__ load_cb_fn, - void *__restrict__ load_cb_data, - unsigned int load_cb_lds_bytes, - void *__restrict__ store_cb_fn, - void *__restrict__ store_cb_data, - scalar_type *__restrict__ ibuf, - scalar_type *__restrict__ obuf) -{ - auto const sb = SB_UNIT; - auto const ebtype = EmbeddedType::NONE; - auto const sbrc_type = SBRC_2D; - auto const transpose_type = NONE; - auto const drtype = DirectRegType::FORCE_OFF_OR_NOT_SUPPORT; - auto const apply_large_twiddle = false; - auto const intrinsic_mode = IntrinsicAccessType::DISABLE_BOTH; - const size_t large_twiddle_base = 8; - const size_t large_twiddle_steps = 0; - - // this kernel: - // uses 8 threads per transform - // does 16 transforms per thread block - // therefore it should be called with 128 threads per thread block - scalar_type R[16]; - extern __shared__ unsigned char __attribute__((aligned(sizeof(scalar_type)))) lds_uchar[]; - real_type_t *__restrict__ lds_real = reinterpret_cast *>(lds_uchar); - scalar_type *__restrict__ lds_complex = reinterpret_cast(lds_uchar); - size_t offset = 0; - unsigned int offset_lds; - unsigned int stride_lds; - size_t batch; - size_t transform; - const bool direct_load_to_reg = false; - const bool direct_store_from_reg = false; - const bool lds_linear = true; - const bool lds_is_real = false; - auto load_cb = get_load_cb(load_cb_fn); - auto store_cb = get_store_cb(store_cb_fn); - - size_t global_mem_idx = 0, offset_pp = 0, remaining_pp = 0; - - const StrideBin SB_1ST = (ebtype == EmbeddedType::C2Real_PRE) ? SB_NONUNIT : SB_UNIT; - const StrideBin SB_2ND = (ebtype == EmbeddedType::C2Real_PRE) ? SB_UNIT : SB_NONUNIT; - - // large twiddles - // - no large twiddles - - // offsets - const size_t dim = 3; - // const size_t stride0 = (sb == SB_UNIT) ? (1) : (stride[0]); - const size_t stride0 = (stride[0]); - unsigned int thread; - size_t remaining; - size_t index_along_d; - transform = blockIdx.x * 16 + threadIdx.x / 8; - remaining = transform; - remaining_pp = 64 * (transform / 64) + (transform % 64) / 16 + (transform * 4) % 64; - for (int d = 1; d < dim; ++d) - { - // index_along_d = remaining % lengths[d]; - remaining = remaining / lengths[d]; - // offset = offset + index_along_d * stride[d]; - - index_along_d = remaining_pp % lengths[d]; - remaining_pp = remaining_pp / lengths[d]; - offset_pp = offset_pp + index_along_d * stride[d]; - } - batch = remaining; - // offset = offset + batch * stride[dim]; - offset_pp = offset_pp + batch * stride[dim]; - stride_lds = 64 + (ebtype == EmbeddedType::NONE ? 0 : lds_padding); - offset_lds = stride_lds * (transform % 16); - - bool inbound = batch < nbatch; - - // load global into lds - if (inbound) - { - thread = threadIdx.x % 8; - - global_mem_idx = offset_pp + (thread + 0) * stride0; - lds_complex[offset_lds + (thread + 0)] = load_cb(ibuf, global_mem_idx, load_cb_data, nullptr); - - global_mem_idx = offset_pp + (thread + 8) * stride0; - lds_complex[offset_lds + (thread + 8)] = load_cb(ibuf, global_mem_idx, load_cb_data, nullptr); - - global_mem_idx = offset_pp + (thread + 16) * stride0; - lds_complex[offset_lds + (thread + 16)] = load_cb(ibuf, global_mem_idx, load_cb_data, nullptr); - - global_mem_idx = offset_pp + (thread + 24) * stride0; - lds_complex[offset_lds + (thread + 24)] = load_cb(ibuf, global_mem_idx, load_cb_data, nullptr); - - global_mem_idx = offset_pp + (thread + 32) * stride0; - lds_complex[offset_lds + (thread + 32)] = load_cb(ibuf, global_mem_idx, load_cb_data, nullptr); - - global_mem_idx = offset_pp + (thread + 40) * stride0; - lds_complex[offset_lds + (thread + 40)] = load_cb(ibuf, global_mem_idx, load_cb_data, nullptr); - - global_mem_idx = offset_pp + (thread + 48) * stride0; - lds_complex[offset_lds + (thread + 48)] = load_cb(ibuf, global_mem_idx, load_cb_data, nullptr); - - global_mem_idx = offset_pp + (thread + 56) * stride0; - lds_complex[offset_lds + (thread + 56)] = load_cb(ibuf, global_mem_idx, load_cb_data, nullptr); - } - - // calc the thread_in_device value once and for all device funcs - unsigned int thread_in_device = lds_linear ? threadIdx.x % 8 : threadIdx.x / 16; - - // call a pre-load from lds to registers (if necessary) - lds_to_reg_input_length64_device_sbrr( - R, lds_complex, stride_lds, offset_lds, thread_in_device, true); - )_SRC"; - - if(direction == -1) - { - body += R"_SRC( - // transform - forward_length64_SBRR_device_pp( - R, lds_real, lds_complex, twiddles, stride_lds, offset_lds, thread_in_device, true); - )_SRC"; - } - else if(direction == 1) - { - body += R"_SRC( - // transform - inverse_length64_SBRR_device_pp( - R, lds_real, lds_complex, twiddles, stride_lds, offset_lds, thread_in_device, true); - )_SRC"; - } - - body += R"_SRC( - // call a post-store from registers to lds (if necessary) - lds_from_reg_output_length64_device_sbrr( - R, lds_complex, stride_lds, offset_lds, thread_in_device, true); - - auto stride_lds_pp = 64; - auto offset_lds_pp = (blockIdx.x * 16 + threadIdx.x) % 64; - - // call a pre-load from lds to registers (if necessary) - lds_to_reg_16_input_length64_device_pp(R, lds_complex, stride_lds_pp, offset_lds_pp); - )_SRC"; - - if(direction == -1) - body += pp_fwd_step_1_2_radix16(); - else if(direction == 1) - body += pp_inv_step_1_2_radix16(); - - body += R"_SRC( - // call a post-store from registers to lds (if necessary) - lds_from_reg_16_output_length64_device_pp(R, lds_complex, stride_lds_pp, offset_lds_pp); - - // store global - __syncthreads(); - if (inbound) - { - global_mem_idx = offset_pp + (thread + 0) * stride0; - store_cb(obuf, - global_mem_idx, - lds_complex[offset_lds + (thread + 0)], - store_cb_data, - nullptr); - - global_mem_idx = offset_pp + (thread + 8) * stride0; - store_cb(obuf, - global_mem_idx, - lds_complex[offset_lds + (thread + 8)], - store_cb_data, - nullptr); - - global_mem_idx = offset_pp + (thread + 16) * stride0; - store_cb(obuf, - global_mem_idx, - lds_complex[offset_lds + (thread + 16)], - store_cb_data, - nullptr); - - global_mem_idx = offset_pp + (thread + 24) * stride0; - store_cb(obuf, - global_mem_idx, - lds_complex[offset_lds + (thread + 24)], - store_cb_data, - nullptr); - - global_mem_idx = offset_pp + (thread + 32) * stride0; - store_cb(obuf, - global_mem_idx, - lds_complex[offset_lds + (thread + 32)], - store_cb_data, - nullptr); - - global_mem_idx = offset_pp + (thread + 40) * stride0; - store_cb(obuf, - global_mem_idx, - lds_complex[offset_lds + (thread + 40)], - store_cb_data, - nullptr); - - global_mem_idx = offset_pp + (thread + 48) * stride0; - store_cb(obuf, - global_mem_idx, - lds_complex[offset_lds + (thread + 48)], - store_cb_data, - nullptr); - - global_mem_idx = offset_pp + (thread + 56) * stride0; - store_cb(obuf, - global_mem_idx, - lds_complex[offset_lds + (thread + 56)], - store_cb_data, - nullptr); - } -} - )_SRC"; - - return body; -} - -std::string partial_pass_sbrr_64_64_64_rtc(const std::string& kernel_name, - const std::vector all_factors, - int direction, - rocfft_precision precision, - CallbackType cbtype) -{ - std::string src; - - // start off with includes - src += rocfft_complex_h; - src += common_h; - src += memory_gfx_h; - src += callback_h; - src += butterfly_constant_h; - - append_radix_h(src, all_factors); - - src += rtc_precision_type_decl(precision); - src += lds_to_reg(); - src += reg_to_lds(); - src += lds_to_reg_pp(); - src += reg_to_lds_pp(); - if(direction == -1) - { - src += fwd_len64(); - src += twiddle_multiply_pp_fwd(); - } - else if(direction == 1) - { - src += inv_len64(); - src += twiddle_multiply_pp_inv(); - } - src += rtc_const_cbtype_decl(cbtype); - src += partial_pass_sbrr_64_64_64_rtc_body(kernel_name, direction); - - return src; -} - -RTCKernel::RTCGenerator RTCKernelPartialPassSBRR64Cubed::generate_from_node( - const LeafNode& node, const std::string& gpu_arch, bool enable_callbacks) -{ - RTCGenerator generator; - - auto scheme = node.scheme; - if(!(scheme == CS_KERNEL_STOCKHAM && node.applyPartialPass)) - return generator; - - size_t batch_accum = node.batch; - for(size_t j = 1; j < node.length.size(); j++) - batch_accum *= node.length[j]; - - auto workgroup_size = PARTIAL_PASS_SBRR_64_64_64_THREADS; - auto threads_per_transform = 8; - auto transforms_per_block = workgroup_size / threads_per_transform; - - auto bwd = transforms_per_block; - - auto b_x = (batch_accum + bwd - 1) / bwd; - auto wgs_x = workgroup_size; - - generator.gridDim = {static_cast(b_x), 1, 1}; - generator.blockDim = {wgs_x, 1, 1}; - - auto precision = node.precision; - auto direction = node.direction; - auto placement = node.placement; - auto inArrayType = node.inArrayType; - auto outArrayType = node.outArrayType; - auto cbType = node.GetCallbackType(enable_callbacks); - - auto kernelFactors = node.kernelFactors; - auto kernelFactorsPP = node.kernelFactorsPP; - - std::vector all_factors(kernelFactors.begin(), kernelFactors.end()); - all_factors.insert(all_factors.end(), kernelFactorsPP.begin(), kernelFactorsPP.end()); - - generator.generate_name = [=]() { - return partial_pass_64_64_64_sbrr_rtc_kernel_name( - precision, direction, placement, inArrayType, outArrayType, cbType); - }; - - generator.generate_src = [=](const std::string& kernel_name) { - return partial_pass_sbrr_64_64_64_rtc( - kernel_name, all_factors, direction, precision, cbType); - }; - - generator.construct_rtckernel = [=](const std::string& kernel_name, - const std::vector& code, - dim3 gridDim, - dim3 blockDim) { - return std::unique_ptr( - new RTCKernelPartialPassSBRR64Cubed(kernel_name, code, gridDim, blockDim)); - }; - - return generator; -} - -RTCKernelArgs RTCKernelPartialPassSBRR64Cubed::get_launch_args(DeviceCallIn& data) -{ - RTCKernelArgs kargs; - - kargs.append_ptr(data.node->twiddles_pp); - kargs.append_ptr(data.node->twiddles); - - kargs.append_ptr(kargs_lengths(data.node->devKernArg)); - kargs.append_ptr(kargs_stride_in(data.node->devKernArg)); - kargs.append_size_t(data.node->batch); - kargs.append_size_t(data.node->lds_padding); - - // callback params - kargs.append_ptr(data.callbacks.load_cb_fn); - kargs.append_ptr(data.callbacks.load_cb_data); - kargs.append_unsigned_int(data.callbacks.load_cb_lds_bytes); - kargs.append_ptr(data.callbacks.store_cb_fn); - kargs.append_ptr(data.callbacks.store_cb_data); - append_load_store_args(kargs, *data.node); - - kargs.append_ptr(data.bufIn[0]); - kargs.append_ptr(data.bufOut[0]); - - return kargs; -} \ No newline at end of file diff --git a/library/src/rtc_stockham_kernel.cpp b/library/src/rtc_stockham_kernel.cpp index 39dbe8e97ab..f8a0e7d1203 100644 --- a/library/src/rtc_stockham_kernel.cpp +++ b/library/src/rtc_stockham_kernel.cpp @@ -179,7 +179,7 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& else if(node.scheme == CS_KERNEL_STOCKHAM) ppType = PartialPassType::PPT_SBRR; else - throw std::runtime_error("Invalid scheme for partial pass"); + throw std::runtime_error("Invalid scheme for partial pass"); } generator.generate_name = [=, &node]() { @@ -249,6 +249,8 @@ RTCKernelArgs RTCKernelStockham::get_launch_args(DeviceCallIn& data) RTCKernelArgs kargs; // twiddles + if(data.node->applyPartialPass && data.node->scheme == CS_KERNEL_STOCKHAM) + kargs.append_ptr(data.node->twiddles_pp); kargs.append_ptr(data.node->twiddles); // large 1D twiddles if(data.node->scheme == CS_KERNEL_STOCKHAM_BLOCK_CC) diff --git a/library/src/tree_node_1D.cpp b/library/src/tree_node_1D.cpp index e600bad0a4e..a865d1ae7fa 100644 --- a/library/src/tree_node_1D.cpp +++ b/library/src/tree_node_1D.cpp @@ -908,19 +908,6 @@ void Stockham1DNode::SetupGPAndFnPtr_internal(DevFnCall& fnPtr, GridParam& gp) if(ebtype != EmbeddedType::NONE) lds_padding = 1; - if(applyPartialPass) - { - // Special case for partial pass 64 x 64 x 64. - // Kernel configuration is hardcoded for now. - // TODO: Once the partial-pass kernels are properly - // integrated into the Stockham kernel generators, - // this configuration will come from the usual location - // in kernel-generator.py. - kernel.threads_per_transform[0] = 8; - kernel.workgroup_size = 128; - kernel.transforms_per_block = kernel.workgroup_size / kernel.threads_per_transform[0]; - } - bwd = kernel.transforms_per_block; wgs = kernel.workgroup_size; gp.b_x = (batch_accum + bwd - 1) / bwd; @@ -944,6 +931,10 @@ void Stockham1DNode::SetupGPAndFnPtr_internal(DevFnCall& fnPtr, GridParam& gp) else lds = (length[0] + lds_padding) * bwd; } + + std::cout << "gp.b_x: " << gp.b_x << std::endl; + std::cout << "gp.wgs_x: " << gp.wgs_x << std::endl; + std::cout << "lds: " << lds << std::endl; } bool Stockham1DNode::CreateDeviceResources() @@ -1135,35 +1126,15 @@ void SBCCNode::SetupGPAndFnPtr_internal(DevFnCall& fnPtr, GridParam& gp) bwd = kernel.transforms_per_block; wgs = kernel.workgroup_size; - if(applyPartialPass) - { - // Special case for partial pass 64 x 64 x 64. - // Kernel configration is hardcoded for now. - // TODO: Once the partial-pass kernels are integrated - // into the stockham kernel generators, change this - // configuration to use kernel-generator.py data. - auto tpt = 8; - wgs = 64; - bwd = wgs / tpt; - } - lds = length[0] * bwd; gp.b_x = ((length[1]) - 1) / bwd + 1; gp.b_x *= product(length.begin() + 2, length.end()) * batch; gp.wgs_x = wgs; - if(applyPartialPass) - { - // grid and thread organization is different - // on partial pass sbcc kernels (for improved - // global memory access patterns). - auto factor = *std::max_element(kernelFactorsPP.begin(), kernelFactorsPP.end()); - - gp.b_x /= factor; - gp.wgs_x *= factor; - lds *= factor; - } + std::cout << "gp.b_x: " << gp.b_x << std::endl; + std::cout << "gp.wgs_x: " << gp.wgs_x << std::endl; + std::cout << "lds: " << lds << std::endl; } std::vector SBCCNode::CollapsibleDims() From 3356f9fb457a80610d98acdcfca4a3ed22be5a8e Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 14 Feb 2025 14:49:29 -0700 Subject: [PATCH 04/69] - Remove no longer needed code. --- library/src/tree_node.cpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/library/src/tree_node.cpp b/library/src/tree_node.cpp index c24feadc734..f69c7f1ef0f 100644 --- a/library/src/tree_node.cpp +++ b/library/src/tree_node.cpp @@ -106,14 +106,6 @@ void LeafNode::GetKernelFactors() { FMKey key = GetKernelKey(); kernelFactors = function_pool::get_kernel(key).factors; - - // Hard-coded kernel factors for len 64x64x64 partial-pass - // TODO: Remove this hard-coded logic once - // partial-pass is integrated into the stockham generators. - if(scheme == CS_KERNEL_STOCKHAM && applyPartialPass) - kernelFactors = {8, 8}; - if(scheme == CS_KERNEL_STOCKHAM_BLOCK_CC && applyPartialPass) - kernelFactors = {8, 8}; } void LeafNode::GetKernelPartialPassFactors() From 4e56a6d70b8ef1df853bedc344c634625403fd4c Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 14 Feb 2025 17:01:24 -0700 Subject: [PATCH 05/69] Fix register size in partial pass sbrr --- .../src/device/generator/stockham_pp_gen_rr.h | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index 340975e141a..1cb54ba8776 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -30,6 +30,8 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR { // TODO: revisit this. Test with factors_pp.size() > 1 max_factor_pp = *std::max_element(specs.factors_pp.begin(), specs.factors_pp.end()); + + R.size = Expression{std::max(nregisters, max_factor_pp)}; } unsigned int max_factor_pp; @@ -274,8 +276,15 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR StatementList& body = f.body; + body += Declaration{t}; + body += Declaration{W}; + for(unsigned int w = 0; w < max_factor_pp; ++w) - body += Assign(R[w], twiddles_pp[thread * length + w] * R[w]); + { + body += Assign{W, twiddles_pp[thread * length + w]}; + body += Assign{t, TwiddleMultiply{R[w], W}}; + body += Assign{R[w], t}; + } return f; } @@ -348,6 +357,15 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR return arguments; } + void collect_length_stride(StatementList& body) + { + if(static_dim) + { + body += Declaration{dim, static_dim}; + } + body += Declaration{stride0, Parens{stride[0]}}; + } + Function generate_global_function() override { Function f("forward_length" + std::to_string(length) + "_" + tiling_name()); From bc281f91fdf2d142257dc1bd7da15eff9aae66a0 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Tue, 18 Feb 2025 11:03:47 -0700 Subject: [PATCH 06/69] - Fix offset initialization. --- library/src/device/generator/stockham_pp_gen_rr.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index 1cb54ba8776..56c7429a975 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -53,8 +53,8 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR stmts += Declaration{thread}; stmts += Declaration(remaining); stmts += Declaration(index_along_d); - stmts += Declaration(remaining_pp); - stmts += Declaration(offset_pp); + stmts += Declaration(remaining_pp, Literal{0}); + stmts += Declaration(offset_pp, Literal{0}); stmts += Assign{transform, block_id * transforms_per_block + thread_id / threads_per_transform}; stmts += Assign{remaining, transform}; From ed2afe3ead8976feeebbab1c04e9ac73443a19e7 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Tue, 18 Feb 2025 11:13:44 -0700 Subject: [PATCH 07/69] - remove debug code. --- library/src/tree_node_1D.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/library/src/tree_node_1D.cpp b/library/src/tree_node_1D.cpp index a865d1ae7fa..9bfbd9ed5c3 100644 --- a/library/src/tree_node_1D.cpp +++ b/library/src/tree_node_1D.cpp @@ -931,10 +931,6 @@ void Stockham1DNode::SetupGPAndFnPtr_internal(DevFnCall& fnPtr, GridParam& gp) else lds = (length[0] + lds_padding) * bwd; } - - std::cout << "gp.b_x: " << gp.b_x << std::endl; - std::cout << "gp.wgs_x: " << gp.wgs_x << std::endl; - std::cout << "lds: " << lds << std::endl; } bool Stockham1DNode::CreateDeviceResources() From 97e6379c7f9e7a5599b079603f30c9a554941a1a Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Tue, 18 Feb 2025 11:18:25 -0700 Subject: [PATCH 08/69] - Clean up debug code. --- library/src/tree_node_1D.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/library/src/tree_node_1D.cpp b/library/src/tree_node_1D.cpp index 9bfbd9ed5c3..e4e953ba541 100644 --- a/library/src/tree_node_1D.cpp +++ b/library/src/tree_node_1D.cpp @@ -1127,10 +1127,6 @@ void SBCCNode::SetupGPAndFnPtr_internal(DevFnCall& fnPtr, GridParam& gp) gp.b_x = ((length[1]) - 1) / bwd + 1; gp.b_x *= product(length.begin() + 2, length.end()) * batch; gp.wgs_x = wgs; - - std::cout << "gp.b_x: " << gp.b_x << std::endl; - std::cout << "gp.wgs_x: " << gp.wgs_x << std::endl; - std::cout << "lds: " << lds << std::endl; } std::vector SBCCNode::CollapsibleDims() From e7dceae465b5edd4eba8832cd73ea389a38de54e Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Tue, 18 Feb 2025 11:39:37 -0700 Subject: [PATCH 09/69] - Handle direction in partial-pass twiddle multiply step. --- library/src/device/generator/stockham_pp_gen_rr.h | 11 +++++++++-- library/src/rtc_stockham_gen.cpp | 4 ++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index 56c7429a975..69397dd122f 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -259,7 +259,7 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR return f; } - Function generate_twiddle_multiply_pp_function() + Function generate_twiddle_multiply_pp_function(int direction) { std::string function_name = "twiddle_multiply_pp_length" + std::to_string(length) + "_device"; @@ -282,7 +282,14 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR for(unsigned int w = 0; w < max_factor_pp; ++w) { body += Assign{W, twiddles_pp[thread * length + w]}; - body += Assign{t, TwiddleMultiply{R[w], W}}; + + if(direction == -1) + body += Assign{t, TwiddleMultiply{R[w], W}}; + else if(direction == 1) + body += Assign{t, TwiddleMultiplyConjugate{R[w], W}}; + else + throw std::runtime_error("Invalid FFT direction for twiddle multiply"); + body += Assign{R[w], t}; } diff --git a/library/src/rtc_stockham_gen.cpp b/library/src/rtc_stockham_gen.cpp index b6c75ac3147..f22ecbc0182 100644 --- a/library/src/rtc_stockham_gen.cpp +++ b/library/src/rtc_stockham_gen.cpp @@ -362,8 +362,8 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, kernel_pp->generate_lds_to_reg_input_step_1_2_function()); reg2lds_pp_steps = std::make_unique( kernel_pp->generate_lds_from_reg_output_pp_step_1_2_function()); - twiddle_multiply_pp - = std::make_unique(kernel_pp->generate_twiddle_multiply_pp_function()); + twiddle_multiply_pp = std::make_unique( + kernel_pp->generate_twiddle_multiply_pp_function(direction)); device = std::make_unique(kernel_pp->generate_device_function()); break; } From bd3af1bd5c4017ea727796bba0e3b7f36d2a63c5 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 20 Feb 2025 16:15:37 -0700 Subject: [PATCH 10/69] - Fix partial pass radix include in rtc_stockham_gen. - Add support for direct from reg loads in partial pass SBRR kernels. --- .../src/device/generator/stockham_pp_gen_rr.h | 66 ++++++++++--------- library/src/device/kernel-generator.py | 2 +- library/src/rtc_stockham_gen.cpp | 6 +- 3 files changed, 38 insertions(+), 36 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index 69397dd122f..a256ede1c4d 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -29,12 +29,14 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR : StockhamKernelRR(specs) { // TODO: revisit this. Test with factors_pp.size() > 1 - max_factor_pp = *std::max_element(specs.factors_pp.begin(), specs.factors_pp.end()); + max_factor_pp = *std::max_element(specs.factors_pp.begin(), specs.factors_pp.end()); + prod_factors_pp = std::accumulate( + specs.factors_pp.begin(), specs.factors_pp.end(), 1, std::multiplies()); R.size = Expression{std::max(nregisters, max_factor_pp)}; } - unsigned int max_factor_pp; + unsigned int max_factor_pp, prod_factors_pp; Variable offset_pp{"offset_pp", "size_t"}; Variable stride_lds_pp{"stride_lds_pp", "size_t"}; Variable offset_lds_pp{"offset_lds_pp", "size_t"}; @@ -84,6 +86,25 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR return stmts; } + StatementList load_global_generator(unsigned int h, + unsigned int hr, + unsigned int width, + unsigned int dt, + Expression guard) const + { + if(hr == 0) + hr = h; + StatementList load; + for(unsigned int w = 0; w < width; ++w) + { + auto tid = Parens{thread + dt + h * threads_per_transform}; + auto idx = Parens{tid + w * length / width}; + load += Assign{R[hr * width + w], + LoadGlobal{buf, offset_pp + Parens{Expression{idx}} * stride0}}; + } + return load; + } + StatementList load_from_global(bool load_registers) override { StatementList stmts; @@ -117,7 +138,7 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR unsigned int width = factors[0]; auto height = static_cast(length) / width / threads_per_transform; - auto load_global = std::mem_fn(&StockhamKernel::load_global_generator); + auto load_global = std::mem_fn(&StockhamPartialPassKernelRR::load_global_generator); stmts += add_work(std::bind(load_global, this, _1, _2, _3, _4, _5), width, height, @@ -154,17 +175,8 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR stmts += If{Equal{embedded_type, "EmbeddedType::Real2C_POST"}, stmts_real2c_post}; } else - { - auto width = factors.back(); - auto cumheight = product(factors.begin(), factors.begin() + (factors.size() - 1)); - auto height = static_cast(length) / width / threads_per_transform; - - auto store_global = std::mem_fn(&StockhamKernel::store_global_generator); - stmts += add_work(std::bind(store_global, this, _1, _2, _3, _4, _5, cumheight), - width, - height, - ThreadGuardMode::GUARD_BY_IF); - } + throw std::runtime_error( + "Direct store from registers not allowed in partial pass SBRR kernels"); return {If{inbound, stmts}}; } @@ -479,14 +491,13 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR // If we're doing direct-from-reg, this function simply returns. body += LineBreak{}; body += CommentLines{"call a post-store from registers to lds (if necessary)"}; + + // Post stores must be in LDS since partial steps 1/2 are always in LDS. StatementList postStore; postStore += Call{"lds_from_reg_output_length" + std::to_string(length) + "_device", pre_post_lds_tmpl, pre_post_lds_args}; - if(!direct_to_from_reg) - body += postStore; - else - body += If{!direct_store_from_reg, postStore}; + body += postStore; // partial pass here body += perform_partial_pass_step_1_2(); @@ -499,21 +510,12 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR storelds += LineBreak{}; storelds += CommentLines{"store global"}; storelds += SyncThreads{}; - storelds += store_to_global(false); - if(!direct_to_from_reg) - { - body += storelds; - } - else - { - StatementList storer; - storer += CommentLines{"store registers into global"}; - storer += store_to_global(true); - - body += If{direct_store_from_reg, storer}; - body += Else{storelds}; - } + // Cannot have direct from register stores + // to global mem since partial steps 1/2 + // are always in LDS + storelds += store_to_global(false); + body += storelds; f.templates = global_templates(); f.arguments = global_arguments(); diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index 729e6aced98..5517d2eda82 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -309,7 +309,7 @@ def list_small_kernels(): NS(length= 56, workgroup_size=128, threads_per_transform= 8, factors=(7, 8)), NS(length= 60, workgroup_size= 64, threads_per_transform= 10, factors=(6, 10)), NS(length= 63, workgroup_size=256, threads_per_transform= 21, factors=(3, 3, 7), half_lds=False, runtime_compile=True), - NS(length= 64, workgroup_size=128, threads_per_transform= 8, factors=(8, 8), half_lds=False, direct_to_from_reg=False), + NS(length= 64, workgroup_size=128, threads_per_transform= 8, factors=(4, 4, 4), half_lds=False, direct_to_from_reg=True), NS(length= 65, workgroup_size=256, threads_per_transform= 13, factors=(13, 5), runtime_compile=True), NS(length= 66, workgroup_size=256, threads_per_transform= 11, factors=(6, 11), half_lds=False, runtime_compile=True), NS(length= 68, workgroup_size=256, threads_per_transform= 17, factors=(17, 4), runtime_compile=True), diff --git a/library/src/rtc_stockham_gen.cpp b/library/src/rtc_stockham_gen.cpp index f22ecbc0182..2e274eb6f05 100644 --- a/library/src/rtc_stockham_gen.cpp +++ b/library/src/rtc_stockham_gen.cpp @@ -407,6 +407,9 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, // get factors vector all_factors = kernel->factors; + + if(ppType != PPT_NONE) + all_factors.insert(all_factors.end(), specs.factors_pp.begin(), specs.factors_pp.end()); } // generated functions default to forward in-place interleaved. @@ -452,9 +455,6 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, // append the neccessary functions only append_radix_h(src, all_factors); - if(ppType != PPT_NONE) - append_radix_h(src, specs.factors_pp); - // SBCCs don't need this if(scheme != CS_KERNEL_STOCKHAM_BLOCK_CC) src += real2complex_device_h; From 9c021bf24d004589f0977edd34429e6cb263058c Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 20 Feb 2025 16:20:45 -0700 Subject: [PATCH 11/69] - Clean up comment --- library/src/device/generator/stockham_pp_gen_rr.h | 1 - 1 file changed, 1 deletion(-) diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index a256ede1c4d..4d9f99faeb2 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -322,7 +322,6 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR auto pre_post_lds_args = device_lds_reg_inout_pp_device_call_arguments(); pre_post_lds_tmpl.set_value(stride_type.name, "lds_linear ? SB_UNIT : SB_NONUNIT"); - // TODO: handle direct_to_from_reg StatementList preLoad; preLoad += Call{"lds_to_reg_input_pp_step_1_2_length" + std::to_string(length) + "_device", pre_post_lds_tmpl, From 860cc78050d837851c9645394db174e701651e68 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 21 Feb 2025 16:49:09 -0700 Subject: [PATCH 12/69] - Disable half-lds support in partial-pass SBRR kernels. - Remove not needed lds_linear branch in partial-pass SBRR generator. --- .../src/device/generator/stockham_pp_gen_rr.h | 37 ++++++++++++++----- library/src/tree_node.cpp | 3 +- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index 4d9f99faeb2..26e3f24c6c4 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -206,13 +206,20 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR return {R, lds_complex, stride_lds_pp, offset_lds_pp}; } + TemplateList device_lds_reg_inout_pp_templates() + { + TemplateList tpls; + tpls.append(scalar_type); + return tpls; + } + Function generate_lds_to_reg_input_step_1_2_function() { std::string function_name = "lds_to_reg_input_pp_step_1_2_length" + std::to_string(length) + "_device"; Function f{function_name}; - f.templates = device_lds_reg_inout_templates(); + f.templates = device_lds_reg_inout_pp_templates(); f.arguments = device_lds_reg_inout_pp_arguments(); f.qualifier = "__device__"; @@ -252,7 +259,7 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR = "lds_from_reg_output_pp_step_1_2_length" + std::to_string(length) + "_device"; Function f{function_name}; - f.templates = device_lds_reg_inout_templates(); + f.templates = device_lds_reg_inout_pp_templates(); f.arguments = device_lds_reg_inout_pp_arguments(); f.qualifier = "__device__"; @@ -308,6 +315,11 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR return f; } + TemplateList device_lds_reg_inout_pp_step_1_2_device_call_templates() + { + return {scalar_type}; + } + // TODO: Move this to a device function StatementList perform_partial_pass_step_1_2() { @@ -318,9 +330,8 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR stmts += Declaration{offset_lds_pp, Parens(block_id * transforms_per_block + thread_id) % length}; - auto pre_post_lds_tmpl = device_lds_reg_inout_device_call_templates(); + auto pre_post_lds_tmpl = device_lds_reg_inout_pp_step_1_2_device_call_templates(); auto pre_post_lds_args = device_lds_reg_inout_pp_device_call_arguments(); - pre_post_lds_tmpl.set_value(stride_type.name, "lds_linear ? SB_UNIT : SB_NONUNIT"); StatementList preLoad; preLoad += Call{"lds_to_reg_input_pp_step_1_2_length" + std::to_string(length) + "_device", @@ -384,6 +395,15 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR body += Declaration{stride0, Parens{stride[0]}}; } + StatementList set_lds_is_real() override + { + // Half-LDS always disabled in partial-pass. + // To make this option work, step_1_2 here + // would need to implement half LDS usage in + // the off-direction pass. + return {Declaration{lds_is_real, Literal{"false"}}}; + } + Function generate_global_function() override { Function f("forward_length" + std::to_string(length) + "_" + tiling_name()); @@ -447,10 +467,7 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR body += LineBreak{}; body += CommentLines{"calc the thread_in_device value once and for all device funcs"}; - body += Declaration{thread_in_device, - Ternary{lds_linear, - thread_id % threads_per_transform, - thread_id / transforms_per_block}}; + body += Declaration{thread_in_device, thread_id % threads_per_transform}; // before starting the transform job (core device function) // we call a re-load lds-to-reg function here, but it's not always doing things. @@ -459,7 +476,7 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR body += CommentLines{"call a pre-load from lds to registers (if necessary)"}; auto pre_post_lds_tmpl = device_lds_reg_inout_device_call_templates(); auto pre_post_lds_args = device_lds_reg_inout_device_call_arguments(); - pre_post_lds_tmpl.set_value(stride_type.name, "lds_linear ? SB_UNIT : SB_NONUNIT"); + pre_post_lds_tmpl.set_value(stride_type.name, "SB_UNIT"); StatementList preLoad; preLoad += Call{"lds_to_reg_input_length" + std::to_string(length) + "_device", pre_post_lds_tmpl, @@ -476,7 +493,7 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR auto templates = device_call_templates(); auto arguments = device_call_arguments(c); - templates.set_value(stride_type.name, "lds_linear ? SB_UNIT : SB_NONUNIT"); + templates.set_value(stride_type.name, "SB_UNIT"); body += Call{"forward_length" + std::to_string(length) + "_" + tiling_name() + "_device", diff --git a/library/src/tree_node.cpp b/library/src/tree_node.cpp index f69c7f1ef0f..c61c60bbddd 100644 --- a/library/src/tree_node.cpp +++ b/library/src/tree_node.cpp @@ -296,7 +296,8 @@ void LeafNode::SetupGridParamAndFuncPtr(DevFnCall& fnPtr, GridParam& gp) double_half_lds_alloc = true; } - if(kernel.half_lds && (!double_half_lds_alloc)) + // no support for half-lds in partial-pass mode + if(kernel.half_lds && (!double_half_lds_alloc) && (!applyPartialPass)) gp.lds_bytes /= 2; } } From ec4eb381b8e03914139c277d55731fb1a77eef59 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 21 Feb 2025 17:00:51 -0700 Subject: [PATCH 13/69] - Fix partial-pass kernel name --- library/src/device/generator/stockham_pp_gen_rr.h | 1 - library/src/rtc_stockham_gen.cpp | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index 26e3f24c6c4..3c54f42d3f2 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -41,7 +41,6 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR Variable stride_lds_pp{"stride_lds_pp", "size_t"}; Variable offset_lds_pp{"offset_lds_pp", "size_t"}; - // TODO: this should be __restrict__ Variable twiddles_pp{"twiddles_pp", "const scalar_type", true, true}; StatementList calculate_offsets() override diff --git a/library/src/rtc_stockham_gen.cpp b/library/src/rtc_stockham_gen.cpp index 2e274eb6f05..13665be62a4 100644 --- a/library/src/rtc_stockham_gen.cpp +++ b/library/src/rtc_stockham_gen.cpp @@ -109,7 +109,7 @@ std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, if(scheme == CS_KERNEL_2D_SINGLE) kernel_name += "x" + std::to_string(specs2d.threads_per_transform); - if(specs.half_lds) + if(specs.half_lds && ppType == PPT_NONE) kernel_name += "_halfLds"; if(specs.static_dim) From d50ae306dcd41858bb95d5183c772a93e92b3af4 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Mon, 24 Feb 2025 10:50:31 -0700 Subject: [PATCH 14/69] Fix formatting. --- library/src/device/generator/stockham_pp_gen_rr.h | 2 +- library/src/rtc_stockham_kernel.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index 3c54f42d3f2..50526ea4c5f 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -1,4 +1,4 @@ -// Copyright (C) 2021 - 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/library/src/rtc_stockham_kernel.cpp b/library/src/rtc_stockham_kernel.cpp index f8a0e7d1203..8274d42f259 100644 --- a/library/src/rtc_stockham_kernel.cpp +++ b/library/src/rtc_stockham_kernel.cpp @@ -179,7 +179,7 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& else if(node.scheme == CS_KERNEL_STOCKHAM) ppType = PartialPassType::PPT_SBRR; else - throw std::runtime_error("Invalid scheme for partial pass"); + throw std::runtime_error("Invalid scheme for partial pass"); } generator.generate_name = [=, &node]() { From c6c02d0d03985862dafd280e3d34f051c8e0f3dd Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 7 Mar 2025 11:26:12 -0700 Subject: [PATCH 15/69] - Add partial pass test scripts. --- scripts/partial-pass/convert_1d_to_3d.m | 17 ++ scripts/partial-pass/convert_3d_to_1d.m | 14 ++ scripts/partial-pass/partial_pass_3d.m | 211 +++++++++++++++++++++++ scripts/partial-pass/rocfft_to_octave.sh | 74 ++++++++ scripts/partial-pass/run_test.m | 23 +++ scripts/partial-pass/run_test.sh | 72 ++++++++ shared/printbuffer.h | 3 +- 7 files changed, 413 insertions(+), 1 deletion(-) create mode 100644 scripts/partial-pass/convert_1d_to_3d.m create mode 100644 scripts/partial-pass/convert_3d_to_1d.m create mode 100644 scripts/partial-pass/partial_pass_3d.m create mode 100755 scripts/partial-pass/rocfft_to_octave.sh create mode 100644 scripts/partial-pass/run_test.m create mode 100755 scripts/partial-pass/run_test.sh diff --git a/scripts/partial-pass/convert_1d_to_3d.m b/scripts/partial-pass/convert_1d_to_3d.m new file mode 100644 index 00000000000..d5997a48656 --- /dev/null +++ b/scripts/partial-pass/convert_1d_to_3d.m @@ -0,0 +1,17 @@ +function out = convert_1d_to_3d(in, n1, n2, n3, batch, ordering) + +if ~isvector(in) + error('Invalid input'); +endif + +if strcmp(ordering,'column-major') + out=reshape(in, n3, n2, n1, batch); +elseif strcmp(ordering,'row-major') + out=reshape(in, n1, n2, n3, batch)'; +else + error('Invalid option'); +endif + + + + diff --git a/scripts/partial-pass/convert_3d_to_1d.m b/scripts/partial-pass/convert_3d_to_1d.m new file mode 100644 index 00000000000..9b05008afa0 --- /dev/null +++ b/scripts/partial-pass/convert_3d_to_1d.m @@ -0,0 +1,14 @@ +function out = convert_3d_to_1d(in, ordering) + +if isvector(in) + error('Invalid input'); +endif + +if strcmp(ordering,'column-major') + out = reshape(in, 1, []); + out=conj(out); +elseif strcmp(ordering,'row-major') + out = reshape(in', 1, []); +else + error('Invalid option'); +endif diff --git a/scripts/partial-pass/partial_pass_3d.m b/scripts/partial-pass/partial_pass_3d.m new file mode 100644 index 00000000000..8bfe90851e1 --- /dev/null +++ b/scripts/partial-pass/partial_pass_3d.m @@ -0,0 +1,211 @@ +function partial_pass_3d(in_length, nbatch, pp_dim, pp_radices, in_batched, out_batched, test_mode) + + test_mode_1 = 'full-3d'; + test_mode_2 = 'direction_1'; + test_mode_3 = 'direction_2'; + test_mode_4 = 'step_1_2'; + test_mode_5 = 'step_3_4'; + test_mode_6 = 'direction_1_step_1_2'; + test_mode_7 = 'direction_2_step_3_4'; + + if ~(strcmp(test_mode,test_mode_1) || strcmp(test_mode,test_mode_2) || strcmp(test_mode,test_mode_3)... + || strcmp(test_mode,test_mode_4) || strcmp(test_mode,test_mode_5) || strcmp(test_mode,test_mode_6)... + || strcmp(test_mode,test_mode_7)) + display(test_mode); + error('Invalid test mode'); + endif + + format longG; + ordering='column-major'; + data_empty_value = -123456789; + + N = prod(in_length); + + pp_mode = 'four-step'; + + in_batched = convert_1d_to_3d(in_batched, in_length(1), in_length(2), in_length(3), nbatch, ordering); + out_batched = convert_1d_to_3d(out_batched, in_length(1), in_length(2), in_length(3), nbatch, ordering); + + for ibatch=1:nbatch + in = in_batched(:,:,:,ibatch); + out_ = out_batched(:,:,:,ibatch); + + % Validate output + idx_data=find(real(out_)~=data_empty_value); + if ( length(idx_data) != N ) + error('Error: incomplete data'); + endif + + % 3D-FFT (MATLAB built-in) + if (strcmp(test_mode,test_mode_1)) + out = fftn(in); + out = convert_3d_to_1d(out, ordering); + out_ = convert_3d_to_1d(out_, ordering); + linf_rocfft_vs_octave_built_in = norm(out-out_,'inf'); + disp(['l-inf norm: ' num2str(linf_rocfft_vs_octave_built_in)]); + return; + endif + + % CS_3D_RC from rocFFT (with partial pass) + [out_3d_rc, out_3d_rc_pp_1, out_3d_rc_pp_2] = run_CS_3D_RC(in_length, in, pp_dim, pp_radices, pp_mode); + out_3d_rc = convert_3d_to_1d(out_3d_rc, ordering); + out_3d_rc_pp_1 = convert_3d_to_1d(out_3d_rc_pp_1, ordering); + out_3d_rc_pp_2 = convert_3d_to_1d(out_3d_rc_pp_2, ordering); + out_ = convert_3d_to_1d(out_, ordering); + + linf_rocfft_vs_octave_3d_rc_pp = norm(out_3d_rc-out_,'inf'); + disp(['l-inf norm: ' num2str(linf_rocfft_vs_octave_3d_rc_pp)]); + endfor + + display(''); + + function [out, out_pp_1, out_pp_2] = run_CS_3D_RC(in_length, in, pp_dim, pp_radices, pp_mode) + n = in_length(pp_dim); + n1 = pp_radices(1); + n2 = pp_radices(2); + + F_n = dft_matrix(n); + F_n1 = dft_matrix(n1); + F_n2 = dft_matrix(n2); + + out = in; + + if (pp_dim == 1) + % 1st kernel (2nd dimension) + out = fft(out,[], 2); + out = partial_pass_step_1_2(out, 1, n1, n2, F_n1, F_n2, F_n, pp_mode); + out_pp_1 = out; + + % 2nd kernel (3rd dimension) + out = partial_pass_step_3_4(out, 1, n1, n2, F_n1, F_n2, F_n, pp_mode); + out_pp_2 = out; + out = fft(out,[], 3); + endif + + if (pp_dim == 2) + % 1st kernel (1st dimension) + out = fft(out,[], 1); + out = partial_pass_step_1_2(out, 2, n1, n2, F_n1, F_n2, F_n, pp_mode); + out_pp_1 = out; + + % 2nd kernel (3rd dimension) + transp_order = [3 2 1]; + out = permute(out, transp_order); + + out = partial_pass_step_3_4(out, 2, n1, n2, F_n1, F_n2, F_n, pp_mode); + out_pp_2 = out; + + out = fft(out,[], 1); + + out = permute(out, transp_order); + endif + + if (pp_dim == 3) + % 1st kernel (1st dimension) + out = fft(out,[], 1); + out = partial_pass_step_1_2(out, 3, n1, n2, F_n1, F_n2, F_n, pp_mode); + out_pp_1 = out; + + % 2nd kernel (2nd dimension) + out = partial_pass_step_3_4(out, 3, n1, n2, F_n1, F_n2, F_n, pp_mode); + out_pp_2 = out; + out = fft(out,[], 2); + endif + endfunction + + function [dim1, dim2] = get_data_dim_partial_pass(input, pp_dim) + if (pp_dim==1) + dim1 = size(input,2); + dim2 = size(input,3); + elseif (pp_dim==2) + dim1 = size(input,1); + dim2 = size(input,3); + elseif (pp_dim==3) + dim1 = size(input,1); + dim2 = size(input,2); + endif + endfunction + + function input_data_decomp = get_pp_decomposed_data(input_data, pp_dim, idx1, idx2, n1, n2) + if (pp_dim==1) + input_data_decomp = reshape(input_data(:,idx1,idx2), n1, n2); + elseif (pp_dim==2) + input_data_decomp = reshape(input_data(idx1,:,idx2), n1, n2); + elseif (pp_dim==3) + input_data_decomp = reshape(input_data(idx1,idx2,:), n1, n2); + endif + endfunction + + function output = set_pp_data(input, input_decomp, pp_dim, idx1, idx2) + output = input; + + if (pp_dim==1) + output(:,idx1,idx2) = reshape(input_decomp, [], 1); + elseif (pp_dim==2) + output(idx1,:,idx2) = reshape(input_decomp, [], 1); + elseif (pp_dim==3) + output(idx1,idx2,:) = reshape(input_decomp, [], 1); + endif + endfunction + + function output = partial_pass_step_1_2(input, pp_dim, n1, n2, F_n1, F_n2, F_n, mode) + output = input; + + [dim1, dim2] = get_data_dim_partial_pass(input, pp_dim); + + for idx2=1:dim2 + for idx1=1:dim1 + in_decomp = get_pp_decomposed_data(output, pp_dim, idx1, idx2, n1, n2); + + if strcmp(mode, 'four-step') + % Length-n2 FFT along rows of in_decomp + out_decomp = fft(in_decomp, n2, 2); + % Twiddle multiply + out_decomp = F_n(1:n1, 1:n2).*out_decomp; + elseif strcmp(mode, 'six-step') + % Local transpose + out_decomp = in_decomp.'; + % Length-n1 FFT along columns of out_decomp + out_decomp = fft(out_decomp, n1, 1); + % Twiddle multiply + out_decomp = F_n(1:n1, 1:n2).*out_decomp; + else + error('invalid partial-pass mode'); + endif + + output = set_pp_data(output, out_decomp, pp_dim, idx1, idx2); + endfor + endfor + endfunction + + function output = partial_pass_step_3_4(input, pp_dim, n1, n2, F_n1, F_n2, F_n, mode) + output = input; + + [dim1, dim2] = get_data_dim_partial_pass(input, pp_dim); + + for idx1=1:dim1 + for idx2=1:dim2 + in_decomp = get_pp_decomposed_data(output, pp_dim, idx1, idx2, n1, n2); + + if strcmp(mode, 'four-step') + % Local transpose + out_decomp = in_decomp.'; + % Length-n1 FFT along rows of out_decomp + out_decomp = fft(out_decomp, n1, 2); + elseif strcmp(mode, 'six-step') + % Local transpose + out_decomp = in_decomp.'; + % Length-n2 FFT along columns of out_decomp + out_decomp = fft(out_decomp, n2, 1); + % Local transpose + out_decomp = out_decomp.'; + else + error('invalid partial-pass mode'); + endif + + output = set_pp_data(output, out_decomp, pp_dim, idx1, idx2); + endfor + endfor + endfunction + +endfunction diff --git a/scripts/partial-pass/rocfft_to_octave.sh b/scripts/partial-pass/rocfft_to_octave.sh new file mode 100755 index 00000000000..0f9164652f5 --- /dev/null +++ b/scripts/partial-pass/rocfft_to_octave.sh @@ -0,0 +1,74 @@ +#! /bin/bash + +# usage /.rocfft_to_octave.sh $arg1 #arg2 $file +# arg1=1 (input) arg1=0 (output) +# arg=2 buffer id + +if [ $1 -eq 1 ]; then + filename="rocfft_input_data.m" +elif [ $1 -eq 0 ]; then + filename="rocfft_output_data.m" +else + echo "error" +fi + +# put input file in variable filename +sed '' $3 | sponge $filename + +# Get buffer description lines in filename and append +# line number to them (the lines starting with +# '--- --- or final output') +cat -n $filename | sed -n '/--- ---\|final output/p' | sponge $filename + +# remove lines with buffer hash +sed '/hash/d' $filename | sponge $filename + +# store result in temp variable +tmp_var=`cat $filename` + +# get line of buffer passed as argument +tmp_var=$(sed -n "/kernel $2/p; /kernel $2/q" <<< "$tmp_var") + +# if no lines found, use line number of 'final output' buffer +if [[ -z "${tmp_var// }" ]] ; then + sed -n "/final output/p; /final output/q" $filename | sponge $filename +else + sed -n "/kernel $2/p; /kernel $2/q" $filename | sponge $filename +fi + +# get line number from this line +sed 's/ .*//' $filename | sponge $filename + +# store line number in variable tmp_var +tmp_var=`cat $filename` + +# put input file in variable filename +sed '' $3 | sponge $filename + +# get buffer from line1 line number to the next '--- ---' line +sed -n "1,$tmp_var b;/--- ---\|final output/ q;p" $3 | sponge $filename + +# +sed '1i data=[' $filename | sponge $filename + +# Remove character '(' from complex number +sed 's/(//g' $filename | sponge $filename + +# Replace character ',' with '+' in complex number +sed 's/,/+/g' $filename | sponge $filename + +# Remove new lines +tr '\n' ' ' < $filename | sponge $filename + +# Replace character ')' with 'i;' +sed 's/)/i;\n/g' $filename | sponge $filename + +# Append '];' to the end of the file +sed '$a];' $filename | sponge $filename + +# +if [ $1 -eq 1 ]; then + sed -i "1s/^/function data = rocfft_input_data()\n/" $filename +elif [ $1 -eq 0 ]; then + sed -i "1s/^/function data = rocfft_output_data()\n/" $filename +fi diff --git a/scripts/partial-pass/run_test.m b/scripts/partial-pass/run_test.m new file mode 100644 index 00000000000..532b316e836 --- /dev/null +++ b/scripts/partial-pass/run_test.m @@ -0,0 +1,23 @@ +function run_test() + +length = load("-ascii", "in_len.txt"); + +batch = load("-ascii", "in_batch.txt"); + +pp_dim = load("-ascii", "in_pp_dim.txt"); + +pp_radices = load("-ascii", "in_pp_radices.txt"); + +fid = fopen("in_test_mode.txt", 'r'); +test_mode = textscan(fid, '%s', 'delimiter', '\n'); +test_mode = cellstr(test_mode); +fclose(fid); + +in_batched = rocfft_input_data(); + +out_batched = rocfft_output_data(); + +partial_pass_3d(length, batch, pp_dim, pp_radices, in_batched, out_batched, test_mode); + +delete('rocfft_input_data.m'); +delete('rocfft_output_data.m'); \ No newline at end of file diff --git a/scripts/partial-pass/run_test.sh b/scripts/partial-pass/run_test.sh new file mode 100755 index 00000000000..49ae997fc20 --- /dev/null +++ b/scripts/partial-pass/run_test.sh @@ -0,0 +1,72 @@ +#!/bin/bash + + test_mode_1="full-3d" + test_mode_2="direction_1" + test_mode_3="direction_2" + test_mode_4="step_1_2" + test_mode_5="step_3_4" + test_mode_6="direction_1_step_1_2" + test_mode_7="direction_2_step_3_4" + +# ------------------------------------------------------------------- +# input parameters +# ------------------------------------------------------------------- + +length=( 64 64 64 ) +batch=( 1 ) +pp_dim=( 2 ) +pp_radices=( 4 16 ) +test_mode=$test_mode_1 +# ------------------------------------------------------------------- + +in_len_file="in_len.txt" +in_batch_file="in_batch.txt" +in_pp_dim_file="in_pp_dim.txt" +in_pp_radices_file="in_pp_radices.txt" +in_test_mode_file="in_test_mode.txt" +rocfft_input_data_file="rocfft_input_data.m" +rocfft_output_data_file="rocfft_output_data.m" +# ------------------------------------------------------------------- + +echo ${length[@]} > $in_len_file +echo ${batch[@]} > $in_batch_file +echo ${pp_dim[@]} > $in_pp_dim_file +echo ${pp_radices[@]} > $in_pp_radices_file +echo ${test_mode} > $in_test_mode_file + +# =================================================================== +rocfft_script_dir=$(pwd) +rofft_dir=$(pwd)/../.. +rocfft_exec_dir=${rofft_dir}/build/clients/staging/ + +cd $rocfft_exec_dir + +ROCFFT_LAYER=16 ./rocfft-bench --precision double --length ${length[0]} ${length[1]} ${length[2]} -b ${batch[0]} &> out.txt + +cd $rocfft_script_dir + +if [ $test_mode = $test_mode_1 ]; then + buffer_arg_1=0 + buffer_arg_2=4 +elif [ $test_mode = $test_mode_6 ]; then + buffer_arg_1=0 + buffer_arg_2=1 +fi + +./rocfft_to_octave.sh 1 $buffer_arg_1 ${rocfft_exec_dir}out.txt +./rocfft_to_octave.sh 0 $buffer_arg_2 ${rocfft_exec_dir}out.txt + +rm $rocfft_exec_dir/out.txt + +octave -W run_test.m + +# =================================================================== + +rm $in_len_file +rm $in_batch_file +rm $in_pp_dim_file +rm $in_pp_radices_file +rm $in_test_mode_file + + + diff --git a/shared/printbuffer.h b/shared/printbuffer.h index 5ae0b64fbb4..2d6d0679787 100644 --- a/shared/printbuffer.h +++ b/shared/printbuffer.h @@ -24,6 +24,7 @@ #include "hostbuf.h" #include "increment.h" #include +#include #include // Output a formatted general-dimensional array with given length and stride in batches @@ -46,7 +47,7 @@ inline void printbuffer(const Toutput* output, { const int i = std::inner_product(index.begin(), index.end(), stride.begin(), i_base + offset); - stream << output[i] << " "; + stream << std::fixed << std::setprecision(14) << output[i] << " "; for(int li = index.size(); li-- > 0;) { if(index[li] == (length[li] - 1)) From f6a26a0687f58f19372d2cdfbb22b171505c1723 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 7 Mar 2025 17:15:21 -0700 Subject: [PATCH 16/69] - Fixes to test scripts --- scripts/partial-pass/dft_matrix.m | 12 ++++ scripts/partial-pass/partial_pass_3d.m | 77 ++++++++++++++++---------- scripts/partial-pass/run_test.sh | 31 +++++++---- 3 files changed, 80 insertions(+), 40 deletions(-) create mode 100644 scripts/partial-pass/dft_matrix.m diff --git a/scripts/partial-pass/dft_matrix.m b/scripts/partial-pass/dft_matrix.m new file mode 100644 index 00000000000..bbea0c56294 --- /dev/null +++ b/scripts/partial-pass/dft_matrix.m @@ -0,0 +1,12 @@ +function F = dft_matrix(n) + +F = zeros(n,n); +omega_n = exp(-2*pi*j/n); +for i=1:n + for j=1:n + F(i,j) = (omega_n^((i-1)*(j-1))); + endfor +endfor + + + diff --git a/scripts/partial-pass/partial_pass_3d.m b/scripts/partial-pass/partial_pass_3d.m index 8bfe90851e1..84031f7ed53 100644 --- a/scripts/partial-pass/partial_pass_3d.m +++ b/scripts/partial-pass/partial_pass_3d.m @@ -2,15 +2,11 @@ function partial_pass_3d(in_length, nbatch, pp_dim, pp_radices, in_batched, out_ test_mode_1 = 'full-3d'; test_mode_2 = 'direction_1'; - test_mode_3 = 'direction_2'; - test_mode_4 = 'step_1_2'; - test_mode_5 = 'step_3_4'; - test_mode_6 = 'direction_1_step_1_2'; - test_mode_7 = 'direction_2_step_3_4'; - - if ~(strcmp(test_mode,test_mode_1) || strcmp(test_mode,test_mode_2) || strcmp(test_mode,test_mode_3)... - || strcmp(test_mode,test_mode_4) || strcmp(test_mode,test_mode_5) || strcmp(test_mode,test_mode_6)... - || strcmp(test_mode,test_mode_7)) + test_mode_3 = 'direction_1_step_1_2'; + test_mode_4 = 'direction_1_step_1_2_3_4'; + + if ~(strcmp(test_mode,test_mode_1) || strcmp(test_mode,test_mode_2) || ... + strcmp(test_mode,test_mode_3) || strcmp(test_mode,test_mode_4)) display(test_mode); error('Invalid test mode'); endif @@ -18,7 +14,7 @@ function partial_pass_3d(in_length, nbatch, pp_dim, pp_radices, in_batched, out_ format longG; ordering='column-major'; data_empty_value = -123456789; - + N = prod(in_length); pp_mode = 'four-step'; @@ -37,30 +33,49 @@ function partial_pass_3d(in_length, nbatch, pp_dim, pp_radices, in_batched, out_ endif % 3D-FFT (MATLAB built-in) + out = fftn(in); + out = convert_3d_to_1d(out, ordering); + out_ = convert_3d_to_1d(out_, ordering); + if (strcmp(test_mode,test_mode_1)) - out = fftn(in); - out = convert_3d_to_1d(out, ordering); - out_ = convert_3d_to_1d(out_, ordering); linf_rocfft_vs_octave_built_in = norm(out-out_,'inf'); disp(['l-inf norm: ' num2str(linf_rocfft_vs_octave_built_in)]); - return; + else + % CS_3D_RC from rocFFT (with partial pass) + [out_3d_rc, out_3d_rc_1, out_3d_rc_pp_1, out_3d_rc_pp_2] = run_CS_3D_RC(in_length, in, pp_dim, pp_radices, pp_mode); + + out_3d_rc = convert_3d_to_1d(out_3d_rc, ordering); + linf_test = norm(out_3d_rc-out,'inf'); + if (linf_test > 1E-8) + error("Error: partial-pass 3D-RC failed accuracy test"); + endif + + if (strcmp(test_mode,test_mode_2)) + out_3d_rc_1 = convert_3d_to_1d(out_3d_rc_1, ordering); + linf_test = norm(out_3d_rc_1-out_,'inf'); + disp(['l-inf norm: ' num2str(linf_test)]); + endif + + if (strcmp(test_mode,test_mode_3)) + out_3d_rc_pp_1 = convert_3d_to_1d(out_3d_rc_pp_1, ordering); + linf_test = norm(out_3d_rc_pp_1-out_,'inf'); + disp(['l-inf norm: ' num2str(linf_test)]); + endif + + if (strcmp(test_mode,test_mode_4)) + out_3d_rc_pp_2 = convert_3d_to_1d(out_3d_rc_pp_2, ordering); + linf_test = norm(out_3d_rc_pp_2-out_,'inf'); + disp(['l-inf norm: ' num2str(linf_test)]); + endif endif - - % CS_3D_RC from rocFFT (with partial pass) - [out_3d_rc, out_3d_rc_pp_1, out_3d_rc_pp_2] = run_CS_3D_RC(in_length, in, pp_dim, pp_radices, pp_mode); - out_3d_rc = convert_3d_to_1d(out_3d_rc, ordering); - out_3d_rc_pp_1 = convert_3d_to_1d(out_3d_rc_pp_1, ordering); - out_3d_rc_pp_2 = convert_3d_to_1d(out_3d_rc_pp_2, ordering); - out_ = convert_3d_to_1d(out_, ordering); - - linf_rocfft_vs_octave_3d_rc_pp = norm(out_3d_rc-out_,'inf'); - disp(['l-inf norm: ' num2str(linf_rocfft_vs_octave_3d_rc_pp)]); endfor - display(''); - - function [out, out_pp_1, out_pp_2] = run_CS_3D_RC(in_length, in, pp_dim, pp_radices, pp_mode) + function [out, out_1, out_pp_1, out_pp_2] = run_CS_3D_RC(in_length, in, pp_dim, pp_radices, pp_mode) n = in_length(pp_dim); + + % Flip radices, as the radix order is reversed in steps 1-2 and 3-4 + pp_radices = flip(pp_radices); + n1 = pp_radices(1); n2 = pp_radices(2); @@ -73,6 +88,7 @@ function partial_pass_3d(in_length, nbatch, pp_dim, pp_radices, in_batched, out_ if (pp_dim == 1) % 1st kernel (2nd dimension) out = fft(out,[], 2); + out_1 = out; out = partial_pass_step_1_2(out, 1, n1, n2, F_n1, F_n2, F_n, pp_mode); out_pp_1 = out; @@ -83,10 +99,14 @@ function partial_pass_3d(in_length, nbatch, pp_dim, pp_radices, in_batched, out_ endif if (pp_dim == 2) + % Correct ordering for intermediate results comparison + transp_order_comp = [3 1 2]; + % 1st kernel (1st dimension) out = fft(out,[], 1); + out_1 = permute(out, transp_order_comp); out = partial_pass_step_1_2(out, 2, n1, n2, F_n1, F_n2, F_n, pp_mode); - out_pp_1 = out; + out_pp_1 = permute(out, transp_order_comp); % 2nd kernel (3rd dimension) transp_order = [3 2 1]; @@ -103,6 +123,7 @@ function partial_pass_3d(in_length, nbatch, pp_dim, pp_radices, in_batched, out_ if (pp_dim == 3) % 1st kernel (1st dimension) out = fft(out,[], 1); + out_1 = out; out = partial_pass_step_1_2(out, 3, n1, n2, F_n1, F_n2, F_n, pp_mode); out_pp_1 = out; diff --git a/scripts/partial-pass/run_test.sh b/scripts/partial-pass/run_test.sh index 49ae997fc20..99621886343 100755 --- a/scripts/partial-pass/run_test.sh +++ b/scripts/partial-pass/run_test.sh @@ -1,22 +1,23 @@ #!/bin/bash - test_mode_1="full-3d" - test_mode_2="direction_1" - test_mode_3="direction_2" - test_mode_4="step_1_2" - test_mode_5="step_3_4" - test_mode_6="direction_1_step_1_2" - test_mode_7="direction_2_step_3_4" +test_mode_1='full-3d' + +# requires changing code generator to skip steps 1-2 +test_mode_2='direction_1' + +test_mode_3='direction_1_step_1_2' + +test_mode_4='direction_1_step_1_2_3_4' # ------------------------------------------------------------------- # input parameters # ------------------------------------------------------------------- length=( 64 64 64 ) -batch=( 1 ) +batch=( 5 ) pp_dim=( 2 ) -pp_radices=( 4 16 ) -test_mode=$test_mode_1 +pp_radices=( 16 4 ) +test_mode=$test_mode_2 # ------------------------------------------------------------------- in_len_file="in_len.txt" @@ -47,10 +48,16 @@ cd $rocfft_script_dir if [ $test_mode = $test_mode_1 ]; then buffer_arg_1=0 - buffer_arg_2=4 -elif [ $test_mode = $test_mode_6 ]; then + buffer_arg_2=2 +elif [ $test_mode = $test_mode_2 ]; then buffer_arg_1=0 buffer_arg_2=1 +elif [ $test_mode = $test_mode_3 ]; then + buffer_arg_1=0 + buffer_arg_2=1 +elif [ $test_mode = $test_mode_4 ]; then + buffer_arg_1=0 + buffer_arg_2=2 fi ./rocfft_to_octave.sh 1 $buffer_arg_1 ${rocfft_exec_dir}out.txt From 64710e78d4f0a177903ddd5bdd4aea72bb740fa0 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 10 Apr 2025 16:49:24 -0600 Subject: [PATCH 17/69] WIP --- library/src/device/generator/generator.h | 29 +++++++++++++++- library/src/device/generator/stockham_gen.cpp | 4 +-- library/src/device/generator/stockham_gen.h | 3 -- .../src/device/generator/stockham_pp_gen_cc.h | 12 ++++--- .../src/device/generator/stockham_pp_gen_rr.h | 33 ++++++++++--------- library/src/include/rtc_stockham_gen.h | 2 ++ library/src/include/tree_node.h | 3 ++ library/src/rocfft_aot_helper.cpp | 19 +++++++---- library/src/rocfft_kernel_config_search.cpp | 3 +- library/src/rtc_stockham_gen.cpp | 8 +++-- library/src/rtc_stockham_kernel.cpp | 9 ++--- library/src/tree_node_1D.cpp | 10 ++---- library/src/tree_node_3D.cpp | 2 ++ scripts/partial-pass/partial_pass_3d.m | 20 +++++++++-- scripts/partial-pass/run_test.sh | 12 +++++-- 15 files changed, 112 insertions(+), 57 deletions(-) mode change 100755 => 100644 scripts/partial-pass/run_test.sh diff --git a/library/src/device/generator/generator.h b/library/src/device/generator/generator.h index a4d80ee83fd..eb2fd15399a 100644 --- a/library/src/device/generator/generator.h +++ b/library/src/device/generator/generator.h @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -628,6 +629,7 @@ class Butterfly; class IntrinsicStore; class IntrinsicStorePlanar; class IntrinsicLoadToDest; +class Printf; struct LineBreak { @@ -704,7 +706,8 @@ using Statement = std::variant; + IntrinsicLoadToDest, + Printf>; class Assign { @@ -1096,6 +1099,28 @@ class IntrinsicLoadToDest Expression rw_flag; }; +class Printf +{ +public: + const char* fmt; + std::vector args; + + Printf(const char* format, const std::vector& arguments) + : fmt(format) + , args(arguments){}; + + std::string render() const + { + auto fmt_render = std::string(fmt); + fmt_render = "\"" + std::regex_replace(fmt_render, std::regex(R"(\n)"), "\\n") + "\""; + + auto args_render = args; + args_render.insert(args_render.begin(), Literal(fmt_render)); + + return Call{"printf", args_render}.render(); + } +}; + // end of Statement class declarations static void operator+=(StatementList& stmts, const Statement& s) @@ -1225,6 +1250,7 @@ struct BaseVisitor MAKE_VISITOR_OPERATOR(StatementList, IntrinsicStore); MAKE_VISITOR_OPERATOR(StatementList, IntrinsicStorePlanar); MAKE_VISITOR_OPERATOR(StatementList, IntrinsicLoadToDest); + MAKE_VISITOR_OPERATOR(StatementList, Printf); MAKE_VISITOR_OPERATOR(ArgumentList, ArgumentList); @@ -1327,6 +1353,7 @@ struct BaseVisitor MAKE_TRIVIAL_STATEMENT_VISIT(SyncThreads) MAKE_TRIVIAL_STATEMENT_VISIT(Butterfly); MAKE_TRIVIAL_STATEMENT_VISIT(IntrinsicLoadToDest); + MAKE_TRIVIAL_STATEMENT_VISIT(Printf); MAKE_TRIVIAL_VISIT(Expression, Variable) diff --git a/library/src/device/generator/stockham_gen.cpp b/library/src/device/generator/stockham_gen.cpp index 1ea6da29f09..3cba76fc9eb 100644 --- a/library/src/device/generator/stockham_gen.cpp +++ b/library/src/device/generator/stockham_gen.cpp @@ -325,14 +325,14 @@ int main() ++arg; factors = parse_uints_csv(*arg); - StockhamGeneratorSpecs specs(factors, {}, factors2d, precisions, workgroup_size, scheme); + StockhamGeneratorSpecs specs(factors, factors2d, precisions, workgroup_size, scheme); specs.half_lds = half_lds; specs.direct_to_from_reg = direct_to_from_reg; specs.threads_per_transform = threads_per_transform.front(); // second dimension for 2D_SINGLE - StockhamGeneratorSpecs specs2d(factors2d, {}, factors, precisions, workgroup_size, scheme); + StockhamGeneratorSpecs specs2d(factors2d, factors, precisions, workgroup_size, scheme); if(!threads_per_transform.empty()) specs2d.threads_per_transform = threads_per_transform.back(); diff --git a/library/src/device/generator/stockham_gen.h b/library/src/device/generator/stockham_gen.h index 62497ec002d..a0fb18204b1 100644 --- a/library/src/device/generator/stockham_gen.h +++ b/library/src/device/generator/stockham_gen.h @@ -30,13 +30,11 @@ struct StockhamGeneratorSpecs { StockhamGeneratorSpecs(const std::vector& factors, - const std::vector& factors_pp, const std::vector& factors2d, const std::vector& precisions, unsigned int workgroup_size, const std::string& scheme) : factors(factors) - , factors_pp(factors_pp) , factors2d(factors2d) , precisions(precisions) , length(product(factors.begin(), factors.end())) @@ -47,7 +45,6 @@ struct StockhamGeneratorSpecs } std::vector factors; - std::vector factors_pp; std::vector factors2d; std::vector precisions; // mapped from rocfft_precision unsigned int length; diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index 947cdf77a84..e9fa37c3391 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -31,8 +31,11 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC { explicit StockhamPartialPassKernelCC(const StockhamGeneratorSpecs& specs, - bool largeTwdBatchIsTransformCount) + bool largeTwdBatchIsTransformCount, + const std::vector& ppFactors) : StockhamKernelCC(specs, largeTwdBatchIsTransformCount, false) + , factors_pp(ppFactors) + { large_twiddle_steps.decl_default = 3; large_twiddle_base.decl_default = 8; @@ -40,14 +43,15 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC // TODO: Address and test all "lds_linear=false" cases // TODO: revisit this. Test with factors_pp.size() > 1 - max_factor_pp = *std::max_element(specs.factors_pp.begin(), specs.factors_pp.end()); + max_factor_pp = *std::max_element(factors_pp.begin(), factors_pp.end()); // TODO: transforms_per_block_pp or threads_per_transform? Revisit all usages transforms_per_block_pp = transforms_per_block / max_factor_pp; } - unsigned int transforms_per_block_pp; - unsigned int max_factor_pp; + unsigned int transforms_per_block_pp; + unsigned int max_factor_pp; + std::vector factors_pp; Variable thread_lds{"thread_lds", "unsigned int"}; Variable idx_lds{"idx_lds", "unsigned int"}; diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index 50526ea4c5f..cb964f83a6d 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -25,22 +25,26 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR { - explicit StockhamPartialPassKernelRR(const StockhamGeneratorSpecs& specs) + explicit StockhamPartialPassKernelRR(const StockhamGeneratorSpecs& specs, + const std::vector& ppFactors, + const size_t ppLength) : StockhamKernelRR(specs) + , factors_pp(ppFactors) + , length_pp(ppLength) { // TODO: revisit this. Test with factors_pp.size() > 1 - max_factor_pp = *std::max_element(specs.factors_pp.begin(), specs.factors_pp.end()); - prod_factors_pp = std::accumulate( - specs.factors_pp.begin(), specs.factors_pp.end(), 1, std::multiplies()); + max_factor_pp = *std::max_element(factors_pp.begin(), factors_pp.end()); R.size = Expression{std::max(nregisters, max_factor_pp)}; } - unsigned int max_factor_pp, prod_factors_pp; - Variable offset_pp{"offset_pp", "size_t"}; - Variable stride_lds_pp{"stride_lds_pp", "size_t"}; - Variable offset_lds_pp{"offset_lds_pp", "size_t"}; + unsigned int max_factor_pp; + std::vector factors_pp; + unsigned int length_pp; + Variable offset_pp{"offset_pp", "size_t"}; + Variable stride_lds_pp{"stride_lds_pp", "size_t"}; + Variable offset_lds_pp{"offset_lds_pp", "size_t"}; Variable twiddles_pp{"twiddles_pp", "const scalar_type", true, true}; StatementList calculate_offsets() override @@ -60,9 +64,8 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR block_id * transforms_per_block + thread_id / threads_per_transform}; stmts += Assign{remaining, transform}; stmts += Assign{remaining_pp, - length * Parens(transform / length) - + Parens(transform % length) / max_factor_pp - + Parens(transform * (length / max_factor_pp)) % length}; + 64 * Parens(transform / 64) + Parens(transform % 64) / transforms_per_block + + Parens(transform * (64 / transforms_per_block)) % 64}; stmts += For{d, 1, @@ -280,7 +283,7 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR Function generate_twiddle_multiply_pp_function(int direction) { std::string function_name - = "twiddle_multiply_pp_length" + std::to_string(length) + "_device"; + = "twiddle_multiply_pp_length" + std::to_string(length_pp) + "_device"; Function f{function_name}; @@ -299,7 +302,7 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR for(unsigned int w = 0; w < max_factor_pp; ++w) { - body += Assign{W, twiddles_pp[thread * length + w]}; + body += Assign{W, twiddles_pp[thread * Literal(length_pp) + w]}; if(direction == -1) body += Assign{t, TwiddleMultiply{R[w], W}}; @@ -353,9 +356,9 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR TemplateList pre_twd_mul_tmpl = TemplateList{scalar_type}; std::vector pre_twd_mul_args - = {R, block_id % (length / max_factor_pp), twiddles_pp}; + = {R, block_id % (length_pp / max_factor_pp), twiddles_pp}; StatementList twdMul; - twdMul += Call{"twiddle_multiply_pp_length" + std::to_string(length) + "_device", + twdMul += Call{"twiddle_multiply_pp_length" + std::to_string(length_pp) + "_device", pre_twd_mul_tmpl, pre_twd_mul_args}; diff --git a/library/src/include/rtc_stockham_gen.h b/library/src/include/rtc_stockham_gen.h index 5b82ca17f4b..9febc778aae 100644 --- a/library/src/include/rtc_stockham_gen.h +++ b/library/src/include/rtc_stockham_gen.h @@ -78,6 +78,8 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, CallbackType cbtype, const BluesteinFuseType& fuseBlue, const PartialPassType& ppType, + const std::vector& ppFactors, + const size_t ppLength, const LoadOps& loadOps, const StoreOps& storeOps); diff --git a/library/src/include/tree_node.h b/library/src/include/tree_node.h index d26552e1f5f..264505e2a59 100644 --- a/library/src/include/tree_node.h +++ b/library/src/include/tree_node.h @@ -379,6 +379,9 @@ class TreeNode // enables partial pass for this node bool applyPartialPass = false; + // Dimension of the FFT where partial-pass is applied + size_t ppDim = 0; + // BluesteinType typeBlue = BluesteinType::BT_NONE; BluesteinFuseType fuseBlue = BluesteinFuseType::BFT_NONE; diff --git a/library/src/rocfft_aot_helper.cpp b/library/src/rocfft_aot_helper.cpp index 0824685495e..da041f1a8b1 100644 --- a/library/src/rocfft_aot_helper.cpp +++ b/library/src/rocfft_aot_helper.cpp @@ -226,8 +226,10 @@ void build_stockham_function_pool(CompileQueue& queue) function_pool& fp = function_pool::get_function_pool(); // fused Bluestein and partial-pass kernels are always built at runtime - auto fuseBlue = BluesteinFuseType::BFT_NONE; - auto ppType = PartialPassType::PPT_NONE; + auto fuseBlue = BluesteinFuseType::BFT_NONE; + auto ppType = PartialPassType::PPT_NONE; + auto ppFactors = std::vector{}; + auto ppLength = 0; for(const auto& i : fp.get_map()) { @@ -243,7 +245,6 @@ void build_stockham_function_pool(CompileQueue& queue) std::copy(i.second.factors.begin(), i.second.factors.end(), std::back_inserter(factors)); StockhamGeneratorSpecs specs{factors, - {}, {}, {static_cast(precision)}, static_cast(i.second.workgroup_size), @@ -306,7 +307,6 @@ void build_stockham_function_pool(CompileQueue& queue) std::function generate_src = [=](const std::string& kernel_name) -> std::string { StockhamGeneratorSpecs specs{factors, - {}, {}, {static_cast(precision)}, static_cast(i.second.workgroup_size), @@ -335,6 +335,8 @@ void build_stockham_function_pool(CompileQueue& queue) cbtype, fuseBlue, ppType, + ppFactors, + ppLength, {}, {}); }; @@ -622,8 +624,10 @@ void build_solution_kernels(CompileQueue& queue) solmap.get_all_kernels(kernel_nodes, true); // fused Bluestein and partial-pass kernels are always built at runtime - auto fuseBlue = BluesteinFuseType::BFT_NONE; - auto ppType = PartialPassType::PPT_NONE; + auto fuseBlue = BluesteinFuseType::BFT_NONE; + auto ppType = PartialPassType::PPT_NONE; + auto ppFactors = std::vector{}; + auto ppLength = 0; for(const SolutionNode& kernel_sol : kernel_nodes) { @@ -669,7 +673,6 @@ void build_solution_kernels(CompileQueue& queue) } StockhamGeneratorSpecs specs{factors, - {}, {}, {static_cast(precision)}, static_cast(config.workgroup_size), @@ -727,6 +730,8 @@ void build_solution_kernels(CompileQueue& queue) cbtype, fuseBlue, ppType, + ppFactors, + ppLength, {}, {}); }; diff --git a/library/src/rocfft_kernel_config_search.cpp b/library/src/rocfft_kernel_config_search.cpp index 5fb5935db09..64f6349fe0f 100644 --- a/library/src/rocfft_kernel_config_search.cpp +++ b/library/src/rocfft_kernel_config_search.cpp @@ -157,7 +157,6 @@ std::string test_kernel_src(const std::string& kernel_name, bool direct_to_from_reg) { StockhamGeneratorSpecs specs{factorization, - {}, {}, {static_cast(rocfft_precision_single)}, wgs, @@ -189,6 +188,8 @@ std::string test_kernel_src(const std::string& kernel_name, BluesteinFuseType::BFT_NONE, PartialPassType::PPT_NONE, {}, + 0, + {}, {}); } diff --git a/library/src/rtc_stockham_gen.cpp b/library/src/rtc_stockham_gen.cpp index 13665be62a4..f82106d5816 100644 --- a/library/src/rtc_stockham_gen.cpp +++ b/library/src/rtc_stockham_gen.cpp @@ -269,6 +269,8 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, CallbackType cbtype, const BluesteinFuseType& fuseBlue, const PartialPassType& ppType, + const std::vector& ppFactors, + const size_t ppLength, const LoadOps& loadOps, const StoreOps& storeOps) { @@ -314,7 +316,7 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, if(scheme == CS_KERNEL_STOCKHAM) { if(ppType == PartialPassType::PPT_SBRR) - kernel = std::make_unique(specs); + kernel = std::make_unique(specs, ppFactors, ppLength); else kernel = std::make_unique(specs); } @@ -322,7 +324,7 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, { if(ppType == PartialPassType::PPT_SBCC) kernel = std::make_unique( - specs, largeTwdBatchIsTransformCount); + specs, largeTwdBatchIsTransformCount, ppFactors); else kernel = std::make_unique( specs, largeTwdBatchIsTransformCount, fuseBluestein); @@ -409,7 +411,7 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, all_factors = kernel->factors; if(ppType != PPT_NONE) - all_factors.insert(all_factors.end(), specs.factors_pp.begin(), specs.factors_pp.end()); + all_factors.insert(all_factors.end(), ppFactors.begin(), ppFactors.end()); } // generated functions default to forward in-place interleaved. diff --git a/library/src/rtc_stockham_kernel.cpp b/library/src/rtc_stockham_kernel.cpp index 8274d42f259..8eaf3be366b 100644 --- a/library/src/rtc_stockham_kernel.cpp +++ b/library/src/rtc_stockham_kernel.cpp @@ -38,10 +38,6 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& std::optional specs; std::optional specs2d; - std::vector factors_pp; - std::copy( - node.kernelFactorsPP.begin(), node.kernelFactorsPP.end(), std::back_inserter(factors_pp)); - // SBRC variants look in the function pool for plain BLOCK_RC to // learn the block width, then decide on the transpose type once // that's known. @@ -91,7 +87,6 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& std::vector precisions = {static_cast(node.precision)}; specs.emplace(factors, - factors_pp, std::vector(), precisions, static_cast(kernel->workgroup_size), @@ -130,7 +125,6 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& } specs.emplace(factors1d, - factors_pp, factors2d, precisions, static_cast(kernel->workgroup_size), @@ -139,7 +133,6 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& specs->half_lds = kernel->half_lds; specs2d.emplace(factors2d, - factors_pp, factors1d, precisions, static_cast(kernel->workgroup_size), @@ -232,6 +225,8 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& node.GetCallbackType(enable_callbacks), node.fuseBlue, ppType, + node.kernelFactorsPP, + node.length[node.ppDim], node.loadOps, node.storeOps); }; diff --git a/library/src/tree_node_1D.cpp b/library/src/tree_node_1D.cpp index e4e953ba541..de7aea7d438 100644 --- a/library/src/tree_node_1D.cpp +++ b/library/src/tree_node_1D.cpp @@ -937,15 +937,9 @@ bool Stockham1DNode::CreateDeviceResources() { if(applyPartialPass) { - // handles partial pass 64 x 64 x 64 case. - // current dimension y is the dimension to split - // into x and z. - - // Create twiddle table for partial pass along y - size_t pp_dim = 1; - + // Create twiddle table for partial pass along ppDim std::tie(twiddles_pp, twiddles_pp_size) - = Repo::GetTwiddlesPP(length[pp_dim], precision, deviceProp); + = Repo::GetTwiddlesPP(length[ppDim], precision, deviceProp); } twd_attach_halfN = (ebtype != EmbeddedType::NONE); diff --git a/library/src/tree_node_3D.cpp b/library/src/tree_node_3D.cpp index 767bc5e2308..dfddb0e5548 100644 --- a/library/src/tree_node_3D.cpp +++ b/library/src/tree_node_3D.cpp @@ -666,6 +666,7 @@ void RC3DNode::BuildTree_internal(SchemeTreeVec& child_scheme_trees) xPartialPassPlan = NodeFactory::CreateNodeFromScheme(CS_KERNEL_STOCKHAM, this); xPartialPassPlan->length = xPartialPassPlanData.length; xPartialPassPlan->dimension = 1; + xPartialPassPlan->ppDim = 1; xPartialPassPlan->allowInplace = true; xPartialPassPlan->comments.push_back("partial-pass enabled for second dimension."); @@ -689,6 +690,7 @@ void RC3DNode::BuildTree_internal(SchemeTreeVec& child_scheme_trees) zPartialPassPlan = NodeFactory::CreateNodeFromScheme(CS_KERNEL_STOCKHAM_BLOCK_CC, this); zPartialPassPlan->length = zPartialPassPlanData.length; zPartialPassPlan->dimension = 1; + zPartialPassPlan->ppDim = 1; zPartialPassPlan->allowInplace = false; zPartialPassPlan->comments.push_back("partial-pass enabled for second dimension."); diff --git a/scripts/partial-pass/partial_pass_3d.m b/scripts/partial-pass/partial_pass_3d.m index 84031f7ed53..f08d46cbea6 100644 --- a/scripts/partial-pass/partial_pass_3d.m +++ b/scripts/partial-pass/partial_pass_3d.m @@ -1,12 +1,14 @@ function partial_pass_3d(in_length, nbatch, pp_dim, pp_radices, in_batched, out_batched, test_mode) + test_mode_0 = 'input'; test_mode_1 = 'full-3d'; test_mode_2 = 'direction_1'; test_mode_3 = 'direction_1_step_1_2'; test_mode_4 = 'direction_1_step_1_2_3_4'; - if ~(strcmp(test_mode,test_mode_1) || strcmp(test_mode,test_mode_2) || ... - strcmp(test_mode,test_mode_3) || strcmp(test_mode,test_mode_4)) + if ~(strcmp(test_mode,test_mode_0) || strcmp(test_mode,test_mode_1) || ... + strcmp(test_mode,test_mode_2) || strcmp(test_mode,test_mode_3) || ... + strcmp(test_mode,test_mode_4)) display(test_mode); error('Invalid test mode'); endif @@ -32,6 +34,18 @@ function partial_pass_3d(in_length, nbatch, pp_dim, pp_radices, in_batched, out_ error('Error: incomplete data'); endif + if (strcmp(test_mode,test_mode_0)) + in = convert_3d_to_1d(in, ordering); + out_ = convert_3d_to_1d(out_, ordering); + + in = sort(in); + out_ = sort(out_); + + linf_rocfft_vs_octave_built_in = norm(in-out_,'inf'); + disp(['l-inf norm: ' num2str(linf_rocfft_vs_octave_built_in)]); + return; + endif + % 3D-FFT (MATLAB built-in) out = fftn(in); out = convert_3d_to_1d(out, ordering); @@ -46,7 +60,7 @@ function partial_pass_3d(in_length, nbatch, pp_dim, pp_radices, in_batched, out_ out_3d_rc = convert_3d_to_1d(out_3d_rc, ordering); linf_test = norm(out_3d_rc-out,'inf'); - if (linf_test > 1E-8) + if (linf_test > 1E-5) error("Error: partial-pass 3D-RC failed accuracy test"); endif diff --git a/scripts/partial-pass/run_test.sh b/scripts/partial-pass/run_test.sh old mode 100755 new mode 100644 index 99621886343..ec9a4fb0529 --- a/scripts/partial-pass/run_test.sh +++ b/scripts/partial-pass/run_test.sh @@ -1,5 +1,8 @@ #!/bin/bash +# requires changing code generator to skip full pass and steps 1-2 +test_mode_0='input' + test_mode_1='full-3d' # requires changing code generator to skip steps 1-2 @@ -13,11 +16,11 @@ test_mode_4='direction_1_step_1_2_3_4' # input parameters # ------------------------------------------------------------------- -length=( 64 64 64 ) +length=( 64 64 128 ) batch=( 5 ) pp_dim=( 2 ) pp_radices=( 16 4 ) -test_mode=$test_mode_2 +test_mode=$test_mode_3 # ------------------------------------------------------------------- in_len_file="in_len.txt" @@ -46,7 +49,10 @@ ROCFFT_LAYER=16 ./rocfft-bench --precision double --length ${length[0]} ${length cd $rocfft_script_dir -if [ $test_mode = $test_mode_1 ]; then +if [ $test_mode = $test_mode_0 ]; then + buffer_arg_1=0 + buffer_arg_2=1 +elif [ $test_mode = $test_mode_1 ]; then buffer_arg_1=0 buffer_arg_2=2 elif [ $test_mode = $test_mode_2 ]; then From de36c9fdd1110a580985195b389f7b0b08fb959b Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Wed, 23 Apr 2025 10:53:18 -0600 Subject: [PATCH 18/69] - Resolve merge conflict. --- library/src/device/generator/stockham_pp_gen_rr.h | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index cb964f83a6d..fdb212d92f3 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -378,10 +378,9 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR ArgumentList global_arguments() override { - auto arguments - = static_dim - ? ArgumentList{twiddles_pp, twiddles, lengths, stride, nbatch, lds_padding} - : ArgumentList{twiddles_pp, twiddles, dim, lengths, stride, nbatch, lds_padding}; + auto arguments = static_dim + ? ArgumentList{twiddles_pp, twiddles, lengths, stride, nbatch} + : ArgumentList{twiddles_pp, twiddles, dim, lengths, stride, nbatch}; for(const auto& arg : get_callback_args().arguments) arguments.append(arg); arguments.append(buf); From b368f0bd1abcc85ef6a39cad88aca7e2536d9108 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Wed, 23 Apr 2025 14:40:30 -0600 Subject: [PATCH 19/69] - Clean up. --- .../src/device/generator/stockham_pp_gen_cc.h | 59 ++++--------------- .../src/device/generator/stockham_pp_gen_rr.h | 18 ++---- library/src/tree_node.cpp | 12 +--- library/src/tree_node_3D.cpp | 4 +- 4 files changed, 21 insertions(+), 72 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index e9fa37c3391..5c04f3ba89d 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -21,13 +21,14 @@ #pragma once #include "stockham_gen_cc.h" -// TODO: - Kernel is not getting launched, found out why -// - Check launch bounds. -// - Implementation here used a kernel with work_group_size = 256, however, the prototype was using 64. -// Change kernel_generator.py to use 64 and fix all the issues, comparing again with the prototype. -// - Start testing with different threads_per_transform once the original configuration works. -// - Then test with other lengths and direct_from_reg=true, half_lds=true, etc. - +// TODO: Once partial pass is fully configurable in kernel-generator.py: +// - Test all "lds_linear=false" cases. +// - Test with factors_pp.size() > 1. +// - Revisit all usages of transforms_per_block_pp and threads_per_transform. +// - Different input/output strides. +// - Revisit mod 128 usage in calculate_offsets() with different input lengths, +// (logic is required to work with nbatch > 1) +// - Revisit factor 192 logic in calculate_offsets() with different input lengths struct StockhamPartialPassKernelCC : public StockhamKernelCC { explicit StockhamPartialPassKernelCC(const StockhamGeneratorSpecs& specs, @@ -40,12 +41,8 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC large_twiddle_steps.decl_default = 3; large_twiddle_base.decl_default = 8; - // TODO: Address and test all "lds_linear=false" cases - - // TODO: revisit this. Test with factors_pp.size() > 1 max_factor_pp = *std::max_element(factors_pp.begin(), factors_pp.end()); - // TODO: transforms_per_block_pp or threads_per_transform? Revisit all usages transforms_per_block_pp = transforms_per_block / max_factor_pp; } @@ -121,7 +118,6 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC StatementList work; for(unsigned int w = 0; w < width; ++w) - //TODO: lstride not used here, address to have input/output strides working work += Assign(lds_complex[offset_lds + (w * stride_lds)], R[w]); return work; @@ -161,7 +157,6 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC StatementList work; for(unsigned int w = 0; w < width; ++w) - //TODO: lstride not used here, address to have input/output strides working work += Assign(R[w], lds_complex[offset_lds + (w * stride_lds)]); return work; @@ -226,8 +221,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += Assign{num_of_tiles, (lengths[1] - 1) / transforms_per_block_pp + 1}; stmts += Assign{plength, num_of_tiles}; stmts += Assign{tile_index, block_id % num_of_tiles}; - //TODO figure out mod 128 for other lengths - // mod 128 required to work with nbatch > 1 + stmts += Assign{remaining, (block_id % 128) / num_of_tiles}; stmts += Assign{offset, tile_index * transforms_per_block_pp * stride[1]}; stmts += For{d, @@ -242,19 +236,14 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += LineBreak{}; stmts += Assign{batch, block_id / plength}; - //stmts += Assign{offset, offset + batch * stride[dim]}; if(!direct_to_from_reg) { - // TODO: figure out this branch stmts += Assign{transform, tile_index * transforms_per_block_pp + thread_id / threads_per_transform}; stmts += Assign{stride_lds, (length + get_lds_padding())}; - // TODO: figure out factor 4 for other lengths stmts += MultiplyAssign(stride_lds, Literal{max_factor_pp}); - - //stmts += Assign{offset_lds, stride_lds * (transform % transforms_per_block_pp)}; } else { @@ -268,13 +257,8 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC Ternary{lds_linear, length + get_lds_padding(), transforms_per_block_pp + get_lds_padding()}}; - // TODO: figure out factor 4 for other lengths - stmts += MultiplyAssign(stride_lds, Literal{max_factor_pp}); - // stmts += Assign{offset_lds, - // Ternary{lds_linear, - // stride_lds * (transform % transforms_per_block_pp), - // thread_id % transforms_per_block_pp}}; + stmts += MultiplyAssign(stride_lds, Literal{max_factor_pp}); } stmts += Declaration{ @@ -282,16 +266,12 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC Ternary{ Parens((tile_index + 1) * transforms_per_block_pp > lengths[1]), "false", "true"}}; - // [dim0, dim1] = [tid_ver, tid_hor] : - // each thread reads position [tid_ver, tid_hor], [tid_ver+step_height*1, tid_hor] , [tid_ver+step_height*2, tid_hor]... - // tid_ver walks the columns; tid_hor walks the rows stmts += Declaration{thread, thread_id / transforms_per_block_pp}; stmts += Declaration{tid_hor, thread_id % transforms_per_block_pp}; stmts += Declaration{thread_lds, thread_id / transforms_per_block_pp}; stmts += Declaration{tid_hor_lds, thread_id % transforms_per_block_pp}; - // TODO: figure out factor 4 here for other lengths stmts += Declaration( tid_hor_pp, thread_id % transforms_per_block_pp + length * (thread % max_factor_pp)); stmts += Declaration(thread_new, thread_id / (transforms_per_block_pp * max_factor_pp)); @@ -300,12 +280,10 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += Declaration(thread_idx, thread_id); stmts += Declaration(block_idx, block_id); - // TODO: figure out factor 192 here for other lengths stmts += Declaration( offset_pp, offset + Parens(offset / length) * Literal{192} + batch_new * stride[dim]); stmts += Declaration(offset_tid_hor, offset_pp + tid_hor_pp * stride[1]); - // TODO: figure out factor 4 here for other lengths if(!direct_to_from_reg) stmts += Assign{transform, tile_index * transforms_per_block_pp @@ -352,11 +330,9 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC "no intrinsic when load to lds. FIXME- check why use nested branch is better"}; stmts += If{in_bound, tmp_stmts}; stmts += If{Not{in_bound}, {If{pred, tmp_stmts}}}; - // stmts += Else{{If{pred, tmp_stmts}}}; // FIXME: Need to check with compiler team. } else { - // TODO: Figure out this branch StatementList intrinsic_stmts; StatementList non_intrinsic_stmts; @@ -392,7 +368,6 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += If{intrinsic_mode != "IntrinsicAccessType::DISABLE_BOTH", intrinsic_stmts}; stmts += Else{non_intrinsic_stmts}; - // stmts += Else{{If{in_bound, tmp_stmts}}}; } return stmts; @@ -448,7 +423,6 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC auto stripmine_w = transforms_per_block; auto stripmine_h = workgroup_size / stripmine_w; - // TODO: stride[1] not being handle here, address this to have output strides working auto offset_tile_wbuf = [&](unsigned int i) { return offset_tid_hor + (thread_new + i * stripmine_h) * stride0; }; @@ -468,11 +442,9 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC "no intrinsic when store from lds. FIXME- check why use nested branch is better"}; stmts += If{in_bound, tmp_stmts}; stmts += If{Not{in_bound}, {If{pred, tmp_stmts}}}; - // stmts += Else{{If{pred, tmp_stmts}}}; // FIXME: Need to check with compiler team. } else { - // TODO: figure out this branch StatementList intrinsic_stmts; StatementList non_intrinsic_stmts; @@ -508,7 +480,6 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += If{intrinsic_mode == "IntrinsicAccessType::ENABLE_BOTH", intrinsic_stmts}; stmts += Else{non_intrinsic_stmts}; - // stmts += Else{{If{in_bound, {If{pred, tmp_stmts}}}}}; } return stmts; @@ -729,7 +700,6 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC StatementList& body = f.body; - // TODO: figure out these factors for other lengths auto factor_transpose_1 = (length * length) / max_factor_pp; auto factor_transpose_2 = length * max_factor_pp; auto factor_transpose_3 = length * length; @@ -746,7 +716,6 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC + Parens(transpose_idx / factor_transpose_3) * (factor_transpose_3 - factor_transpose_1)}; - // TODO: clean-up this expression: global_idx / factor_transpose_5 * factor_transpose_5 body += Assign{transpose_idx, transpose_idx + global_idx / factor_transpose_5 * factor_transpose_5}; @@ -755,12 +724,10 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC return f; } - // TODO: Move this to a device function StatementList perform_partial_pass_step_3_4() { StatementList stmts; - // TODO: figure out factor 1 here (what happens with different in/out strides and lengths) stmts += Declaration{stride_lds_pp, Literal{1}}; stmts += Declaration{offset_lds_pp, thread_id * transforms_per_block_pp}; @@ -768,7 +735,6 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC auto pre_post_lds_args = device_lds_reg_inout_pp_device_call_arguments(); pre_post_lds_tmpl.set_value(stride_type.name, "lds_linear ? SB_UNIT : SB_NONUNIT"); - // TODO: handle direct_to_from_reg StatementList preLoad; preLoad += Call{"lds_to_reg_input_pp_step_3_4_length" + std::to_string(length) + "_device", pre_post_lds_tmpl, @@ -777,8 +743,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC for(unsigned int npass = 0; npass < factors_pp.size(); ++npass) { - unsigned int width = factors_pp[npass]; - // TODO: revisit this. Different from same function in stockham_pp_gen_rr.h + unsigned int width = factors_pp[npass]; unsigned int height = threads_per_transform / max_factor_pp; auto butterfly = std::mem_fn(&StockhamKernel::butterfly_generator); @@ -821,7 +786,6 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC auto tidx = cumheight - firstFactor + w - 1 + (width - 1) * (tid % cumheight); auto ridx = hr * width + w; - // TODO- Can try IntrinsicLoadToDest, but should not be a bottleneck work += Assign(W, twiddles[tidx]); work += Assign(t, TwiddleMultiply(R[ridx], W)); work += Assign(R[ridx], t); @@ -1031,7 +995,6 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC body += Else{loadlds}; } - // partial pass here body += perform_partial_pass_step_3_4(); body += LineBreak{}; diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index fdb212d92f3..7f60ab074cb 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -21,7 +21,11 @@ #pragma once #include "stockham_gen_rr.h" -// TODO: transform_per_block or max_factor_pp? Revisit all usages +// TODO: Once partial pass is fully configurable in kernel-generator.py: +// - Revisit all usages of transform_per_block and max_factor_pp. +// - Test with factors_pp.size() > 1 +// - Revisit lstride usage and input/output strides +// - Revisit factor 64 logic in calculate_offsets() with different input lengths struct StockhamPartialPassKernelRR : public StockhamKernelRR { @@ -32,7 +36,6 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR , factors_pp(ppFactors) , length_pp(ppLength) { - // TODO: revisit this. Test with factors_pp.size() > 1 max_factor_pp = *std::max_element(factors_pp.begin(), factors_pp.end()); R.size = Expression{std::max(nregisters, max_factor_pp)}; @@ -191,7 +194,6 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR StatementList work; for(unsigned int w = 0; w < width; ++w) - //TODO: lstride not used here, address to have input/output strides working work += Assign(R[w], lds_complex[offset_lds + (w * stride_lds)]); return work; @@ -229,7 +231,6 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR auto load_lds = std::mem_fn(&StockhamPartialPassKernelRR::load_lds_step_1_2_generator); // first pass of load (full) - // TODO: revisit width. it used to be factors[0] unsigned int width = max_factor_pp; float height = static_cast(length) / width / threads_per_transform; body += SyncThreads(); @@ -249,7 +250,6 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR StatementList work; for(unsigned int w = 0; w < width; ++w) - //TODO: lstride not used here, address to have input/output strides working work += Assign(lds_complex[offset_lds + (w * stride_lds)], R[w]); return work; @@ -269,7 +269,6 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR auto store_lds = std::mem_fn(&StockhamPartialPassKernelRR::store_pp_step_1_2_lds_generator); // last pass of store (full) - // TODO: revisit width. it used to be factors.back() unsigned int width = max_factor_pp; float height = static_cast(length) / width / threads_per_transform; body += SyncThreads(); @@ -322,12 +321,10 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR return {scalar_type}; } - // TODO: Move this to a device function StatementList perform_partial_pass_step_1_2() { StatementList stmts; - // TODO: figure out factor 1 here (what happens with different in/out strides and lengths) stmts += Declaration{stride_lds_pp, length}; stmts += Declaration{offset_lds_pp, Parens(block_id * transforms_per_block + thread_id) % length}; @@ -343,8 +340,7 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR for(unsigned int npass = 0; npass < factors_pp.size(); ++npass) { - unsigned int width = factors_pp[npass]; - // TODO: revisit this. Different from same function in stockham_pp_gen_cc.h + unsigned int width = factors_pp[npass]; unsigned int height = transforms_per_block / max_factor_pp; auto butterfly = std::mem_fn(&StockhamKernel::butterfly_generator); @@ -426,7 +422,6 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR body += Declaration{batch}; body += Declaration{transform}; - // TODO- don't override, unify them body += set_direct_to_from_registers(); // half-lds @@ -516,7 +511,6 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR pre_post_lds_args}; body += postStore; - // partial pass here body += perform_partial_pass_step_1_2(); body += LineBreak{}; diff --git a/library/src/tree_node.cpp b/library/src/tree_node.cpp index 9b748304dff..7e9cfa14f6f 100644 --- a/library/src/tree_node.cpp +++ b/library/src/tree_node.cpp @@ -106,21 +106,13 @@ void LeafNode::GetKernelFactors() { FMKey key = GetKernelKey(); kernelFactors = pool.get_kernel(key).factors; - - // Hard-coded kernel factors for len 64x64x64 partial-pass - // TODO: Remove this hard-coded logic once - // partial-pass is integrated into the stockham generators. - if(scheme == CS_KERNEL_STOCKHAM && applyPartialPass) - kernelFactors = {8, 8}; - if(scheme == CS_KERNEL_STOCKHAM_BLOCK_CC && applyPartialPass) - kernelFactors = {8, 8}; } void LeafNode::GetKernelPartialPassFactors() { // Hard-coded kernel partial-pass factors for len 64x64x64. - // TODO: Remove this hard-coded logic once - // partial-pass is integrated into the Stockham generators. + // TODO: Remove this hard-coded logic once partial-pass + // kernels are configurable in kernel-generator.py. if(scheme == CS_KERNEL_STOCKHAM && applyPartialPass) { kernelFactorsPP = {16}; diff --git a/library/src/tree_node_3D.cpp b/library/src/tree_node_3D.cpp index 9bca8eb2bf8..98b73fea01f 100644 --- a/library/src/tree_node_3D.cpp +++ b/library/src/tree_node_3D.cpp @@ -609,8 +609,8 @@ bool RC3DNode::CheckPartialPassSupport() } } - // TODO: Once partial pass is property integrated into - // the Stockham generators, revisit these restrictions. + // TODO: Revisit these restrictions once partial pass is + // fully configurable in kernel-generator.py. bool batchCondition = (batch >= 5); size_t checkDist = product(length.begin(), length.end()); From e89612c13ae8ec273df33000293aa077bccd8490 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Wed, 23 Apr 2025 14:52:46 -0600 Subject: [PATCH 20/69] - Delete partial-pass test scripts. --- scripts/partial-pass/convert_1d_to_3d.m | 17 -- scripts/partial-pass/convert_3d_to_1d.m | 14 -- scripts/partial-pass/dft_matrix.m | 12 -- scripts/partial-pass/partial_pass_3d.m | 246 ----------------------- scripts/partial-pass/rocfft_to_octave.sh | 74 ------- scripts/partial-pass/run_test.m | 23 --- scripts/partial-pass/run_test.sh | 85 -------- 7 files changed, 471 deletions(-) delete mode 100644 scripts/partial-pass/convert_1d_to_3d.m delete mode 100644 scripts/partial-pass/convert_3d_to_1d.m delete mode 100644 scripts/partial-pass/dft_matrix.m delete mode 100644 scripts/partial-pass/partial_pass_3d.m delete mode 100755 scripts/partial-pass/rocfft_to_octave.sh delete mode 100644 scripts/partial-pass/run_test.m delete mode 100644 scripts/partial-pass/run_test.sh diff --git a/scripts/partial-pass/convert_1d_to_3d.m b/scripts/partial-pass/convert_1d_to_3d.m deleted file mode 100644 index d5997a48656..00000000000 --- a/scripts/partial-pass/convert_1d_to_3d.m +++ /dev/null @@ -1,17 +0,0 @@ -function out = convert_1d_to_3d(in, n1, n2, n3, batch, ordering) - -if ~isvector(in) - error('Invalid input'); -endif - -if strcmp(ordering,'column-major') - out=reshape(in, n3, n2, n1, batch); -elseif strcmp(ordering,'row-major') - out=reshape(in, n1, n2, n3, batch)'; -else - error('Invalid option'); -endif - - - - diff --git a/scripts/partial-pass/convert_3d_to_1d.m b/scripts/partial-pass/convert_3d_to_1d.m deleted file mode 100644 index 9b05008afa0..00000000000 --- a/scripts/partial-pass/convert_3d_to_1d.m +++ /dev/null @@ -1,14 +0,0 @@ -function out = convert_3d_to_1d(in, ordering) - -if isvector(in) - error('Invalid input'); -endif - -if strcmp(ordering,'column-major') - out = reshape(in, 1, []); - out=conj(out); -elseif strcmp(ordering,'row-major') - out = reshape(in', 1, []); -else - error('Invalid option'); -endif diff --git a/scripts/partial-pass/dft_matrix.m b/scripts/partial-pass/dft_matrix.m deleted file mode 100644 index bbea0c56294..00000000000 --- a/scripts/partial-pass/dft_matrix.m +++ /dev/null @@ -1,12 +0,0 @@ -function F = dft_matrix(n) - -F = zeros(n,n); -omega_n = exp(-2*pi*j/n); -for i=1:n - for j=1:n - F(i,j) = (omega_n^((i-1)*(j-1))); - endfor -endfor - - - diff --git a/scripts/partial-pass/partial_pass_3d.m b/scripts/partial-pass/partial_pass_3d.m deleted file mode 100644 index f08d46cbea6..00000000000 --- a/scripts/partial-pass/partial_pass_3d.m +++ /dev/null @@ -1,246 +0,0 @@ -function partial_pass_3d(in_length, nbatch, pp_dim, pp_radices, in_batched, out_batched, test_mode) - - test_mode_0 = 'input'; - test_mode_1 = 'full-3d'; - test_mode_2 = 'direction_1'; - test_mode_3 = 'direction_1_step_1_2'; - test_mode_4 = 'direction_1_step_1_2_3_4'; - - if ~(strcmp(test_mode,test_mode_0) || strcmp(test_mode,test_mode_1) || ... - strcmp(test_mode,test_mode_2) || strcmp(test_mode,test_mode_3) || ... - strcmp(test_mode,test_mode_4)) - display(test_mode); - error('Invalid test mode'); - endif - - format longG; - ordering='column-major'; - data_empty_value = -123456789; - - N = prod(in_length); - - pp_mode = 'four-step'; - - in_batched = convert_1d_to_3d(in_batched, in_length(1), in_length(2), in_length(3), nbatch, ordering); - out_batched = convert_1d_to_3d(out_batched, in_length(1), in_length(2), in_length(3), nbatch, ordering); - - for ibatch=1:nbatch - in = in_batched(:,:,:,ibatch); - out_ = out_batched(:,:,:,ibatch); - - % Validate output - idx_data=find(real(out_)~=data_empty_value); - if ( length(idx_data) != N ) - error('Error: incomplete data'); - endif - - if (strcmp(test_mode,test_mode_0)) - in = convert_3d_to_1d(in, ordering); - out_ = convert_3d_to_1d(out_, ordering); - - in = sort(in); - out_ = sort(out_); - - linf_rocfft_vs_octave_built_in = norm(in-out_,'inf'); - disp(['l-inf norm: ' num2str(linf_rocfft_vs_octave_built_in)]); - return; - endif - - % 3D-FFT (MATLAB built-in) - out = fftn(in); - out = convert_3d_to_1d(out, ordering); - out_ = convert_3d_to_1d(out_, ordering); - - if (strcmp(test_mode,test_mode_1)) - linf_rocfft_vs_octave_built_in = norm(out-out_,'inf'); - disp(['l-inf norm: ' num2str(linf_rocfft_vs_octave_built_in)]); - else - % CS_3D_RC from rocFFT (with partial pass) - [out_3d_rc, out_3d_rc_1, out_3d_rc_pp_1, out_3d_rc_pp_2] = run_CS_3D_RC(in_length, in, pp_dim, pp_radices, pp_mode); - - out_3d_rc = convert_3d_to_1d(out_3d_rc, ordering); - linf_test = norm(out_3d_rc-out,'inf'); - if (linf_test > 1E-5) - error("Error: partial-pass 3D-RC failed accuracy test"); - endif - - if (strcmp(test_mode,test_mode_2)) - out_3d_rc_1 = convert_3d_to_1d(out_3d_rc_1, ordering); - linf_test = norm(out_3d_rc_1-out_,'inf'); - disp(['l-inf norm: ' num2str(linf_test)]); - endif - - if (strcmp(test_mode,test_mode_3)) - out_3d_rc_pp_1 = convert_3d_to_1d(out_3d_rc_pp_1, ordering); - linf_test = norm(out_3d_rc_pp_1-out_,'inf'); - disp(['l-inf norm: ' num2str(linf_test)]); - endif - - if (strcmp(test_mode,test_mode_4)) - out_3d_rc_pp_2 = convert_3d_to_1d(out_3d_rc_pp_2, ordering); - linf_test = norm(out_3d_rc_pp_2-out_,'inf'); - disp(['l-inf norm: ' num2str(linf_test)]); - endif - endif - endfor - - function [out, out_1, out_pp_1, out_pp_2] = run_CS_3D_RC(in_length, in, pp_dim, pp_radices, pp_mode) - n = in_length(pp_dim); - - % Flip radices, as the radix order is reversed in steps 1-2 and 3-4 - pp_radices = flip(pp_radices); - - n1 = pp_radices(1); - n2 = pp_radices(2); - - F_n = dft_matrix(n); - F_n1 = dft_matrix(n1); - F_n2 = dft_matrix(n2); - - out = in; - - if (pp_dim == 1) - % 1st kernel (2nd dimension) - out = fft(out,[], 2); - out_1 = out; - out = partial_pass_step_1_2(out, 1, n1, n2, F_n1, F_n2, F_n, pp_mode); - out_pp_1 = out; - - % 2nd kernel (3rd dimension) - out = partial_pass_step_3_4(out, 1, n1, n2, F_n1, F_n2, F_n, pp_mode); - out_pp_2 = out; - out = fft(out,[], 3); - endif - - if (pp_dim == 2) - % Correct ordering for intermediate results comparison - transp_order_comp = [3 1 2]; - - % 1st kernel (1st dimension) - out = fft(out,[], 1); - out_1 = permute(out, transp_order_comp); - out = partial_pass_step_1_2(out, 2, n1, n2, F_n1, F_n2, F_n, pp_mode); - out_pp_1 = permute(out, transp_order_comp); - - % 2nd kernel (3rd dimension) - transp_order = [3 2 1]; - out = permute(out, transp_order); - - out = partial_pass_step_3_4(out, 2, n1, n2, F_n1, F_n2, F_n, pp_mode); - out_pp_2 = out; - - out = fft(out,[], 1); - - out = permute(out, transp_order); - endif - - if (pp_dim == 3) - % 1st kernel (1st dimension) - out = fft(out,[], 1); - out_1 = out; - out = partial_pass_step_1_2(out, 3, n1, n2, F_n1, F_n2, F_n, pp_mode); - out_pp_1 = out; - - % 2nd kernel (2nd dimension) - out = partial_pass_step_3_4(out, 3, n1, n2, F_n1, F_n2, F_n, pp_mode); - out_pp_2 = out; - out = fft(out,[], 2); - endif - endfunction - - function [dim1, dim2] = get_data_dim_partial_pass(input, pp_dim) - if (pp_dim==1) - dim1 = size(input,2); - dim2 = size(input,3); - elseif (pp_dim==2) - dim1 = size(input,1); - dim2 = size(input,3); - elseif (pp_dim==3) - dim1 = size(input,1); - dim2 = size(input,2); - endif - endfunction - - function input_data_decomp = get_pp_decomposed_data(input_data, pp_dim, idx1, idx2, n1, n2) - if (pp_dim==1) - input_data_decomp = reshape(input_data(:,idx1,idx2), n1, n2); - elseif (pp_dim==2) - input_data_decomp = reshape(input_data(idx1,:,idx2), n1, n2); - elseif (pp_dim==3) - input_data_decomp = reshape(input_data(idx1,idx2,:), n1, n2); - endif - endfunction - - function output = set_pp_data(input, input_decomp, pp_dim, idx1, idx2) - output = input; - - if (pp_dim==1) - output(:,idx1,idx2) = reshape(input_decomp, [], 1); - elseif (pp_dim==2) - output(idx1,:,idx2) = reshape(input_decomp, [], 1); - elseif (pp_dim==3) - output(idx1,idx2,:) = reshape(input_decomp, [], 1); - endif - endfunction - - function output = partial_pass_step_1_2(input, pp_dim, n1, n2, F_n1, F_n2, F_n, mode) - output = input; - - [dim1, dim2] = get_data_dim_partial_pass(input, pp_dim); - - for idx2=1:dim2 - for idx1=1:dim1 - in_decomp = get_pp_decomposed_data(output, pp_dim, idx1, idx2, n1, n2); - - if strcmp(mode, 'four-step') - % Length-n2 FFT along rows of in_decomp - out_decomp = fft(in_decomp, n2, 2); - % Twiddle multiply - out_decomp = F_n(1:n1, 1:n2).*out_decomp; - elseif strcmp(mode, 'six-step') - % Local transpose - out_decomp = in_decomp.'; - % Length-n1 FFT along columns of out_decomp - out_decomp = fft(out_decomp, n1, 1); - % Twiddle multiply - out_decomp = F_n(1:n1, 1:n2).*out_decomp; - else - error('invalid partial-pass mode'); - endif - - output = set_pp_data(output, out_decomp, pp_dim, idx1, idx2); - endfor - endfor - endfunction - - function output = partial_pass_step_3_4(input, pp_dim, n1, n2, F_n1, F_n2, F_n, mode) - output = input; - - [dim1, dim2] = get_data_dim_partial_pass(input, pp_dim); - - for idx1=1:dim1 - for idx2=1:dim2 - in_decomp = get_pp_decomposed_data(output, pp_dim, idx1, idx2, n1, n2); - - if strcmp(mode, 'four-step') - % Local transpose - out_decomp = in_decomp.'; - % Length-n1 FFT along rows of out_decomp - out_decomp = fft(out_decomp, n1, 2); - elseif strcmp(mode, 'six-step') - % Local transpose - out_decomp = in_decomp.'; - % Length-n2 FFT along columns of out_decomp - out_decomp = fft(out_decomp, n2, 1); - % Local transpose - out_decomp = out_decomp.'; - else - error('invalid partial-pass mode'); - endif - - output = set_pp_data(output, out_decomp, pp_dim, idx1, idx2); - endfor - endfor - endfunction - -endfunction diff --git a/scripts/partial-pass/rocfft_to_octave.sh b/scripts/partial-pass/rocfft_to_octave.sh deleted file mode 100755 index 0f9164652f5..00000000000 --- a/scripts/partial-pass/rocfft_to_octave.sh +++ /dev/null @@ -1,74 +0,0 @@ -#! /bin/bash - -# usage /.rocfft_to_octave.sh $arg1 #arg2 $file -# arg1=1 (input) arg1=0 (output) -# arg=2 buffer id - -if [ $1 -eq 1 ]; then - filename="rocfft_input_data.m" -elif [ $1 -eq 0 ]; then - filename="rocfft_output_data.m" -else - echo "error" -fi - -# put input file in variable filename -sed '' $3 | sponge $filename - -# Get buffer description lines in filename and append -# line number to them (the lines starting with -# '--- --- or final output') -cat -n $filename | sed -n '/--- ---\|final output/p' | sponge $filename - -# remove lines with buffer hash -sed '/hash/d' $filename | sponge $filename - -# store result in temp variable -tmp_var=`cat $filename` - -# get line of buffer passed as argument -tmp_var=$(sed -n "/kernel $2/p; /kernel $2/q" <<< "$tmp_var") - -# if no lines found, use line number of 'final output' buffer -if [[ -z "${tmp_var// }" ]] ; then - sed -n "/final output/p; /final output/q" $filename | sponge $filename -else - sed -n "/kernel $2/p; /kernel $2/q" $filename | sponge $filename -fi - -# get line number from this line -sed 's/ .*//' $filename | sponge $filename - -# store line number in variable tmp_var -tmp_var=`cat $filename` - -# put input file in variable filename -sed '' $3 | sponge $filename - -# get buffer from line1 line number to the next '--- ---' line -sed -n "1,$tmp_var b;/--- ---\|final output/ q;p" $3 | sponge $filename - -# -sed '1i data=[' $filename | sponge $filename - -# Remove character '(' from complex number -sed 's/(//g' $filename | sponge $filename - -# Replace character ',' with '+' in complex number -sed 's/,/+/g' $filename | sponge $filename - -# Remove new lines -tr '\n' ' ' < $filename | sponge $filename - -# Replace character ')' with 'i;' -sed 's/)/i;\n/g' $filename | sponge $filename - -# Append '];' to the end of the file -sed '$a];' $filename | sponge $filename - -# -if [ $1 -eq 1 ]; then - sed -i "1s/^/function data = rocfft_input_data()\n/" $filename -elif [ $1 -eq 0 ]; then - sed -i "1s/^/function data = rocfft_output_data()\n/" $filename -fi diff --git a/scripts/partial-pass/run_test.m b/scripts/partial-pass/run_test.m deleted file mode 100644 index 532b316e836..00000000000 --- a/scripts/partial-pass/run_test.m +++ /dev/null @@ -1,23 +0,0 @@ -function run_test() - -length = load("-ascii", "in_len.txt"); - -batch = load("-ascii", "in_batch.txt"); - -pp_dim = load("-ascii", "in_pp_dim.txt"); - -pp_radices = load("-ascii", "in_pp_radices.txt"); - -fid = fopen("in_test_mode.txt", 'r'); -test_mode = textscan(fid, '%s', 'delimiter', '\n'); -test_mode = cellstr(test_mode); -fclose(fid); - -in_batched = rocfft_input_data(); - -out_batched = rocfft_output_data(); - -partial_pass_3d(length, batch, pp_dim, pp_radices, in_batched, out_batched, test_mode); - -delete('rocfft_input_data.m'); -delete('rocfft_output_data.m'); \ No newline at end of file diff --git a/scripts/partial-pass/run_test.sh b/scripts/partial-pass/run_test.sh deleted file mode 100644 index ec9a4fb0529..00000000000 --- a/scripts/partial-pass/run_test.sh +++ /dev/null @@ -1,85 +0,0 @@ -#!/bin/bash - -# requires changing code generator to skip full pass and steps 1-2 -test_mode_0='input' - -test_mode_1='full-3d' - -# requires changing code generator to skip steps 1-2 -test_mode_2='direction_1' - -test_mode_3='direction_1_step_1_2' - -test_mode_4='direction_1_step_1_2_3_4' - -# ------------------------------------------------------------------- -# input parameters -# ------------------------------------------------------------------- - -length=( 64 64 128 ) -batch=( 5 ) -pp_dim=( 2 ) -pp_radices=( 16 4 ) -test_mode=$test_mode_3 -# ------------------------------------------------------------------- - -in_len_file="in_len.txt" -in_batch_file="in_batch.txt" -in_pp_dim_file="in_pp_dim.txt" -in_pp_radices_file="in_pp_radices.txt" -in_test_mode_file="in_test_mode.txt" -rocfft_input_data_file="rocfft_input_data.m" -rocfft_output_data_file="rocfft_output_data.m" -# ------------------------------------------------------------------- - -echo ${length[@]} > $in_len_file -echo ${batch[@]} > $in_batch_file -echo ${pp_dim[@]} > $in_pp_dim_file -echo ${pp_radices[@]} > $in_pp_radices_file -echo ${test_mode} > $in_test_mode_file - -# =================================================================== -rocfft_script_dir=$(pwd) -rofft_dir=$(pwd)/../.. -rocfft_exec_dir=${rofft_dir}/build/clients/staging/ - -cd $rocfft_exec_dir - -ROCFFT_LAYER=16 ./rocfft-bench --precision double --length ${length[0]} ${length[1]} ${length[2]} -b ${batch[0]} &> out.txt - -cd $rocfft_script_dir - -if [ $test_mode = $test_mode_0 ]; then - buffer_arg_1=0 - buffer_arg_2=1 -elif [ $test_mode = $test_mode_1 ]; then - buffer_arg_1=0 - buffer_arg_2=2 -elif [ $test_mode = $test_mode_2 ]; then - buffer_arg_1=0 - buffer_arg_2=1 -elif [ $test_mode = $test_mode_3 ]; then - buffer_arg_1=0 - buffer_arg_2=1 -elif [ $test_mode = $test_mode_4 ]; then - buffer_arg_1=0 - buffer_arg_2=2 -fi - -./rocfft_to_octave.sh 1 $buffer_arg_1 ${rocfft_exec_dir}out.txt -./rocfft_to_octave.sh 0 $buffer_arg_2 ${rocfft_exec_dir}out.txt - -rm $rocfft_exec_dir/out.txt - -octave -W run_test.m - -# =================================================================== - -rm $in_len_file -rm $in_batch_file -rm $in_pp_dim_file -rm $in_pp_radices_file -rm $in_test_mode_file - - - From 6f286b033debf73f60ae998b8f4449914a0f9495 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Wed, 23 Apr 2025 14:55:14 -0600 Subject: [PATCH 21/69] - undo test changes. --- shared/printbuffer.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shared/printbuffer.h b/shared/printbuffer.h index 2d6d0679787..1c9b3895a29 100644 --- a/shared/printbuffer.h +++ b/shared/printbuffer.h @@ -47,7 +47,7 @@ inline void printbuffer(const Toutput* output, { const int i = std::inner_product(index.begin(), index.end(), stride.begin(), i_base + offset); - stream << std::fixed << std::setprecision(14) << output[i] << " "; + stream << output[i] << " "; for(int li = index.size(); li-- > 0;) { if(index[li] == (length[li] - 1)) From 987a7f92e03b3f6dafddb7a952711f69fea9e9dd Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Wed, 23 Apr 2025 14:58:46 -0600 Subject: [PATCH 22/69] - Clang formatting. --- library/src/tree_node.cpp | 2 +- library/src/tree_node_3D.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/library/src/tree_node.cpp b/library/src/tree_node.cpp index 7e9cfa14f6f..14b322caf95 100644 --- a/library/src/tree_node.cpp +++ b/library/src/tree_node.cpp @@ -111,7 +111,7 @@ void LeafNode::GetKernelFactors() void LeafNode::GetKernelPartialPassFactors() { // Hard-coded kernel partial-pass factors for len 64x64x64. - // TODO: Remove this hard-coded logic once partial-pass + // TODO: Remove this hard-coded logic once partial-pass // kernels are configurable in kernel-generator.py. if(scheme == CS_KERNEL_STOCKHAM && applyPartialPass) { diff --git a/library/src/tree_node_3D.cpp b/library/src/tree_node_3D.cpp index 98b73fea01f..075af58ab89 100644 --- a/library/src/tree_node_3D.cpp +++ b/library/src/tree_node_3D.cpp @@ -610,7 +610,7 @@ bool RC3DNode::CheckPartialPassSupport() } // TODO: Revisit these restrictions once partial pass is - // fully configurable in kernel-generator.py. + // fully configurable in kernel-generator.py. bool batchCondition = (batch >= 5); size_t checkDist = product(length.begin(), length.end()); From 7b092a1743e5dc82f8d2bc1ac73b0b97d9e18c0b Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 24 Apr 2025 11:28:40 -0600 Subject: [PATCH 23/69] - Further clean up --- .../src/device/generator/stockham_pp_gen_cc.h | 180 +++++------------- library/src/device/kernel-generator.py | 4 +- library/src/rtc_stockham_kernel.cpp | 17 ++ library/src/tree_node_1D.cpp | 32 ++++ 4 files changed, 99 insertions(+), 134 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index 5c04f3ba89d..a0dccd97f54 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -43,7 +43,10 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC max_factor_pp = *std::max_element(factors_pp.begin(), factors_pp.end()); - transforms_per_block_pp = transforms_per_block / max_factor_pp; + transforms_per_block_pp = transforms_per_block; + + transforms_per_block *= max_factor_pp; + workgroup_size *= max_factor_pp; } unsigned int transforms_per_block_pp; @@ -236,30 +239,12 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += LineBreak{}; stmts += Assign{batch, block_id / plength}; - if(!direct_to_from_reg) - { - stmts - += Assign{transform, - tile_index * transforms_per_block_pp + thread_id / threads_per_transform}; - stmts += Assign{stride_lds, (length + get_lds_padding())}; - stmts += MultiplyAssign(stride_lds, Literal{max_factor_pp}); - } - else - { - stmts += Assign{ - transform, - Ternary{lds_linear, - tile_index * transforms_per_block_pp + thread_id / threads_per_transform, - tile_index * transforms_per_block_pp - + thread_id % transforms_per_block_pp}}; - stmts += Assign{stride_lds, - Ternary{lds_linear, - length + get_lds_padding(), - transforms_per_block_pp + get_lds_padding()}}; + stmts += Assign{transform, + tile_index * transforms_per_block_pp + thread_id / threads_per_transform}; + stmts += Assign{stride_lds, (length + get_lds_padding())}; - stmts += MultiplyAssign(stride_lds, Literal{max_factor_pp}); - } + stmts += MultiplyAssign(stride_lds, Literal{max_factor_pp}); stmts += Declaration{ in_bound, @@ -284,17 +269,9 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC offset_pp, offset + Parens(offset / length) * Literal{192} + batch_new * stride[dim]); stmts += Declaration(offset_tid_hor, offset_pp + tid_hor_pp * stride[1]); - if(!direct_to_from_reg) - stmts += Assign{transform, - tile_index * transforms_per_block_pp - + thread_id / (threads_per_transform * max_factor_pp)}; - else - stmts += Assign{transform, - Ternary{lds_linear, - tile_index * transforms_per_block_pp - + thread_id / (threads_per_transform * max_factor_pp), - tile_index * transforms_per_block_pp - + thread_id % transforms_per_block_pp}}; + stmts += Assign{transform, + tile_index * transforms_per_block_pp + + thread_id / (threads_per_transform * max_factor_pp)}; stmts += Assign{offset_lds, Ternary{lds_linear, @@ -412,75 +389,33 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC return work; } - StatementList store_to_global(bool store_registers) override + StatementList store_to_global(bool store_registers = false) override { StatementList stmts; StatementList tmp_stmts; Expression pred{tile_index * transforms_per_block_pp + tid_hor < lengths[1]}; - if(!store_registers) - { - auto stripmine_w = transforms_per_block; - auto stripmine_h = workgroup_size / stripmine_w; - - auto offset_tile_wbuf = [&](unsigned int i) { - return offset_tid_hor + (thread_new + i * stripmine_h) * stride0; - }; - auto offset_tile_rlds = [&](unsigned int i) { - return tid_hor_lds * stride_lds - + (thread_lds + i * stripmine_h * max_factor_pp) * 1; - }; + auto stripmine_w = transforms_per_block; + auto stripmine_h = workgroup_size / stripmine_w; - for(unsigned int i = 0; i < length / stripmine_h; ++i) - tmp_stmts += StoreGlobal{ - buf, - CallExpr{"local_transpose_pp_length" + std::to_string(length) + "_device", - {offset_tile_wbuf(i)}}, - lds_complex[offset_tile_rlds(i)]}; + auto offset_tile_wbuf = [&](unsigned int i) { + return offset_tid_hor + (thread_new + i * stripmine_h) * stride0; + }; + auto offset_tile_rlds = [&](unsigned int i) { + return tid_hor_lds * stride_lds + (thread_lds + i * stripmine_h * max_factor_pp) * 1; + }; - stmts += CommentLines{ - "no intrinsic when store from lds. FIXME- check why use nested branch is better"}; - stmts += If{in_bound, tmp_stmts}; - stmts += If{Not{in_bound}, {If{pred, tmp_stmts}}}; - } - else - { - StatementList intrinsic_stmts; - StatementList non_intrinsic_stmts; + for(unsigned int i = 0; i < length / stripmine_h; ++i) + tmp_stmts += StoreGlobal{ + buf, + CallExpr{"local_transpose_pp_length" + std::to_string(length) + "_device", + {offset_tile_wbuf(i)}}, + lds_complex[offset_tile_rlds(i)]}; - auto width = factors.back(); - auto cumheight = product(factors.begin(), factors.begin() + (factors.size() - 1)); - auto height = static_cast(length) / width / threads_per_transform; - - auto store_global = std::mem_fn(&StockhamKernelCC::store_global_generator); - intrinsic_stmts += CommentLines{"use intrinsic store"}; - intrinsic_stmts += add_work(std::bind(store_global, - this, - _1, - _2, - _3, - _4, - _5, - cumheight, - true, - Expression{Parens(in_bound || pred)}), - width, - height, - ThreadGuardMode::GURAD_BY_FUNC_ARG); - - tmp_stmts += add_work( - std::bind( - store_global, this, _1, _2, _3, _4, _5, cumheight, false, Expression{in_bound}), - width, - height, - ThreadGuardMode::GUARD_BY_IF); - non_intrinsic_stmts += CommentLines{"can't use intrinsic store"}; - non_intrinsic_stmts += If{in_bound, tmp_stmts}; - non_intrinsic_stmts += If{!in_bound, {If{pred, tmp_stmts}}}; - - stmts += If{intrinsic_mode == "IntrinsicAccessType::ENABLE_BOTH", intrinsic_stmts}; - stmts += Else{non_intrinsic_stmts}; - } + stmts += CommentLines{ + "no intrinsic when store from lds. FIXME- check why use nested branch is better"}; + stmts += If{in_bound, tmp_stmts}; + stmts += If{Not{in_bound}, {If{pred, tmp_stmts}}}; return stmts; } @@ -793,6 +728,16 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC return work; } + // Partial-pass steps 3/4 right after load_from_global + // and local transposition just before store_to_global + // do not allow direct_from_reg + StatementList set_direct_to_from_registers() override + { + return {Declaration{direct_load_to_reg, Literal{"false"}}, + Declaration{direct_store_from_reg, Literal{"false"}}, + Declaration{lds_linear, Literal{"true"}}}; + } + ArgumentList device_arguments() override { ArgumentList args = StockhamKernel::device_arguments(); @@ -981,19 +926,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC loadlds += load_from_global(false); loadlds += LineBreak{}; - if(!direct_to_from_reg) - { - body += loadlds; - } - else - { - StatementList loadr; - loadr += CommentLines{"load global into registers"}; - loadr += load_from_global(true); - - body += If{direct_load_to_reg, loadr}; - body += Else{loadlds}; - } + body += loadlds; body += perform_partial_pass_step_3_4(); @@ -1018,10 +951,8 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC preLoad += Call{"lds_to_reg_input_pp_length" + std::to_string(length) + "_device", pre_post_lds_tmpl, pre_post_lds_args}; - if(!direct_to_from_reg) - body += preLoad; - else - body += If{!direct_load_to_reg, preLoad}; + + body += preLoad; body += LineBreak{}; body += CommentLines{"transform"}; @@ -1040,18 +971,15 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC } // after finishing the transform job (core device function) - // we call a post-store reg-to-lds function here, but it's not always doing things. - // If we're doing direct-from-reg, this function simply returns. + // we call a post-store reg-to-lds function here body += LineBreak{}; body += CommentLines{"call a post-store from registers to lds (if necessary)"}; StatementList postStore; postStore += Call{"lds_from_reg_output_pp_length" + std::to_string(length) + "_device", pre_post_lds_tmpl, pre_post_lds_args}; - if(!direct_to_from_reg) - body += postStore; - else - body += If{!direct_store_from_reg, postStore}; + + body += postStore; body += LineBreak{}; StatementList storelds; @@ -1060,21 +988,9 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC storelds += LineBreak{}; storelds += CommentLines{"store global"}; storelds += SyncThreads{}; - storelds += store_to_global(false); - - if(!direct_to_from_reg) - { - body += storelds; - } - else - { - StatementList storer; - storer += CommentLines{"store registers into global"}; - storer += store_to_global(true); + storelds += store_to_global(); - body += If{direct_store_from_reg, storer}; - body += Else{storelds}; - } + body += storelds; f.templates = global_templates(); f.arguments = global_arguments(); diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index 63c546a4f3b..8ee89e1a987 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -292,7 +292,7 @@ def list_small_kernels(): NS(length= 56, workgroup_size=128, threads_per_transform= 8, factors=(7, 8)), NS(length= 60, workgroup_size= 64, threads_per_transform= 10, factors=(6, 10)), NS(length= 63, workgroup_size=256, threads_per_transform= 21, factors=(3, 3, 7), half_lds=False, runtime_compile=True), - NS(length= 64, workgroup_size=128, threads_per_transform= 8, factors=(4, 4, 4), half_lds=False, direct_to_from_reg=True), + NS(length= 64, workgroup_size= 64, threads_per_transform= 16, factors=(4, 4, 4), half_lds=False, direct_to_from_reg=True), NS(length= 65, workgroup_size=256, threads_per_transform= 13, factors=(13, 5), runtime_compile=True), NS(length= 66, workgroup_size=256, threads_per_transform= 11, factors=(6, 11), half_lds=False, runtime_compile=True), NS(length= 68, workgroup_size=256, threads_per_transform= 17, factors=(17, 4), runtime_compile=True), @@ -870,7 +870,7 @@ def list_large_kernels(): NS(length=60, factors=[6, 10], use_3steps_large_twd={ 'sp': 'false', 'dp': 'false'}), NS(length=64, factors=[8, 8], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'false'}, workgroup_size=256, direct_to_from_reg=False), + 'sp': 'true', 'dp': 'false'}, workgroup_size=256), NS(length=72, factors=[8, 3, 3], use_3steps_large_twd={ 'sp': 'true', 'dp': 'false'}), NS(length=80, factors=[10, 8], use_3steps_large_twd={ diff --git a/library/src/rtc_stockham_kernel.cpp b/library/src/rtc_stockham_kernel.cpp index 134d4e622d5..70d7087ec94 100644 --- a/library/src/rtc_stockham_kernel.cpp +++ b/library/src/rtc_stockham_kernel.cpp @@ -72,6 +72,23 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& // the generator as-is kernel = node.pool.get_kernel(key); + if(node.applyPartialPass) + { + if(node.scheme == CS_KERNEL_STOCKHAM_BLOCK_CC) + { + kernel->threads_per_transform[0] = 8; + kernel->workgroup_size = 64; + } + else if(node.scheme == CS_KERNEL_STOCKHAM) + { + kernel->threads_per_transform[0] = 8; + kernel->workgroup_size = 128; + } + + kernel->transforms_per_block + = kernel->workgroup_size / kernel->threads_per_transform[0]; + } + std::vector factors; std::copy(kernel->factors.begin(), kernel->factors.end(), std::back_inserter(factors)); std::vector precisions = {static_cast(node.precision)}; diff --git a/library/src/tree_node_1D.cpp b/library/src/tree_node_1D.cpp index 040be5950f4..60cffd087af 100644 --- a/library/src/tree_node_1D.cpp +++ b/library/src/tree_node_1D.cpp @@ -903,6 +903,16 @@ void Stockham1DNode::SetupGridParam_internal(GridParam& gp) auto key = GetKernelKey(); auto kernel = pool.get_kernel(key); + if(applyPartialPass) + { + // TODO: Hardcoded configuration for 64 x 64 x 64. + // Remove this once the partial-pass kernels are + // fully configurable in kernel-generator.py. + kernel.threads_per_transform[0] = 8; + kernel.workgroup_size = 128; + kernel.transforms_per_block = kernel.workgroup_size / kernel.threads_per_transform[0]; + } + bwd = kernel.transforms_per_block; wgs = kernel.workgroup_size; gp.b_x = (batch_accum + bwd - 1) / bwd; @@ -1104,11 +1114,33 @@ void SBCCNode::SetupGridParam_internal(GridParam& gp) bwd = kernel.transforms_per_block; wgs = kernel.workgroup_size; + if(applyPartialPass) + { + // TODO: Hardcoded configuration for 64 x 64 x 64. + // Remove this once the partial-pass kernels are + // fully configurable in kernel-generator.py. + auto tpt = 8; + wgs = 64; + bwd = wgs / tpt; + } + lds = length[0] * bwd; gp.b_x = ((length[1]) - 1) / bwd + 1; gp.b_x *= product(length.begin() + 2, length.end()) * batch; gp.wgs_x = wgs; + + if(applyPartialPass) + { + // Grid arrangement is different for partial + // pass SBCC kernels. This arrangement leads + // to improved global memory access patterns. + auto factor = *std::max_element(kernelFactorsPP.begin(), kernelFactorsPP.end()); + + gp.b_x /= factor; + gp.wgs_x *= factor; + lds *= factor; + } } std::vector SBCCNode::CollapsibleDims() From 5c111725cf621c8919bc1f0bdcc76b1d2ba05551 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 24 Apr 2025 11:32:10 -0600 Subject: [PATCH 24/69] - Remove no longer needed include --- shared/printbuffer.h | 1 - 1 file changed, 1 deletion(-) diff --git a/shared/printbuffer.h b/shared/printbuffer.h index 1c9b3895a29..5ae0b64fbb4 100644 --- a/shared/printbuffer.h +++ b/shared/printbuffer.h @@ -24,7 +24,6 @@ #include "hostbuf.h" #include "increment.h" #include -#include #include // Output a formatted general-dimensional array with given length and stride in batches From 708a98cd8d4649e281d87705ef5f425fc1b8b340 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 24 Apr 2025 14:53:13 -0600 Subject: [PATCH 25/69] - More clean-up. --- library/src/device/generator/stockham_pp_gen_cc.h | 5 ++--- library/src/rtc_stockham_gen.cpp | 1 - library/src/rtc_stockham_kernel.cpp | 3 +++ library/src/tree_node.cpp | 2 +- library/src/tree_node_3D.cpp | 5 +++++ 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index a0dccd97f54..633acc803fd 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -22,13 +22,14 @@ #include "stockham_gen_cc.h" // TODO: Once partial pass is fully configurable in kernel-generator.py: -// - Test all "lds_linear=false" cases. // - Test with factors_pp.size() > 1. // - Revisit all usages of transforms_per_block_pp and threads_per_transform. // - Different input/output strides. // - Revisit mod 128 usage in calculate_offsets() with different input lengths, // (logic is required to work with nbatch > 1) // - Revisit factor 192 logic in calculate_offsets() with different input lengths +// - Revisit and test local transpose logic for different input lengths + struct StockhamPartialPassKernelCC : public StockhamKernelCC { explicit StockhamPartialPassKernelCC(const StockhamGeneratorSpecs& specs, @@ -54,12 +55,10 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC std::vector factors_pp; Variable thread_lds{"thread_lds", "unsigned int"}; - Variable idx_lds{"idx_lds", "unsigned int"}; Variable stride_lds_pp{"stride_lds_pp", "unsigned int"}; Variable offset_lds_pp{"offset_lds_pp", "unsigned int"}; Variable tid_hor_lds{"tid_hor_lds", "unsigned int"}; - Variable offfset_unbatched{"offfset_unbatched", "unsigned int"}; Variable tid_hor_pp{"tid_hor_pp", "unsigned int"}; Variable offset_tid_hor{"offset_tid_hor", "unsigned int"}; Variable offset_pp{"offset_pp", "unsigned int"}; diff --git a/library/src/rtc_stockham_gen.cpp b/library/src/rtc_stockham_gen.cpp index f82106d5816..c7d38b96374 100644 --- a/library/src/rtc_stockham_gen.cpp +++ b/library/src/rtc_stockham_gen.cpp @@ -456,7 +456,6 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, src += large_twiddles_h; // append the neccessary functions only append_radix_h(src, all_factors); - // SBCCs don't need this if(scheme != CS_KERNEL_STOCKHAM_BLOCK_CC) src += real2complex_device_h; diff --git a/library/src/rtc_stockham_kernel.cpp b/library/src/rtc_stockham_kernel.cpp index 70d7087ec94..10fad9f86de 100644 --- a/library/src/rtc_stockham_kernel.cpp +++ b/library/src/rtc_stockham_kernel.cpp @@ -74,6 +74,9 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& if(node.applyPartialPass) { + // TODO: Hardcoded configuration for 64 x 64 x 64. + // Remove this once the partial-pass kernels are + // fully configurable in kernel-generator.py. if(node.scheme == CS_KERNEL_STOCKHAM_BLOCK_CC) { kernel->threads_per_transform[0] = 8; diff --git a/library/src/tree_node.cpp b/library/src/tree_node.cpp index 14b322caf95..70706d9e769 100644 --- a/library/src/tree_node.cpp +++ b/library/src/tree_node.cpp @@ -110,7 +110,7 @@ void LeafNode::GetKernelFactors() void LeafNode::GetKernelPartialPassFactors() { - // Hard-coded kernel partial-pass factors for len 64x64x64. + // Hard-coded partial-pass kernel factors for len 64x64x64. // TODO: Remove this hard-coded logic once partial-pass // kernels are configurable in kernel-generator.py. if(scheme == CS_KERNEL_STOCKHAM && applyPartialPass) diff --git a/library/src/tree_node_3D.cpp b/library/src/tree_node_3D.cpp index 075af58ab89..479280afd06 100644 --- a/library/src/tree_node_3D.cpp +++ b/library/src/tree_node_3D.cpp @@ -644,6 +644,11 @@ void RC3DNode::BuildTree_internal(SchemeTreeVec& child_scheme_trees) if(CheckPartialPassSupport()) { + // TODO: Child nodes currently hardcoded to a x+z configuration + // in 3D partial-pass. Add support for other configurations, + // e.g., x+y, y+z, once partial pass is fully configurable + // in kernel-generator.py. + // work along y will be split between x and z applyPartialPass = true; From 47cbc4eb992ab322a2955baa3694bb72d7cf656f Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 25 Apr 2025 09:24:37 -0600 Subject: [PATCH 26/69] - match SBRR configuration. --- library/src/rtc_stockham_kernel.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/library/src/rtc_stockham_kernel.cpp b/library/src/rtc_stockham_kernel.cpp index 10fad9f86de..163533f8f63 100644 --- a/library/src/rtc_stockham_kernel.cpp +++ b/library/src/rtc_stockham_kernel.cpp @@ -86,6 +86,7 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& { kernel->threads_per_transform[0] = 8; kernel->workgroup_size = 128; + kernel->direct_to_from_reg = false; } kernel->transforms_per_block From eecba0ac09071772e9068d93fc70f0a8bd6243ef Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 25 Apr 2025 14:39:08 -0600 Subject: [PATCH 27/69] - Changes to kernel-generator.py and function pool to support partial-pass kernels. --- library/src/compute_scheme.cpp | 3 +- library/src/device/generator/stockham_gen.cpp | 207 ++++++++++++++---- library/src/device/generator/stockham_gen.h | 2 + library/src/device/kernel-generator.py | 68 ++++-- library/src/include/compute_scheme.h | 3 +- library/src/include/function_map_key.h | 77 +++++++ library/src/include/function_pool.h | 27 ++- 7 files changed, 322 insertions(+), 65 deletions(-) diff --git a/library/src/compute_scheme.cpp b/library/src/compute_scheme.cpp index cbc2abfba85..26a6073f405 100644 --- a/library/src/compute_scheme.cpp +++ b/library/src/compute_scheme.cpp @@ -128,7 +128,8 @@ static const std::set& ProblemScheme() (CS_3D_RTRT), (CS_3D_BLOCK_RC), (CS_3D_BLOCK_CR), - (CS_3D_RC)}; + (CS_3D_RC), + (CS_3D_PP)}; return ProblemSchemeSet; } diff --git a/library/src/device/generator/stockham_gen.cpp b/library/src/device/generator/stockham_gen.cpp index 3cba76fc9eb..9b2aae5596a 100644 --- a/library/src/device/generator/stockham_gen.cpp +++ b/library/src/device/generator/stockham_gen.cpp @@ -144,6 +144,26 @@ void make_launcher(const std::vector& precision_types, } } +// parse comma-separated string booleans +std::vector parse_bool_csv(const std::string& arg) +{ + std::vector bools; + + size_t prev_pos = 0; + for(;;) + { + auto pos = arg.find(',', prev_pos); + if(pos == std::string::npos) + { + bools.push_back(arg.substr(prev_pos) == "1"); + break; + } + bools.push_back(arg.substr(prev_pos, pos - prev_pos) == "1"); + prev_pos = pos + 1; + } + return bools; +} + // parse comma-separated string uints std::vector parse_uints_csv(const std::string& arg) { @@ -166,6 +186,73 @@ std::vector parse_uints_csv(const std::string& arg) const char* COMMA = ","; +void output_json(const std::vector& launchers, + const std::string& kernel_name, + std::ostream& output) +{ + // output json (via stdout) describing the launchers that were generated, so + // kernel-generator can generate the function pool + + const char* LIST_DELIM = ""; + + // store all variants of one kernel in a list, and store with kernel name as key + output << "\"" << kernel_name << "\" : "; + output << "["; + for(auto& launcher : launchers) + { + output << LIST_DELIM; + output << launcher.to_string() << "\n"; + LIST_DELIM = COMMA; + } + output << "]"; +} + +void stockham_partial_pass_variants(const std::string& kernel_name, + const size_t& pp_length, + const StockhamGeneratorSpecs& specs1, + const std::vector& pp_factors1, + const StockhamGeneratorSpecs& specs2, + const std::vector& pp_factors2, + std::ostream& output) +{ + std::vector launchers; + + if(specs1.scheme == "CS_3D_PP" && specs2.scheme == "CS_3D_PP") + { + // SBRR_PP + SBCC_PP + if(specs1.static_dim == 0 && specs2.static_dim == 2) + { + std::vector factors1(pp_factors1.begin(), pp_factors1.end()); + StockhamPartialPassKernelRR kernelRR(specs1, factors1, pp_length); + make_launcher( + specs1.precisions, {{"pp_stoc", specs1.scheme, "", ""}}, kernelRR, launchers); + + std::vector factors2(pp_factors2.begin(), pp_factors2.end()); + StockhamPartialPassKernelCC kernelCC(specs2, false, factors2); + make_launcher( + specs2.precisions, {{"pp_sbcc", specs2.scheme, "", ""}}, kernelCC, launchers); + } + // SBRR_PP + SBCC_PP + else if(specs1.static_dim == 1 && specs2.static_dim == 2) + { + } + // SBRR_PP + SBRR_PP + else if(specs1.static_dim == 0 && specs2.static_dim == 1) + { + } + else + { + throw std::runtime_error("invalid dimensions for CS_3D_PP"); + } + } + else + { + throw std::runtime_error("unhandled scheme"); + } + + output_json(launchers, kernel_name, output); +} + void stockham_variants(const std::string& kernel_name, const StockhamGeneratorSpecs& specs, const StockhamGeneratorSpecs& specs2d, @@ -247,32 +334,16 @@ void stockham_variants(const std::string& kernel_name, else throw std::runtime_error("unhandled scheme"); - // output json (via stdout) describing the launchers that were generated, so - // kernel-generator can generate the function pool - - const char* LIST_DELIM = ""; - - // store all variants of one kernel in a list, and store with kernel name as key - output << "\"" << kernel_name << "\" : "; - output << "["; - for(auto& launcher : launchers) - { - output << LIST_DELIM; - output << launcher.to_string() << "\n"; - LIST_DELIM = COMMA; - } - output << "]"; + output_json(launchers, kernel_name, output); } int main() { std::string line; - std::string kernel_name; - std::string scheme; - bool direct_to_from_reg; - bool half_lds; - unsigned int workgroup_size; + std::string kernel_name; + std::string scheme; + bool half_lds; const char* DELIM = ""; std::cout << "{"; @@ -297,14 +368,33 @@ int main() ++arg; scheme = *arg; + unsigned int pp_length; + std::vector dims, pp_factors_1, pp_factors_2; + if(scheme == "CS_3D_PP") + { + ++arg; + pp_length = std::stoul(*arg); + + ++arg; + pp_factors_2 = parse_uints_csv(*arg); + + ++arg; + pp_factors_1 = parse_uints_csv(*arg); + + ++arg; + dims = parse_uints_csv(*arg); + } + ++arg; - direct_to_from_reg = *arg == "1"; + std::vector direct_to_from_reg; + direct_to_from_reg = parse_bool_csv(*arg); ++arg; half_lds = *arg == "1"; ++arg; - workgroup_size = std::stoul(*arg); + std::vector workgroup_size; + workgroup_size = parse_uints_csv(*arg); ++arg; std::vector threads_per_transform; @@ -314,31 +404,70 @@ int main() std::vector precisions; precisions = parse_uints_csv(*arg); - std::vector factors; - std::vector factors2d; - if(scheme == "CS_KERNEL_2D_SINGLE") + // create spec and pass to stockham_variants, writes partial output to stdout + std::cout << DELIM; + + if(scheme == "CS_3D_PP") { + std::vector factors1, factors2; + + ++arg; + factors2 = parse_uints_csv(*arg); + ++arg; - factors2d = parse_uints_csv(*arg); + factors1 = parse_uints_csv(*arg); + + if(dims.size() != 2) + throw std::runtime_error("CS_3D_PP requires two dimensions configuration"); + + if(threads_per_transform.size() != 2) + throw std::runtime_error( + "CS_3D_PP requires two threads_per_transform configuration"); + + if(direct_to_from_reg.size() != 2) + throw std::runtime_error("CS_3D_PP requires two direct_to_from_reg configuration"); + + StockhamGeneratorSpecs specs1(factors1, {}, precisions, workgroup_size[0], scheme); + specs1.static_dim = dims[0]; + specs1.direct_to_from_reg = direct_to_from_reg[0]; + specs1.threads_per_transform = threads_per_transform[0]; + + StockhamGeneratorSpecs specs2(factors2, {}, precisions, workgroup_size[1], scheme); + specs2.static_dim = dims[1]; + specs2.direct_to_from_reg = direct_to_from_reg[1]; + specs2.threads_per_transform = threads_per_transform[1]; + + stockham_partial_pass_variants( + kernel_name, pp_length, specs1, pp_factors_1, specs2, pp_factors_2, std::cout); } + else + { + std::vector factors; + std::vector factors2d; + if(scheme == "CS_KERNEL_2D_SINGLE") + { + ++arg; + factors2d = parse_uints_csv(*arg); + } - ++arg; - factors = parse_uints_csv(*arg); + ++arg; + factors = parse_uints_csv(*arg); - StockhamGeneratorSpecs specs(factors, factors2d, precisions, workgroup_size, scheme); - specs.half_lds = half_lds; - specs.direct_to_from_reg = direct_to_from_reg; + StockhamGeneratorSpecs specs(factors, factors2d, precisions, workgroup_size[0], scheme); + specs.half_lds = half_lds; + specs.direct_to_from_reg = direct_to_from_reg[0]; - specs.threads_per_transform = threads_per_transform.front(); + specs.threads_per_transform = threads_per_transform.front(); - // second dimension for 2D_SINGLE - StockhamGeneratorSpecs specs2d(factors2d, factors, precisions, workgroup_size, scheme); - if(!threads_per_transform.empty()) - specs2d.threads_per_transform = threads_per_transform.back(); + // second dimension for 2D_SINGLE + StockhamGeneratorSpecs specs2d( + factors2d, factors, precisions, workgroup_size[0], scheme); + if(!threads_per_transform.empty()) + specs2d.threads_per_transform = threads_per_transform.back(); + + stockham_variants(kernel_name, specs, specs2d, std::cout); + } - // create spec and pass to stockham_variants, writes partial output to stdout - std::cout << DELIM; - stockham_variants(kernel_name, specs, specs2d, std::cout); DELIM = COMMA; std::cout << std::flush; } diff --git a/library/src/device/generator/stockham_gen.h b/library/src/device/generator/stockham_gen.h index 1b2bff3f5a0..6d44c67a49d 100644 --- a/library/src/device/generator/stockham_gen.h +++ b/library/src/device/generator/stockham_gen.h @@ -60,6 +60,8 @@ struct StockhamGeneratorSpecs unsigned int static_dim = 0; std::string scheme; + unsigned int pp_dim = 0; + // this value indicating if the wgs, tpt are excatly what we want // (i.e. were already derived somewhere) // to tell StockhamKernel not to do its auto-derivation again. diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index 8ee89e1a987..9238e797d9a 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -152,7 +152,7 @@ def generate_cpu_function_pool_main(num_files): type='void', name=f'function_pool_init_{i}', value= - 'std::unordered_map& def_key_pool, std::unordered_map& function_map' + 'FPKeyMap& def_key_pool, FPMap& function_map' ) call_list = StatementList() @@ -176,8 +176,6 @@ def generate_cpu_function_pool_pieces(functions, num_files): 'half': 'rocfft_precision_half', } var_kernel = Variable('kernel', 'FFTKernel') - initial_statement = StatementList() - initial_statement += var_kernel.declaration() # Init list to store contents of function_pool_init function per file being generated piece_contents = [ @@ -204,8 +202,8 @@ def generate_cpu_function_pool_pieces(functions, num_files): # Assemble contents of each file to return in a list pieces = [None] * num_files piece_args = ArgumentList( - 'std::unordered_map& def_key_pool', - 'std::unordered_map& function_map') + 'FPKeyMap& def_key_pool', + 'FPMap& function_map') for i in range(num_files): pieces[i] = StatementList( Include('"../include/function_pool.h"'), @@ -1004,6 +1002,20 @@ def list_large_kernels(): k.length = functools.reduce(lambda a, b: a * b, k.factors) return sbcc_kernels + sbcr_kernels + sbrc_kernels + +def list_3d_partial_pass_kernels(): + """Return list of to generate.""" + + pp_3d_kernels = [ + NS(length=[64,64,64], dims=[0, 2], factors=[[4, 4, 4],[8, 8]], factors_pp=[[16], [4]], threads_per_transform=[8, 16], workgroup_size=[128, 256], direct_to_from_reg=[False, False]) + ] + + expanded = [] + expanded.extend(NS(**kernel.__dict__, + scheme='CS_3D_PP', runtime_compile=True) for kernel in pp_3d_kernels) + + return expanded + # yapf: enable @@ -1032,9 +1044,14 @@ def generate_kernel_functions(kernels, precisions, launchers_json): launcher = NS(**launcher_dict) factors = launcher.factors - length = launcher.lengths[0] if len( - launcher.lengths) == 1 else (launcher.lengths[0], - launcher.lengths[1]) + + if len(launcher.lengths) == 1: + length = launcher.lengths[0] + elif len(launcher.lengths) == 2: + length = (launcher.lengths[0], launcher.lengths[1]) + elif len(launcher.lengths) == 3: + length = (launcher.lengths[0], launcher.lengths[1], launcher.lengths[2]) + transforms_per_block = launcher.transforms_per_block workgroup_size = launcher.workgroup_size threads_per_transform = workgroup_size // transforms_per_block @@ -1149,13 +1166,33 @@ def generate_kernels(kernels, precisions, stockham_gen): if len(k.factors) == 1: half_lds = False - # for unspecified direct_to_from_reg, default is True only for CS_KERNEL_STOCKHAM and SBCC - direct_to_from_reg = getattr(k, 'direct_to_from_reg', True) - # Send data over to subprocess - proc.stdin.write(f' {str(k.workgroup_size)}') + if isinstance(k.workgroup_size, list): + proc.stdin.write(" " + ','.join([str(f) for f in k.workgroup_size])) + else: + proc.stdin.write(f' {str(k.workgroup_size)}') + proc.stdin.write(' 1' if half_lds else ' 0') - proc.stdin.write(' 1' if direct_to_from_reg else ' 0') + + direct_to_from_reg = getattr(k, 'direct_to_from_reg', True) + + if isinstance(direct_to_from_reg, list): + proc.stdin.write(" " + ','.join(['1' if f else '0' for f in direct_to_from_reg])) + else: + # for unspecified direct_to_from_reg, default is True only for CS_KERNEL_STOCKHAM and SBCC + direct_to_from_reg = getattr(k, 'direct_to_from_reg', True) + proc.stdin.write(' 1' if direct_to_from_reg else ' 0') + + if hasattr(k, 'dims'): + proc.stdin.write(" " + ','.join([str(f) for f in k.dims])) + + if hasattr(k, 'factors_pp'): + proc.stdin.write(" " + ','.join([str(f) + for f in k.factors_pp[0]]) + " ") + proc.stdin.write(','.join([str(f) + for f in k.factors_pp[1]]) + " ") + proc.stdin.write(str(k.length[1])) + proc.stdin.write(f' {k.scheme}') proc.stdin.write(f' {kernel_name(k)}\n') @@ -1211,11 +1248,12 @@ def cli(): # kernels = [] - # move 2d out from all, no need to iterate the 2d-kernels for non-2d patterns + # move 2d out from all, no need to iterate the 2d-kernels for non-2d patterns kernels_2d = list_2d_kernels() + kernel_3d_pp = list_3d_partial_pass_kernels() all_kernels = list_small_kernels() + list_large_kernels() - kernels += all_kernels + kernels_2d + kernels += all_kernels + kernels_2d + kernel_3d_pp kernels = unique(kernels) diff --git a/library/src/include/compute_scheme.h b/library/src/include/compute_scheme.h index 13593e58531..fc70cb22b00 100644 --- a/library/src/include/compute_scheme.h +++ b/library/src/include/compute_scheme.h @@ -79,7 +79,8 @@ enum ComputeScheme CS_3D_BLOCK_CR, CS_3D_RC, CS_KERNEL_3D_STOCKHAM_BLOCK_CC, // not implemented yet - CS_KERNEL_3D_SINGLE // not implemented yet + CS_KERNEL_3D_SINGLE, // not implemented yet + CS_3D_PP }; // print abbreviation for kernel scheme diff --git a/library/src/include/function_map_key.h b/library/src/include/function_map_key.h index 7a72a8fe7bf..1da075d4820 100644 --- a/library/src/include/function_map_key.h +++ b/library/src/include/function_map_key.h @@ -478,4 +478,81 @@ struct SimpleHash } }; +struct FMKeyPP +{ + std::array lengths; + rocfft_precision precision; + ComputeScheme scheme = CS_3D_PP; + KernelConfig kernel_config_1 = KernelConfig::EmptyConfig(); + KernelConfig kernel_config_2 = KernelConfig::EmptyConfig(); + + FMKeyPP() = default; + FMKeyPP(const FMKeyPP&) = default; + + // with every data + FMKeyPP(size_t length0, + size_t length1, + size_t length2, + rocfft_precision precision, + ComputeScheme scheme = CS_3D_PP, + KernelConfig kernel_config_1 = KernelConfig::EmptyConfig(), + KernelConfig kernel_config_2 = KernelConfig::EmptyConfig()) + : lengths({length0, length1, length2}) + , precision(precision) + , scheme(scheme) + , kernel_config_1(kernel_config_1) + , kernel_config_2(kernel_config_2) + { + } + + FMKeyPP& operator=(const FMKeyPP&) = default; + + bool operator==(const FMKeyPP& rhs) const + { + return std::tie(lengths, precision, scheme, kernel_config_1, kernel_config_2) + == std::tie(rhs.lengths, + rhs.precision, + rhs.scheme, + rhs.kernel_config_1, + rhs.kernel_config_2); + } + + bool operator!=(const FMKeyPP& rhs) const + { + return !((*this) == rhs); + } + + bool operator<(const FMKeyPP& rhs) const + { + return std::tie(lengths, precision, scheme, kernel_config_1, kernel_config_2) + < std::tie(rhs.lengths, + rhs.precision, + rhs.scheme, + rhs.kernel_config_1, + rhs.kernel_config_2); + } + + static FMKeyPP EmptyFMKeyPP() + { + static FMKeyPP empty; + return empty; + } +}; + +struct SimpleHashPP +{ + size_t operator()(const FMKeyPP& p) const noexcept + { + size_t h = 0; + for(auto& v : p.lengths) + h ^= std::hash{}(v); + h ^= std::hash{}(p.precision); + h ^= std::hash{}(p.scheme); + h ^= std::hash{}(p.kernel_config_1); + h ^= std::hash{}(p.kernel_config_2); + + return h; + } +}; + #endif diff --git a/library/src/include/function_pool.h b/library/src/include/function_pool.h index 5a179ed57a2..fd4fd8d2671 100644 --- a/library/src/include/function_pool.h +++ b/library/src/include/function_pool.h @@ -111,13 +111,22 @@ struct FFTKernel } }; +typedef std::unordered_map FPKeyMap; +typedef std::unordered_map FPKeyMapPP; + +typedef std::unordered_map FPMap; +typedef std::unordered_map, SimpleHashPP> FPMapPP; + struct function_pool_data { // when AOT generator adds a default key-kernel, // we get the keys of two version: empty-config vs full-config // make the pair as an entry in a map so that we know they are the same things - std::unordered_map def_key_pool; - std::unordered_map function_map; + FPKeyMap def_key_pool; + FPKeyMapPP def_key_pool_pp; + + FPMap function_map; + FPMapPP function_map_pp; function_pool_data(); @@ -130,9 +139,9 @@ struct function_pool_data class function_pool { - unsigned int max_lds_bytes; - std::unordered_map& def_key_pool; - std::unordered_map& function_map; + unsigned int max_lds_bytes; + FPKeyMap& def_key_pool; + FPMap& function_map; const FMKey& get_actual_key(const FMKey& key) const { @@ -250,10 +259,10 @@ class function_pool // That is, the default kernel-config we set in the kernel-generator.py we save a pair as // that allows us to use // the empty-config key to get the default kernel -static bool insert_default_entry(const FMKey& def_key, - const FFTKernel& kernel, - std::unordered_map& def_key_pool, - std::unordered_map& function_map) +static bool insert_default_entry(const FMKey& def_key, + const FFTKernel& kernel, + FPKeyMap& def_key_pool, + FPMap& function_map) { // simple_key means the same thing as def_key, but we just remove kernel-config // so we don't need to know the exact config when we're lookin' for the default kernel From acfe158bf5c23cad5be65634ca004b0429371136 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 8 May 2025 15:10:31 -0600 Subject: [PATCH 28/69] WIP --- library/src/assignment_policy.cpp | 2 +- library/src/compute_scheme.cpp | 4 + library/src/device/generator.py | 6 + library/src/device/generator/stockham_gen.cpp | 197 ++++++++++-- library/src/device/generator/stockham_gen.h | 30 +- .../src/device/generator/stockham_gen_base.h | 10 + .../src/device/generator/stockham_pp_gen_cc.h | 35 ++- .../src/device/generator/stockham_pp_gen_rr.h | 24 +- library/src/device/kernel-generator.py | 116 +++++-- library/src/include/compute_scheme.h | 2 + library/src/include/function_map_key.h | 89 ++++-- library/src/include/function_pool.h | 119 ++++++- library/src/include/node_factory.h | 1 + library/src/include/rtc_stockham_gen.h | 49 ++- library/src/include/tree_node.h | 50 ++- library/src/include/tree_node_1D.h | 61 ++++ library/src/include/tree_node_3D.h | 20 +- library/src/node_factory.cpp | 39 ++- library/src/plan.cpp | 79 ++--- library/src/rocfft_aot_helper.cpp | 20 +- library/src/rocfft_kernel_config_search.cpp | 6 +- library/src/rtc_stockham_gen.cpp | 81 +++-- library/src/rtc_stockham_kernel.cpp | 92 +++--- library/src/tree_node.cpp | 75 +++-- library/src/tree_node_1D.cpp | 116 ++++--- library/src/tree_node_3D.cpp | 290 +++++++++--------- 26 files changed, 1094 insertions(+), 519 deletions(-) diff --git a/library/src/assignment_policy.cpp b/library/src/assignment_policy.cpp index f081662e0db..5cb9c834aeb 100644 --- a/library/src/assignment_policy.cpp +++ b/library/src/assignment_policy.cpp @@ -1235,7 +1235,7 @@ void AssignmentPolicy::PadPlan(ExecPlan& execPlan) RecursiveTraverse(execPlan.rootPlan.get(), [&execPlan](TreeNode* n) { // Skip nodes that have partial passes - if(n->applyPartialPass) + if(n->isPartialPassEnabled()) return; // Look for nodes that begin writing to a new temp buffer diff --git a/library/src/compute_scheme.cpp b/library/src/compute_scheme.cpp index 26a6073f405..b801012f499 100644 --- a/library/src/compute_scheme.cpp +++ b/library/src/compute_scheme.cpp @@ -34,7 +34,9 @@ static const std::map& ComputeSchemetoStringMap() static const std::map ComputeSchemetoString = {{ENUMSTR(CS_NONE)}, {ENUMSTR(CS_KERNEL_STOCKHAM)}, + {ENUMSTR(CS_KERNEL_STOCKHAM_PP)}, {ENUMSTR(CS_KERNEL_STOCKHAM_BLOCK_CC)}, + {ENUMSTR(CS_KERNEL_STOCKHAM_PP_BLOCK_CC)}, {ENUMSTR(CS_KERNEL_STOCKHAM_BLOCK_RC)}, {ENUMSTR(CS_KERNEL_STOCKHAM_BLOCK_CR)}, {ENUMSTR(CS_KERNEL_TRANSPOSE)}, @@ -144,8 +146,10 @@ std::string PrintKernelSchemeAbbr(ComputeScheme cs) switch(cs) { case CS_KERNEL_STOCKHAM: + case CS_KERNEL_STOCKHAM_PP: return "sbrr"; case CS_KERNEL_STOCKHAM_BLOCK_CC: + case CS_KERNEL_STOCKHAM_PP_BLOCK_CC: return "sbcc"; case CS_KERNEL_STOCKHAM_BLOCK_CR: return "sbcr"; diff --git a/library/src/device/generator.py b/library/src/device/generator.py index 2f1ed57ef6f..ccf96907e15 100644 --- a/library/src/device/generator.py +++ b/library/src/device/generator.py @@ -817,6 +817,12 @@ def assert_insert(self, key, value, def_key_pool, function_map): function_map)).inline() throw = StatementList(Throw('std::runtime_error("' + str(key) + '")')) return If(Equal(insert, "false"), throw) + def assert_pp_insert(self, key, value_1, value_2, def_key_pool, function_map): + insert = Call('insert_default_pp_entry', + arguments=ArgumentList(key, value_1, value_2, def_key_pool, + function_map)).inline() + throw = StatementList(Throw('std::runtime_error("' + str(key) + '")')) + return If(Equal(insert, "false"), throw) # def __getitem__(self, idx): # return ArrayElement(self.name, idx) diff --git a/library/src/device/generator/stockham_gen.cpp b/library/src/device/generator/stockham_gen.cpp index 9b2aae5596a..042be0b8914 100644 --- a/library/src/device/generator/stockham_gen.cpp +++ b/library/src/device/generator/stockham_gen.cpp @@ -42,16 +42,24 @@ using namespace std::placeholders; // enough to genernate the function pool entry struct GeneratedLauncher { - GeneratedLauncher(StockhamKernel& kernel, - const std::string& scheme, - bool double_precision, - const std::string& sbrc_type, - const std::string& sbrc_transpose_type) + GeneratedLauncher(StockhamKernel& kernel, + const std::string& scheme, + const std::string& pp_child_scheme, + const std::vector& pp_factors, + const unsigned int& pp_current_dim, + const unsigned int& pp_off_dim, + bool double_precision, + const std::string& sbrc_type, + const std::string& sbrc_transpose_type) : scheme(scheme) + , pp_child_scheme(pp_child_scheme) + , pp_factors(pp_factors) + , pp_current_dim(pp_current_dim) + , pp_off_dim(pp_off_dim) , lengths(kernel.launcher_lengths()) , factors(kernel.launcher_factors()) - , transforms_per_block(kernel.transforms_per_block) - , workgroup_size(kernel.workgroup_size) + , transforms_per_block(kernel.launcher_transforms_per_block()) + , workgroup_size(kernel.launcher_workgroup_size()) , half_lds(kernel.half_lds) , direct_to_from_reg(kernel.direct_to_from_reg) , sbrc_type(sbrc_type) @@ -61,6 +69,10 @@ struct GeneratedLauncher } std::string scheme; + std::string pp_child_scheme; + std::vector pp_factors; + unsigned int pp_current_dim; + unsigned int pp_off_dim; std::vector lengths; std::vector factors; @@ -112,6 +124,10 @@ struct GeneratedLauncher add_member("sbrc_type", quote_str(sbrc_type)); add_member("sbrc_transpose_type", quote_str(sbrc_transpose_type)); add_member("double_precision", double_precision ? "true" : "false"); + add_member("pp_child_scheme", quote_str(pp_child_scheme)); + add_member("pp_factors", vec_to_list(pp_factors)); + add_member("pp_current_dim", std::to_string(pp_current_dim)); + add_member("pp_off_dim", std::to_string(pp_off_dim)); output += "}"; return output; @@ -129,6 +145,10 @@ struct LaunchSuffix void make_launcher(const std::vector& precision_types, const std::vector& launcher_suffixes, StockhamKernel& kernel, + const std::string& pp_child_scheme, + const std::vector& pp_factors, + const unsigned int& pp_current_dim, + const unsigned int& pp_off_dim, std::vector& generated_launchers) { for(auto precision_type : precision_types) @@ -137,6 +157,10 @@ void make_launcher(const std::vector& precision_types, { generated_launchers.emplace_back(kernel, launcher.scheme, + pp_child_scheme, + pp_factors, + pp_current_dim, + pp_off_dim, precision_type == rocfft_precision_double, launcher.sbrc_type, launcher.sbrc_transpose_type); @@ -208,11 +232,10 @@ void output_json(const std::vector& launchers, } void stockham_partial_pass_variants(const std::string& kernel_name, - const size_t& pp_length, const StockhamGeneratorSpecs& specs1, - const std::vector& pp_factors1, const StockhamGeneratorSpecs& specs2, - const std::vector& pp_factors2, + const StockhamPartialPassParams& params_1, + const StockhamPartialPassParams& params_2, std::ostream& output) { std::vector launchers; @@ -220,25 +243,73 @@ void stockham_partial_pass_variants(const std::string& kernel_name if(specs1.scheme == "CS_3D_PP" && specs2.scheme == "CS_3D_PP") { // SBRR_PP + SBCC_PP - if(specs1.static_dim == 0 && specs2.static_dim == 2) + if(params_1.current_dim == 0 && params_2.current_dim == 2) { - std::vector factors1(pp_factors1.begin(), pp_factors1.end()); - StockhamPartialPassKernelRR kernelRR(specs1, factors1, pp_length); - make_launcher( - specs1.precisions, {{"pp_stoc", specs1.scheme, "", ""}}, kernelRR, launchers); - - std::vector factors2(pp_factors2.begin(), pp_factors2.end()); - StockhamPartialPassKernelCC kernelCC(specs2, false, factors2); - make_launcher( - specs2.precisions, {{"pp_sbcc", specs2.scheme, "", ""}}, kernelCC, launchers); + StockhamPartialPassKernelRR kernelRR(specs1, params_1); + make_launcher(specs1.precisions, + {{"pp_stoc", specs1.scheme, "", ""}}, + kernelRR, + "CS_KERNEL_STOCKHAM_PP", + params_1.factors_off_dim, + params_1.current_dim, + params_1.off_dim, + launchers); + + StockhamPartialPassKernelCC kernelCC(specs2, params_2, false); + make_launcher(specs2.precisions, + {{"pp_sbcc", specs2.scheme, "", ""}}, + kernelCC, + "CS_KERNEL_STOCKHAM_PP_BLOCK_CC", + params_2.factors_off_dim, + params_2.current_dim, + params_2.off_dim, + launchers); } // SBRR_PP + SBCC_PP - else if(specs1.static_dim == 1 && specs2.static_dim == 2) + else if(params_1.current_dim == 1 && params_2.current_dim == 2) { + StockhamPartialPassKernelRR kernelRR(specs1, params_1); + make_launcher(specs1.precisions, + {{"pp_stoc", specs1.scheme, "", ""}}, + kernelRR, + "CS_KERNEL_STOCKHAM_PP", + params_1.factors_off_dim, + params_1.current_dim, + params_1.off_dim, + launchers); + + StockhamPartialPassKernelCC kernelCC(specs2, params_2, false); + make_launcher(specs2.precisions, + {{"pp_sbcc", specs2.scheme, "", ""}}, + kernelCC, + "CS_KERNEL_STOCKHAM_PP_BLOCK_CC", + params_2.factors_off_dim, + params_2.current_dim, + params_2.off_dim, + launchers); } // SBRR_PP + SBRR_PP - else if(specs1.static_dim == 0 && specs2.static_dim == 1) + else if(params_1.current_dim == 0 && params_2.current_dim == 1) { + StockhamPartialPassKernelRR kernelRR1(specs1, params_1); + make_launcher(specs1.precisions, + {{"pp_stoc", specs1.scheme, "", ""}}, + kernelRR1, + "CS_KERNEL_STOCKHAM_PP", + params_1.factors_off_dim, + params_1.current_dim, + params_1.off_dim, + launchers); + + StockhamPartialPassKernelRR kernelRR2(specs2, params_2); + make_launcher(specs2.precisions, + {{"pp_stoc", specs2.scheme, "", ""}}, + kernelRR2, + "CS_KERNEL_STOCKHAM_PP", + params_2.factors_off_dim, + params_2.current_dim, + params_2.off_dim, + launchers); } else { @@ -264,12 +335,26 @@ void stockham_variants(const std::string& kernel_name, if(specs.scheme == "CS_KERNEL_STOCKHAM") { StockhamKernelRR kernel(specs); - make_launcher(specs.precisions, {{"stoc", specs.scheme, "", ""}}, kernel, launchers); + make_launcher(specs.precisions, + {{"stoc", specs.scheme, "", ""}}, + kernel, + "CS_NONE", + std::vector(), + 0, + 0, + launchers); } else if(specs.scheme == "CS_KERNEL_STOCKHAM_BLOCK_CC") { StockhamKernelCC kernel(specs, false, false); - make_launcher(specs.precisions, {{"sbcc", specs.scheme, "", ""}}, kernel, launchers); + make_launcher(specs.precisions, + {{"sbcc", specs.scheme, "", ""}}, + kernel, + "CS_NONE", + std::vector(), + 0, + 0, + launchers); } else if(specs.scheme == "CS_KERNEL_STOCKHAM_BLOCK_RC") { @@ -312,13 +397,27 @@ void stockham_variants(const std::string& kernel_name, "SBRC_3D_FFT_ERC_TRANS_Z_XY", "TILE_UNALIGNED"}); - make_launcher(specs.precisions, suffixes, kernel, launchers); + make_launcher(specs.precisions, + suffixes, + kernel, + "CS_NONE", + std::vector(), + 0, + 0, + launchers); } else if(specs.scheme == "CS_KERNEL_STOCKHAM_BLOCK_CR") { StockhamKernelCR kernel(specs); - make_launcher(specs.precisions, {{"sbcr", specs.scheme, "", ""}}, kernel, launchers); + make_launcher(specs.precisions, + {{"sbcr", specs.scheme, "", ""}}, + kernel, + "CS_NONE", + std::vector(), + 0, + 0, + launchers); } else if(specs.scheme == "CS_KERNEL_2D_SINGLE") { @@ -327,8 +426,15 @@ void stockham_variants(const std::string& kernel_name, // output 2D launchers for(auto prec_type : specs.precisions) { - launchers.emplace_back( - fused2d, specs.scheme, (prec_type == rocfft_precision_double), "", ""); + launchers.emplace_back(fused2d, + specs.scheme, + "CS_NONE", + std::vector(), + 0, + 0, + (prec_type == rocfft_precision_double), + "", + ""); } } else @@ -368,12 +474,11 @@ int main() ++arg; scheme = *arg; - unsigned int pp_length; - std::vector dims, pp_factors_1, pp_factors_2; + std::vector parent_length, dims, pp_factors_1, pp_factors_2; if(scheme == "CS_3D_PP") { ++arg; - pp_length = std::stoul(*arg); + parent_length = parse_uints_csv(*arg); ++arg; pp_factors_2 = parse_uints_csv(*arg); @@ -420,6 +525,29 @@ int main() if(dims.size() != 2) throw std::runtime_error("CS_3D_PP requires two dimensions configuration"); + unsigned int dims_sum = 0, off_dim = 0; + for(const auto& dim : dims) + { + if(dim < 0 || dim > 2) + throw std::runtime_error("Invalid dimensions configuration for CS_3D_PP"); + + dims_sum += dim; + } + switch(dims_sum) + { + case 1: + off_dim = 2; + break; + case 2: + off_dim = 1; + break; + case 3: + off_dim = 0; + break; + default: + throw std::runtime_error("Invalid dimensions configuration for CS_3D_PP"); + } + if(threads_per_transform.size() != 2) throw std::runtime_error( "CS_3D_PP requires two threads_per_transform configuration"); @@ -428,17 +556,18 @@ int main() throw std::runtime_error("CS_3D_PP requires two direct_to_from_reg configuration"); StockhamGeneratorSpecs specs1(factors1, {}, precisions, workgroup_size[0], scheme); - specs1.static_dim = dims[0]; specs1.direct_to_from_reg = direct_to_from_reg[0]; specs1.threads_per_transform = threads_per_transform[0]; StockhamGeneratorSpecs specs2(factors2, {}, precisions, workgroup_size[1], scheme); - specs2.static_dim = dims[1]; specs2.direct_to_from_reg = direct_to_from_reg[1]; specs2.threads_per_transform = threads_per_transform[1]; + StockhamPartialPassParams pp_params_1(parent_length, dims[0], off_dim, pp_factors_1); + StockhamPartialPassParams pp_params_2(parent_length, dims[1], off_dim, pp_factors_2); + stockham_partial_pass_variants( - kernel_name, pp_length, specs1, pp_factors_1, specs2, pp_factors_2, std::cout); + kernel_name, specs1, specs2, pp_params_1, pp_params_2, std::cout); } else { diff --git a/library/src/device/generator/stockham_gen.h b/library/src/device/generator/stockham_gen.h index 6d44c67a49d..4a400b20aef 100644 --- a/library/src/device/generator/stockham_gen.h +++ b/library/src/device/generator/stockham_gen.h @@ -60,8 +60,6 @@ struct StockhamGeneratorSpecs unsigned int static_dim = 0; std::string scheme; - unsigned int pp_dim = 0; - // this value indicating if the wgs, tpt are excatly what we want // (i.e. were already derived somewhere) // to tell StockhamKernel not to do its auto-derivation again. @@ -81,3 +79,31 @@ void stockham_variants(const std::vector& kernel_name, const StockhamGeneratorSpecs& specs, const StockhamGeneratorSpecs& specs2d, std::ostream& output); + +struct StockhamPartialPassParams +{ + StockhamPartialPassParams() = default; + + StockhamPartialPassParams(const std::vector& parent_length, + const unsigned int current_dim, + const unsigned int off_dim, + const std::vector& factors_off_dim) + : parent_length(parent_length) + , current_dim(current_dim) + , off_dim(off_dim) + , factors_off_dim(factors_off_dim) + { + } + + std::vector parent_length; + unsigned int current_dim = 0; + unsigned int off_dim = 0; + std::vector factors_off_dim; +}; + +void stockham_partial_pass_variants(const std::string& kernel_name, + const StockhamGeneratorSpecs& specs1, + const StockhamGeneratorSpecs& specs2, + const StockhamPartialPassParams& params_1, + const StockhamPartialPassParams& params_2, + std::ostream& output); diff --git a/library/src/device/generator/stockham_gen_base.h b/library/src/device/generator/stockham_gen_base.h index 803d9511cda..afe26ec4131 100644 --- a/library/src/device/generator/stockham_gen_base.h +++ b/library/src/device/generator/stockham_gen_base.h @@ -211,6 +211,16 @@ struct StockhamKernel : public StockhamGeneratorSpecs // butterfly registers Variable R{"R", "scalar_type", false, false}; + virtual unsigned int launcher_workgroup_size() + { + return workgroup_size; + } + + virtual unsigned int launcher_transforms_per_block() + { + return transforms_per_block; + } + virtual std::vector launcher_lengths() { return {length}; diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index 633acc803fd..6c1e2fbf144 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -32,16 +32,18 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC { - explicit StockhamPartialPassKernelCC(const StockhamGeneratorSpecs& specs, - bool largeTwdBatchIsTransformCount, - const std::vector& ppFactors) + explicit StockhamPartialPassKernelCC(const StockhamGeneratorSpecs& specs, + const StockhamPartialPassParams& params, + bool largeTwdBatchIsTransformCount) : StockhamKernelCC(specs, largeTwdBatchIsTransformCount, false) - , factors_pp(ppFactors) + , params(params) { large_twiddle_steps.decl_default = 3; large_twiddle_base.decl_default = 8; + factors_pp = params.factors_off_dim; + max_factor_pp = *std::max_element(factors_pp.begin(), factors_pp.end()); transforms_per_block_pp = transforms_per_block; @@ -50,9 +52,12 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC workgroup_size *= max_factor_pp; } - unsigned int transforms_per_block_pp; - unsigned int max_factor_pp; - std::vector factors_pp; + StockhamPartialPassParams params; + + unsigned int transforms_per_block_pp; + unsigned int max_factor_pp; + + std::vector factors_pp; Variable thread_lds{"thread_lds", "unsigned int"}; Variable stride_lds_pp{"stride_lds_pp", "unsigned int"}; @@ -73,6 +78,22 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC Variable global_idx{"global_idx", "unsigned int"}; Variable transpose_idx{"transpose_idx", "unsigned int"}; + std::vector launcher_lengths() override + { + return params.parent_length; + } + + unsigned int launcher_workgroup_size() override + { + return workgroup_size / max_factor_pp; + } + + unsigned int launcher_transforms_per_block() override + { + return transforms_per_block / max_factor_pp; + } + + StatementList load_global_generator(unsigned int h, unsigned int hr, unsigned int width, diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index 7f60ab074cb..d6bad12e88c 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -29,21 +29,24 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR { - explicit StockhamPartialPassKernelRR(const StockhamGeneratorSpecs& specs, - const std::vector& ppFactors, - const size_t ppLength) + explicit StockhamPartialPassKernelRR(const StockhamGeneratorSpecs& specs, + const StockhamPartialPassParams& params) : StockhamKernelRR(specs) - , factors_pp(ppFactors) - , length_pp(ppLength) + , params(params) { + length_pp = params.parent_length[params.off_dim]; + factors_pp = params.factors_off_dim; + max_factor_pp = *std::max_element(factors_pp.begin(), factors_pp.end()); R.size = Expression{std::max(nregisters, max_factor_pp)}; } - unsigned int max_factor_pp; - std::vector factors_pp; - unsigned int length_pp; + StockhamPartialPassParams params; + + unsigned int max_factor_pp; + std::vector factors_pp; + unsigned int length_pp; Variable offset_pp{"offset_pp", "size_t"}; Variable stride_lds_pp{"stride_lds_pp", "size_t"}; @@ -91,6 +94,11 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR return stmts; } + std::vector launcher_lengths() override + { + return params.parent_length; + } + StatementList load_global_generator(unsigned int h, unsigned int hr, unsigned int width, diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index 9238e797d9a..0005c537d55 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -140,6 +140,12 @@ def __str__(self): f += ', ' f += 'true' if aot_rtc else 'false' + f += ', ' + str(self.function.meta.pp_child_scheme) + f += ', ' + str(self.function.meta.pp_current_dim) + f += ', ' + str(self.function.meta.pp_off_dim) + pp_factors = getattr(self.function.meta, 'pp_factors', None) + if pp_factors is not None: + f += ', {' + cjoin(pp_factors) + '}' f += ')' return f @@ -152,11 +158,11 @@ def generate_cpu_function_pool_main(num_files): type='void', name=f'function_pool_init_{i}', value= - 'FPKeyMap& def_key_pool, FPMap& function_map' + 'std::tuple& def_keys, std::tuple& function_maps' ) call_list = StatementList() - call_args = ArgumentList('def_key_pool', 'function_map') + call_args = ArgumentList('def_keys', 'function_maps') for i in range(num_files): call_list += Call(name=f'function_pool_init_{i}', arguments=call_args) return StatementList( @@ -167,8 +173,11 @@ def generate_cpu_function_pool_main(num_files): body=call_list)) -def generate_cpu_function_pool_pieces(functions, num_files): +def generate_cpu_function_pool_pieces(functions, pp_functions, num_files): """Generate function(s) to populate the kernel function pool.""" + + all_functions = functions + pp_functions + function_map = Map('function_map') precisions = { 'sp': 'rocfft_precision_single', @@ -176,41 +185,71 @@ def generate_cpu_function_pool_pieces(functions, num_files): 'half': 'rocfft_precision_half', } var_kernel = Variable('kernel', 'FFTKernel') + var_pp_kernel_1 = Variable('pp_kernel_1', 'FFTKernel') + var_pp_kernel_2 = Variable('pp_kernel_2', 'FFTKernel') # Init list to store contents of function_pool_init function per file being generated piece_contents = [ - StatementList() + var_kernel.declaration() for _ in range(num_files) + StatementList() + var_kernel.declaration() + var_pp_kernel_1.declaration() + + var_pp_kernel_2.declaration() for _ in range(num_files) ] # Cycles through each file per loop execution to distribute work amongst N files + i = 0 + j = 0 + i_offset = 0 if len(pp_functions) == 0 else len(precisions) curr_file = 0 - - for i, f in enumerate(functions): + while i < len(all_functions) - i_offset: + f = all_functions[i] length, precision, scheme, transpose = f.meta.length, f.meta.precision, f.meta.scheme, f.meta.transpose - if isinstance(length, (int, str)): - length = [length, 0] - piece_contents[curr_file] += Assign(var_kernel, FFTKernel(f)) - key = Call( - name='FMKey', - arguments=ArgumentList(length[0], length[1], precisions[precision], - scheme, transpose or 'NONE', - 'kernel.get_kernel_config()')).inline() - piece_contents[curr_file] += function_map.assert_insert( - key, var_kernel, 'def_key_pool', 'function_map') - curr_file = (curr_file + 1) % num_files + + if scheme == 'CS_3D_PP': + piece_contents[curr_file] += Assign(var_pp_kernel_1, FFTKernel(f)) + f = all_functions[i + i_offset] + piece_contents[curr_file] += Assign(var_pp_kernel_2, FFTKernel(f)) + + key = Call( + name='FMKeyPP', + arguments=ArgumentList(length[0], length[1], length[2], + precisions[precision], + scheme, 'pp_kernel_1.get_kernel_config()', + 'pp_kernel_2.get_kernel_config()')).inline() + piece_contents[curr_file] += function_map.assert_pp_insert( + key, var_pp_kernel_1, var_pp_kernel_2, 'std::get<1>(def_keys)', 'std::get<1>(function_maps)') + + j = j + 1 + else: + if isinstance(length, (int, str)): + length = [length, 0] + piece_contents[curr_file] += Assign(var_kernel, FFTKernel(f)) + key = Call( + name='FMKey', + arguments=ArgumentList(length[0], length[1], precisions[precision], + scheme, transpose or 'NONE', + 'kernel.get_kernel_config()')).inline() + piece_contents[curr_file] += function_map.assert_insert( + key, var_kernel, 'std::get<0>(def_keys)', 'std::get<0>(function_maps)') + + if j == len(precisions): + j = 0 + i = i + len(precisions) + 1 + else: + i = i + 1 + + curr_file = (curr_file + 1) % num_files # Assemble contents of each file to return in a list pieces = [None] * num_files piece_args = ArgumentList( - 'FPKeyMap& def_key_pool', - 'FPMap& function_map') - for i in range(num_files): - pieces[i] = StatementList( + 'std::tuple& def_keys', + 'std::tuple& function_maps') + for k in range(num_files): + pieces[k] = StatementList( Include('"../include/function_pool.h"'), - Function(name=f'void function_pool_init_{i}', + Function(name=f'void function_pool_init_{k}', value=False, arguments=piece_args, - body=piece_contents[i])) + body=piece_contents[k])) return pieces @@ -1007,7 +1046,7 @@ def list_3d_partial_pass_kernels(): """Return list of to generate.""" pp_3d_kernels = [ - NS(length=[64,64,64], dims=[0, 2], factors=[[4, 4, 4],[8, 8]], factors_pp=[[16], [4]], threads_per_transform=[8, 16], workgroup_size=[128, 256], direct_to_from_reg=[False, False]) + NS(length=[64,64,64], dims=[0, 2], factors=[[4, 4, 4],[8, 8]], factors_pp=[[16], [4]], threads_per_transform=[8, 8], workgroup_size=[128, 64], direct_to_from_reg=[False, False]), ] expanded = [] @@ -1034,7 +1073,8 @@ def generate_kernel_functions(kernels, precisions, launchers_json): each kernel in `kernels`, and its variations. """ - cpu_functions = [] + kernel_functions = [] + pp_kernel_functions = [] data = Variable('data_p', 'const void *') back = Variable('back_p', 'void *') # launchers_json has kernel names as keys to a list of launchers for each kernel variant @@ -1058,6 +1098,10 @@ def generate_kernel_functions(kernels, precisions, launchers_json): half_lds = launcher.half_lds direct_to_from_reg = launcher.direct_to_from_reg scheme = launcher.scheme + pp_child_scheme = launcher.pp_child_scheme + pp_factors = launcher.pp_factors + pp_current_dim = launcher.pp_current_dim + pp_off_dim = launcher.pp_off_dim sbrc_transpose_type = launcher.sbrc_transpose_type precision = 'dp' if launcher.double_precision else 'sp' runtime_compile = kernel.runtime_compile @@ -1084,17 +1128,24 @@ def generate_kernel_functions(kernels, precisions, launchers_json): params=params, precision=p, runtime_compile=runtime_compile, - scheme=scheme, + scheme=scheme, workgroup_size=workgroup_size, transforms_per_block=transforms_per_block, threads_per_transform=tpt_list, transpose=sbrc_transpose_type, use_3steps_large_twd=use_3steps_large_twd, + pp_child_scheme=pp_child_scheme, + pp_factors = pp_factors, + pp_current_dim = pp_current_dim, + pp_off_dim = pp_off_dim )) - cpu_functions.append(f) + if (scheme == 'CS_3D_PP'): + pp_kernel_functions.append(f) + else: + kernel_functions.append(f) - return cpu_functions + return kernel_functions, pp_kernel_functions def read_subprocess(proc_output, output): @@ -1185,13 +1236,11 @@ def generate_kernels(kernels, precisions, stockham_gen): if hasattr(k, 'dims'): proc.stdin.write(" " + ','.join([str(f) for f in k.dims])) - - if hasattr(k, 'factors_pp'): proc.stdin.write(" " + ','.join([str(f) for f in k.factors_pp[0]]) + " ") proc.stdin.write(','.join([str(f) for f in k.factors_pp[1]]) + " ") - proc.stdin.write(str(k.length[1])) + proc.stdin.write(','.join([str(f) for f in k.length])) proc.stdin.write(f' {k.scheme}') proc.stdin.write(f' {kernel_name(k)}\n') @@ -1269,9 +1318,10 @@ def cli(): # if args.command == 'generate': - cpu_functions = generate_kernels(kernels, precisions, + functions, pp_functions = generate_kernels(kernels, precisions, args.stockham_gen) - func_files = generate_cpu_function_pool_pieces(cpu_functions, + func_files = generate_cpu_function_pool_pieces(functions, + pp_functions, args.num_files) for i in range(args.num_files): write(f'function_pool_init_{i}.cpp', func_files[i], format=False) diff --git a/library/src/include/compute_scheme.h b/library/src/include/compute_scheme.h index fc70cb22b00..20ae9dcec16 100644 --- a/library/src/include/compute_scheme.h +++ b/library/src/include/compute_scheme.h @@ -30,7 +30,9 @@ enum ComputeScheme { CS_NONE, CS_KERNEL_STOCKHAM, + CS_KERNEL_STOCKHAM_PP, CS_KERNEL_STOCKHAM_BLOCK_CC, + CS_KERNEL_STOCKHAM_PP_BLOCK_CC, CS_KERNEL_STOCKHAM_BLOCK_RC, CS_KERNEL_STOCKHAM_BLOCK_CR, CS_KERNEL_TRANSPOSE, diff --git a/library/src/include/function_map_key.h b/library/src/include/function_map_key.h index 1da075d4820..51f20cfb446 100644 --- a/library/src/include/function_map_key.h +++ b/library/src/include/function_map_key.h @@ -41,6 +41,7 @@ struct KernelConfig int workgroup_size = 0; std::array threads_per_transform = {0, 0}; std::vector factors = {0}; + std::vector factors_pp = {0}; // above data is what we can tune // // the followings are other information of this kernel. @@ -69,6 +70,7 @@ struct KernelConfig KernelConfig(bool use_3steps, std::vector&& factors, + std::vector&& factors_pp, int tpb, int wgs, std::array&& tpt, @@ -89,6 +91,7 @@ struct KernelConfig , workgroup_size(wgs) , threads_per_transform(tpt) , factors(factors) + , factors_pp(factors_pp) , ebType(ebType) , direction(direction) , static_dim(static_dim) @@ -109,7 +112,8 @@ struct KernelConfig transforms_per_block, workgroup_size, threads_per_transform, - factors) + factors, + factors_pp) == std::tie(rhs.use_3steps_large_twd, rhs.half_lds, rhs.direct_to_from_reg, @@ -117,7 +121,8 @@ struct KernelConfig rhs.transforms_per_block, rhs.workgroup_size, rhs.threads_per_transform, - rhs.factors); + rhs.factors, + rhs.factors_pp); } bool operator<(const KernelConfig& rhs) const @@ -129,7 +134,8 @@ struct KernelConfig transforms_per_block, workgroup_size, threads_per_transform, - factors) + factors, + factors_pp) < std::tie(rhs.use_3steps_large_twd, rhs.half_lds, rhs.direct_to_from_reg, @@ -137,7 +143,8 @@ struct KernelConfig rhs.transforms_per_block, rhs.workgroup_size, rhs.threads_per_transform, - rhs.factors); + rhs.factors, + rhs.factors_pp); } std::string Print() const @@ -160,6 +167,15 @@ struct KernelConfig } ss << "]"; + ss << ", factors_pp: ["; + COMMA = ""; + for(auto factor : factors_pp) + { + ss << COMMA << factor; + COMMA = ", "; + } + ss << "]"; + ss << "}"; return ss.str(); @@ -194,9 +210,13 @@ namespace std // which means the maximal factorization pass is 8 auto factors_max_len = config.factors; factors_max_len.resize(TWIDDLES_MAX_RADICES); - for(auto& v : factors_max_len) h ^= std::hash{}(v); + + auto factors_pp_max_len = config.factors_pp; + factors_pp_max_len.resize(TWIDDLES_MAX_RADICES); + for(auto& v : factors_pp_max_len) + h ^= std::hash{}(v); return h; } }; @@ -220,6 +240,7 @@ struct ToString str += FieldDescriptor().describe("wgs", value.workgroup_size) + ","; str += VectorFieldDescriptor().describe("tpt", tpt) + ","; str += VectorFieldDescriptor().describe("factors", value.factors) + ","; + str += VectorFieldDescriptor().describe("factors_pp", value.factors_pp) + ","; // below: not tunable data, for AOT cache str += FieldDescriptor().describe("ebtype", PrintEBType(value.ebType)) + ","; str += FieldDescriptor().describe("direction", value.direction) + ","; @@ -254,6 +275,7 @@ struct FromString FieldParser().parse("wgs", ret.workgroup_size, current); VectorFieldParser().parse("tpt", tpt, current); VectorFieldParser().parse("factors", ret.factors, current); + VectorFieldParser().parse("factors_pp", ret.factors_pp, current); if(DescriptorFormatVersion::UsingVersion < 2) { @@ -284,6 +306,23 @@ struct FromString } }; +struct FMKeyBase +{ + FMKeyBase(rocfft_precision precision, ComputeScheme scheme = CS_NONE) + : precision(precision) + , scheme(scheme) + { + } + + FMKeyBase() = default; + FMKeyBase(const FMKeyBase&) = default; + + virtual ~FMKeyBase() {}; + + rocfft_precision precision; + ComputeScheme scheme; +}; + // length, precision, scheme are theose fundemantal information of a kernel; // SBRC_TRANS is also neccessary for SBRC or SBRC_3D, but for non-SBRC, it is just NONE // And the newly added KernerlConfig is the key to supporting the "multi-configurations". @@ -297,13 +336,12 @@ struct FromString // (And that is what exactly "fuction_pool::insert_default_entry()" and // "function_pool::get_actual_key()"" is doing // -struct FMKey +struct FMKey : public FMKeyBase { std::array lengths; - rocfft_precision precision; - ComputeScheme scheme = CS_KERNEL_STOCKHAM; - SBRC_TRANSPOSE_TYPE sbrcTrans = NONE; - KernelConfig kernel_config = KernelConfig::EmptyConfig(); + + SBRC_TRANSPOSE_TYPE sbrcTrans = NONE; + KernelConfig kernel_config = KernelConfig::EmptyConfig(); FMKey() = default; FMKey(const FMKey&) = default; @@ -314,11 +352,11 @@ struct FMKey ComputeScheme scheme = CS_KERNEL_STOCKHAM, SBRC_TRANSPOSE_TYPE transpose = NONE, KernelConfig kernel_config = KernelConfig::EmptyConfig()) - : lengths({length0, 0}) - , precision(precision) - , scheme(scheme) + : FMKeyBase(precision, scheme) + , lengths({length0, 0}) , sbrcTrans(transpose) , kernel_config(kernel_config) + { } @@ -329,9 +367,8 @@ struct FMKey ComputeScheme scheme = CS_KERNEL_2D_SINGLE, SBRC_TRANSPOSE_TYPE transpose = NONE, KernelConfig kernel_config = KernelConfig::EmptyConfig()) - : lengths({length0, length1}) - , precision(precision) - , scheme(scheme) + : FMKeyBase(precision, scheme) + , lengths({length0, length1}) , sbrcTrans(transpose) , kernel_config(kernel_config) { @@ -478,11 +515,14 @@ struct SimpleHash } }; -struct FMKeyPP +// TODO: Check if FMKeyPP and FMKey can derive from a common base class +// This should simplfy operations in the function_pool, e.g., +// FMKey key = GetKernelKey(); +// if(!pool.has_function(key)) +// in LeafNode::KernelCheck(std::vector& kernel_keys) +struct FMKeyPP : public FMKeyBase { std::array lengths; - rocfft_precision precision; - ComputeScheme scheme = CS_3D_PP; KernelConfig kernel_config_1 = KernelConfig::EmptyConfig(); KernelConfig kernel_config_2 = KernelConfig::EmptyConfig(); @@ -497,9 +537,8 @@ struct FMKeyPP ComputeScheme scheme = CS_3D_PP, KernelConfig kernel_config_1 = KernelConfig::EmptyConfig(), KernelConfig kernel_config_2 = KernelConfig::EmptyConfig()) - : lengths({length0, length1, length2}) - , precision(precision) - , scheme(scheme) + : FMKeyBase(precision, scheme) + , lengths({length0, length1, length2}) , kernel_config_1(kernel_config_1) , kernel_config_2(kernel_config_2) { @@ -537,6 +576,12 @@ struct FMKeyPP static FMKeyPP empty; return empty; } + + // TODO: Implement this + bool base_lds_usage_fits(unsigned int lds_size) const + { + return true; + } }; struct SimpleHashPP diff --git a/library/src/include/function_pool.h b/library/src/include/function_pool.h index fd4fd8d2671..889355ee341 100644 --- a/library/src/include/function_pool.h +++ b/library/src/include/function_pool.h @@ -41,6 +41,27 @@ inline std::string PrintMissingKernelInfo(const FMKey& key) return msg.str(); } +struct PartialPassParams +{ + PartialPassParams() = default; + + PartialPassParams(ComputeScheme scheme, + unsigned int current_dim, + unsigned int off_dim, + std::vector factors_off_dim) + : scheme(scheme) + , current_dim(current_dim) + , off_dim(off_dim) + , factors_off_dim(factors_off_dim) + { + } + + ComputeScheme scheme; + unsigned int current_dim; + unsigned int off_dim; + std::vector factors_off_dim; +}; + struct FFTKernel { std::vector factors; @@ -61,19 +82,25 @@ struct FFTKernel // build time), using runtime compilation. bool aot_rtc = false; + PartialPassParams pp_params; + FFTKernel() = default; FFTKernel(const FFTKernel&) = default; FFTKernel& operator=(const FFTKernel&) = default; - FFTKernel(bool use_3steps, - std::vector&& factors, - int tpb, - int wgs, - std::array&& tpt, - bool half_lds = false, - bool direct_to_from_reg = false, - bool aot_rtc = false) + FFTKernel(bool use_3steps, + std::vector&& factors, + int tpb, + int wgs, + std::array&& tpt, + bool half_lds = false, + bool direct_to_from_reg = false, + bool aot_rtc = false, + ComputeScheme pp_scheme = CS_NONE, + unsigned int pp_current_dim = 0, + unsigned int pp_off_dim = 0, + std::vector&& pp_factors_off_dim = std::vector()) : factors(factors) , transforms_per_block(tpb) , workgroup_size(wgs) @@ -82,6 +109,7 @@ struct FFTKernel , half_lds(half_lds) , direct_to_from_reg(direct_to_from_reg) , aot_rtc(aot_rtc) + , pp_params(pp_scheme, pp_current_dim, pp_off_dim, pp_factors_off_dim) { } @@ -122,11 +150,8 @@ struct function_pool_data // when AOT generator adds a default key-kernel, // we get the keys of two version: empty-config vs full-config // make the pair as an entry in a map so that we know they are the same things - FPKeyMap def_key_pool; - FPKeyMapPP def_key_pool_pp; - - FPMap function_map; - FPMapPP function_map_pp; + std::tuple def_keys; + std::tuple function_maps; function_pool_data(); @@ -141,7 +166,10 @@ class function_pool { unsigned int max_lds_bytes; FPKeyMap& def_key_pool; - FPMap& function_map; + FPKeyMapPP& def_pp_key_pool; + + FPMap& function_map; + FPMapPP& function_pp_map; const FMKey& get_actual_key(const FMKey& key) const { @@ -155,11 +183,21 @@ class function_pool return key; } + const FMKeyPP& get_actual_pp_key(const FMKeyPP& key) const + { + if(def_pp_key_pool.count(key) > 0) + return def_pp_key_pool.at(key); + else + return key; + } + public: function_pool(unsigned int max_lds_bytes) : max_lds_bytes(max_lds_bytes) - , def_key_pool(function_pool_data::get_function_pool_data().def_key_pool) - , function_map(function_pool_data::get_function_pool_data().function_map) + , def_key_pool(std::get<0>(function_pool_data::get_function_pool_data().def_keys)) + , def_pp_key_pool(std::get<1>(function_pool_data::get_function_pool_data().def_keys)) + , function_map(std::get<0>(function_pool_data::get_function_pool_data().function_maps)) + , function_pp_map(std::get<1>(function_pool_data::get_function_pool_data().function_maps)) { // We would only see zero if we received a // default-constructed device prop struct, which means @@ -173,7 +211,7 @@ class function_pool { } - function_pool(function_pool& p) = delete; + function_pool(function_pool& p) = delete; function_pool& operator=(const function_pool&) = delete; ~function_pool() = default; @@ -196,6 +234,14 @@ class function_pool return function_map.count(real_key) > 0; } + bool has_pp_function(const FMKeyPP& key) const + { + auto real_key = get_actual_pp_key(key); + if(!real_key.base_lds_usage_fits(max_lds_bytes)) + return false; + return function_pp_map.count(real_key) > 0; + } + size_t get_largest_length(rocfft_precision precision) const { auto supported = get_lengths(precision, CS_KERNEL_STOCKHAM); @@ -230,6 +276,25 @@ class function_pool return function_map.at(real_key); } + FFTKernel get_pp_kernel(const FMKeyPP& key, ComputeScheme scheme) const + { + auto real_key = get_actual_pp_key(key); + if(!real_key.base_lds_usage_fits(max_lds_bytes)) + throw std::out_of_range("kernel not found in partial-pass map"); + + auto kernel_list = function_pp_map.at(real_key); + + auto scheme_0 = kernel_list[0].pp_params.scheme; + auto scheme_1 = kernel_list[1].pp_params.scheme; + + if(scheme == scheme_0) + return kernel_list[0]; + else if(scheme == scheme_1) + return kernel_list[1]; + else + throw std::out_of_range("kernel not found in partial-pass map"); + } + // helper for common used bool has_SBCC_kernel(size_t length, rocfft_precision precision) const { @@ -275,4 +340,24 @@ static bool insert_default_entry(const FMKey& def_key, return std::get<1>(function_map.emplace(def_key, kernel)); } +static bool insert_default_pp_entry(const FMKeyPP& def_key, + const FFTKernel& kernel_0, + const FFTKernel& kernel_1, + FPKeyMapPP& def_key_pool, + FPMapPP& function_map) +{ + // simple_key means the same thing as def_key, but we just remove kernel-config + // so we don't need to know the exact config when we're lookin' for the default kernel + FMKeyPP simple_key(def_key); + simple_key.kernel_config_1 = KernelConfig::EmptyConfig(); + simple_key.kernel_config_2 = KernelConfig::EmptyConfig(); + + def_key_pool.emplace(simple_key, def_key); + + std::array kernels = {kernel_0, kernel_1}; + + // still use the detailed key with config to maintain the function map + return std::get<1>(function_map.emplace(def_key, kernels)); +} + #endif // FUNCTION_POOL_H diff --git a/library/src/include/node_factory.h b/library/src/include/node_factory.h index d2d2d4329a6..fd3f395b092 100644 --- a/library/src/include/node_factory.h +++ b/library/src/include/node_factory.h @@ -72,6 +72,7 @@ class NodeFactory NodeMetaData& nodeData); // using scheme CS_2D_RC or not static bool use_CS_3D_BLOCK_RC(const function_pool& pool, NodeMetaData& nodeData); static bool use_CS_3D_RC(const function_pool& pool, NodeMetaData& nodeData); + static bool use_CS_3D_PP(const function_pool& pool, NodeMetaData& nodeData); // how many SBRC kernels can we put into a 3D transform? static size_t count_3D_SBRC_nodes(const function_pool& pool, NodeMetaData& nodeData); diff --git a/library/src/include/rtc_stockham_gen.h b/library/src/include/rtc_stockham_gen.h index 9febc778aae..cd5753362f2 100644 --- a/library/src/include/rtc_stockham_gen.h +++ b/library/src/include/rtc_stockham_gen.h @@ -57,30 +57,29 @@ std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, // generate source for RTC stockham kernel. transforms_per_block may // be nullptr, but if non-null, stockham_rtc stores the number of // transforms each threadblock will do -std::string stockham_rtc(const StockhamGeneratorSpecs& specs, - const StockhamGeneratorSpecs& specs2d, - unsigned int* transforms_per_block, - const std::string& kernel_name, - ComputeScheme scheme, - int direction, - rocfft_precision precision, - rocfft_result_placement placement, - rocfft_array_type inArrayType, - rocfft_array_type outArrayType, - bool unit_stride, - size_t largeTwdBase, - size_t largeTwdSteps, - bool largeTwdBatchIsTransformCount, - EmbeddedType ebtype, - DirectRegType dir2regMode, - IntrinsicAccessType intrinsicMode, - SBRC_TRANSPOSE_TYPE transpose_type, - CallbackType cbtype, - const BluesteinFuseType& fuseBlue, - const PartialPassType& ppType, - const std::vector& ppFactors, - const size_t ppLength, - const LoadOps& loadOps, - const StoreOps& storeOps); +std::string stockham_rtc(const StockhamGeneratorSpecs& specs, + const StockhamGeneratorSpecs& specs2d, + const StockhamPartialPassParams& params_pp, + unsigned int* transforms_per_block, + const std::string& kernel_name, + ComputeScheme scheme, + int direction, + rocfft_precision precision, + rocfft_result_placement placement, + rocfft_array_type inArrayType, + rocfft_array_type outArrayType, + bool unit_stride, + size_t largeTwdBase, + size_t largeTwdSteps, + bool largeTwdBatchIsTransformCount, + EmbeddedType ebtype, + DirectRegType dir2regMode, + IntrinsicAccessType intrinsicMode, + SBRC_TRANSPOSE_TYPE transpose_type, + CallbackType cbtype, + const BluesteinFuseType& fuseBlue, + const PartialPassType& ppType, + const LoadOps& loadOps, + const StoreOps& storeOps); #endif diff --git a/library/src/include/tree_node.h b/library/src/include/tree_node.h index 1abb90d3731..fe28cf5e384 100644 --- a/library/src/include/tree_node.h +++ b/library/src/include/tree_node.h @@ -178,13 +178,12 @@ struct NodeMetaData size_t iDist = 0, oDist = 0; size_t iDistBlue = 0, oDistBlue = 0; size_t iOffset = 0, oOffset = 0; - bool applyPartialPass = false; - int direction = -1; - rocfft_result_placement placement = rocfft_placement_inplace; - rocfft_precision precision = rocfft_precision_single; - rocfft_array_type inArrayType = rocfft_array_type_unset; - rocfft_array_type outArrayType = rocfft_array_type_unset; - hipDeviceProp_t deviceProp = {}; + int direction = -1; + rocfft_result_placement placement = rocfft_placement_inplace; + rocfft_precision precision = rocfft_precision_single; + rocfft_array_type inArrayType = rocfft_array_type_unset; + rocfft_array_type outArrayType = rocfft_array_type_unset; + hipDeviceProp_t deviceProp = {}; bool rootIsC2C; explicit NodeMetaData(TreeNode* refNode); @@ -354,6 +353,8 @@ class TreeNode // specified kernel key from solution map. (if there is any) std::unique_ptr specified_key; + std::unique_ptr specified_pp_key; + // Tree structure: // non-owning pointer to parent node, may be null TreeNode* parent = nullptr; @@ -373,11 +374,8 @@ class TreeNode size_t lengthBlue = 0; size_t lengthBlueN = 0; - // enables partial pass for this node - bool applyPartialPass = false; - - // Dimension of the FFT where partial-pass is applied - size_t ppDim = 0; + size_t ppOffDim = 0; + size_t ppCurrDim = 0; // BluesteinType typeBlue = BluesteinType::BT_NONE; @@ -490,6 +488,12 @@ class TreeNode return {}; } + bool isPartialPassEnabled() const + { + return (scheme == CS_3D_PP || scheme == CS_KERNEL_STOCKHAM_PP + || scheme == CS_KERNEL_STOCKHAM_PP_BLOCK_CC); + } + // able to fuse CS_KERNEL_STOCKHAM and CS_KERNEL_TRANSPOSE_Z_XY ? bool fuse_CS_KERNEL_TRANSPOSE_Z_XY(); // able to fuse CS_KERNEL_STOCKHAM and CS_KERNEL_TRANSPOSE_XY_Z ? @@ -540,6 +544,8 @@ class TreeNode TreeNode* GetRealEvenAncestor(); bool IsRootPlanC2CTransform(); + TreeNode* GetPartialPassAncestor() const; + // Set length of transpose kernel node, since those are easily // knowable just by looking at the scheme and they're used in // many plans. Throws an exception if this is not a transpose @@ -582,6 +588,18 @@ class TreeNode : FMKey(length[0], length[1], precision, scheme); } + virtual FMKeyPP GetPPKernelsKey() const + { + if(specified_pp_key) + return *specified_pp_key.get(); + + auto pp_parent_node = GetPartialPassAncestor(); + if(pp_parent_node) + return FMKeyPP(length[0], length[1], length[2], precision, pp_parent_node->scheme); + else + throw std::runtime_error("Invalid parent node for partial pass"); + } + // Compute the large twd decomposition base void set_large_twd_base_steps(size_t largeTWDLength); @@ -762,7 +780,7 @@ class InternalTempBuffer : comm_rank(comm_rank) { } - InternalTempBuffer(const InternalTempBuffer&) = delete; + InternalTempBuffer(const InternalTempBuffer&) = delete; InternalTempBuffer& operator=(const InternalTempBuffer&) = delete; ~InternalTempBuffer() = default; @@ -818,8 +836,8 @@ class InternalTempBuffer class BufferPtr { public: - BufferPtr() = default; - BufferPtr(const BufferPtr&) = default; + BufferPtr() = default; + BufferPtr(const BufferPtr&) = default; BufferPtr& operator=(const BufferPtr&) = default; ~BufferPtr() = default; @@ -946,7 +964,7 @@ struct MultiPlanItem { MultiPlanItem(); virtual ~MultiPlanItem(); - MultiPlanItem(const MultiPlanItem&) = delete; + MultiPlanItem(const MultiPlanItem&) = delete; MultiPlanItem& operator=(const MultiPlanItem&) = delete; // multi-process requests diff --git a/library/src/include/tree_node_1D.h b/library/src/include/tree_node_1D.h index c3785cf3a8b..2cebaffee27 100644 --- a/library/src/include/tree_node_1D.h +++ b/library/src/include/tree_node_1D.h @@ -102,6 +102,34 @@ class Stockham1DNode : public LeafNode } }; +/***************************************************** + * CS_KERNEL_STOCKHAM_PP * + *****************************************************/ +class StockhamPP1DNode : public LeafNode +{ + friend class NodeFactory; + +protected: + StockhamPP1DNode(TreeNode* p, ComputeScheme s) + : LeafNode(p, s) + { + externalKernel = true; + need_twd_table = true; + } + + void SetupGridParam_internal(GridParam& gp) override; + +public: + bool CreateDeviceResources() override; + std::vector CollapsibleDims() override; + bool UseOutputLengthForPadding() override + { + // with embedded r2c, stockham nodes will change length, so the + // output length is different from the input length. + return ebtype != EmbeddedType::NONE; + } +}; + /***************************************************** * SBCC * *****************************************************/ @@ -145,6 +173,39 @@ class SBCCNode : public LeafNode std::vector CollapsibleDims() override; }; +/***************************************************** + * SBCC Partial-Pass * + *****************************************************/ +class SBCCPPNode : public LeafNode +{ + friend class NodeFactory; + +protected: + SBCCPPNode(TreeNode* p, ComputeScheme s) + : LeafNode(p, s) + { + externalKernel = true; + need_twd_table = true; + } + + void SetupGridParam_internal(GridParam& gp) override; + + // InitIntrinsicMode is the first step to check if eligible for buffer load/store + void InitIntrinsicMode(); + +public: + // reads + writes are along columns so both may benefit from padding + bool PaddingBenefitsInput() override + { + return true; + } + bool PaddingBenefitsOutput() override + { + return true; + } + std::vector CollapsibleDims() override; +}; + /***************************************************** * SBRC * *****************************************************/ diff --git a/library/src/include/tree_node_3D.h b/library/src/include/tree_node_3D.h index aa216433d95..e8b56086981 100644 --- a/library/src/include/tree_node_3D.h +++ b/library/src/include/tree_node_3D.h @@ -129,9 +129,23 @@ class RC3DNode : public RC2DNode scheme = CS_3D_RC; } - // Determines if the current node suports - // the partial pass optimization. - bool CheckPartialPassSupport(); + void AssignParams_internal() override; + void BuildTree_internal(SchemeTreeVec& child_scheme_trees = EmptySchemeTreeVec) override; +}; + +/***************************************************** + * CS_3D_PP * + *****************************************************/ +class PP3DNode : public InternalNode +{ + friend class NodeFactory; + +protected: + explicit PP3DNode(TreeNode* p) + : InternalNode(p) + { + scheme = CS_3D_PP; + } void AssignParams_internal() override; void BuildTree_internal(SchemeTreeVec& child_scheme_trees = EmptySchemeTreeVec) override; diff --git a/library/src/node_factory.cpp b/library/src/node_factory.cpp index 4f31cc5e307..5b46836734b 100644 --- a/library/src/node_factory.cpp +++ b/library/src/node_factory.cpp @@ -435,12 +435,18 @@ std::unique_ptr NodeFactory::CreateNodeFromScheme(ComputeScheme s, Tre return std::unique_ptr(new BLOCKCR3DNode(parent)); case CS_3D_RC: return std::unique_ptr(new RC3DNode(parent)); + case CS_3D_PP: + return std::unique_ptr(new PP3DNode(parent)); // Leaf Node that need to check external kernel file case CS_KERNEL_STOCKHAM: return std::unique_ptr(new Stockham1DNode(parent, s)); + case CS_KERNEL_STOCKHAM_PP: + return std::unique_ptr(new StockhamPP1DNode(parent, s)); case CS_KERNEL_STOCKHAM_BLOCK_CC: return std::unique_ptr(new SBCCNode(parent, s)); + case CS_KERNEL_STOCKHAM_PP_BLOCK_CC: + return std::unique_ptr(new SBCCPPNode(parent, s)); case CS_KERNEL_STOCKHAM_BLOCK_RC: return std::unique_ptr(new SBRCNode(parent, s)); case CS_KERNEL_STOCKHAM_BLOCK_CR: @@ -790,8 +796,11 @@ ComputeScheme NodeFactory::Decide3DScheme(const function_pool& pool, NodeMetaDat // multi-dimension cases and small 2d, 3d within one kernel bool MultiDimFuseKernelsAvailable = false; - // try 3 SBCR kernels first - if(Apply_SBCR(pool, nodeData)) + if(use_CS_3D_PP(pool, nodeData)) // try 2 partial-pass kernels first + { + return CS_3D_PP; + } + else if(Apply_SBCR(pool, nodeData)) // try 3 kernels next { return CS_3D_BLOCK_CR; } @@ -960,3 +969,29 @@ bool NodeFactory::use_CS_3D_RC(const function_pool& pool, NodeMetaData& nodeData return false; } + +bool NodeFactory::use_CS_3D_PP(const function_pool& pool, NodeMetaData& nodeData) +{ + if(!pool.has_pp_function( + + FMKeyPP(nodeData.length[0], + nodeData.length[1], + nodeData.length[2], + nodeData.precision, + CS_3D_PP))) + return false; + + // Partial pass is currently restricted large enough batch sizes, + // unite stride, interleaved FFTs. + bool batchCondition = (nodeData.batch >= 5); + + size_t checkDist = product(nodeData.length.begin(), nodeData.length.end()); + bool distCondition = (nodeData.iDist == checkDist && nodeData.oDist == checkDist); + + bool strideCondition = (nodeData.inStride[0] == 1 && nodeData.outStride[0] == 1); + + bool arrayTypeCondition = (nodeData.inArrayType != rocfft_array_type_complex_planar + && nodeData.outArrayType != rocfft_array_type_complex_planar); + + return (batchCondition && distCondition && strideCondition && arrayTypeCondition); +} diff --git a/library/src/plan.cpp b/library/src/plan.cpp index d800b4793f0..710a1d83e04 100644 --- a/library/src/plan.cpp +++ b/library/src/plan.cpp @@ -1199,7 +1199,7 @@ struct TempBufferLease buf = std::move(other.buf); return *this; } - TempBufferLease(const TempBufferLease& other) = delete; + TempBufferLease(const TempBufferLease& other) = delete; TempBufferLease& operator=(const TempBufferLease& other) = delete; std::shared_ptr data() @@ -3312,24 +3312,23 @@ void TreeNode::CopyNodeData(const TreeNode& srcNode) length = srcNode.length; if(!srcNode.outputLength.empty()) outputLength = srcNode.outputLength; - inStride = srcNode.inStride; - inStrideBlue = srcNode.inStrideBlue; - outStride = srcNode.outStride; - outStrideBlue = srcNode.outStrideBlue; - iDist = srcNode.iDist; - iDistBlue = srcNode.iDistBlue; - oDist = srcNode.oDist; - oDistBlue = srcNode.oDistBlue; - iOffset = srcNode.iOffset; - oOffset = srcNode.oOffset; - placement = srcNode.placement; - precision = srcNode.precision; - applyPartialPass = srcNode.applyPartialPass; - direction = srcNode.direction; - inArrayType = srcNode.inArrayType; - outArrayType = srcNode.outArrayType; - allowInplace = srcNode.allowInplace; - allowOutofplace = srcNode.allowOutofplace; + inStride = srcNode.inStride; + inStrideBlue = srcNode.inStrideBlue; + outStride = srcNode.outStride; + outStrideBlue = srcNode.outStrideBlue; + iDist = srcNode.iDist; + iDistBlue = srcNode.iDistBlue; + oDist = srcNode.oDist; + oDistBlue = srcNode.oDistBlue; + iOffset = srcNode.iOffset; + oOffset = srcNode.oOffset; + placement = srcNode.placement; + precision = srcNode.precision; + direction = srcNode.direction; + inArrayType = srcNode.inArrayType; + outArrayType = srcNode.outArrayType; + allowInplace = srcNode.allowInplace; + allowOutofplace = srcNode.allowOutofplace; // conditional large1D = srcNode.large1D; @@ -3360,22 +3359,21 @@ void TreeNode::CopyNodeData(const NodeMetaData& data) length = data.length; if(!data.outputLength.empty()) outputLength = data.outputLength; - inStride = data.inStride; - inStrideBlue = data.inStrideBlue; - outStride = data.outStride; - outStrideBlue = data.outStrideBlue; - iDist = data.iDist; - iDistBlue = data.iDistBlue; - oDist = data.oDist; - oDistBlue = data.oDistBlue; - iOffset = data.iOffset; - oOffset = data.oOffset; - placement = data.placement; - precision = data.precision; - applyPartialPass = data.applyPartialPass; - direction = data.direction; - inArrayType = data.inArrayType; - outArrayType = data.outArrayType; + inStride = data.inStride; + inStrideBlue = data.inStrideBlue; + outStride = data.outStride; + outStrideBlue = data.outStrideBlue; + iDist = data.iDist; + iDistBlue = data.iDistBlue; + oDist = data.oDist; + oDistBlue = data.oDistBlue; + iOffset = data.iOffset; + oOffset = data.oOffset; + placement = data.placement; + precision = data.precision; + direction = data.direction; + inArrayType = data.inArrayType; + outArrayType = data.outArrayType; } bool TreeNode::isPlacementAllowed(rocfft_result_placement test_placement) const @@ -3971,6 +3969,17 @@ TreeNode* TreeNode::GetRealEvenAncestor() return parent->GetRealEvenAncestor(); } +TreeNode* TreeNode::GetPartialPassAncestor() const +{ + if(!parent) + return nullptr; + + if(parent->scheme == CS_3D_PP) + return parent; + + return parent->GetPartialPassAncestor(); +} + bool TreeNode::IsRootPlanC2CTransform() { auto root = GetPlanRoot(); diff --git a/library/src/rocfft_aot_helper.cpp b/library/src/rocfft_aot_helper.cpp index ad7b1abe524..7d7acff5881 100644 --- a/library/src/rocfft_aot_helper.cpp +++ b/library/src/rocfft_aot_helper.cpp @@ -226,10 +226,9 @@ void build_stockham_function_pool(CompileQueue& queue) function_pool fp(65536); // fused Bluestein and partial-pass kernels are always built at runtime - auto fuseBlue = BluesteinFuseType::BFT_NONE; - auto ppType = PartialPassType::PPT_NONE; - auto ppFactors = std::vector{}; - auto ppLength = 0; + auto fuseBlue = BluesteinFuseType::BFT_NONE; + auto ppType = PartialPassType::PPT_NONE; + auto ppParams = StockhamPartialPassParams(); for(const auto& i : fp.get_map()) { @@ -316,6 +315,7 @@ void build_stockham_function_pool(CompileQueue& queue) specs.direct_to_from_reg = i.second.direct_to_from_reg; return stockham_rtc(specs, specs, + ppParams, nullptr, kernel_name, scheme, @@ -335,8 +335,6 @@ void build_stockham_function_pool(CompileQueue& queue) cbtype, fuseBlue, ppType, - ppFactors, - ppLength, {}, {}); }; @@ -624,10 +622,9 @@ void build_solution_kernels(CompileQueue& queue) solmap.get_all_kernels(kernel_nodes, true); // fused Bluestein and partial-pass kernels are always built at runtime - auto fuseBlue = BluesteinFuseType::BFT_NONE; - auto ppType = PartialPassType::PPT_NONE; - auto ppFactors = std::vector{}; - auto ppLength = 0; + auto fuseBlue = BluesteinFuseType::BFT_NONE; + auto ppType = PartialPassType::PPT_NONE; + auto ppParams = StockhamPartialPassParams(); for(const SolutionNode& kernel_sol : kernel_nodes) { @@ -711,6 +708,7 @@ void build_solution_kernels(CompileQueue& queue) = [=](const std::string& kernel_name) -> std::string { return stockham_rtc(specs, specs, + ppParams, nullptr, kernel_name, scheme, @@ -730,8 +728,6 @@ void build_solution_kernels(CompileQueue& queue) cbtype, fuseBlue, ppType, - ppFactors, - ppLength, {}, {}); }; diff --git a/library/src/rocfft_kernel_config_search.cpp b/library/src/rocfft_kernel_config_search.cpp index 525196200e2..2fbcdae90a8 100644 --- a/library/src/rocfft_kernel_config_search.cpp +++ b/library/src/rocfft_kernel_config_search.cpp @@ -161,12 +161,16 @@ std::string test_kernel_src(const std::string& kernel_name, {static_cast(rocfft_precision_single)}, wgs, PrintScheme(compute_scheme)}; + + auto ppParams = StockhamPartialPassParams(); + specs.threads_per_transform = tpt; specs.half_lds = half_lds; specs.direct_to_from_reg = direct_to_from_reg; return stockham_rtc(specs, specs, + ppParams, &transforms_per_block, kernel_name, compute_scheme, @@ -188,8 +192,6 @@ std::string test_kernel_src(const std::string& kernel_name, BluesteinFuseType::BFT_NONE, PartialPassType::PPT_NONE, {}, - 0, - {}, {}); } diff --git a/library/src/rtc_stockham_gen.cpp b/library/src/rtc_stockham_gen.cpp index c7d38b96374..945d6795e83 100644 --- a/library/src/rtc_stockham_gen.cpp +++ b/library/src/rtc_stockham_gen.cpp @@ -138,9 +138,11 @@ std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, switch(scheme) { case CS_KERNEL_STOCKHAM: + case CS_KERNEL_STOCKHAM_PP: kernel_name += "_sbrr"; break; case CS_KERNEL_STOCKHAM_BLOCK_CC: + case CS_KERNEL_STOCKHAM_PP_BLOCK_CC: kernel_name += "_sbcc"; break; case CS_KERNEL_STOCKHAM_BLOCK_CR: @@ -248,31 +250,30 @@ std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, return kernel_name; } -std::string stockham_rtc(const StockhamGeneratorSpecs& specs, - const StockhamGeneratorSpecs& specs2d, - unsigned int* transforms_per_block, - const std::string& kernel_name, - ComputeScheme scheme, - int direction, - rocfft_precision precision, - rocfft_result_placement placement, - rocfft_array_type inArrayType, - rocfft_array_type outArrayType, - bool unit_stride, - size_t largeTwdBase, - size_t largeTwdSteps, - bool largeTwdBatchIsTransformCount, - EmbeddedType ebtype, - DirectRegType dir2regMode, - IntrinsicAccessType intrinsicMode, - SBRC_TRANSPOSE_TYPE transpose_type, - CallbackType cbtype, - const BluesteinFuseType& fuseBlue, - const PartialPassType& ppType, - const std::vector& ppFactors, - const size_t ppLength, - const LoadOps& loadOps, - const StoreOps& storeOps) +std::string stockham_rtc(const StockhamGeneratorSpecs& specs, + const StockhamGeneratorSpecs& specs2d, + const StockhamPartialPassParams& params_pp, + unsigned int* transforms_per_block, + const std::string& kernel_name, + ComputeScheme scheme, + int direction, + rocfft_precision precision, + rocfft_result_placement placement, + rocfft_array_type inArrayType, + rocfft_array_type outArrayType, + bool unit_stride, + size_t largeTwdBase, + size_t largeTwdSteps, + bool largeTwdBatchIsTransformCount, + EmbeddedType ebtype, + DirectRegType dir2regMode, + IntrinsicAccessType intrinsicMode, + SBRC_TRANSPOSE_TYPE transpose_type, + CallbackType cbtype, + const BluesteinFuseType& fuseBlue, + const PartialPassType& ppType, + const LoadOps& loadOps, + const StoreOps& storeOps) { std::unique_ptr lds2reg, reg2lds, device; std::unique_ptr lds2reg_pp_steps, reg2lds_pp_steps; @@ -314,21 +315,15 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, { std::unique_ptr kernel; if(scheme == CS_KERNEL_STOCKHAM) - { - if(ppType == PartialPassType::PPT_SBRR) - kernel = std::make_unique(specs, ppFactors, ppLength); - else - kernel = std::make_unique(specs); - } + kernel = std::make_unique(specs); + else if(scheme == CS_KERNEL_STOCKHAM_PP) + kernel = std::make_unique(specs, params_pp); else if(scheme == CS_KERNEL_STOCKHAM_BLOCK_CC) - { - if(ppType == PartialPassType::PPT_SBCC) - kernel = std::make_unique( - specs, largeTwdBatchIsTransformCount, ppFactors); - else - kernel = std::make_unique( - specs, largeTwdBatchIsTransformCount, fuseBluestein); - } + kernel = std::make_unique( + specs, largeTwdBatchIsTransformCount, fuseBluestein); + else if(scheme == CS_KERNEL_STOCKHAM_PP_BLOCK_CC) + kernel = std::make_unique( + specs, params_pp, largeTwdBatchIsTransformCount); else if(scheme == CS_KERNEL_STOCKHAM_BLOCK_CR) kernel = std::make_unique(specs); else if(scheme == CS_KERNEL_STOCKHAM_BLOCK_RC) @@ -411,7 +406,9 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, all_factors = kernel->factors; if(ppType != PPT_NONE) - all_factors.insert(all_factors.end(), ppFactors.begin(), ppFactors.end()); + all_factors.insert(all_factors.end(), + params_pp.factors_off_dim.begin(), + params_pp.factors_off_dim.end()); } // generated functions default to forward in-place interleaved. @@ -452,12 +449,12 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, src += butterfly_constant_h; // only SBCCs need this - if(scheme == CS_KERNEL_STOCKHAM_BLOCK_CC) + if(scheme == CS_KERNEL_STOCKHAM_BLOCK_CC || scheme == CS_KERNEL_STOCKHAM_PP_BLOCK_CC) src += large_twiddles_h; // append the neccessary functions only append_radix_h(src, all_factors); // SBCCs don't need this - if(scheme != CS_KERNEL_STOCKHAM_BLOCK_CC) + if(scheme != CS_KERNEL_STOCKHAM_BLOCK_CC && scheme != CS_KERNEL_STOCKHAM_PP_BLOCK_CC) src += real2complex_device_h; src += lds2reg->render(); diff --git a/library/src/rtc_stockham_kernel.cpp b/library/src/rtc_stockham_kernel.cpp index 163533f8f63..fb1dbc0a2fb 100644 --- a/library/src/rtc_stockham_kernel.cpp +++ b/library/src/rtc_stockham_kernel.cpp @@ -37,6 +37,7 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& std::optional specs; std::optional specs2d; + StockhamPartialPassParams pp_params; // SBRC variants look in the function pool for plain BLOCK_RC to // learn the block width, then decide on the transpose type once @@ -59,7 +60,9 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& switch(pool_scheme) { case CS_KERNEL_STOCKHAM: + case CS_KERNEL_STOCKHAM_PP: case CS_KERNEL_STOCKHAM_BLOCK_CC: + case CS_KERNEL_STOCKHAM_PP_BLOCK_CC: case CS_KERNEL_STOCKHAM_BLOCK_CR: case CS_KERNEL_STOCKHAM_BLOCK_RC: { @@ -68,43 +71,54 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& if((pool_scheme == CS_KERNEL_STOCKHAM_BLOCK_RC) && (node.sbrcTranstype == NONE)) throw std::runtime_error("Invalid SBRC_TRANS_TYPE for SBRC kernel"); - // these go into the function pool normally and are passed to - // the generator as-is - kernel = node.pool.get_kernel(key); + std::vector factors; - if(node.applyPartialPass) + if(node.isPartialPassEnabled()) { - // TODO: Hardcoded configuration for 64 x 64 x 64. - // Remove this once the partial-pass kernels are - // fully configurable in kernel-generator.py. - if(node.scheme == CS_KERNEL_STOCKHAM_BLOCK_CC) - { - kernel->threads_per_transform[0] = 8; - kernel->workgroup_size = 64; - } - else if(node.scheme == CS_KERNEL_STOCKHAM) - { - kernel->threads_per_transform[0] = 8; - kernel->workgroup_size = 128; - kernel->direct_to_from_reg = false; - } + auto pp_key = node.GetPPKernelsKey(); - kernel->transforms_per_block - = kernel->workgroup_size / kernel->threads_per_transform[0]; - } + kernel = node.pool.get_pp_kernel(pp_key, node.scheme); - std::vector factors; - std::copy(kernel->factors.begin(), kernel->factors.end(), std::back_inserter(factors)); - std::vector precisions = {static_cast(node.precision)}; + std::copy(kernel->factors.begin(), kernel->factors.end(), std::back_inserter(factors)); + std::vector precisions = {static_cast(node.precision)}; + + pp_params.off_dim = node.ppOffDim; + pp_params.current_dim = node.ppCurrDim; + pp_params.factors_off_dim = std::vector( + kernel->pp_params.factors_off_dim.begin(), kernel->pp_params.factors_off_dim.end()); + pp_params.parent_length + = std::vector(node.length.begin(), node.length.end()); + + specs.emplace(factors, + std::vector(), + precisions, + static_cast(kernel->workgroup_size), + PrintScheme(node.scheme)); + + specs->threads_per_transform = kernel->threads_per_transform[0]; + specs->half_lds = kernel->half_lds; + specs->direct_to_from_reg = kernel->direct_to_from_reg; + } + else + { + // these go into the function pool normally and are passed to + // the generator as-is + kernel = node.pool.get_kernel(key); + + std::copy(kernel->factors.begin(), kernel->factors.end(), std::back_inserter(factors)); + std::vector precisions = {static_cast(node.precision)}; + + specs.emplace(factors, + std::vector(), + precisions, + static_cast(kernel->workgroup_size), + PrintScheme(node.scheme)); + + specs->threads_per_transform = kernel->threads_per_transform[0]; + specs->half_lds = kernel->half_lds; + specs->direct_to_from_reg = kernel->direct_to_from_reg; + } - specs.emplace(factors, - std::vector(), - precisions, - static_cast(kernel->workgroup_size), - PrintScheme(node.scheme)); - specs->threads_per_transform = kernel->threads_per_transform[0]; - specs->half_lds = kernel->half_lds; - specs->direct_to_from_reg = kernel->direct_to_from_reg; break; } case CS_KERNEL_2D_SINGLE: @@ -171,11 +185,11 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& bool unit_stride = node.inStride.front() == 1 && node.outStride.front() == 1; auto ppType = PartialPassType::PPT_NONE; - if(node.applyPartialPass) + if(node.isPartialPassEnabled()) { - if(node.scheme == CS_KERNEL_STOCKHAM_BLOCK_CC) + if(node.scheme == CS_KERNEL_STOCKHAM_PP_BLOCK_CC) ppType = PartialPassType::PPT_SBCC; - else if(node.scheme == CS_KERNEL_STOCKHAM) + else if(node.scheme == CS_KERNEL_STOCKHAM_PP) ppType = PartialPassType::PPT_SBRR; else throw std::runtime_error("Invalid scheme for partial pass"); @@ -208,6 +222,7 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& generator.generate_src = [=, &node](const std::string& kernel_name) { return stockham_rtc(*specs, specs2d ? *specs2d : *specs, + pp_params, nullptr, kernel_name, node.scheme, @@ -227,8 +242,6 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& node.GetCallbackType(enable_callbacks), node.fuseBlue, ppType, - node.kernelFactorsPP, - node.length[node.ppDim], node.loadOps, node.storeOps); }; @@ -246,11 +259,12 @@ RTCKernelArgs RTCKernelStockham::get_launch_args(DeviceCallIn& data) RTCKernelArgs kargs; // twiddles - if(data.node->applyPartialPass && data.node->scheme == CS_KERNEL_STOCKHAM) + if(data.node->scheme == CS_KERNEL_STOCKHAM_PP) kargs.append_ptr(data.node->twiddles_pp); kargs.append_ptr(data.node->twiddles); // large 1D twiddles - if(data.node->scheme == CS_KERNEL_STOCKHAM_BLOCK_CC) + if(data.node->scheme == CS_KERNEL_STOCKHAM_BLOCK_CC + || data.node->scheme == CS_KERNEL_STOCKHAM_PP_BLOCK_CC) kargs.append_ptr(data.node->twiddles_large); if(!hardcoded_dim) kargs.append_size_t(data.node->length.size()); diff --git a/library/src/tree_node.cpp b/library/src/tree_node.cpp index 70706d9e769..678464e59a4 100644 --- a/library/src/tree_node.cpp +++ b/library/src/tree_node.cpp @@ -104,18 +104,27 @@ FMKey LeafNode::GetKernelKey() const void LeafNode::GetKernelFactors() { - FMKey key = GetKernelKey(); - kernelFactors = pool.get_kernel(key).factors; + if(isPartialPassEnabled()) + { + FMKeyPP key = GetPPKernelsKey(); + kernelFactors = pool.get_pp_kernel(key, scheme).factors; + } + else + { + FMKey key = GetKernelKey(); + kernelFactors = pool.get_kernel(key).factors; + } } void LeafNode::GetKernelPartialPassFactors() { - // Hard-coded partial-pass kernel factors for len 64x64x64. - // TODO: Remove this hard-coded logic once partial-pass - // kernels are configurable in kernel-generator.py. - if(scheme == CS_KERNEL_STOCKHAM && applyPartialPass) + FMKeyPP key = GetPPKernelsKey(); + auto kernel = pool.get_pp_kernel(key, scheme); + kernelFactorsPP = std::vector(kernel.pp_params.factors_off_dim.begin(), + kernel.pp_params.factors_off_dim.end()); + + if(scheme == CS_KERNEL_STOCKHAM_PP) { - kernelFactorsPP = {16}; std::stringstream msg; msg << "work in the off-dimension:" << std::endl; msg << "\t radix: ["; @@ -124,9 +133,8 @@ void LeafNode::GetKernelPartialPassFactors() msg << " ] pass(es) + Hadamard product with twiddle factors. \n"; comments.push_back(msg.str()); } - if(scheme == CS_KERNEL_STOCKHAM_BLOCK_CC && applyPartialPass) + if(scheme == CS_KERNEL_STOCKHAM_PP_BLOCK_CC) { - kernelFactorsPP = {4}; std::stringstream msg; msg << "work in the off-dimension:" << std::endl; msg << "\t local data transposition + radix: ["; @@ -186,26 +194,38 @@ bool LeafNode::KernelCheck(std::vector& kernel_keys) } } - // get the final key and check if we have the kernel. - // Note that the check is trivial if we are using "specified_key" - // since we definitly have the kernel, but not trivial if it's the auto-gen key - FMKey key = GetKernelKey(); - if(!pool.has_function(key)) + if(isPartialPassEnabled()) { - if(LOG_TRACE_ENABLED()) - (*LogSingleton::GetInstance().GetTraceOS()) << PrintMissingKernelInfo(key); + FMKeyPP key = GetPPKernelsKey(); + if(!pool.has_pp_function(key)) + return false; - return false; + auto kernel = pool.get_pp_kernel(key, scheme); + dir2regMode = (kernel.direct_to_from_reg) ? DirectRegType::TRY_ENABLE_IF_SUPPORT + : DirectRegType::FORCE_OFF_OR_NOT_SUPPORT; + + GetKernelPartialPassFactors(); } + else + { + // get the final key and check if we have the kernel. + // Note that the check is trivial if we are using "specified_key" + // since we definitly have the kernel, but not trivial if it's the auto-gen key + FMKey key = GetKernelKey(); + if(!pool.has_function(key)) + { + if(LOG_TRACE_ENABLED()) + (*LogSingleton::GetInstance().GetTraceOS()) << PrintMissingKernelInfo(key); - dir2regMode = (pool.get_kernel(key).direct_to_from_reg) - ? DirectRegType::TRY_ENABLE_IF_SUPPORT - : DirectRegType::FORCE_OFF_OR_NOT_SUPPORT; + return false; + } - GetKernelFactors(); + dir2regMode = (pool.get_kernel(key).direct_to_from_reg) + ? DirectRegType::TRY_ENABLE_IF_SUPPORT + : DirectRegType::FORCE_OFF_OR_NOT_SUPPORT; - if(applyPartialPass) - GetKernelPartialPassFactors(); + GetKernelFactors(); + } return true; } @@ -281,8 +301,6 @@ void LeafNode::SetupGridParam(GridParam& gp) { if(pool.has_function(key)) { - auto kernel = pool.get_kernel(key); - // NB: // Special case on specific arch: // For some cases using hald_lds, finer tuning(enlarge) dynamic @@ -297,7 +315,8 @@ void LeafNode::SetupGridParam(GridParam& gp) } // no support for half-lds in partial-pass mode - if(kernel.half_lds && (!double_half_lds_alloc) && (!applyPartialPass)) + auto kernel = pool.get_kernel(key); + if(kernel.half_lds && (!double_half_lds_alloc)) gp.lds_bytes /= 2; } } @@ -311,12 +330,14 @@ void LeafNode::SetupGridParam(GridParam& gp) if(kernel.half_lds) gp.lds_bytes /= 2; } + } + if(scheme == CS_KERNEL_STOCKHAM_BLOCK_CC || scheme == CS_KERNEL_STOCKHAM_PP_BLOCK_CC) + { auto apply_large_twd = (largeTwdBase > 0 && ltwdSteps > 0); if(apply_large_twd && largeTwdBase < 8) { // append twiddle table to dynamic lds - auto kernel = pool.get_kernel(key); gp.lds_bytes += twiddles_large_size; } } diff --git a/library/src/tree_node_1D.cpp b/library/src/tree_node_1D.cpp index 60cffd087af..440aae5eb9b 100644 --- a/library/src/tree_node_1D.cpp +++ b/library/src/tree_node_1D.cpp @@ -903,16 +903,6 @@ void Stockham1DNode::SetupGridParam_internal(GridParam& gp) auto key = GetKernelKey(); auto kernel = pool.get_kernel(key); - if(applyPartialPass) - { - // TODO: Hardcoded configuration for 64 x 64 x 64. - // Remove this once the partial-pass kernels are - // fully configurable in kernel-generator.py. - kernel.threads_per_transform[0] = 8; - kernel.workgroup_size = 128; - kernel.transforms_per_block = kernel.workgroup_size / kernel.threads_per_transform[0]; - } - bwd = kernel.transforms_per_block; wgs = kernel.workgroup_size; gp.b_x = (batch_accum + bwd - 1) / bwd; @@ -934,13 +924,6 @@ void Stockham1DNode::SetupGridParam_internal(GridParam& gp) bool Stockham1DNode::CreateDeviceResources() { - if(applyPartialPass) - { - // Create twiddle table for partial pass along ppDim - std::tie(twiddles_pp, twiddles_pp_size) - = Repo::GetTwiddlesPP(length[ppDim], precision, deviceProp); - } - twd_attach_halfN = (ebtype != EmbeddedType::NONE); return LeafNode::CreateDeviceResources(); } @@ -951,16 +934,51 @@ std::vector Stockham1DNode::CollapsibleDims() if(typeBlue == BT_MULTI_KERNEL_FUSED) return {}; - // do not collapse on partial-pass nodes - if(applyPartialPass) - return {}; - // fastest dim is FFT, the rest is collapsible std::vector ret(length.size() - 1); std::iota(ret.begin(), ret.end(), 1); return ret; } +/***************************************************** + * CS_KERNEL_STOCKHAM_PP * + *****************************************************/ +void StockhamPP1DNode::SetupGridParam_internal(GridParam& gp) +{ + // get working group size and number of transforms + size_t batch_accum = batch; + for(size_t j = 1; j < length.size(); j++) + batch_accum *= length[j]; + + FMKeyPP key = GetPPKernelsKey(); + auto kernel = pool.get_pp_kernel(key, scheme); + + bwd = kernel.transforms_per_block; + wgs = kernel.workgroup_size; + gp.b_x = (batch_accum + bwd - 1) / bwd; + gp.wgs_x = wgs; + + const auto lds_padding = ebtype != EmbeddedType::NONE ? 1 : 0; + + lds = (length[0] + lds_padding) * bwd; +} + +bool StockhamPP1DNode::CreateDeviceResources() +{ + // Create twiddle table for partial pass along ppOffDim + std::tie(twiddles_pp, twiddles_pp_size) + = Repo::GetTwiddlesPP(length[ppOffDim], precision, deviceProp); + + twd_attach_halfN = (ebtype != EmbeddedType::NONE); + return LeafNode::CreateDeviceResources(); +} + +std::vector StockhamPP1DNode::CollapsibleDims() +{ + // do not collapse on partial-pass nodes + return {}; +} + /***************************************************** * SBCC * *****************************************************/ @@ -1114,33 +1132,11 @@ void SBCCNode::SetupGridParam_internal(GridParam& gp) bwd = kernel.transforms_per_block; wgs = kernel.workgroup_size; - if(applyPartialPass) - { - // TODO: Hardcoded configuration for 64 x 64 x 64. - // Remove this once the partial-pass kernels are - // fully configurable in kernel-generator.py. - auto tpt = 8; - wgs = 64; - bwd = wgs / tpt; - } - lds = length[0] * bwd; gp.b_x = ((length[1]) - 1) / bwd + 1; gp.b_x *= product(length.begin() + 2, length.end()) * batch; gp.wgs_x = wgs; - - if(applyPartialPass) - { - // Grid arrangement is different for partial - // pass SBCC kernels. This arrangement leads - // to improved global memory access patterns. - auto factor = *std::max_element(kernelFactorsPP.begin(), kernelFactorsPP.end()); - - gp.b_x /= factor; - gp.wgs_x *= factor; - lds *= factor; - } } std::vector SBCCNode::CollapsibleDims() @@ -1155,6 +1151,40 @@ std::vector SBCCNode::CollapsibleDims() return ret; } +/***************************************************** + * SBCC Partial-Pass * + *****************************************************/ + +void SBCCPPNode::SetupGridParam_internal(GridParam& gp) +{ + FMKeyPP key = GetPPKernelsKey(); + auto kernel = pool.get_pp_kernel(key, scheme); + + bwd = kernel.transforms_per_block; + wgs = kernel.workgroup_size; + + lds = length[0] * bwd; + + gp.b_x = ((length[1]) - 1) / bwd + 1; + gp.b_x *= product(length.begin() + 2, length.end()) * batch; + gp.wgs_x = wgs; + + // Grid arrangement is different for partial + // pass SBCC kernels for improved global memory + // access patterns. + auto factor = *std::max_element(kernelFactorsPP.begin(), kernelFactorsPP.end()); + + gp.b_x /= factor; + gp.wgs_x *= factor; + lds *= factor; +} + +std::vector SBCCPPNode::CollapsibleDims() +{ + // do not collapse on partial-pass nodes + return {}; +} + /***************************************************** * SBRC * *****************************************************/ diff --git a/library/src/tree_node_3D.cpp b/library/src/tree_node_3D.cpp index 479280afd06..f8d60e5d7f4 100644 --- a/library/src/tree_node_3D.cpp +++ b/library/src/tree_node_3D.cpp @@ -581,51 +581,6 @@ void BLOCKCR3DNode::AssignParams_internal() /***************************************************** * CS_3D_RC * *****************************************************/ -bool RC3DNode::CheckPartialPassSupport() -{ - if(parent != nullptr) - { - // 3D_RC node is child of a real tranform node - // skip partial pass for now - return false; - } - - // Partial pass is currently restricted to length 64x64x64, - // large enough batch sizes, unite stride, interleaved FFTs - // without result scaling. - - // List of supported 3D sizes for partial pass. - typedef std::tuple len3D_item_t; - const std::vector supportedLengthsPP = {len3D_item_t{64, 64, 64}}; - - bool lengthCondition = false; - for(const auto& lenItem : supportedLengthsPP) - { - if(std::get<0>(lenItem) == length[0] && std::get<1>(lenItem) == length[1] - && std::get<2>(lenItem) == length[2]) - { - lengthCondition = true; - break; - } - } - - // TODO: Revisit these restrictions once partial pass is - // fully configurable in kernel-generator.py. - bool batchCondition = (batch >= 5); - - size_t checkDist = product(length.begin(), length.end()); - bool distCondition = (iDist == checkDist && oDist == checkDist); - - bool strideCondition = (inStride[0] == 1 && outStride[0] == 1); - - bool arrayTypeCondition = (inArrayType != rocfft_array_type_complex_planar - && outArrayType != rocfft_array_type_complex_planar); - - bool loadStoreOpsCondition = (!loadOps.enabled() && !storeOps.enabled()); - - return (lengthCondition && batchCondition && distCondition && strideCondition - && arrayTypeCondition && loadStoreOpsCondition); -} void RC3DNode::BuildTree_internal(SchemeTreeVec& child_scheme_trees) { @@ -642,120 +597,57 @@ void RC3DNode::BuildTree_internal(SchemeTreeVec& child_scheme_trees) determined_scheme_node1 = child_scheme_trees[1]->curScheme; } - if(CheckPartialPassSupport()) + // 2d fft + NodeMetaData xyPlanData(this); + xyPlanData.length.push_back(length[0]); + xyPlanData.length.push_back(length[1]); + xyPlanData.dimension = 2; + xyPlanData.length.push_back(length[2]); + for(size_t index = 3; index < length.size(); index++) { - // TODO: Child nodes currently hardcoded to a x+z configuration - // in 3D partial-pass. Add support for other configurations, - // e.g., x+y, y+z, once partial pass is fully configurable - // in kernel-generator.py. - - // work along y will be split between x and z - applyPartialPass = true; - - // x row fft + partial pass(es) along y - NodeMetaData xPartialPassPlanData(this); - xPartialPassPlanData.length.push_back(length[0]); - xPartialPassPlanData.length.push_back(length[1]); - // technically 1 < dimension < 2 for x node. - xPartialPassPlanData.dimension = 1; - xPartialPassPlanData.applyPartialPass = true; - xPartialPassPlanData.length.push_back(length[2]); - for(size_t index = 3; index < length.size(); index++) - { - xPartialPassPlanData.length.push_back(length[index]); - } - - // use explicit (modified) SBRR kernel - std::unique_ptr xPartialPassPlan; - - xPartialPassPlan = NodeFactory::CreateNodeFromScheme(CS_KERNEL_STOCKHAM, this); - xPartialPassPlan->length = xPartialPassPlanData.length; - xPartialPassPlan->dimension = 1; - xPartialPassPlan->ppDim = 1; - xPartialPassPlan->allowInplace = true; - xPartialPassPlan->comments.push_back("partial-pass enabled for second dimension."); - - // partial pass(es) along y + z col fft - NodeMetaData zPartialPassPlanData(this); - zPartialPassPlanData.length.push_back(length[2]); - zPartialPassPlanData.applyPartialPass = true; - // technically 1 < dimension < 2 for z node. - zPartialPassPlanData.dimension = 1; - zPartialPassPlanData.length.push_back(length[0]); - zPartialPassPlanData.length.push_back(length[1]); - for(size_t index = 3; index < length.size(); index++) - { - zPartialPassPlanData.length.push_back(length[index]); - } - zPartialPassPlanData.outputLength = length; + xyPlanData.length.push_back(length[index]); + } + auto xyPlan = NodeFactory::CreateExplicitNode(xyPlanData, this, determined_scheme_node0); + xyPlan->RecursiveBuildTree((noSolution) ? nullptr : child_scheme_trees[0].get()); - // use explicit (modified) SBCC kernel - std::unique_ptr zPartialPassPlan; + // z col fft + NodeMetaData zPlanData(this); + zPlanData.length.push_back(length[2]); + zPlanData.dimension = 1; + zPlanData.length.push_back(length[0]); + zPlanData.length.push_back(length[1]); + for(size_t index = 3; index < length.size(); index++) + { + zPlanData.length.push_back(length[index]); + } + zPlanData.outputLength = length; - zPartialPassPlan = NodeFactory::CreateNodeFromScheme(CS_KERNEL_STOCKHAM_BLOCK_CC, this); - zPartialPassPlan->length = zPartialPassPlanData.length; - zPartialPassPlan->dimension = 1; - zPartialPassPlan->ppDim = 1; - zPartialPassPlan->allowInplace = false; - zPartialPassPlan->comments.push_back("partial-pass enabled for second dimension."); + // use explicit SBCC kernel if available + std::unique_ptr zPlan; - childNodes.emplace_back(std::move(xPartialPassPlan)); - childNodes.emplace_back(std::move(zPartialPassPlan)); + if(determined_scheme_node1 != CS_NONE) + { + zPlan = NodeFactory::CreateExplicitNode(zPlanData, this, determined_scheme_node1); + zPlan->RecursiveBuildTree((noSolution) ? nullptr : child_scheme_trees[1].get()); } else { - // 2d fft - NodeMetaData xyPlanData(this); - xyPlanData.length.push_back(length[0]); - xyPlanData.length.push_back(length[1]); - xyPlanData.dimension = 2; - xyPlanData.length.push_back(length[2]); - for(size_t index = 3; index < length.size(); index++) - { - xyPlanData.length.push_back(length[index]); - } - auto xyPlan = NodeFactory::CreateExplicitNode(xyPlanData, this, determined_scheme_node0); - xyPlan->RecursiveBuildTree((noSolution) ? nullptr : child_scheme_trees[0].get()); - - // z col fft - NodeMetaData zPlanData(this); - zPlanData.length.push_back(length[2]); - zPlanData.dimension = 1; - zPlanData.length.push_back(length[0]); - zPlanData.length.push_back(length[1]); - for(size_t index = 3; index < length.size(); index++) + if(pool.has_SBCC_kernel(length[2], precision)) { - zPlanData.length.push_back(length[index]); - } - zPlanData.outputLength = length; - - // use explicit SBCC kernel if available - std::unique_ptr zPlan; - - if(determined_scheme_node1 != CS_NONE) - { - zPlan = NodeFactory::CreateExplicitNode(zPlanData, this, determined_scheme_node1); - zPlan->RecursiveBuildTree((noSolution) ? nullptr : child_scheme_trees[1].get()); + zPlan = NodeFactory::CreateNodeFromScheme(CS_KERNEL_STOCKHAM_BLOCK_CC, this); + zPlan->length = zPlanData.length; + zPlan->dimension = 1; } else { - if(pool.has_SBCC_kernel(length[2], precision)) - { - zPlan = NodeFactory::CreateNodeFromScheme(CS_KERNEL_STOCKHAM_BLOCK_CC, this); - zPlan->length = zPlanData.length; - zPlan->dimension = 1; - } - else - { - zPlan = NodeFactory::CreateExplicitNode(zPlanData, this); - zPlan->RecursiveBuildTree(nullptr); - } + zPlan = NodeFactory::CreateExplicitNode(zPlanData, this); + zPlan->RecursiveBuildTree(nullptr); } - - // RC - childNodes.emplace_back(std::move(xyPlan)); - childNodes.emplace_back(std::move(zPlan)); } + + // RC + childNodes.emplace_back(std::move(xyPlan)); + childNodes.emplace_back(std::move(zPlan)); } void RC3DNode::AssignParams_internal() @@ -772,8 +664,6 @@ void RC3DNode::AssignParams_internal() xyPlan->outStride = outStride; xyPlan->oDist = oDist; - xyPlan->applyPartialPass = applyPartialPass; - xyPlan->AssignParams(); zPlan->inStride.push_back(outStride[2]); @@ -787,7 +677,109 @@ void RC3DNode::AssignParams_internal() zPlan->outStride = zPlan->inStride; zPlan->oDist = zPlan->iDist; - zPlan->applyPartialPass = applyPartialPass; + zPlan->AssignParams(); +} + +/***************************************************** + * CS_3D_PP * + *****************************************************/ +void PP3DNode::BuildTree_internal(SchemeTreeVec& child_scheme_trees) +{ + // bool noSolution = child_scheme_trees.empty(); + + // // check schemes from solution map + // ComputeScheme determined_scheme_node0 = CS_NONE; + // ComputeScheme determined_scheme_node1 = CS_NONE; + // if(!noSolution) + // { + // if((child_scheme_trees.size() != 2)) + // throw std::runtime_error("RC3DNode: Unexpected child scheme from solution map"); + // determined_scheme_node0 = child_scheme_trees[0]->curScheme; + // determined_scheme_node1 = child_scheme_trees[1]->curScheme; + // } + + // TODO: Child nodes currently hardcoded to a x+z configuration + // in 3D partial-pass. Add support for other configurations, + // e.g., x+y, y+z, once partial pass is fully configurable + // in kernel-generator.py. + + // work along y will be split between x and z + + // x row fft + partial pass(es) along y + NodeMetaData xPartialPassPlanData(this); + xPartialPassPlanData.length.push_back(length[0]); + xPartialPassPlanData.length.push_back(length[1]); + // technically 1 < dimension < 2 for x node. + xPartialPassPlanData.dimension = 1; + xPartialPassPlanData.length.push_back(length[2]); + for(size_t index = 3; index < length.size(); index++) + { + xPartialPassPlanData.length.push_back(length[index]); + } + + // use explicit (modified) SBRR kernel + std::unique_ptr xPartialPassPlan; + + xPartialPassPlan = NodeFactory::CreateNodeFromScheme(CS_KERNEL_STOCKHAM_PP, this); + xPartialPassPlan->length = xPartialPassPlanData.length; + xPartialPassPlan->dimension = 1; + xPartialPassPlan->ppOffDim = 1; + xPartialPassPlan->allowInplace = true; + xPartialPassPlan->comments.push_back("partial-pass enabled for second dimension."); + + // partial pass(es) along y + z col fft + NodeMetaData zPartialPassPlanData(this); + zPartialPassPlanData.length.push_back(length[2]); + // technically 1 < dimension < 2 for z node. + zPartialPassPlanData.dimension = 1; + zPartialPassPlanData.length.push_back(length[0]); + zPartialPassPlanData.length.push_back(length[1]); + for(size_t index = 3; index < length.size(); index++) + { + zPartialPassPlanData.length.push_back(length[index]); + } + zPartialPassPlanData.outputLength = length; + + // use explicit (modified) SBCC kernel + std::unique_ptr zPartialPassPlan; + + zPartialPassPlan = NodeFactory::CreateNodeFromScheme(CS_KERNEL_STOCKHAM_PP_BLOCK_CC, this); + zPartialPassPlan->length = zPartialPassPlanData.length; + zPartialPassPlan->dimension = 1; + zPartialPassPlan->ppOffDim = 1; + zPartialPassPlan->allowInplace = false; + zPartialPassPlan->comments.push_back("partial-pass enabled for second dimension."); + + childNodes.emplace_back(std::move(xPartialPassPlan)); + childNodes.emplace_back(std::move(zPartialPassPlan)); +} + +void PP3DNode::AssignParams_internal() +{ + // in partial pass case: + // xy plan is a x row 1D-FFT + plus partial pass(es) along y + // z plan is partial pass(es) along y + z col 1D-FFT. + auto& xyPlan = childNodes[0]; + auto& zPlan = childNodes[1]; + + xyPlan->inStride = inStride; + xyPlan->iDist = iDist; + + xyPlan->outStride = outStride; + xyPlan->oDist = oDist; + + xyPlan->AssignParams(); + + zPlan->inStride.push_back(outStride[2]); + zPlan->inStride.push_back(outStride[0]); + zPlan->inStride.push_back(outStride[1]); + for(size_t index = 3; index < length.size(); index++) + zPlan->inStride.push_back(outStride[index]); + + zPlan->iDist = xyPlan->oDist; + + zPlan->outStride = zPlan->inStride; + zPlan->oDist = zPlan->iDist; zPlan->AssignParams(); } From ce4e4d4ae4cd7ac7638d0f4100e987e4e6637416 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 8 May 2025 16:10:32 -0600 Subject: [PATCH 29/69] formatting. --- library/src/device/generator/stockham_gen.h | 2 +- library/src/device/generator/stockham_pp_gen_cc.h | 3 +-- library/src/include/function_map_key.h | 2 +- library/src/include/function_pool.h | 2 +- library/src/include/tree_node.h | 8 ++++---- library/src/plan.cpp | 2 +- 6 files changed, 9 insertions(+), 10 deletions(-) diff --git a/library/src/device/generator/stockham_gen.h b/library/src/device/generator/stockham_gen.h index 4a400b20aef..feb4b1acf4f 100644 --- a/library/src/device/generator/stockham_gen.h +++ b/library/src/device/generator/stockham_gen.h @@ -97,7 +97,7 @@ struct StockhamPartialPassParams std::vector parent_length; unsigned int current_dim = 0; - unsigned int off_dim = 0; + unsigned int off_dim = 0; std::vector factors_off_dim; }; diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index 6c1e2fbf144..5d8b12def5b 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -91,8 +91,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC unsigned int launcher_transforms_per_block() override { return transforms_per_block / max_factor_pp; - } - + } StatementList load_global_generator(unsigned int h, unsigned int hr, diff --git a/library/src/include/function_map_key.h b/library/src/include/function_map_key.h index 51f20cfb446..f006929d863 100644 --- a/library/src/include/function_map_key.h +++ b/library/src/include/function_map_key.h @@ -317,7 +317,7 @@ struct FMKeyBase FMKeyBase() = default; FMKeyBase(const FMKeyBase&) = default; - virtual ~FMKeyBase() {}; + virtual ~FMKeyBase(){}; rocfft_precision precision; ComputeScheme scheme; diff --git a/library/src/include/function_pool.h b/library/src/include/function_pool.h index 889355ee341..7b26d36c1a3 100644 --- a/library/src/include/function_pool.h +++ b/library/src/include/function_pool.h @@ -211,7 +211,7 @@ class function_pool { } - function_pool(function_pool& p) = delete; + function_pool(function_pool& p) = delete; function_pool& operator=(const function_pool&) = delete; ~function_pool() = default; diff --git a/library/src/include/tree_node.h b/library/src/include/tree_node.h index fe28cf5e384..412e607fc18 100644 --- a/library/src/include/tree_node.h +++ b/library/src/include/tree_node.h @@ -780,7 +780,7 @@ class InternalTempBuffer : comm_rank(comm_rank) { } - InternalTempBuffer(const InternalTempBuffer&) = delete; + InternalTempBuffer(const InternalTempBuffer&) = delete; InternalTempBuffer& operator=(const InternalTempBuffer&) = delete; ~InternalTempBuffer() = default; @@ -836,8 +836,8 @@ class InternalTempBuffer class BufferPtr { public: - BufferPtr() = default; - BufferPtr(const BufferPtr&) = default; + BufferPtr() = default; + BufferPtr(const BufferPtr&) = default; BufferPtr& operator=(const BufferPtr&) = default; ~BufferPtr() = default; @@ -964,7 +964,7 @@ struct MultiPlanItem { MultiPlanItem(); virtual ~MultiPlanItem(); - MultiPlanItem(const MultiPlanItem&) = delete; + MultiPlanItem(const MultiPlanItem&) = delete; MultiPlanItem& operator=(const MultiPlanItem&) = delete; // multi-process requests diff --git a/library/src/plan.cpp b/library/src/plan.cpp index 710a1d83e04..63dfb2edf61 100644 --- a/library/src/plan.cpp +++ b/library/src/plan.cpp @@ -1199,7 +1199,7 @@ struct TempBufferLease buf = std::move(other.buf); return *this; } - TempBufferLease(const TempBufferLease& other) = delete; + TempBufferLease(const TempBufferLease& other) = delete; TempBufferLease& operator=(const TempBufferLease& other) = delete; std::shared_ptr data() From 4ae3e0244124d787683db487c4988d1661fa2539 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 8 May 2025 16:30:24 -0600 Subject: [PATCH 30/69] - Fix ROCFFT_LAYER=8 not displaying the new partial-pass nodes. --- library/src/compute_scheme.cpp | 1 + library/src/include/compute_scheme.h | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/library/src/compute_scheme.cpp b/library/src/compute_scheme.cpp index b801012f499..d37e8e61b40 100644 --- a/library/src/compute_scheme.cpp +++ b/library/src/compute_scheme.cpp @@ -84,6 +84,7 @@ static const std::map& ComputeSchemetoStringMap() {ENUMSTR(CS_3D_BLOCK_RC)}, {ENUMSTR(CS_3D_BLOCK_CR)}, {ENUMSTR(CS_3D_RC)}, + {ENUMSTR(CS_3D_PP)}, {ENUMSTR(CS_KERNEL_3D_STOCKHAM_BLOCK_CC)}, {ENUMSTR(CS_KERNEL_3D_SINGLE)}}; return ComputeSchemetoString; diff --git a/library/src/include/compute_scheme.h b/library/src/include/compute_scheme.h index 20ae9dcec16..3c466b9fcda 100644 --- a/library/src/include/compute_scheme.h +++ b/library/src/include/compute_scheme.h @@ -80,9 +80,9 @@ enum ComputeScheme CS_3D_BLOCK_RC, CS_3D_BLOCK_CR, CS_3D_RC, + CS_3D_PP, CS_KERNEL_3D_STOCKHAM_BLOCK_CC, // not implemented yet CS_KERNEL_3D_SINGLE, // not implemented yet - CS_3D_PP }; // print abbreviation for kernel scheme From 96ee9cf42e64f7007342c536e2bd46738f78da81 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 15 May 2025 14:48:07 -0600 Subject: [PATCH 31/69] - WIP: Fixes for dealing with column- to row-major configuration entries. --- library/src/device/generator/stockham_gen.cpp | 78 ++++++++++++------- library/src/device/kernel-generator.py | 3 +- library/src/include/tree_node.h | 6 +- 3 files changed, 56 insertions(+), 31 deletions(-) diff --git a/library/src/device/generator/stockham_gen.cpp b/library/src/device/generator/stockham_gen.cpp index 042be0b8914..d087fbcadd0 100644 --- a/library/src/device/generator/stockham_gen.cpp +++ b/library/src/device/generator/stockham_gen.cpp @@ -23,6 +23,7 @@ using namespace std::placeholders; #include "generator.h" #include "stockham_gen.h" +#include #include #include #include @@ -265,52 +266,43 @@ void stockham_partial_pass_variants(const std::string& kernel_name params_2.off_dim, launchers); } - // SBRR_PP + SBCC_PP - else if(params_1.current_dim == 1 && params_2.current_dim == 2) + else if(params_1.current_dim == 2 && params_2.current_dim == 0) { - StockhamPartialPassKernelRR kernelRR(specs1, params_1); + StockhamPartialPassKernelRR kernelCC(specs1, params_1); make_launcher(specs1.precisions, - {{"pp_stoc", specs1.scheme, "", ""}}, - kernelRR, - "CS_KERNEL_STOCKHAM_PP", - params_1.factors_off_dim, - params_1.current_dim, - params_1.off_dim, - launchers); - - StockhamPartialPassKernelCC kernelCC(specs2, params_2, false); - make_launcher(specs2.precisions, - {{"pp_sbcc", specs2.scheme, "", ""}}, + {{"pp_sbcc", specs1.scheme, "", ""}}, kernelCC, "CS_KERNEL_STOCKHAM_PP_BLOCK_CC", - params_2.factors_off_dim, - params_2.current_dim, - params_2.off_dim, - launchers); - } - // SBRR_PP + SBRR_PP - else if(params_1.current_dim == 0 && params_2.current_dim == 1) - { - StockhamPartialPassKernelRR kernelRR1(specs1, params_1); - make_launcher(specs1.precisions, - {{"pp_stoc", specs1.scheme, "", ""}}, - kernelRR1, - "CS_KERNEL_STOCKHAM_PP", params_1.factors_off_dim, params_1.current_dim, params_1.off_dim, launchers); - StockhamPartialPassKernelRR kernelRR2(specs2, params_2); + StockhamPartialPassKernelCC kernelRR(specs2, params_2, false); make_launcher(specs2.precisions, {{"pp_stoc", specs2.scheme, "", ""}}, - kernelRR2, + kernelRR, "CS_KERNEL_STOCKHAM_PP", params_2.factors_off_dim, params_2.current_dim, params_2.off_dim, launchers); } + // SBRR_PP + SBCC_PP + else if((params_1.current_dim == 1 && params_2.current_dim == 2) + || (params_1.current_dim == 2 && params_2.current_dim == 1)) + { + throw std::runtime_error("CS_KERNEL_STOCKHAM_PP + CS_KERNEL_STOCKHAM_PP_BLOCK_CC not " + "yet implemented for CS_3D_PP"); + } + // SBRR_PP + SBRR_PP + else if((params_1.current_dim == 0 && params_2.current_dim == 1) + || (params_1.current_dim == 1 && params_2.current_dim == 0)) + { + throw std::runtime_error( + "CS_KERNEL_STOCKHAM_PP_BLOCK_CC + CS_KERNEL_STOCKHAM_PP_BLOCK_CC not yet " + "implemented for CS_3D_PP"); + } else { throw std::runtime_error("invalid dimensions for CS_3D_PP"); @@ -443,6 +435,17 @@ void stockham_variants(const std::string& kernel_name, output_json(launchers, kernel_name, output); } +size_t pp_current_dim_rm(const size_t current_dim) +{ + if(current_dim == 0) + return 2; + + if(current_dim == 2) + return 0; + + return current_dim; +} + int main() { std::string line; @@ -548,6 +551,23 @@ int main() throw std::runtime_error("Invalid dimensions configuration for CS_3D_PP"); } + // parent length from column to row-major + std::reverse(parent_length.begin(), parent_length.end()); + + // current dim from column to row-major + auto dims_rm = dims; + for(auto& dim : dims_rm) + dim = pp_current_dim_rm(dim); + + if(dims_rm != dims) + { + factors1.swap(factors2); + pp_factors_1.swap(pp_factors_2); + std::reverse(workgroup_size.begin(), workgroup_size.end()); + std::reverse(threads_per_transform.begin(), threads_per_transform.end()); + std::reverse(direct_to_from_reg.begin(), direct_to_from_reg.end()); + } + if(threads_per_transform.size() != 2) throw std::runtime_error( "CS_3D_PP requires two threads_per_transform configuration"); diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index 0005c537d55..6be67206169 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -1046,7 +1046,8 @@ def list_3d_partial_pass_kernels(): """Return list of to generate.""" pp_3d_kernels = [ - NS(length=[64,64,64], dims=[0, 2], factors=[[4, 4, 4],[8, 8]], factors_pp=[[16], [4]], threads_per_transform=[8, 8], workgroup_size=[128, 64], direct_to_from_reg=[False, False]), + NS(length=[64,64,64], dims=[0, 2], factors=[[8, 8],[4, 4, 4]], factors_pp=[[4],[16]], threads_per_transform=[8, 8], workgroup_size=[64,128], direct_to_from_reg=[False, False]), + NS(length=[64,64,128], dims=[0, 2], factors=[[4, 4, 4],[8, 8, 2]], factors_pp=[[4], [16]], threads_per_transform=[8, 8], workgroup_size=[64,128], direct_to_from_reg=[False, False]), ] expanded = [] diff --git a/library/src/include/tree_node.h b/library/src/include/tree_node.h index 412e607fc18..12a4fdc7159 100644 --- a/library/src/include/tree_node.h +++ b/library/src/include/tree_node.h @@ -595,7 +595,11 @@ class TreeNode auto pp_parent_node = GetPartialPassAncestor(); if(pp_parent_node) - return FMKeyPP(length[0], length[1], length[2], precision, pp_parent_node->scheme); + return FMKeyPP(pp_parent_node->length[0], + pp_parent_node->length[1], + pp_parent_node->length[2], + precision, + pp_parent_node->scheme); else throw std::runtime_error("Invalid parent node for partial pass"); } From eedc32dca031ad581ce7f789a2f4156a3577caec Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 16 May 2025 08:56:37 -0600 Subject: [PATCH 32/69] - Resolve merge conflicts. --- library/src/device/generator/stockham_gen.cpp | 12 ---- library/src/include/function_pool.h | 72 +++---------------- 2 files changed, 9 insertions(+), 75 deletions(-) diff --git a/library/src/device/generator/stockham_gen.cpp b/library/src/device/generator/stockham_gen.cpp index 2a35bb7280d..623f5768d3e 100644 --- a/library/src/device/generator/stockham_gen.cpp +++ b/library/src/device/generator/stockham_gen.cpp @@ -460,18 +460,6 @@ static size_t max_bytes_per_element(const std::vector& precisions) return element_size; } -static size_t max_bytes_per_element(const std::vector& precisions) -{ - // generate for the maximum element size in the available - // precisions - size_t element_size = 0; - for(auto p : precisions) - { - element_size = std::max(element_size, complex_type_size(static_cast(p))); - } - return element_size; -} - int main() { std::string line; diff --git a/library/src/include/function_pool.h b/library/src/include/function_pool.h index 21b87ec0bd5..4c87f86b584 100644 --- a/library/src/include/function_pool.h +++ b/library/src/include/function_pool.h @@ -140,11 +140,11 @@ struct FFTKernel } }; -typedef std::unordered_multimap FPKeyMap; -typedef std::unordered_multimap FPKeyMapPP; +typedef std::unordered_multimap FPKeyMap; +typedef std::unordered_map FPKeyMapPP; -typedef std::unordered_multimap FPMap; -typedef std::unordered_multimap, SimpleHashPP> FPMapPP; +typedef std::unordered_multimap FPMap; +typedef std::unordered_map, SimpleHashPP> FPMapPP; struct function_pool_data { @@ -213,14 +213,6 @@ class function_pool return key; } - const FMKeyPP& get_actual_pp_key(const FMKeyPP& key) const - { - if(def_pp_key_pool.count(key) > 0) - return def_pp_key_pool.at(key); - else - return key; - } - public: function_pool(unsigned int max_lds_bytes) : max_lds_bytes(max_lds_bytes) @@ -238,8 +230,10 @@ class function_pool function_pool(const hipDeviceProp_t& prop) : max_lds_bytes(prop.sharedMemPerBlock) - , def_key_pool(function_pool_data::get_function_pool_data().def_key_pool) - , function_map(function_pool_data::get_function_pool_data().function_map) + , def_key_pool(std::get<0>(function_pool_data::get_function_pool_data().def_keys)) + , def_pp_key_pool(std::get<1>(function_pool_data::get_function_pool_data().def_keys)) + , function_map(std::get<0>(function_pool_data::get_function_pool_data().function_maps)) + , function_pp_map(std::get<1>(function_pool_data::get_function_pool_data().function_maps)) , deviceProp(prop) { // We would only see zero if we received a @@ -281,14 +275,6 @@ class function_pool return function_pp_map.count(real_key) > 0; } - bool has_pp_function(const FMKeyPP& key) const - { - auto real_key = get_actual_pp_key(key); - if(!real_key.base_lds_usage_fits(max_lds_bytes)) - return false; - return function_pp_map.count(real_key) > 0; - } - size_t get_largest_length(rocfft_precision precision) const { auto supported = get_lengths(precision, CS_KERNEL_STOCKHAM); @@ -343,25 +329,6 @@ class function_pool throw std::out_of_range("kernel not found in partial-pass map"); } - FFTKernel get_pp_kernel(const FMKeyPP& key, ComputeScheme scheme) const - { - auto real_key = get_actual_pp_key(key); - if(!real_key.base_lds_usage_fits(max_lds_bytes)) - throw std::out_of_range("kernel not found in partial-pass map"); - - auto kernel_list = function_pp_map.at(real_key); - - auto scheme_0 = kernel_list[0].pp_params.scheme; - auto scheme_1 = kernel_list[1].pp_params.scheme; - - if(scheme == scheme_0) - return kernel_list[0]; - else if(scheme == scheme_1) - return kernel_list[1]; - else - throw std::out_of_range("kernel not found in partial-pass map"); - } - // helper for common used bool has_SBCC_kernel(size_t length, rocfft_precision precision) const { @@ -413,27 +380,7 @@ static void insert_default_entry(const FMKey& def_key, def_key_pool.emplace(simple_key, def_key_with_lds); // still use the detailed key with config to maintain the function map - std::get<1>(function_map.emplace(def_key_with_lds, kernel)) -} - -static bool insert_default_pp_entry(const FMKeyPP& def_key, - const FFTKernel& kernel_0, - const FFTKernel& kernel_1, - FPKeyMapPP& def_key_pool, - FPMapPP& function_map) -{ - // simple_key means the same thing as def_key, but we just remove kernel-config - // so we don't need to know the exact config when we're lookin' for the default kernel - FMKeyPP simple_key(def_key); - simple_key.kernel_config_1 = KernelConfig::EmptyConfig(); - simple_key.kernel_config_2 = KernelConfig::EmptyConfig(); - - def_key_pool.emplace(simple_key, def_key); - - std::array kernels = {kernel_0, kernel_1}; - - // still use the detailed key with config to maintain the function map - return std::get<1>(function_map.emplace(def_key, kernels)); + function_map.emplace(def_key_with_lds, kernel); } static bool insert_default_pp_entry(const FMKeyPP& def_key, @@ -447,7 +394,6 @@ static bool insert_default_pp_entry(const FMKeyPP& def_key, FMKeyPP simple_key(def_key); simple_key.kernel_config_1 = KernelConfig::EmptyConfig(); simple_key.kernel_config_2 = KernelConfig::EmptyConfig(); - def_key_pool.emplace(simple_key, def_key); From 834f888e0c21cd0d6bf779b7b9f14218e99047e4 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 16 May 2025 09:40:37 -0600 Subject: [PATCH 33/69] - Formatting. --- library/src/device/generator/stockham_gen.cpp | 11 +++++------ library/src/include/function_pool.h | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/library/src/device/generator/stockham_gen.cpp b/library/src/device/generator/stockham_gen.cpp index 623f5768d3e..8c72553c127 100644 --- a/library/src/device/generator/stockham_gen.cpp +++ b/library/src/device/generator/stockham_gen.cpp @@ -436,7 +436,6 @@ void stockham_variants(const std::string& kernel_name, output_json(launchers, kernel_name, output); } - size_t pp_current_dim_rm(const size_t current_dim) { if(current_dim == 0) @@ -464,9 +463,9 @@ int main() { std::string line; - std::string kernel_name; - std::string scheme; - bool half_lds; + std::string kernel_name; + std::string scheme; + bool half_lds; unsigned int lds_size_bytes; const char* DELIM = ""; @@ -631,13 +630,13 @@ int main() // second dimension for 2D_SINGLE StockhamGeneratorSpecs specs2d( factors2d, factors, precisions, workgroup_size[0], scheme); - + if(!threads_per_transform.empty()) specs2d.threads_per_transform = threads_per_transform.back(); // aim for occupancy-2 by default specs.lds_byte_limit = lds_size_bytes / 2; - specs2d.lds_byte_limit = lds_size_bytes / 2; + specs2d.lds_byte_limit = lds_size_bytes / 2; stockham_variants(kernel_name, specs, specs2d, std::cout); } diff --git a/library/src/include/function_pool.h b/library/src/include/function_pool.h index 4c87f86b584..4ce1afa17e3 100644 --- a/library/src/include/function_pool.h +++ b/library/src/include/function_pool.h @@ -243,7 +243,7 @@ class function_pool throw std::runtime_error("function_pool: max_lds_bytes not initialized"); } - function_pool(function_pool& p) = delete; + function_pool(function_pool& p) = delete; function_pool& operator=(const function_pool&) = delete; ~function_pool() = default; From 66b16bbab7142059cfc5d3af9dc16601c629af1c Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 16 May 2025 10:45:52 -0600 Subject: [PATCH 34/69] - Refactor kernel-generator.py changes. --- library/src/device/kernel-generator.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index 017ee869a86..828e2911a82 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -195,17 +195,16 @@ def generate_cpu_function_pool_pieces(functions, pp_functions, num_files): ] # Cycles through each file per loop execution to distribute work amongst N files - i = 0 - j = 0 - i_offset = 0 if len(pp_functions) == 0 else len(precisions) + curr_func, curr_pp_func = 0, 0 + curr_func_offset = 0 if len(pp_functions) == 0 else len(precisions) curr_file = 0 - while i < len(all_functions) - i_offset: - f = all_functions[i] + while curr_func < len(all_functions) - curr_func_offset: + f = all_functions[curr_func] length, precision, scheme, transpose = f.meta.length, f.meta.precision, f.meta.scheme, f.meta.transpose if scheme == 'CS_3D_PP': piece_contents[curr_file] += Assign(var_pp_kernel_1, FFTKernel(f)) - f = all_functions[i + i_offset] + f = all_functions[curr_func + curr_func_offset] piece_contents[curr_file] += Assign(var_pp_kernel_2, FFTKernel(f)) key = Call( @@ -218,7 +217,7 @@ def generate_cpu_function_pool_pieces(functions, pp_functions, num_files): key, var_pp_kernel_1, var_pp_kernel_2, 'std::get<1>(def_keys)', 'std::get<1>(function_maps)') - j = j + 1 + curr_pp_func = curr_pp_func + 1 else: if isinstance(length, (int, str)): length = [length, 0] @@ -232,11 +231,10 @@ def generate_cpu_function_pool_pieces(functions, pp_functions, num_files): key, var_kernel, 'std::get<0>(def_keys)', 'std::get<0>(function_maps)', f.meta.lds_size_bytes) - if j == len(precisions): - j = 0 - i = i + len(precisions) + 1 + if curr_pp_func == len(precisions): + curr_func, curr_pp_func = curr_func + len(precisions) + 1, 0 else: - i = i + 1 + curr_func = curr_func + 1 curr_file = (curr_file + 1) % num_files @@ -1160,7 +1158,7 @@ def generate_kernel_functions(kernels, precisions, launchers_json): params=params, precision=p, runtime_compile=runtime_compile, - scheme=scheme, + scheme=scheme, workgroup_size=workgroup_size, transforms_per_block=transforms_per_block, threads_per_transform=tpt_list, @@ -1255,7 +1253,8 @@ def generate_kernels(kernels, precisions, stockham_gen): if len(k.factors) == 1: half_lds = False - # Send data over to subprocess + # Send data over to subprocess + if isinstance(k.workgroup_size, list): proc.stdin.write(" " + ','.join([str(f) for f in k.workgroup_size])) else: @@ -1272,6 +1271,7 @@ def generate_kernels(kernels, precisions, stockham_gen): direct_to_from_reg = getattr(k, 'direct_to_from_reg', True) proc.stdin.write(' 1' if direct_to_from_reg else ' 0') + # check for data specific to partial-pass 3D kernels if hasattr(k, 'dims'): proc.stdin.write(" " + ','.join([str(f) for f in k.dims])) proc.stdin.write(" " + ','.join([str(f) From e2c142bac4a87c84c280768c3673b32324a2dd52 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 16 May 2025 12:00:39 -0600 Subject: [PATCH 35/69] - Refactor and improvements. --- library/src/device/generator/stockham_gen.cpp | 126 +++++++++++------- library/src/device/kernel-generator.py | 6 +- library/src/include/function_pool.h | 24 ++-- 3 files changed, 91 insertions(+), 65 deletions(-) diff --git a/library/src/device/generator/stockham_gen.cpp b/library/src/device/generator/stockham_gen.cpp index 8c72553c127..8db742dbc35 100644 --- a/library/src/device/generator/stockham_gen.cpp +++ b/library/src/device/generator/stockham_gen.cpp @@ -436,15 +436,72 @@ void stockham_variants(const std::string& kernel_name, output_json(launchers, kernel_name, output); } -size_t pp_current_dim_rm(const size_t current_dim) +size_t pp_dim_rm(const size_t dim) { - if(current_dim == 0) + if(dim == 0) return 2; - if(current_dim == 2) + if(dim == 2) return 0; - return current_dim; + return dim; +} + +void pp_params_rm(std::vector& parent_length, + std::vector& dims, + std::vector& factors1, + std::vector& factors2, + std::vector& pp_factors1, + std::vector& pp_factors2, + std::vector& workgroup_size, + std::vector& threads_per_transform, + std::vector& direct_to_from_reg) +{ + std::reverse(parent_length.begin(), parent_length.end()); + + auto dims_rm = dims; + for(auto& dim : dims_rm) + dim = pp_dim_rm(dim); + + if(dims_rm != dims) + { + factors1.swap(factors2); + pp_factors1.swap(pp_factors2); + std::reverse(workgroup_size.begin(), workgroup_size.end()); + std::reverse(threads_per_transform.begin(), threads_per_transform.end()); + std::reverse(direct_to_from_reg.begin(), direct_to_from_reg.end()); + } +} + +unsigned int get_pp_off_dim(const std::vector& dims) +{ + if(dims.size() != 2) + throw std::runtime_error("CS_3D_PP requires two dimensions configuration"); + + unsigned int dims_sum = 0, off_dim = 0; + for(const auto& dim : dims) + { + if(dim < 0 || dim > 2) + throw std::runtime_error("Invalid dimensions configuration for CS_3D_PP"); + + dims_sum += dim; + } + switch(dims_sum) + { + case 1: + off_dim = 2; + break; + case 2: + off_dim = 1; + break; + case 3: + off_dim = 0; + break; + default: + throw std::runtime_error("Invalid dimensions configuration for CS_3D_PP"); + } + + return off_dim; } static size_t max_bytes_per_element(const std::vector& precisions) @@ -494,17 +551,17 @@ int main() ++arg; scheme = *arg; - std::vector parent_length, dims, pp_factors_1, pp_factors_2; + std::vector parent_length, dims, pp_factors1, pp_factors2; if(scheme == "CS_3D_PP") { ++arg; parent_length = parse_uints_csv(*arg); ++arg; - pp_factors_2 = parse_uints_csv(*arg); + pp_factors2 = parse_uints_csv(*arg); ++arg; - pp_factors_1 = parse_uints_csv(*arg); + pp_factors1 = parse_uints_csv(*arg); ++arg; dims = parse_uints_csv(*arg); @@ -542,48 +599,17 @@ int main() ++arg; factors1 = parse_uints_csv(*arg); - if(dims.size() != 2) - throw std::runtime_error("CS_3D_PP requires two dimensions configuration"); + auto off_dim = get_pp_off_dim(dims); - unsigned int dims_sum = 0, off_dim = 0; - for(const auto& dim : dims) - { - if(dim < 0 || dim > 2) - throw std::runtime_error("Invalid dimensions configuration for CS_3D_PP"); - - dims_sum += dim; - } - switch(dims_sum) - { - case 1: - off_dim = 2; - break; - case 2: - off_dim = 1; - break; - case 3: - off_dim = 0; - break; - default: - throw std::runtime_error("Invalid dimensions configuration for CS_3D_PP"); - } - - // parent length from column to row-major - std::reverse(parent_length.begin(), parent_length.end()); - - // current dim from column to row-major - auto dims_rm = dims; - for(auto& dim : dims_rm) - dim = pp_current_dim_rm(dim); - - if(dims_rm != dims) - { - factors1.swap(factors2); - pp_factors_1.swap(pp_factors_2); - std::reverse(workgroup_size.begin(), workgroup_size.end()); - std::reverse(threads_per_transform.begin(), threads_per_transform.end()); - std::reverse(direct_to_from_reg.begin(), direct_to_from_reg.end()); - } + pp_params_rm(parent_length, + dims, + factors1, + factors2, + pp_factors1, + pp_factors2, + workgroup_size, + threads_per_transform, + direct_to_from_reg); if(threads_per_transform.size() != 2) throw std::runtime_error( @@ -600,8 +626,8 @@ int main() specs2.direct_to_from_reg = direct_to_from_reg[1]; specs2.threads_per_transform = threads_per_transform[1]; - StockhamPartialPassParams pp_params_1(parent_length, dims[0], off_dim, pp_factors_1); - StockhamPartialPassParams pp_params_2(parent_length, dims[1], off_dim, pp_factors_2); + StockhamPartialPassParams pp_params_1(parent_length, dims[0], off_dim, pp_factors1); + StockhamPartialPassParams pp_params_2(parent_length, dims[1], off_dim, pp_factors2); stockham_partial_pass_variants( kernel_name, specs1, specs2, pp_params_1, pp_params_2, std::cout); diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index 828e2911a82..c26871c46bb 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -158,7 +158,7 @@ def generate_cpu_function_pool_main(num_files): type='void', name=f'function_pool_init_{i}', value= - 'std::tuple& def_keys, std::tuple& function_maps' + 'std::tuple& def_keys, std::tuple& function_maps' ) call_list = StatementList() @@ -241,8 +241,8 @@ def generate_cpu_function_pool_pieces(functions, pp_functions, num_files): # Assemble contents of each file to return in a list pieces = [None] * num_files piece_args = ArgumentList( - 'std::tuple& def_keys', - 'std::tuple& function_maps') + 'std::tuple& def_keys', + 'std::tuple& function_maps') for k in range(num_files): pieces[k] = StatementList( Include('"../include/function_pool.h"'), diff --git a/library/src/include/function_pool.h b/library/src/include/function_pool.h index 4ce1afa17e3..670f9af0d16 100644 --- a/library/src/include/function_pool.h +++ b/library/src/include/function_pool.h @@ -57,9 +57,9 @@ struct PartialPassParams { } - ComputeScheme scheme; - unsigned int current_dim; - unsigned int off_dim; + ComputeScheme scheme = CS_NONE; + unsigned int current_dim = 0; + unsigned int off_dim = 0; std::vector factors_off_dim; }; @@ -141,18 +141,18 @@ struct FFTKernel }; typedef std::unordered_multimap FPKeyMap; -typedef std::unordered_map FPKeyMapPP; +typedef std::unordered_map PPFPKeyMap; typedef std::unordered_multimap FPMap; -typedef std::unordered_map, SimpleHashPP> FPMapPP; +typedef std::unordered_map, SimpleHashPP> PPFPMap; struct function_pool_data { // when AOT generator adds a default key-kernel, // we get the keys of two version: empty-config vs full-config // make the pair as an entry in a map so that we know they are the same things - std::tuple def_keys; - std::tuple function_maps; + std::tuple def_keys; + std::tuple function_maps; function_pool_data(); @@ -167,10 +167,10 @@ class function_pool { unsigned int max_lds_bytes; FPKeyMap& def_key_pool; - FPKeyMapPP& def_pp_key_pool; + PPFPKeyMap& def_pp_key_pool; FPMap& function_map; - FPMapPP& function_pp_map; + PPFPMap& function_pp_map; // look in the specified map for the specified key, returning an // iterator to the item that fits best into the available LDS @@ -243,7 +243,7 @@ class function_pool throw std::runtime_error("function_pool: max_lds_bytes not initialized"); } - function_pool(function_pool& p) = delete; + function_pool(function_pool& p) = delete; function_pool& operator=(const function_pool&) = delete; ~function_pool() = default; @@ -386,8 +386,8 @@ static void insert_default_entry(const FMKey& def_key, static bool insert_default_pp_entry(const FMKeyPP& def_key, const FFTKernel& kernel_0, const FFTKernel& kernel_1, - FPKeyMapPP& def_key_pool, - FPMapPP& function_map) + PPFPKeyMap& def_key_pool, + PPFPMap& function_map) { // simple_key means the same thing as def_key, but we just remove kernel-config // so we don't need to know the exact config when we're lookin' for the default kernel From 304dd854781afa18bcc82862cf975742ea6b5956 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 16 May 2025 16:05:51 -0600 Subject: [PATCH 36/69] - Refactor function pool. --- library/src/device/generator.py | 11 ++-- library/src/device/kernel-generator.py | 8 +-- library/src/include/function_map_key.h | 52 ++++++++--------- library/src/include/function_pool.h | 77 +++++++++++++------------- library/src/include/tree_node.h | 6 +- library/src/node_factory.cpp | 4 +- library/src/rtc_stockham_kernel.cpp | 2 +- library/src/tree_node.cpp | 14 ++--- library/src/tree_node_1D.cpp | 8 +-- 9 files changed, 89 insertions(+), 93 deletions(-) diff --git a/library/src/device/generator.py b/library/src/device/generator.py index 325a3a0ad97..7e373e73d49 100644 --- a/library/src/device/generator.py +++ b/library/src/device/generator.py @@ -811,17 +811,16 @@ def assert_emplace(self, key, value, what_error): Throw('std::runtime_error("' + str(what_error) + '")')) return If(Equal(status, "false"), throw) - def assert_insert(self, key, value, def_key_pool, function_map, + def insert(self, key, value, def_key_pool, function_map, lds_size_bytes): return Call('insert_default_entry', arguments=ArgumentList(key, value, def_key_pool, function_map, lds_size_bytes)) - def assert_pp_insert(self, key, value_1, value_2, def_key_pool, function_map): - insert = Call('insert_default_pp_entry', + def insert_pp(self, key, value_1, value_2, def_key_pool, function_map, + lds_size_bytes): + return Call('insert_default_entry', arguments=ArgumentList(key, value_1, value_2, def_key_pool, - function_map)).inline() - throw = StatementList(Throw('std::runtime_error("' + str(key) + '")')) - return If(Equal(insert, "false"), throw) + function_map, lds_size_bytes)) # def __getitem__(self, idx): # return ArrayElement(self.name, idx) diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index c26871c46bb..23efe1487d1 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -208,14 +208,14 @@ def generate_cpu_function_pool_pieces(functions, pp_functions, num_files): piece_contents[curr_file] += Assign(var_pp_kernel_2, FFTKernel(f)) key = Call( - name='FMKeyPP', + name='PPFMKey', arguments=ArgumentList(length[0], length[1], length[2], precisions[precision], scheme, 'pp_kernel_1.get_kernel_config()', 'pp_kernel_2.get_kernel_config()')).inline() - piece_contents[curr_file] += function_map.assert_pp_insert( + piece_contents[curr_file] += function_map.insert_pp( key, var_pp_kernel_1, var_pp_kernel_2, 'std::get<1>(def_keys)', - 'std::get<1>(function_maps)') + 'std::get<1>(function_maps)', f.meta.lds_size_bytes) curr_pp_func = curr_pp_func + 1 else: @@ -227,7 +227,7 @@ def generate_cpu_function_pool_pieces(functions, pp_functions, num_files): arguments=ArgumentList(length[0], length[1], precisions[precision], scheme, transpose or 'NONE', 'kernel.get_kernel_config()')).inline() - piece_contents[curr_file] += function_map.assert_insert( + piece_contents[curr_file] += function_map.insert( key, var_kernel, 'std::get<0>(def_keys)', 'std::get<0>(function_maps)', f.meta.lds_size_bytes) diff --git a/library/src/include/function_map_key.h b/library/src/include/function_map_key.h index efdcfc576de..1ccf5bcc173 100644 --- a/library/src/include/function_map_key.h +++ b/library/src/include/function_map_key.h @@ -308,8 +308,11 @@ struct FromString struct FMKeyBase { - FMKeyBase(rocfft_precision precision, ComputeScheme scheme = CS_NONE) - : precision(precision) + FMKeyBase(std::array lengths, + rocfft_precision precision, + ComputeScheme scheme = CS_NONE) + : lengths(lengths) + , precision(precision) , scheme(scheme) { } @@ -324,6 +327,8 @@ struct FMKeyBase // allowed to coexist in the function pool. size_t lds_size_bytes = 0; + std::array lengths; + rocfft_precision precision; ComputeScheme scheme; }; @@ -343,8 +348,6 @@ struct FMKeyBase // struct FMKey : public FMKeyBase { - std::array lengths; - SBRC_TRANSPOSE_TYPE sbrcTrans = NONE; KernelConfig kernel_config = KernelConfig::EmptyConfig(); @@ -357,8 +360,7 @@ struct FMKey : public FMKeyBase ComputeScheme scheme = CS_KERNEL_STOCKHAM, SBRC_TRANSPOSE_TYPE transpose = NONE, KernelConfig kernel_config = KernelConfig::EmptyConfig()) - : FMKeyBase(precision, scheme) - , lengths({length0, 0}) + : FMKeyBase({length0, 0, 0}, precision, scheme) , sbrcTrans(transpose) , kernel_config(kernel_config) @@ -372,8 +374,7 @@ struct FMKey : public FMKeyBase ComputeScheme scheme = CS_KERNEL_2D_SINGLE, SBRC_TRANSPOSE_TYPE transpose = NONE, KernelConfig kernel_config = KernelConfig::EmptyConfig()) - : FMKeyBase(precision, scheme) - , lengths({length0, length1}) + : FMKeyBase({length0, length1, 0}, precision, scheme) , sbrcTrans(transpose) , kernel_config(kernel_config) { @@ -497,38 +498,31 @@ struct SimpleHash } }; -// TODO: Check if FMKeyPP and FMKey can derive from a common base class -// This should simplfy operations in the function_pool, e.g., -// FMKey key = GetKernelKey(); -// if(!pool.has_function(key)) -// in LeafNode::KernelCheck(std::vector& kernel_keys) -struct FMKeyPP : public FMKeyBase +struct PPFMKey : public FMKeyBase { - std::array lengths; - KernelConfig kernel_config_1 = KernelConfig::EmptyConfig(); - KernelConfig kernel_config_2 = KernelConfig::EmptyConfig(); + KernelConfig kernel_config_1 = KernelConfig::EmptyConfig(); + KernelConfig kernel_config_2 = KernelConfig::EmptyConfig(); - FMKeyPP() = default; - FMKeyPP(const FMKeyPP&) = default; + PPFMKey() = default; + PPFMKey(const PPFMKey&) = default; // with every data - FMKeyPP(size_t length0, + PPFMKey(size_t length0, size_t length1, size_t length2, rocfft_precision precision, ComputeScheme scheme = CS_3D_PP, KernelConfig kernel_config_1 = KernelConfig::EmptyConfig(), KernelConfig kernel_config_2 = KernelConfig::EmptyConfig()) - : FMKeyBase(precision, scheme) - , lengths({length0, length1, length2}) + : FMKeyBase({length0, length1, length2}, precision, scheme) , kernel_config_1(kernel_config_1) , kernel_config_2(kernel_config_2) { } - FMKeyPP& operator=(const FMKeyPP&) = default; + PPFMKey& operator=(const PPFMKey&) = default; - bool operator==(const FMKeyPP& rhs) const + bool operator==(const PPFMKey& rhs) const { return std::tie(lengths, precision, scheme, kernel_config_1, kernel_config_2) == std::tie(rhs.lengths, @@ -538,12 +532,12 @@ struct FMKeyPP : public FMKeyBase rhs.kernel_config_2); } - bool operator!=(const FMKeyPP& rhs) const + bool operator!=(const PPFMKey& rhs) const { return !((*this) == rhs); } - bool operator<(const FMKeyPP& rhs) const + bool operator<(const PPFMKey& rhs) const { return std::tie(lengths, precision, scheme, kernel_config_1, kernel_config_2) < std::tie(rhs.lengths, @@ -553,9 +547,9 @@ struct FMKeyPP : public FMKeyBase rhs.kernel_config_2); } - static FMKeyPP EmptyFMKeyPP() + static PPFMKey EmptyPPFMKey() { - static FMKeyPP empty; + static PPFMKey empty; return empty; } @@ -568,7 +562,7 @@ struct FMKeyPP : public FMKeyBase struct SimpleHashPP { - size_t operator()(const FMKeyPP& p) const noexcept + size_t operator()(const PPFMKey& p) const noexcept { size_t h = 0; for(auto& v : p.lengths) diff --git a/library/src/include/function_pool.h b/library/src/include/function_pool.h index 670f9af0d16..c4fbb785122 100644 --- a/library/src/include/function_pool.h +++ b/library/src/include/function_pool.h @@ -98,10 +98,10 @@ struct FFTKernel bool half_lds = false, bool direct_to_from_reg = false, bool aot_rtc = false, - ComputeScheme pp_scheme = CS_NONE, - unsigned int pp_current_dim = 0, - unsigned int pp_off_dim = 0, - std::vector&& pp_factors_off_dim = std::vector()) + ComputeScheme scheme = CS_NONE, + unsigned int current_dim = 0, + unsigned int off_dim = 0, + std::vector&& factors_off_dim = std::vector()) : factors(factors) , transforms_per_block(tpb) , workgroup_size(wgs) @@ -110,7 +110,7 @@ struct FFTKernel , half_lds(half_lds) , direct_to_from_reg(direct_to_from_reg) , aot_rtc(aot_rtc) - , pp_params(pp_scheme, pp_current_dim, pp_off_dim, pp_factors_off_dim) + , pp_params(scheme, current_dim, off_dim, factors_off_dim) { } @@ -140,11 +140,11 @@ struct FFTKernel } }; -typedef std::unordered_multimap FPKeyMap; -typedef std::unordered_map PPFPKeyMap; +typedef std::unordered_multimap FPKeyMap; +typedef std::unordered_multimap PPFPKeyMap; -typedef std::unordered_multimap FPMap; -typedef std::unordered_map, SimpleHashPP> PPFPMap; +typedef std::unordered_multimap FPMap; +typedef std::unordered_multimap, SimpleHashPP> PPFPMap; struct function_pool_data { @@ -170,12 +170,12 @@ class function_pool PPFPKeyMap& def_pp_key_pool; FPMap& function_map; - PPFPMap& function_pp_map; + PPFPMap& pp_function_map; // look in the specified map for the specified key, returning an // iterator to the item that fits best into the available LDS - template - typename Tmap::const_iterator find_key_in_map(const Tmap& fmap, const FMKey& key) const + template + typename Tmap::const_iterator find_key_in_map(const Tmap& fmap, const TKey& key) const { auto range = fmap.equal_range(key); auto best = fmap.end(); @@ -205,10 +205,11 @@ class function_pool return key; } - const FMKeyPP& get_actual_pp_key(const FMKeyPP& key) const + const PPFMKey& get_actual_key(const PPFMKey& key) const { - if(def_pp_key_pool.count(key) > 0) - return def_pp_key_pool.at(key); + auto it = find_key_in_map(def_pp_key_pool, key); + if(it != def_pp_key_pool.end()) + return it->second; else return key; } @@ -219,7 +220,7 @@ class function_pool , def_key_pool(std::get<0>(function_pool_data::get_function_pool_data().def_keys)) , def_pp_key_pool(std::get<1>(function_pool_data::get_function_pool_data().def_keys)) , function_map(std::get<0>(function_pool_data::get_function_pool_data().function_maps)) - , function_pp_map(std::get<1>(function_pool_data::get_function_pool_data().function_maps)) + , pp_function_map(std::get<1>(function_pool_data::get_function_pool_data().function_maps)) { // We would only see zero if we received a // default-constructed device prop struct, which means @@ -233,7 +234,7 @@ class function_pool , def_key_pool(std::get<0>(function_pool_data::get_function_pool_data().def_keys)) , def_pp_key_pool(std::get<1>(function_pool_data::get_function_pool_data().def_keys)) , function_map(std::get<0>(function_pool_data::get_function_pool_data().function_maps)) - , function_pp_map(std::get<1>(function_pool_data::get_function_pool_data().function_maps)) + , pp_function_map(std::get<1>(function_pool_data::get_function_pool_data().function_maps)) , deviceProp(prop) { // We would only see zero if we received a @@ -243,7 +244,7 @@ class function_pool throw std::runtime_error("function_pool: max_lds_bytes not initialized"); } - function_pool(function_pool& p) = delete; + function_pool(function_pool& p) = delete; function_pool& operator=(const function_pool&) = delete; ~function_pool() = default; @@ -267,12 +268,10 @@ class function_pool return find_key_in_map(function_map, real_key) != function_map.end(); } - bool has_pp_function(const FMKeyPP& key) const + bool has_function(const PPFMKey& key) const { - auto real_key = get_actual_pp_key(key); - if(!real_key.base_lds_usage_fits(max_lds_bytes)) - return false; - return function_pp_map.count(real_key) > 0; + auto real_key = get_actual_key(key); + return find_key_in_map(pp_function_map, real_key) != pp_function_map.end(); } size_t get_largest_length(rocfft_precision precision) const @@ -310,13 +309,14 @@ class function_pool return it->second; } - FFTKernel get_pp_kernel(const FMKeyPP& key, ComputeScheme scheme) const + FFTKernel get_kernel(const PPFMKey& key, ComputeScheme scheme) const { - auto real_key = get_actual_pp_key(key); - if(!real_key.base_lds_usage_fits(max_lds_bytes)) + auto real_key = get_actual_key(key); + auto it = find_key_in_map(pp_function_map, real_key); + if(it == pp_function_map.end()) throw std::out_of_range("kernel not found in partial-pass map"); - auto kernel_list = function_pp_map.at(real_key); + auto kernel_list = it->second; auto scheme_0 = kernel_list[0].pp_params.scheme; auto scheme_1 = kernel_list[1].pp_params.scheme; @@ -383,24 +383,27 @@ static void insert_default_entry(const FMKey& def_key, function_map.emplace(def_key_with_lds, kernel); } -static bool insert_default_pp_entry(const FMKeyPP& def_key, - const FFTKernel& kernel_0, - const FFTKernel& kernel_1, - PPFPKeyMap& def_key_pool, - PPFPMap& function_map) +static void insert_default_entry(const PPFMKey& def_key, + const FFTKernel& kernel_0, + const FFTKernel& kernel_1, + PPFPKeyMap& def_key_pool, + PPFPMap& function_map, + size_t lds_size_bytes) { - // simple_key means the same thing as def_key, but we just remove kernel-config - // so we don't need to know the exact config when we're lookin' for the default kernel - FMKeyPP simple_key(def_key); + PPFMKey def_key_with_lds = def_key; + def_key_with_lds.lds_size_bytes = lds_size_bytes; + + PPFMKey simple_key(def_key_with_lds); + simple_key.kernel_config_1 = KernelConfig::EmptyConfig(); simple_key.kernel_config_2 = KernelConfig::EmptyConfig(); - def_key_pool.emplace(simple_key, def_key); + def_key_pool.emplace(simple_key, def_key_with_lds); std::array kernels = {kernel_0, kernel_1}; // still use the detailed key with config to maintain the function map - return std::get<1>(function_map.emplace(def_key, kernels)); + function_map.emplace(def_key_with_lds, kernels); } #endif // FUNCTION_POOL_H diff --git a/library/src/include/tree_node.h b/library/src/include/tree_node.h index 12a4fdc7159..e1c211f1ed9 100644 --- a/library/src/include/tree_node.h +++ b/library/src/include/tree_node.h @@ -353,7 +353,7 @@ class TreeNode // specified kernel key from solution map. (if there is any) std::unique_ptr specified_key; - std::unique_ptr specified_pp_key; + std::unique_ptr specified_pp_key; // Tree structure: // non-owning pointer to parent node, may be null @@ -588,14 +588,14 @@ class TreeNode : FMKey(length[0], length[1], precision, scheme); } - virtual FMKeyPP GetPPKernelsKey() const + virtual PPFMKey GetPPKernelsKey() const { if(specified_pp_key) return *specified_pp_key.get(); auto pp_parent_node = GetPartialPassAncestor(); if(pp_parent_node) - return FMKeyPP(pp_parent_node->length[0], + return PPFMKey(pp_parent_node->length[0], pp_parent_node->length[1], pp_parent_node->length[2], precision, diff --git a/library/src/node_factory.cpp b/library/src/node_factory.cpp index f641e312fbc..f158dfdb308 100644 --- a/library/src/node_factory.cpp +++ b/library/src/node_factory.cpp @@ -1014,9 +1014,9 @@ bool NodeFactory::use_CS_3D_RC(const function_pool& pool, NodeMetaData& nodeData bool NodeFactory::use_CS_3D_PP(const function_pool& pool, NodeMetaData& nodeData) { - if(!pool.has_pp_function( + if(!pool.has_function( - FMKeyPP(nodeData.length[0], + PPFMKey(nodeData.length[0], nodeData.length[1], nodeData.length[2], nodeData.precision, diff --git a/library/src/rtc_stockham_kernel.cpp b/library/src/rtc_stockham_kernel.cpp index fb1dbc0a2fb..ba2d8c15e2b 100644 --- a/library/src/rtc_stockham_kernel.cpp +++ b/library/src/rtc_stockham_kernel.cpp @@ -77,7 +77,7 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& { auto pp_key = node.GetPPKernelsKey(); - kernel = node.pool.get_pp_kernel(pp_key, node.scheme); + kernel = node.pool.get_kernel(pp_key, node.scheme); std::copy(kernel->factors.begin(), kernel->factors.end(), std::back_inserter(factors)); std::vector precisions = {static_cast(node.precision)}; diff --git a/library/src/tree_node.cpp b/library/src/tree_node.cpp index 678464e59a4..2164447052b 100644 --- a/library/src/tree_node.cpp +++ b/library/src/tree_node.cpp @@ -106,8 +106,8 @@ void LeafNode::GetKernelFactors() { if(isPartialPassEnabled()) { - FMKeyPP key = GetPPKernelsKey(); - kernelFactors = pool.get_pp_kernel(key, scheme).factors; + PPFMKey key = GetPPKernelsKey(); + kernelFactors = pool.get_kernel(key, scheme).factors; } else { @@ -118,8 +118,8 @@ void LeafNode::GetKernelFactors() void LeafNode::GetKernelPartialPassFactors() { - FMKeyPP key = GetPPKernelsKey(); - auto kernel = pool.get_pp_kernel(key, scheme); + PPFMKey key = GetPPKernelsKey(); + auto kernel = pool.get_kernel(key, scheme); kernelFactorsPP = std::vector(kernel.pp_params.factors_off_dim.begin(), kernel.pp_params.factors_off_dim.end()); @@ -196,11 +196,11 @@ bool LeafNode::KernelCheck(std::vector& kernel_keys) if(isPartialPassEnabled()) { - FMKeyPP key = GetPPKernelsKey(); - if(!pool.has_pp_function(key)) + PPFMKey key = GetPPKernelsKey(); + if(!pool.has_function(key)) return false; - auto kernel = pool.get_pp_kernel(key, scheme); + auto kernel = pool.get_kernel(key, scheme); dir2regMode = (kernel.direct_to_from_reg) ? DirectRegType::TRY_ENABLE_IF_SUPPORT : DirectRegType::FORCE_OFF_OR_NOT_SUPPORT; diff --git a/library/src/tree_node_1D.cpp b/library/src/tree_node_1D.cpp index 440aae5eb9b..3e976faed52 100644 --- a/library/src/tree_node_1D.cpp +++ b/library/src/tree_node_1D.cpp @@ -950,8 +950,8 @@ void StockhamPP1DNode::SetupGridParam_internal(GridParam& gp) for(size_t j = 1; j < length.size(); j++) batch_accum *= length[j]; - FMKeyPP key = GetPPKernelsKey(); - auto kernel = pool.get_pp_kernel(key, scheme); + PPFMKey key = GetPPKernelsKey(); + auto kernel = pool.get_kernel(key, scheme); bwd = kernel.transforms_per_block; wgs = kernel.workgroup_size; @@ -1157,8 +1157,8 @@ std::vector SBCCNode::CollapsibleDims() void SBCCPPNode::SetupGridParam_internal(GridParam& gp) { - FMKeyPP key = GetPPKernelsKey(); - auto kernel = pool.get_pp_kernel(key, scheme); + PPFMKey key = GetPPKernelsKey(); + auto kernel = pool.get_kernel(key, scheme); bwd = kernel.transforms_per_block; wgs = kernel.workgroup_size; From 0a53c1e4c3cae088832eb915fc15b3988681319f Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Tue, 20 May 2025 16:31:15 -0600 Subject: [PATCH 37/69] - Refactoring. - Clean up. --- library/src/include/function_pool.h | 19 ++++++++-- library/src/include/tree_node.h | 57 ++++++++++++++++++++++++---- library/src/rtc_stockham_kernel.cpp | 58 +++++++++-------------------- library/src/tree_node.cpp | 52 ++++++-------------------- library/src/tree_node_1D.cpp | 14 +++---- library/src/tree_node_2D.cpp | 2 +- library/src/tree_node_3D.cpp | 4 +- 7 files changed, 103 insertions(+), 103 deletions(-) diff --git a/library/src/include/function_pool.h b/library/src/include/function_pool.h index c4fbb785122..cc4a34ca5c5 100644 --- a/library/src/include/function_pool.h +++ b/library/src/include/function_pool.h @@ -30,15 +30,28 @@ #include #include -inline std::string PrintMissingKernelInfo(const FMKey& key) +inline std::string PrintMissingKernelInfoBase(const FMKeyBase& key) { std::stringstream msg; msg << "Kernel not found: \n" << "\tlength: " << key.lengths[0] << "," << key.lengths[1] << "\n" << "\tprecision: " << key.precision << "\n" - << "\tscheme: " << PrintScheme(key.scheme) << "\n" - << "\tSBRC Transpose type: " << PrintSBRCTransposeType(key.sbrcTrans) << std::endl; + << "\tscheme: " << PrintScheme(key.scheme) << "\n"; + return msg.str(); +} +inline std::string PrintMissingKernelInfo(const PPFMKey& key) +{ + std::stringstream msg; + msg << PrintMissingKernelInfoBase(key); + return msg.str(); +} + +inline std::string PrintMissingKernelInfo(const FMKey& key) +{ + std::stringstream msg; + msg << PrintMissingKernelInfoBase(key) + << "\tSBRC Transpose type: " << PrintSBRCTransposeType(key.sbrcTrans) << std::endl; return msg.str(); } diff --git a/library/src/include/tree_node.h b/library/src/include/tree_node.h index e1c211f1ed9..14e504693d8 100644 --- a/library/src/include/tree_node.h +++ b/library/src/include/tree_node.h @@ -41,6 +41,7 @@ #include "function_pool.h" #include "kargs.h" #include "load_store_ops.h" +#include "logging.h" #include "rtc_kernel.h" #include @@ -594,14 +595,56 @@ class TreeNode return *specified_pp_key.get(); auto pp_parent_node = GetPartialPassAncestor(); - if(pp_parent_node) - return PPFMKey(pp_parent_node->length[0], - pp_parent_node->length[1], - pp_parent_node->length[2], - precision, - pp_parent_node->scheme); - else + if(!pp_parent_node) throw std::runtime_error("Invalid parent node for partial pass"); + + return PPFMKey(pp_parent_node->length[0], + pp_parent_node->length[1], + pp_parent_node->length[2], + precision, + pp_parent_node->scheme); + } + + virtual FFTKernel GetKernel() const + { + if(isPartialPassEnabled()) + { + auto key = GetPPKernelsKey(); + return pool.get_kernel(key, scheme); + } + else + { + auto key = GetKernelKey(); + return pool.get_kernel(key); + } + } + + virtual bool HasKernel() const + { + if(isPartialPassEnabled()) + { + auto key = GetPPKernelsKey(); + if(!pool.has_function(key)) + { + if(LOG_TRACE_ENABLED()) + (*LogSingleton::GetInstance().GetTraceOS()) << PrintMissingKernelInfo(key); + + return false; + } + } + else + { + auto key = GetKernelKey(); + if(!pool.has_function(key)) + { + if(LOG_TRACE_ENABLED()) + (*LogSingleton::GetInstance().GetTraceOS()) << PrintMissingKernelInfo(key); + + return false; + } + } + + return true; } // Compute the large twd decomposition base diff --git a/library/src/rtc_stockham_kernel.cpp b/library/src/rtc_stockham_kernel.cpp index ba2d8c15e2b..50524ce432d 100644 --- a/library/src/rtc_stockham_kernel.cpp +++ b/library/src/rtc_stockham_kernel.cpp @@ -53,10 +53,6 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& std::optional kernel; - // find function pool entry so we can construct specs for the generator - // NB: make sure all SBRC-type node have the correct trans_type value - FMKey key; - key = node.GetKernelKey(); switch(pool_scheme) { case CS_KERNEL_STOCKHAM: @@ -71,59 +67,39 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& if((pool_scheme == CS_KERNEL_STOCKHAM_BLOCK_RC) && (node.sbrcTranstype == NONE)) throw std::runtime_error("Invalid SBRC_TRANS_TYPE for SBRC kernel"); - std::vector factors; + // these go into the function pool normally and are passed to + // the generator as-is + kernel = node.GetKernel(); - if(node.isPartialPassEnabled()) - { - auto pp_key = node.GetPPKernelsKey(); + std::vector factors; + std::copy(kernel->factors.begin(), kernel->factors.end(), std::back_inserter(factors)); + std::vector precisions = {static_cast(node.precision)}; - kernel = node.pool.get_kernel(pp_key, node.scheme); + specs.emplace(factors, + std::vector(), + precisions, + static_cast(kernel->workgroup_size), + PrintScheme(node.scheme)); - std::copy(kernel->factors.begin(), kernel->factors.end(), std::back_inserter(factors)); - std::vector precisions = {static_cast(node.precision)}; + specs->threads_per_transform = kernel->threads_per_transform[0]; + specs->half_lds = kernel->half_lds; + specs->direct_to_from_reg = kernel->direct_to_from_reg; + if(node.isPartialPassEnabled()) + { pp_params.off_dim = node.ppOffDim; pp_params.current_dim = node.ppCurrDim; pp_params.factors_off_dim = std::vector( kernel->pp_params.factors_off_dim.begin(), kernel->pp_params.factors_off_dim.end()); pp_params.parent_length = std::vector(node.length.begin(), node.length.end()); - - specs.emplace(factors, - std::vector(), - precisions, - static_cast(kernel->workgroup_size), - PrintScheme(node.scheme)); - - specs->threads_per_transform = kernel->threads_per_transform[0]; - specs->half_lds = kernel->half_lds; - specs->direct_to_from_reg = kernel->direct_to_from_reg; - } - else - { - // these go into the function pool normally and are passed to - // the generator as-is - kernel = node.pool.get_kernel(key); - - std::copy(kernel->factors.begin(), kernel->factors.end(), std::back_inserter(factors)); - std::vector precisions = {static_cast(node.precision)}; - - specs.emplace(factors, - std::vector(), - precisions, - static_cast(kernel->workgroup_size), - PrintScheme(node.scheme)); - - specs->threads_per_transform = kernel->threads_per_transform[0]; - specs->half_lds = kernel->half_lds; - specs->direct_to_from_reg = kernel->direct_to_from_reg; } break; } case CS_KERNEL_2D_SINGLE: { - kernel = node.pool.get_kernel(key); + kernel = node.GetKernel(); std::vector factors1d; std::vector factors2d; diff --git a/library/src/tree_node.cpp b/library/src/tree_node.cpp index 2164447052b..2dc6f8e30cc 100644 --- a/library/src/tree_node.cpp +++ b/library/src/tree_node.cpp @@ -104,22 +104,13 @@ FMKey LeafNode::GetKernelKey() const void LeafNode::GetKernelFactors() { - if(isPartialPassEnabled()) - { - PPFMKey key = GetPPKernelsKey(); - kernelFactors = pool.get_kernel(key, scheme).factors; - } - else - { - FMKey key = GetKernelKey(); - kernelFactors = pool.get_kernel(key).factors; - } + auto kernel = GetKernel(); + kernelFactors = kernel.factors; } void LeafNode::GetKernelPartialPassFactors() { - PPFMKey key = GetPPKernelsKey(); - auto kernel = pool.get_kernel(key, scheme); + auto kernel = GetKernel(); kernelFactorsPP = std::vector(kernel.pp_params.factors_off_dim.begin(), kernel.pp_params.factors_off_dim.end()); @@ -194,38 +185,19 @@ bool LeafNode::KernelCheck(std::vector& kernel_keys) } } - if(isPartialPassEnabled()) - { - PPFMKey key = GetPPKernelsKey(); - if(!pool.has_function(key)) - return false; + // get the final key and check if we have the kernel. + // Note that the check is trivial if we are using "specified_key" + // since we definitly have the kernel, but not trivial if it's the auto-gen key + HasKernel(); - auto kernel = pool.get_kernel(key, scheme); - dir2regMode = (kernel.direct_to_from_reg) ? DirectRegType::TRY_ENABLE_IF_SUPPORT - : DirectRegType::FORCE_OFF_OR_NOT_SUPPORT; + GetKernelFactors(); + if(isPartialPassEnabled()) GetKernelPartialPassFactors(); - } - else - { - // get the final key and check if we have the kernel. - // Note that the check is trivial if we are using "specified_key" - // since we definitly have the kernel, but not trivial if it's the auto-gen key - FMKey key = GetKernelKey(); - if(!pool.has_function(key)) - { - if(LOG_TRACE_ENABLED()) - (*LogSingleton::GetInstance().GetTraceOS()) << PrintMissingKernelInfo(key); - return false; - } - - dir2regMode = (pool.get_kernel(key).direct_to_from_reg) - ? DirectRegType::TRY_ENABLE_IF_SUPPORT - : DirectRegType::FORCE_OFF_OR_NOT_SUPPORT; - - GetKernelFactors(); - } + auto kernel = GetKernel(); + dir2regMode = (kernel.direct_to_from_reg) ? DirectRegType::TRY_ENABLE_IF_SUPPORT + : DirectRegType::FORCE_OFF_OR_NOT_SUPPORT; return true; } diff --git a/library/src/tree_node_1D.cpp b/library/src/tree_node_1D.cpp index 3e976faed52..0187acd0e7e 100644 --- a/library/src/tree_node_1D.cpp +++ b/library/src/tree_node_1D.cpp @@ -900,8 +900,7 @@ void Stockham1DNode::SetupGridParam_internal(GridParam& gp) for(size_t j = 1; j < length.size(); j++) batch_accum *= length[j]; - auto key = GetKernelKey(); - auto kernel = pool.get_kernel(key); + auto kernel = GetKernel(); bwd = kernel.transforms_per_block; wgs = kernel.workgroup_size; @@ -950,8 +949,7 @@ void StockhamPP1DNode::SetupGridParam_internal(GridParam& gp) for(size_t j = 1; j < length.size(); j++) batch_accum *= length[j]; - PPFMKey key = GetPPKernelsKey(); - auto kernel = pool.get_kernel(key, scheme); + auto kernel = GetKernel(); bwd = kernel.transforms_per_block; wgs = kernel.workgroup_size; @@ -990,8 +988,7 @@ bool SBCCNode::KernelCheck(std::vector& kernel_keys) if(large1D > 0) { - FMKey key = GetKernelKey(); - auto kernel = pool.get_kernel(key); + auto kernel = GetKernel(); largeTwd3Steps = kernel.use_3steps_large_twd; get_large_twd_base_steps(large1D, largeTwd3Steps, largeTwdBase, ltwdSteps); } @@ -1128,7 +1125,7 @@ void SBCCNode::TuneIntrinsicMode() void SBCCNode::SetupGridParam_internal(GridParam& gp) { - auto kernel = pool.get_kernel(GetKernelKey()); + auto kernel = GetKernel(); bwd = kernel.transforms_per_block; wgs = kernel.workgroup_size; @@ -1157,8 +1154,7 @@ std::vector SBCCNode::CollapsibleDims() void SBCCPPNode::SetupGridParam_internal(GridParam& gp) { - PPFMKey key = GetPPKernelsKey(); - auto kernel = pool.get_kernel(key, scheme); + auto kernel = GetKernel(); bwd = kernel.transforms_per_block; wgs = kernel.workgroup_size; diff --git a/library/src/tree_node_2D.cpp b/library/src/tree_node_2D.cpp index addb144ccdb..9d554f00d24 100644 --- a/library/src/tree_node_2D.cpp +++ b/library/src/tree_node_2D.cpp @@ -316,7 +316,7 @@ bool Single2DNode::CreateDeviceResources() void Single2DNode::SetupGridParam_internal(GridParam& gp) { - auto kernel = pool.get_kernel(GetKernelKey()); + auto kernel = GetKernel(); bwd = kernel.transforms_per_block; wgs = kernel.workgroup_size; diff --git a/library/src/tree_node_3D.cpp b/library/src/tree_node_3D.cpp index f8d60e5d7f4..c55e2d42dbf 100644 --- a/library/src/tree_node_3D.cpp +++ b/library/src/tree_node_3D.cpp @@ -884,7 +884,7 @@ void SBRCTranspose3DNode::TuneDirectRegType() void SBRCTransXY_ZNode::SetupGridParam_internal(GridParam& gp) { // sbrcTransType has already been assigned in KernelCheck(); - auto kernel = pool.get_kernel(GetKernelKey()); + auto kernel = GetKernel(); bwd = kernel.transforms_per_block; wgs = kernel.workgroup_size; lds = length[0] * bwd; @@ -899,7 +899,7 @@ void SBRCTransXY_ZNode::SetupGridParam_internal(GridParam& gp) void SBRCTransZ_XYNode::SetupGridParam_internal(GridParam& gp) { // sbrcTransType has already been assigned in KernelCheck(); - auto kernel = pool.get_kernel(GetKernelKey()); + auto kernel = GetKernel(); bwd = kernel.transforms_per_block; wgs = kernel.workgroup_size; lds = length[0] * bwd; From a37b90c4d8fd695ac9d84f9d6474133ac4a36fed Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Wed, 21 May 2025 12:38:52 -0600 Subject: [PATCH 38/69] - Get partial pass off-dim from function pool. - Improve comments. - Clean up. --- library/src/include/compute_scheme.h | 2 +- library/src/include/tree_node_3D.h | 5 +- library/src/node_factory.cpp | 3 +- library/src/tree_node_3D.cpp | 212 ++++++++++++++++----------- 4 files changed, 135 insertions(+), 87 deletions(-) diff --git a/library/src/include/compute_scheme.h b/library/src/include/compute_scheme.h index 3c466b9fcda..954ef153cad 100644 --- a/library/src/include/compute_scheme.h +++ b/library/src/include/compute_scheme.h @@ -82,7 +82,7 @@ enum ComputeScheme CS_3D_RC, CS_3D_PP, CS_KERNEL_3D_STOCKHAM_BLOCK_CC, // not implemented yet - CS_KERNEL_3D_SINGLE, // not implemented yet + CS_KERNEL_3D_SINGLE // not implemented yet }; // print abbreviation for kernel scheme diff --git a/library/src/include/tree_node_3D.h b/library/src/include/tree_node_3D.h index e8b56086981..9bcb279b602 100644 --- a/library/src/include/tree_node_3D.h +++ b/library/src/include/tree_node_3D.h @@ -147,8 +147,9 @@ class PP3DNode : public InternalNode scheme = CS_3D_PP; } - void AssignParams_internal() override; - void BuildTree_internal(SchemeTreeVec& child_scheme_trees = EmptySchemeTreeVec) override; + void AssignParams_internal() override; + void BuildTree_internal(SchemeTreeVec& child_scheme_trees = EmptySchemeTreeVec) override; + size_t GetPPOffDim() const; }; /***************************************************** diff --git a/library/src/node_factory.cpp b/library/src/node_factory.cpp index f158dfdb308..1dc18d31258 100644 --- a/library/src/node_factory.cpp +++ b/library/src/node_factory.cpp @@ -983,8 +983,7 @@ bool NodeFactory::use_CS_3D_RC(const function_pool& pool, NodeMetaData& nodeData return false; // Check the C part. - // The first R is built recursively with 2D_FFT, or with a - // 1_D FTT + partial pass(es). leave the check part to themselves + // The first R is built recursively with 2D_FFT, leave the check part to themselves auto kernel = pool.get_kernel(key); // hack for this special case diff --git a/library/src/tree_node_3D.cpp b/library/src/tree_node_3D.cpp index c55e2d42dbf..aaa62b60035 100644 --- a/library/src/tree_node_3D.cpp +++ b/library/src/tree_node_3D.cpp @@ -652,9 +652,6 @@ void RC3DNode::BuildTree_internal(SchemeTreeVec& child_scheme_trees) void RC3DNode::AssignParams_internal() { - // in partial pass case: - // xy plan is a x row 1D-FFT + plus partial pass(es) along y - // z plan is partial pass(es) along y + z col 1D-FFT. auto& xyPlan = childNodes[0]; auto& zPlan = childNodes[1]; @@ -683,105 +680,156 @@ void RC3DNode::AssignParams_internal() /***************************************************** * CS_3D_PP * *****************************************************/ -void PP3DNode::BuildTree_internal(SchemeTreeVec& child_scheme_trees) +size_t PP3DNode::GetPPOffDim() const { - // bool noSolution = child_scheme_trees.empty(); - - // // check schemes from solution map - // ComputeScheme determined_scheme_node0 = CS_NONE; - // ComputeScheme determined_scheme_node1 = CS_NONE; - // if(!noSolution) - // { - // if((child_scheme_trees.size() != 2)) - // throw std::runtime_error("RC3DNode: Unexpected child scheme from solution map"); - // determined_scheme_node0 = child_scheme_trees[0]->curScheme; - // determined_scheme_node1 = child_scheme_trees[1]->curScheme; - // } - - // TODO: Child nodes currently hardcoded to a x+z configuration - // in 3D partial-pass. Add support for other configurations, - // e.g., x+y, y+z, once partial pass is fully configurable - // in kernel-generator.py. - - // work along y will be split between x and z - - // x row fft + partial pass(es) along y - NodeMetaData xPartialPassPlanData(this); - xPartialPassPlanData.length.push_back(length[0]); - xPartialPassPlanData.length.push_back(length[1]); - // technically 1 < dimension < 2 for x node. - xPartialPassPlanData.dimension = 1; - xPartialPassPlanData.length.push_back(length[2]); - for(size_t index = 3; index < length.size(); index++) - { - xPartialPassPlanData.length.push_back(length[index]); - } + // CS_3D_PP will have two corresponding kernels in + // the function pool, both will have the same off-dim + // value, and at least of one them must be an SBCC PP + auto child_scheme = CS_KERNEL_STOCKHAM_PP_BLOCK_CC; - // use explicit (modified) SBRR kernel - std::unique_ptr xPartialPassPlan; + auto key = PPFMKey(length[0], length[1], length[2], precision, scheme); + if(!pool.has_function(key)) + throw std::runtime_error("GetPPOffDim failed to find a valid kernel"); - xPartialPassPlan = NodeFactory::CreateNodeFromScheme(CS_KERNEL_STOCKHAM_PP, this); - xPartialPassPlan->length = xPartialPassPlanData.length; - xPartialPassPlan->dimension = 1; - xPartialPassPlan->ppOffDim = 1; - xPartialPassPlan->allowInplace = true; - xPartialPassPlan->comments.push_back("partial-pass enabled for second dimension."); + auto kernel = pool.get_kernel(key, child_scheme); - // partial pass(es) along y + z col fft - NodeMetaData zPartialPassPlanData(this); - zPartialPassPlanData.length.push_back(length[2]); - // technically 1 < dimension < 2 for z node. - zPartialPassPlanData.dimension = 1; - zPartialPassPlanData.length.push_back(length[0]); - zPartialPassPlanData.length.push_back(length[1]); - for(size_t index = 3; index < length.size(); index++) + return kernel.pp_params.off_dim; +} + +void PP3DNode::BuildTree_internal(SchemeTreeVec& child_scheme_trees) +{ + ppOffDim = GetPPOffDim(); + + switch(ppOffDim) + { + case 0: // work along x will be split between y and z { - zPartialPassPlanData.length.push_back(length[index]); + // y col fft + partial pass along x + // partial pass along x + z col fft + throw std::runtime_error( + "PP3DNode::BuildTree_internal: partial-passes along x not currently supported"); + break; } - zPartialPassPlanData.outputLength = length; + case 1: // work along y will be split between x and z + { + // x row fft + partial pass along y + // partial pass along y + z col fft + + // Create node for x row fft + partial pass(es) along y + NodeMetaData xPartialPassPlanData(this); + xPartialPassPlanData.length.push_back(length[0]); + xPartialPassPlanData.length.push_back(length[1]); + // technically 1 < dimension < 2 for x node. + xPartialPassPlanData.dimension = 1; + xPartialPassPlanData.length.push_back(length[2]); + for(size_t index = 3; index < length.size(); index++) + { + xPartialPassPlanData.length.push_back(length[index]); + } + + // use explicit SBRR partial-pass kernel + std::unique_ptr xPartialPassPlan; + + xPartialPassPlan = NodeFactory::CreateNodeFromScheme(CS_KERNEL_STOCKHAM_PP, this); + xPartialPassPlan->length = xPartialPassPlanData.length; + xPartialPassPlan->dimension = 1; + xPartialPassPlan->ppOffDim = ppOffDim; + xPartialPassPlan->allowInplace = true; + xPartialPassPlan->comments.push_back("partial-pass enabled for second dimension."); + + // Create node for partial pass(es) along y + z col fft + NodeMetaData zPartialPassPlanData(this); + zPartialPassPlanData.length.push_back(length[2]); + // technically 1 < dimension < 2 for z node. + zPartialPassPlanData.dimension = 1; + zPartialPassPlanData.length.push_back(length[0]); + zPartialPassPlanData.length.push_back(length[1]); + for(size_t index = 3; index < length.size(); index++) + { + zPartialPassPlanData.length.push_back(length[index]); + } + zPartialPassPlanData.outputLength = length; + + // use explicit SBCC partial-pass kernel + std::unique_ptr zPartialPassPlan; - // use explicit (modified) SBCC kernel - std::unique_ptr zPartialPassPlan; + zPartialPassPlan = NodeFactory::CreateNodeFromScheme(CS_KERNEL_STOCKHAM_PP_BLOCK_CC, this); + zPartialPassPlan->length = zPartialPassPlanData.length; + zPartialPassPlan->dimension = 1; + zPartialPassPlan->ppOffDim = ppOffDim; + zPartialPassPlan->allowInplace = false; + zPartialPassPlan->comments.push_back("partial-pass enabled for second dimension."); - zPartialPassPlan = NodeFactory::CreateNodeFromScheme(CS_KERNEL_STOCKHAM_PP_BLOCK_CC, this); - zPartialPassPlan->length = zPartialPassPlanData.length; - zPartialPassPlan->dimension = 1; - zPartialPassPlan->ppOffDim = 1; - zPartialPassPlan->allowInplace = false; - zPartialPassPlan->comments.push_back("partial-pass enabled for second dimension."); + childNodes.emplace_back(std::move(xPartialPassPlan)); + childNodes.emplace_back(std::move(zPartialPassPlan)); - childNodes.emplace_back(std::move(xPartialPassPlan)); - childNodes.emplace_back(std::move(zPartialPassPlan)); + break; + } + case 2: // work along z will be split between x and y + { + // x row fft + partial pass along z + // partial pass along z + y col fft + throw std::runtime_error( + "PP3DNode::BuildTree_internal:: partial-passes along z not currently supported"); + break; + } + default: + throw std::runtime_error("PP3DNode::BuildTree_internal:: Unexpected ppOffDim"); + } } void PP3DNode::AssignParams_internal() { - // in partial pass case: - // xy plan is a x row 1D-FFT + plus partial pass(es) along y - // z plan is partial pass(es) along y + z col 1D-FFT. - auto& xyPlan = childNodes[0]; - auto& zPlan = childNodes[1]; + switch(ppOffDim) + { + case 0: // work along x will be split between y and z + { + // y col fft + partial pass along x + // partial pass along x + z col fft + throw std::runtime_error( + "PP3DNode::AssignParams_internal: partial-passes along x not currently supported"); + break; + } + case 1: // work along y will be split between x and z + { + // xy plan is a x row 1D-FFT + plus partial pass(es) along y + // z plan is partial pass(es) along y + z col 1D-FFT. + auto& xyPlan = childNodes[0]; + auto& zPlan = childNodes[1]; - xyPlan->inStride = inStride; - xyPlan->iDist = iDist; + xyPlan->inStride = inStride; + xyPlan->iDist = iDist; - xyPlan->outStride = outStride; - xyPlan->oDist = oDist; + xyPlan->outStride = outStride; + xyPlan->oDist = oDist; - xyPlan->AssignParams(); + xyPlan->AssignParams(); - zPlan->inStride.push_back(outStride[2]); - zPlan->inStride.push_back(outStride[0]); - zPlan->inStride.push_back(outStride[1]); - for(size_t index = 3; index < length.size(); index++) - zPlan->inStride.push_back(outStride[index]); + zPlan->inStride.push_back(outStride[2]); + zPlan->inStride.push_back(outStride[0]); + zPlan->inStride.push_back(outStride[1]); + for(size_t index = 3; index < length.size(); index++) + zPlan->inStride.push_back(outStride[index]); - zPlan->iDist = xyPlan->oDist; + zPlan->iDist = xyPlan->oDist; - zPlan->outStride = zPlan->inStride; - zPlan->oDist = zPlan->iDist; + zPlan->outStride = zPlan->inStride; + zPlan->oDist = zPlan->iDist; - zPlan->AssignParams(); + zPlan->AssignParams(); + break; + } + case 2: // work along z will be split between x and y + { + // x row fft + partial pass along z + // partial pass along z + y col fft + throw std::runtime_error( + "PP3DNode::AssignParams_internal: partial-passes along z not currently supported"); + break; + } + default: + throw std::runtime_error("PP3DNode::AssignParams_internal: Unexpected ppOffDim"); + } } // Leaf Node From bbbe036a450a3a5853ae62e23e544975152502a0 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Wed, 21 May 2025 13:54:21 -0600 Subject: [PATCH 39/69] - Clean up. - Improve comments. --- library/src/device/kernel-generator.py | 5 +-- library/src/include/tree_node.h | 21 ++++++++-- library/src/include/tree_node_1D.h | 30 ++------------ library/src/node_factory.cpp | 2 +- library/src/rtc_stockham_kernel.cpp | 1 - library/src/tree_node.cpp | 56 ++++++++++++++++++-------- library/src/tree_node_1D.cpp | 5 +-- library/src/tree_node_3D.cpp | 2 +- 8 files changed, 67 insertions(+), 55 deletions(-) diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index 23efe1487d1..d16d93cca71 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -1075,8 +1075,7 @@ def list_3d_partial_pass_kernels(): """Return list of to generate.""" pp_3d_kernels = [ - NS(length=[64,64,64], dims=[0, 2], factors=[[8, 8],[4, 4, 4]], factors_pp=[[4],[16]], threads_per_transform=[8, 8], workgroup_size=[64,128], direct_to_from_reg=[False, False]), - NS(length=[64,64,128], dims=[0, 2], factors=[[4, 4, 4],[8, 8, 2]], factors_pp=[[4], [16]], threads_per_transform=[8, 8], workgroup_size=[64,128], direct_to_from_reg=[False, False]), + NS(length=[64,64,64], dims=[0, 2], factors=[[8, 8],[4, 4, 4]], factors_pp=[[4],[16]], threads_per_transform=[8, 8], workgroup_size=[64,128], direct_to_from_reg=[False, False]) ] expanded = [] @@ -1253,7 +1252,7 @@ def generate_kernels(kernels, precisions, stockham_gen): if len(k.factors) == 1: half_lds = False - # Send data over to subprocess + # Send data over to subprocess if isinstance(k.workgroup_size, list): proc.stdin.write(" " + ','.join([str(f) for f in k.workgroup_size])) diff --git a/library/src/include/tree_node.h b/library/src/include/tree_node.h index 14e504693d8..3d7f1f1a5dc 100644 --- a/library/src/include/tree_node.h +++ b/library/src/include/tree_node.h @@ -351,9 +351,8 @@ class TreeNode // sbrc transpose type mutable SBRC_TRANSPOSE_TYPE sbrcTranstype = SBRC_TRANSPOSE_TYPE::NONE; - // specified kernel key from solution map. (if there is any) - std::unique_ptr specified_key; - + // specified kernel keys from solution map. (if there are any) + std::unique_ptr specified_key; std::unique_ptr specified_pp_key; // Tree structure: @@ -375,7 +374,10 @@ class TreeNode size_t lengthBlue = 0; size_t lengthBlueN = 0; - size_t ppOffDim = 0; + // Index of off-dimension in partial-pass nodes + size_t ppOffDim = 0; + + // Index of current dimension (full pass) in partial-pass nodes size_t ppCurrDim = 0; // @@ -489,6 +491,7 @@ class TreeNode return {}; } + // Check node scheme to see if partial pass is enabled bool isPartialPassEnabled() const { return (scheme == CS_3D_PP || scheme == CS_KERNEL_STOCKHAM_PP @@ -545,6 +548,8 @@ class TreeNode TreeNode* GetRealEvenAncestor(); bool IsRootPlanC2CTransform(); + // Return ancestor node of 'this' that is partial-pass, or + // nullptr if there is no such ancestor TreeNode* GetPartialPassAncestor() const; // Set length of transpose kernel node, since those are easily @@ -589,6 +594,10 @@ class TreeNode : FMKey(length[0], length[1], precision, scheme); } + // Partial pass parent nodes, e.g., CS_3D_PP, have + // two kernels associated with them. The key for + // querying the function pool is different from the + // the standard kernel key. virtual PPFMKey GetPPKernelsKey() const { if(specified_pp_key) @@ -605,6 +614,8 @@ class TreeNode pp_parent_node->scheme); } + // Query the function pool with the right key, + // and return the kernel linked to this node. virtual FFTKernel GetKernel() const { if(isPartialPassEnabled()) @@ -619,6 +630,8 @@ class TreeNode } } + // Check if the function pool has a kernel, + // querying it with the key linked to this node. virtual bool HasKernel() const { if(isPartialPassEnabled()) diff --git a/library/src/include/tree_node_1D.h b/library/src/include/tree_node_1D.h index 2cebaffee27..8f33524f2c5 100644 --- a/library/src/include/tree_node_1D.h +++ b/library/src/include/tree_node_1D.h @@ -105,16 +105,14 @@ class Stockham1DNode : public LeafNode /***************************************************** * CS_KERNEL_STOCKHAM_PP * *****************************************************/ -class StockhamPP1DNode : public LeafNode +class StockhamPP1DNode : public Stockham1DNode { friend class NodeFactory; protected: StockhamPP1DNode(TreeNode* p, ComputeScheme s) - : LeafNode(p, s) + : Stockham1DNode(p, s) { - externalKernel = true; - need_twd_table = true; } void SetupGridParam_internal(GridParam& gp) override; @@ -122,12 +120,6 @@ class StockhamPP1DNode : public LeafNode public: bool CreateDeviceResources() override; std::vector CollapsibleDims() override; - bool UseOutputLengthForPadding() override - { - // with embedded r2c, stockham nodes will change length, so the - // output length is different from the input length. - return ebtype != EmbeddedType::NONE; - } }; /***************************************************** @@ -176,33 +168,19 @@ class SBCCNode : public LeafNode /***************************************************** * SBCC Partial-Pass * *****************************************************/ -class SBCCPPNode : public LeafNode +class SBCCPPNode : public SBCCNode { friend class NodeFactory; protected: SBCCPPNode(TreeNode* p, ComputeScheme s) - : LeafNode(p, s) + : SBCCNode(p, s) { - externalKernel = true; - need_twd_table = true; } void SetupGridParam_internal(GridParam& gp) override; - // InitIntrinsicMode is the first step to check if eligible for buffer load/store - void InitIntrinsicMode(); - public: - // reads + writes are along columns so both may benefit from padding - bool PaddingBenefitsInput() override - { - return true; - } - bool PaddingBenefitsOutput() override - { - return true; - } std::vector CollapsibleDims() override; }; diff --git a/library/src/node_factory.cpp b/library/src/node_factory.cpp index 1dc18d31258..d46bbafb8ed 100644 --- a/library/src/node_factory.cpp +++ b/library/src/node_factory.cpp @@ -1022,7 +1022,7 @@ bool NodeFactory::use_CS_3D_PP(const function_pool& pool, NodeMetaData& nodeData CS_3D_PP))) return false; - // Partial pass is currently restricted large enough batch sizes, + // Partial pass is currently restricted to large enough batch sizes, // unite stride, interleaved FFTs. bool batchCondition = (nodeData.batch >= 5); diff --git a/library/src/rtc_stockham_kernel.cpp b/library/src/rtc_stockham_kernel.cpp index 50524ce432d..2f54ecb2a88 100644 --- a/library/src/rtc_stockham_kernel.cpp +++ b/library/src/rtc_stockham_kernel.cpp @@ -80,7 +80,6 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& precisions, static_cast(kernel->workgroup_size), PrintScheme(node.scheme)); - specs->threads_per_transform = kernel->threads_per_transform[0]; specs->half_lds = kernel->half_lds; specs->direct_to_from_reg = kernel->direct_to_from_reg; diff --git a/library/src/tree_node.cpp b/library/src/tree_node.cpp index 2dc6f8e30cc..8a9fa118cd2 100644 --- a/library/src/tree_node.cpp +++ b/library/src/tree_node.cpp @@ -114,25 +114,49 @@ void LeafNode::GetKernelPartialPassFactors() kernelFactorsPP = std::vector(kernel.pp_params.factors_off_dim.begin(), kernel.pp_params.factors_off_dim.end()); - if(scheme == CS_KERNEL_STOCKHAM_PP) + switch(ppOffDim) { - std::stringstream msg; - msg << "work in the off-dimension:" << std::endl; - msg << "\t radix: ["; - for(const auto factor : kernelFactorsPP) - msg << " " << factor; - msg << " ] pass(es) + Hadamard product with twiddle factors. \n"; - comments.push_back(msg.str()); + case 0: // work along x will be split between y and z + { + throw std::runtime_error( + "GetKernelPartialPassFactors: partial-passes along x not currently supported"); + break; } - if(scheme == CS_KERNEL_STOCKHAM_PP_BLOCK_CC) + case 1: // work along y will be split between x and z { - std::stringstream msg; - msg << "work in the off-dimension:" << std::endl; - msg << "\t local data transposition + radix: ["; - for(const auto factor : kernelFactorsPP) - msg << " " << factor; - msg << " ] pass(es). \n"; - comments.push_back(msg.str()); + if(scheme == CS_KERNEL_STOCKHAM_PP) + { + std::stringstream msg; + msg << "work in the off-dimension:" << std::endl; + msg << "\t radix: ["; + for(const auto factor : kernelFactorsPP) + msg << " " << factor; + msg << " ] pass(es) + Hadamard product with twiddle factors. \n"; + comments.push_back(msg.str()); + } + if(scheme == CS_KERNEL_STOCKHAM_PP_BLOCK_CC) + { + std::stringstream msg; + msg << "work in the off-dimension:" << std::endl; + msg << "\t local data transposition + radix: ["; + for(const auto factor : kernelFactorsPP) + msg << " " << factor; + msg << " ] pass(es). \n"; + comments.push_back(msg.str()); + } + + break; + } + case 2: // work along z will be split between x and y + { + // x row fft + partial pass along z + // partial pass along z + y col fft + throw std::runtime_error( + "GetKernelPartialPassFactors: partial-passes along z not currently supported"); + break; + } + default: + throw std::runtime_error("Invalid off-dimension for partial pass"); } } diff --git a/library/src/tree_node_1D.cpp b/library/src/tree_node_1D.cpp index 0187acd0e7e..49917cf4821 100644 --- a/library/src/tree_node_1D.cpp +++ b/library/src/tree_node_1D.cpp @@ -1165,9 +1165,8 @@ void SBCCPPNode::SetupGridParam_internal(GridParam& gp) gp.b_x *= product(length.begin() + 2, length.end()) * batch; gp.wgs_x = wgs; - // Grid arrangement is different for partial - // pass SBCC kernels for improved global memory - // access patterns. + // Grid arrangement is different than regular SBCC + // for improved global memory access patterns. auto factor = *std::max_element(kernelFactorsPP.begin(), kernelFactorsPP.end()); gp.b_x /= factor; diff --git a/library/src/tree_node_3D.cpp b/library/src/tree_node_3D.cpp index aaa62b60035..514e28d46f2 100644 --- a/library/src/tree_node_3D.cpp +++ b/library/src/tree_node_3D.cpp @@ -770,7 +770,7 @@ void PP3DNode::BuildTree_internal(SchemeTreeVec& child_scheme_trees) // x row fft + partial pass along z // partial pass along z + y col fft throw std::runtime_error( - "PP3DNode::BuildTree_internal:: partial-passes along z not currently supported"); + "PP3DNode::BuildTree_internal: partial-passes along z not currently supported"); break; } default: From 6c0dfb2e52e2364463a6f1535b019b435a165e59 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Wed, 21 May 2025 14:38:27 -0600 Subject: [PATCH 40/69] - Further partial-pass kernel config validation. --- library/src/device/generator/stockham_gen.cpp | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/library/src/device/generator/stockham_gen.cpp b/library/src/device/generator/stockham_gen.cpp index 8db742dbc35..e898bdc0b50 100644 --- a/library/src/device/generator/stockham_gen.cpp +++ b/library/src/device/generator/stockham_gen.cpp @@ -504,6 +504,37 @@ unsigned int get_pp_off_dim(const std::vector& dims) return off_dim; } +void validate_pp_length(const StockhamPartialPassParams& pp_params, + const std::vector& factors) + +{ + unsigned int length_curr + = std::accumulate(factors.begin(), factors.end(), 1, std::multiplies()); + + auto curr_dim = pp_params.current_dim; + if(length_curr != pp_params.parent_length[curr_dim]) + throw std::runtime_error("Invalid partial-pass kernel length configuration"); +} + +void validate_pp_off_dim_length(const StockhamPartialPassParams& pp_params_1, + const StockhamPartialPassParams& pp_params_2) +{ + auto off_factors_all = pp_params_1.factors_off_dim; + off_factors_all.insert(off_factors_all.end(), + pp_params_2.factors_off_dim.begin(), + pp_params_2.factors_off_dim.end()); + + unsigned int length_off_dim = std::accumulate( + off_factors_all.begin(), off_factors_all.end(), 1, std::multiplies()); + + if(pp_params_1.parent_length[pp_params_1.off_dim] + != pp_params_2.parent_length[pp_params_2.off_dim]) + throw std::runtime_error("Invalid partial-pass kernel off-dimension length"); + + if(length_off_dim != pp_params_1.parent_length[pp_params_1.off_dim]) + throw std::runtime_error("Invalid partial-pass kernel off-dimension length"); +} + static size_t max_bytes_per_element(const std::vector& precisions) { // generate for the maximum element size in the available @@ -629,6 +660,10 @@ int main() StockhamPartialPassParams pp_params_1(parent_length, dims[0], off_dim, pp_factors1); StockhamPartialPassParams pp_params_2(parent_length, dims[1], off_dim, pp_factors2); + validate_pp_length(pp_params_1, factors1); + validate_pp_length(pp_params_2, factors2); + validate_pp_off_dim_length(pp_params_1, pp_params_2); + stockham_partial_pass_variants( kernel_name, specs1, specs2, pp_params_1, pp_params_2, std::cout); } From aae9b331463877b28785bbee99ff3e7ffe45b826 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Tue, 27 May 2025 17:10:36 -0600 Subject: [PATCH 41/69] - Add further validation for kernel-generator.py partial-pass data. - Remove redundant code from partial-pass SBCC generator's constructor. - Improve comments. --- library/src/device/generator/stockham_gen.cpp | 65 +++++++++++++++++-- .../src/device/generator/stockham_pp_gen_cc.h | 3 - library/src/device/kernel-generator.py | 2 +- 3 files changed, 60 insertions(+), 10 deletions(-) diff --git a/library/src/device/generator/stockham_gen.cpp b/library/src/device/generator/stockham_gen.cpp index e898bdc0b50..65b9c5abc9b 100644 --- a/library/src/device/generator/stockham_gen.cpp +++ b/library/src/device/generator/stockham_gen.cpp @@ -289,20 +289,21 @@ void stockham_partial_pass_variants(const std::string& kernel_name params_2.off_dim, launchers); } - // SBRR_PP + SBCC_PP + // SBCC_PP + SBCC_PP else if((params_1.current_dim == 1 && params_2.current_dim == 2) || (params_1.current_dim == 2 && params_2.current_dim == 1)) { - throw std::runtime_error("CS_KERNEL_STOCKHAM_PP + CS_KERNEL_STOCKHAM_PP_BLOCK_CC not " + throw std::runtime_error("CS_KERNEL_STOCKHAM_PP_BLOCK_CC + " + "CS_KERNEL_STOCKHAM_PP_BLOCK_CC with x as off-dimension not " "yet implemented for CS_3D_PP"); } - // SBRR_PP + SBRR_PP + // SBRR_PP + SBCC_PP else if((params_1.current_dim == 0 && params_2.current_dim == 1) || (params_1.current_dim == 1 && params_2.current_dim == 0)) { - throw std::runtime_error( - "CS_KERNEL_STOCKHAM_PP_BLOCK_CC + CS_KERNEL_STOCKHAM_PP_BLOCK_CC not yet " - "implemented for CS_3D_PP"); + throw std::runtime_error("CS_KERNEL_STOCKHAM_PP + CS_KERNEL_STOCKHAM_PP_BLOCK_CC with " + "with z as off-dimension not yet " + "implemented for CS_3D_PP"); } else { @@ -535,6 +536,57 @@ void validate_pp_off_dim_length(const StockhamPartialPassParams& pp_params_1, throw std::runtime_error("Invalid partial-pass kernel off-dimension length"); } +void validate_pp_grid_params(const StockhamPartialPassParams& params_1, + const StockhamPartialPassParams& params_2, + const StockhamGeneratorSpecs& specs_1, + const StockhamGeneratorSpecs& specs_2) +{ + if(specs_1.scheme == "CS_3D_PP" && specs_2.scheme == "CS_3D_PP") + { + // SBRR_PP + SBCC_PP + if((params_1.current_dim == 0 && params_2.current_dim == 2) + || (params_1.current_dim == 2 && params_2.current_dim == 0)) + { + // SBRR needs tpb to be at least max(pp_factors), + // so that it has the required off-dim data in LDS + // to perform partial pass + auto tpb_sbrr = (params_1.current_dim == 0 && params_2.current_dim == 2) + ? specs_1.workgroup_size / specs_1.threads_per_transform + : specs_2.workgroup_size / specs_2.threads_per_transform; + if(tpb_sbrr < *std::max_element(params_1.factors_off_dim.begin(), + params_1.factors_off_dim.end())) + { + throw std::runtime_error("CS_KERNEL_STOCKHAM_PP requires transform-per-block " + "to be at least max(pp_factors)"); + } + } + // SBCC_PP + SBCC_PP + else if((params_1.current_dim == 1 && params_2.current_dim == 2) + || (params_1.current_dim == 2 && params_2.current_dim == 1)) + { + throw std::runtime_error("CS_KERNEL_STOCKHAM_PP_BLOCK_CC + " + "CS_KERNEL_STOCKHAM_PP_BLOCK_CC with x as off-dimension not " + "yet implemented for CS_3D_PP"); + } + // SBRR_PP + SBCC_PP + else if((params_1.current_dim == 0 && params_2.current_dim == 1) + || (params_1.current_dim == 1 && params_2.current_dim == 0)) + { + throw std::runtime_error("CS_KERNEL_STOCKHAM_PP + CS_KERNEL_STOCKHAM_PP_BLOCK_CC with " + "with z as off-dimension not yet " + "implemented for CS_3D_PP"); + } + else + { + throw std::runtime_error("invalid dimensions for CS_3D_PP"); + } + } + else + { + throw std::runtime_error("unhandled scheme"); + } +} + static size_t max_bytes_per_element(const std::vector& precisions) { // generate for the maximum element size in the available @@ -663,6 +715,7 @@ int main() validate_pp_length(pp_params_1, factors1); validate_pp_length(pp_params_2, factors2); validate_pp_off_dim_length(pp_params_1, pp_params_2); + validate_pp_grid_params(pp_params_1, pp_params_2, specs1, specs2); stockham_partial_pass_variants( kernel_name, specs1, specs2, pp_params_1, pp_params_2, std::cout); diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index 188a16e92ca..1d765dd075f 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -53,9 +53,6 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC , params(params) { - large_twiddle_steps.decl_default = 3; - large_twiddle_base.decl_default = 8; - factors_pp = params.factors_off_dim; max_factor_pp = *std::max_element(factors_pp.begin(), factors_pp.end()); diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index d16d93cca71..0fa4c4b53b7 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -1075,7 +1075,7 @@ def list_3d_partial_pass_kernels(): """Return list of to generate.""" pp_3d_kernels = [ - NS(length=[64,64,64], dims=[0, 2], factors=[[8, 8],[4, 4, 4]], factors_pp=[[4],[16]], threads_per_transform=[8, 8], workgroup_size=[64,128], direct_to_from_reg=[False, False]) + NS(length=[64,64,64], dims=[0, 2], factors=[[8, 8],[4, 4, 4]], factors_pp=[[4],[16]], threads_per_transform=[8, 8], workgroup_size=[64,128], direct_to_from_reg=[False, False]) ] expanded = [] From 20f1d3c5862a39dc5363feafc308adabd676b1a8 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Wed, 28 May 2025 14:41:50 -0600 Subject: [PATCH 42/69] - Remove no longer needed field from KernelConfig. - Minor cosmetic changes. --- library/src/device/kernel-generator.py | 6 ++--- library/src/include/function_map_key.h | 32 ++++---------------------- library/src/rtc_stockham_gen.cpp | 2 ++ library/src/tree_node.cpp | 4 ++-- library/src/tree_node_3D.cpp | 11 ++++----- 5 files changed, 17 insertions(+), 38 deletions(-) diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index 0fa4c4b53b7..50b61a423f7 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -1072,10 +1072,10 @@ def list_large_kernels(): return sbcc_kernels + sbcr_kernels + sbrc_kernels def list_3d_partial_pass_kernels(): - """Return list of to generate.""" + """Return list of partial-pass 3D kernels to generate.""" pp_3d_kernels = [ - NS(length=[64,64,64], dims=[0, 2], factors=[[8, 8],[4, 4, 4]], factors_pp=[[4],[16]], threads_per_transform=[8, 8], workgroup_size=[64,128], direct_to_from_reg=[False, False]) + NS(length=[64,64,64], dims=[0, 2], factors=[[8, 8],[4, 4, 4]], factors_pp=[[4],[16]], threads_per_transform=[8, 8], workgroup_size=[64,128], direct_to_from_reg=[False, False]), ] expanded = [] @@ -1336,7 +1336,7 @@ def cli(): # kernels = [] - # move 2d out from all, no need to iterate the 2d-kernels for non-2d patterns + # move 2d out from all, no need to iterate the 2d-kernels for non-2d patterns kernels_2d = list_2d_kernels() kernel_3d_pp = list_3d_partial_pass_kernels() all_kernels = list_small_kernels() + list_large_kernels() diff --git a/library/src/include/function_map_key.h b/library/src/include/function_map_key.h index 1ccf5bcc173..9774e8555b1 100644 --- a/library/src/include/function_map_key.h +++ b/library/src/include/function_map_key.h @@ -41,7 +41,6 @@ struct KernelConfig int workgroup_size = 0; std::array threads_per_transform = {0, 0}; std::vector factors = {0}; - std::vector factors_pp = {0}; // above data is what we can tune // // the followings are other information of this kernel. @@ -70,7 +69,6 @@ struct KernelConfig KernelConfig(bool use_3steps, std::vector&& factors, - std::vector&& factors_pp, int tpb, int wgs, std::array&& tpt, @@ -91,7 +89,6 @@ struct KernelConfig , workgroup_size(wgs) , threads_per_transform(tpt) , factors(factors) - , factors_pp(factors_pp) , ebType(ebType) , direction(direction) , static_dim(static_dim) @@ -112,8 +109,7 @@ struct KernelConfig transforms_per_block, workgroup_size, threads_per_transform, - factors, - factors_pp) + factors) == std::tie(rhs.use_3steps_large_twd, rhs.half_lds, rhs.direct_to_from_reg, @@ -121,8 +117,7 @@ struct KernelConfig rhs.transforms_per_block, rhs.workgroup_size, rhs.threads_per_transform, - rhs.factors, - rhs.factors_pp); + rhs.factors); } bool operator<(const KernelConfig& rhs) const @@ -134,8 +129,7 @@ struct KernelConfig transforms_per_block, workgroup_size, threads_per_transform, - factors, - factors_pp) + factors) < std::tie(rhs.use_3steps_large_twd, rhs.half_lds, rhs.direct_to_from_reg, @@ -143,8 +137,7 @@ struct KernelConfig rhs.transforms_per_block, rhs.workgroup_size, rhs.threads_per_transform, - rhs.factors, - rhs.factors_pp); + rhs.factors); } std::string Print() const @@ -167,15 +160,6 @@ struct KernelConfig } ss << "]"; - ss << ", factors_pp: ["; - COMMA = ""; - for(auto factor : factors_pp) - { - ss << COMMA << factor; - COMMA = ", "; - } - ss << "]"; - ss << "}"; return ss.str(); @@ -210,12 +194,8 @@ namespace std // which means the maximal factorization pass is 8 auto factors_max_len = config.factors; factors_max_len.resize(TWIDDLES_MAX_RADICES); - for(auto& v : factors_max_len) - h ^= std::hash{}(v); - auto factors_pp_max_len = config.factors_pp; - factors_pp_max_len.resize(TWIDDLES_MAX_RADICES); - for(auto& v : factors_pp_max_len) + for(auto& v : factors_max_len) h ^= std::hash{}(v); return h; } @@ -240,7 +220,6 @@ struct ToString str += FieldDescriptor().describe("wgs", value.workgroup_size) + ","; str += VectorFieldDescriptor().describe("tpt", tpt) + ","; str += VectorFieldDescriptor().describe("factors", value.factors) + ","; - str += VectorFieldDescriptor().describe("factors_pp", value.factors_pp) + ","; // below: not tunable data, for AOT cache str += FieldDescriptor().describe("ebtype", PrintEBType(value.ebType)) + ","; str += FieldDescriptor().describe("direction", value.direction) + ","; @@ -275,7 +254,6 @@ struct FromString FieldParser().parse("wgs", ret.workgroup_size, current); VectorFieldParser().parse("tpt", tpt, current); VectorFieldParser().parse("factors", ret.factors, current); - VectorFieldParser().parse("factors_pp", ret.factors_pp, current); if(DescriptorFormatVersion::UsingVersion < 2) { diff --git a/library/src/rtc_stockham_gen.cpp b/library/src/rtc_stockham_gen.cpp index 945d6795e83..947ffd8ce76 100644 --- a/library/src/rtc_stockham_gen.cpp +++ b/library/src/rtc_stockham_gen.cpp @@ -381,6 +381,8 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, device = std::make_unique(kernel_pp->generate_device_function()); break; } + default: + throw std::runtime_error("unhandled partial pass type"); } if(fuseBluestein) diff --git a/library/src/tree_node.cpp b/library/src/tree_node.cpp index 8fd7d41e5c5..ab9a7f319b1 100644 --- a/library/src/tree_node.cpp +++ b/library/src/tree_node.cpp @@ -212,7 +212,8 @@ bool LeafNode::KernelCheck(std::vector& kernel_keys) // get the final key and check if we have the kernel. // Note that the check is trivial if we are using "specified_key" // since we definitly have the kernel, but not trivial if it's the auto-gen key - HasKernel(); + if(!HasKernel()) + return false; GetKernelFactors(); @@ -310,7 +311,6 @@ void LeafNode::SetupGridParam(GridParam& gp) double_half_lds_alloc = true; } - // no support for half-lds in partial-pass mode auto kernel = pool.get_kernel(key); if(kernel.half_lds && (!double_half_lds_alloc)) gp.lds_bytes /= 2; diff --git a/library/src/tree_node_3D.cpp b/library/src/tree_node_3D.cpp index 13e7d78db17..31ea93bec5b 100644 --- a/library/src/tree_node_3D.cpp +++ b/library/src/tree_node_3D.cpp @@ -681,16 +681,15 @@ void RC3DNode::AssignParams_internal() *****************************************************/ size_t PP3DNode::GetPPOffDim() const { - // CS_3D_PP will have two corresponding kernels in - // the function pool, both will have the same off-dim - // value, and at least of one them must be an SBCC PP - auto child_scheme = CS_KERNEL_STOCKHAM_PP_BLOCK_CC; - auto key = PPFMKey(length[0], length[1], length[2], precision, scheme); if(!pool.has_function(key)) throw std::runtime_error("GetPPOffDim failed to find a valid kernel"); - auto kernel = pool.get_kernel(key, child_scheme); + // CS_3D_PP will have two corresponding kernels in + // the function pool, both will have the same off-dim + // value, and at least of one them must be an SBCC PP + auto child_scheme = CS_KERNEL_STOCKHAM_PP_BLOCK_CC; + auto kernel = pool.get_kernel(key, child_scheme); return kernel.pp_params.off_dim; } From 82d8e3a47bf1349db74208ad805499bea2e217a1 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Mon, 2 Jun 2025 15:18:15 -0600 Subject: [PATCH 43/69] - Refactor kernel configuration lists --- library/src/device/kernel-generator.py | 795 +----------------- .../kernels/configs/config_2d_single.py | 162 ++++ .../src/device/kernels/configs/config_lds.py | 21 + .../device/kernels/configs/config_pp_3d.py | 25 + .../src/device/kernels/configs/config_sbcc.py | 94 +++ .../src/device/kernels/configs/config_sbcr.py | 42 + .../src/device/kernels/configs/config_sbrc.py | 53 ++ .../src/device/kernels/configs/config_sbrr.py | 509 +++++++++++ 8 files changed, 938 insertions(+), 763 deletions(-) create mode 100644 library/src/device/kernels/configs/config_2d_single.py create mode 100644 library/src/device/kernels/configs/config_lds.py create mode 100644 library/src/device/kernels/configs/config_pp_3d.py create mode 100644 library/src/device/kernels/configs/config_sbcc.py create mode 100644 library/src/device/kernels/configs/config_sbcr.py create mode 100644 library/src/device/kernels/configs/config_sbrc.py create mode 100644 library/src/device/kernels/configs/config_sbrr.py diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index 27ad90fea79..11e9647e646 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -33,6 +33,13 @@ import json import threading +import kernels.configs.config_sbrc as config_sbrc +import kernels.configs.config_sbrr as config_sbrr +import kernels.configs.config_sbcc as config_sbcc +import kernels.configs.config_sbcr as config_sbcr +import kernels.configs.config_2d_single as config_2d_single +import kernels.configs.config_pp_3d as config_pp_3d + from copy import deepcopy from pathlib import Path from types import SimpleNamespace as NS @@ -241,729 +248,23 @@ def kernel_name(ns): return f'rocfft_len{length}{postfix}' -LDS_160k = 160 * 1024 - - -# yapf: disable def list_small_kernels(): """Return list of small kernels to generate.""" - - # Note: Default half_lds is True and default direct_to_from_reg is True as well. - # TODO- Currently, if half_lds is True, then direct_to_from_reg must be True - # but if half_lds is False, direct_to_from_reg can be either (still can be True). - kernels1d = [ - NS(length= 1, workgroup_size= 64, threads_per_transform= 1, factors=(1,), runtime_compile=True), - NS(length= 2, workgroup_size= 64, threads_per_transform= 1, factors=(2,), runtime_compile=True), - NS(length= 3, workgroup_size= 64, threads_per_transform= 1, factors=(3,), runtime_compile=True), - NS(length= 4, workgroup_size=128, threads_per_transform= 1, factors=(4,), runtime_compile=True), - NS(length= 5, workgroup_size=128, threads_per_transform= 1, factors=(5,), runtime_compile=True), - NS(length= 6, workgroup_size=128, threads_per_transform= 1, factors=(6,), runtime_compile=True), - NS(length= 7, workgroup_size= 64, threads_per_transform= 1, factors=(7,), runtime_compile=True), - NS(length= 8, workgroup_size= 64, threads_per_transform= 4, factors=(4, 2), runtime_compile=True), - NS(length= 9, workgroup_size= 64, threads_per_transform= 3, factors=(3, 3), runtime_compile=True), - NS(length= 10, workgroup_size= 64, threads_per_transform= 1, factors=(10,), runtime_compile=True), - NS(length= 11, workgroup_size=128, threads_per_transform= 1, factors=(11,), runtime_compile=True), - NS(length= 12, workgroup_size=128, threads_per_transform= 6, factors=(6, 2), runtime_compile=True), - NS(length= 13, workgroup_size= 64, threads_per_transform= 1, factors=(13,), runtime_compile=True), - NS(length= 14, workgroup_size=128, threads_per_transform= 7, factors=(7, 2), runtime_compile=True), - NS(length= 15, workgroup_size=128, threads_per_transform= 5, factors=(3, 5), runtime_compile=True), - NS(length= 16, workgroup_size= 64, threads_per_transform= 4, factors=(4, 4), runtime_compile=True), - NS(length= 17, workgroup_size=256, threads_per_transform= 1, factors=(17,), runtime_compile=True), - NS(length= 18, workgroup_size= 64, threads_per_transform= 6, factors=(3, 6), runtime_compile=True), - NS(length= 20, workgroup_size=256, threads_per_transform= 10, factors=(5, 4), runtime_compile=True), - NS(length= 21, workgroup_size=128, threads_per_transform= 7, factors=(3, 7), runtime_compile=True), - NS(length= 22, workgroup_size= 64, threads_per_transform= 2, factors=(11, 2), runtime_compile=True), - NS(length= 24, workgroup_size=256, threads_per_transform= 8, factors=(8, 3), runtime_compile=True), - NS(length= 25, workgroup_size=256, threads_per_transform= 5, factors=(5, 5), runtime_compile=True), - NS(length= 26, workgroup_size= 64, threads_per_transform= 2, factors=(13, 2), runtime_compile=True), - NS(length= 27, workgroup_size=256, threads_per_transform= 9, factors=(3, 3, 3), runtime_compile=True), - NS(length= 28, workgroup_size= 64, threads_per_transform= 4, factors=(7, 4), runtime_compile=True), - NS(length= 30, workgroup_size=128, threads_per_transform= 10, factors=(10, 3), runtime_compile=True), - NS(length= 32, workgroup_size=128, threads_per_transform= 16, factors=(8, 4)), - NS(length= 33, workgroup_size=256, threads_per_transform= 11, factors=(11, 3), runtime_compile=True), - NS(length= 34, workgroup_size=256, threads_per_transform= 17, factors=(17, 2), runtime_compile=True), - NS(length= 35, workgroup_size=256, threads_per_transform= 7, factors=(5, 7), half_lds=False, runtime_compile=True), - NS(length= 36, workgroup_size= 64, threads_per_transform= 6, factors=(6, 6)), - NS(length= 39, workgroup_size=256, threads_per_transform= 13, factors=(13, 3), runtime_compile=True), - NS(length= 40, workgroup_size=128, threads_per_transform= 10, factors=(10, 4)), - NS(length= 42, workgroup_size=256, threads_per_transform= 7, factors=(7, 6)), - NS(length= 44, workgroup_size= 64, threads_per_transform= 4, factors=(11, 4)), - NS(length= 45, workgroup_size=128, threads_per_transform= 15, factors=(5, 3, 3)), - NS(length= 48, workgroup_size= 64, threads_per_transform= 16, factors=(4, 3, 4)), - NS(length= 49, workgroup_size= 64, threads_per_transform= 7, factors=(7, 7)), - NS(length= 50, workgroup_size=256, threads_per_transform= 10, factors=(10, 5)), - NS(length= 51, workgroup_size=256, threads_per_transform= 17, factors=(17, 3), runtime_compile=True), - NS(length= 52, workgroup_size= 64, threads_per_transform= 4, factors=(13, 4)), - NS(length= 54, workgroup_size=256, threads_per_transform= 18, factors=(6, 3, 3)), - NS(length= 55, workgroup_size=256, threads_per_transform= 11, factors=(5, 11), half_lds=False, runtime_compile=True), - NS(length= 56, workgroup_size=128, threads_per_transform= 8, factors=(7, 8)), - NS(length= 60, workgroup_size= 64, threads_per_transform= 10, factors=(6, 10)), - NS(length= 63, workgroup_size=256, threads_per_transform= 21, factors=(3, 3, 7), half_lds=False, runtime_compile=True), - NS(length= 64, workgroup_size= 64, threads_per_transform= 16, factors=(4, 4, 4), half_lds=False, direct_to_from_reg=True), - NS(length= 65, workgroup_size=256, threads_per_transform= 13, factors=(13, 5), runtime_compile=True), - NS(length= 66, workgroup_size=256, threads_per_transform= 11, factors=(6, 11), half_lds=False, runtime_compile=True), - NS(length= 68, workgroup_size=256, threads_per_transform= 17, factors=(17, 4), runtime_compile=True), - NS(length= 70, workgroup_size=256, threads_per_transform= 14, factors=(2, 5, 7), runtime_compile=True), - NS(length= 72, workgroup_size= 64, threads_per_transform= 9, factors=(8, 3, 3)), - NS(length= 75, workgroup_size=256, threads_per_transform= 25, factors=(5, 5, 3)), - NS(length= 77, workgroup_size=256, threads_per_transform= 11, factors=(7, 11), runtime_compile=True), - NS(length= 78, workgroup_size=256, threads_per_transform= 13, factors=(6, 13), half_lds=False, runtime_compile=True), - NS(length= 80, workgroup_size= 64, threads_per_transform= 10, factors=(5, 2, 8)), - NS(length= 81, workgroup_size=128, threads_per_transform= 27, factors=(3, 3, 3, 3)), - NS(length= 84, workgroup_size=128, threads_per_transform= 12, factors=(7, 2, 6)), - NS(length= 85, workgroup_size=256, threads_per_transform= 17, factors=(17, 5), runtime_compile=True), - NS(length= 88, workgroup_size=128, threads_per_transform= 11, factors=(11, 8)), - NS(length= 90, workgroup_size= 64, threads_per_transform= 9, factors=(3, 3, 10)), - NS(length= 91, workgroup_size=256, threads_per_transform= 13, factors=(7, 13), half_lds=False, runtime_compile=True), - NS(length= 96, workgroup_size=128, threads_per_transform= 16, factors=(6, 16), half_lds=False, direct_to_from_reg=False), - NS(length= 98, workgroup_size= 256, threads_per_transform= 14, factors=(2, 7, 7), half_lds=False, runtime_compile=True), - NS(length= 99, workgroup_size= 256, threads_per_transform= 11, factors=(3, 3, 11), half_lds=False, runtime_compile=True), - NS(length= 100, workgroup_size= 64, threads_per_transform= 10, factors=(10, 10)), - NS(length= 102, workgroup_size=128, threads_per_transform= 17, factors=(17, 6), runtime_compile=True), - NS(length= 104, workgroup_size= 64, threads_per_transform= 8, factors=(13, 8)), - NS(length= 105, workgroup_size=256, threads_per_transform= 21, factors=(7, 3, 5), half_lds=False, runtime_compile=True), - NS(length= 108, workgroup_size=256, threads_per_transform= 36, factors=(6, 6, 3)), - NS(length= 110, workgroup_size=256, threads_per_transform= 11, factors=(2, 5, 11), half_lds=False, runtime_compile=True), - NS(length= 112, workgroup_size=256, threads_per_transform= 16, factors=(16, 7), half_lds=False, direct_to_from_reg=False), - NS(length= 117, workgroup_size= 64, threads_per_transform= 13, factors=(13, 9), runtime_compile=True), - NS(length= 119, workgroup_size=256, threads_per_transform= 17, factors=(17, 7), runtime_compile=True), - NS(length= 120, workgroup_size= 64, threads_per_transform= 12, factors=(6, 10, 2), runtime_compile=True), - NS(length= 121, workgroup_size=128, threads_per_transform= 11, factors=(11, 11), runtime_compile=True), - NS(length= 125, workgroup_size=256, threads_per_transform= 25, factors=(5, 5, 5), half_lds=False, direct_to_from_reg=False), - NS(length= 126, workgroup_size= 256, threads_per_transform= 42, factors=(6, 7, 3), half_lds=False, runtime_compile=True), - NS(length= 128, workgroup_size=256, threads_per_transform= 16, factors=(16, 8)), - NS(length= 130, workgroup_size= 64, threads_per_transform= 13, factors=(13, 10), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 132, workgroup_size=128, threads_per_transform= 22, factors=(11, 6, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 135, workgroup_size=128, threads_per_transform= 9, factors=(5, 3, 3, 3), runtime_compile=True), - NS(length= 136, workgroup_size=128, threads_per_transform=17, factors=(17, 8), runtime_compile=True), - NS(length= 140, workgroup_size= 64, threads_per_transform= 28, factors=(7, 5, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 143, workgroup_size=256, threads_per_transform= 13, factors=(13, 11), half_lds=False, runtime_compile=True), - NS(length= 144, workgroup_size=128, threads_per_transform= 12, factors=(6, 6, 4)), - NS(length= 147, workgroup_size= 64, threads_per_transform= 21, factors=(7, 7, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 150, workgroup_size= 64, threads_per_transform= 5, factors=(10, 5, 3), runtime_compile=True), - NS(length= 153, workgroup_size=128, threads_per_transform= 17, factors=(17, 9), runtime_compile=True), - NS(length= 154, workgroup_size=128, threads_per_transform= 22, factors=(11, 7, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 156, workgroup_size= 128, threads_per_transform=13, factors=(3, 4, 13), half_lds=False, runtime_compile=True), - NS(length= 160, workgroup_size=256, threads_per_transform= 16, factors=(16, 10)), - NS(length= 162, workgroup_size=256, threads_per_transform= 27, factors=(6, 3, 3, 3), runtime_compile=True), - NS(length= 165, workgroup_size= 64, threads_per_transform= 11, factors=(11, 5, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 168, workgroup_size=256, threads_per_transform= 56, factors=(8, 7, 3), half_lds=False, direct_to_from_reg=False), - NS(length= 169, workgroup_size=256, threads_per_transform= 13, factors=(13, 13), runtime_compile=True), - NS(length= 170, workgroup_size=128, threads_per_transform= 17, factors=(17, 10), runtime_compile=True), - NS(length= 175, workgroup_size=256, threads_per_transform= 35, factors=(5, 5, 7), half_lds=False, runtime_compile=True), - NS(length= 176, workgroup_size= 64, threads_per_transform= 16, factors=(11, 16), runtime_compile=True), - NS(length= 180, workgroup_size=256, threads_per_transform= 60, factors=(10, 6, 3), half_lds=False, direct_to_from_reg=False), - NS(length= 182, workgroup_size= 64, threads_per_transform= 13, factors=(13, 2, 7), half_lds=False, runtime_compile=True), - NS(length= 187, workgroup_size=128, threads_per_transform= 17, factors=(17, 11), runtime_compile=True), - NS(length= 189, workgroup_size= 64, threads_per_transform= 21, factors=(7, 3, 3, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 192, workgroup_size=128, threads_per_transform= 16, factors=(6, 4, 4, 2)), - NS(length= 195, workgroup_size= 64, threads_per_transform= 13, factors=(13, 5, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 196, workgroup_size= 64, threads_per_transform= 28, factors=(4, 7, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 198, workgroup_size=128, threads_per_transform= 22, factors=(11, 2, 9), half_lds=False, runtime_compile=True), - NS(length= 200, workgroup_size= 64, threads_per_transform= 20, factors=(10, 10, 2)), - NS(length= 204, workgroup_size=128, threads_per_transform= 17, factors=(17, 4, 3), runtime_compile=True), - NS(length= 208, workgroup_size= 64, threads_per_transform= 16, factors=(13, 16)), - NS(length= 210, workgroup_size= 64, threads_per_transform= 30, factors=(10, 7, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 216, workgroup_size=256, threads_per_transform= 36, factors=(6, 6, 6)), - NS(length= 220, workgroup_size=128, threads_per_transform= 22, factors=(10, 2, 11), half_lds=False, runtime_compile=True), - NS(length= 221, workgroup_size=128, threads_per_transform= 17, factors=(17, 13), runtime_compile=True), - NS(length= 224, workgroup_size= 64, threads_per_transform= 16, factors=(7, 2, 2, 2, 2, 2)), - NS(length= 225, workgroup_size=256, threads_per_transform= 75, factors=(5, 5, 3, 3), runtime_compile=True), - NS(length= 231, workgroup_size=256, threads_per_transform= 33, factors=(11, 7, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 234, workgroup_size= 64, threads_per_transform= 26, factors=(13, 9, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 238, workgroup_size= 64, threads_per_transform= 17, factors=(17, 7, 2), runtime_compile=True), - NS(length= 240, workgroup_size=128, threads_per_transform= 48, factors=(8, 5, 6)), - NS(length= 242, workgroup_size=128, threads_per_transform= 22, factors=(11, 2, 11), half_lds=False, runtime_compile=True), - NS(length= 243, workgroup_size=256, threads_per_transform= 81, factors=(3, 3, 3, 3, 3)), - NS(length= 245, workgroup_size=256, threads_per_transform= 35, factors=(7, 5, 7), half_lds=False, runtime_compile=True), - NS(length= 250, workgroup_size=128, threads_per_transform= 25, factors=(10, 5, 5), runtime_compile=True), - NS(length= 252, workgroup_size= 64, threads_per_transform= 63, factors=(7, 3, 3, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 255, workgroup_size= 64, threads_per_transform= 17, factors=(17, 5, 3), runtime_compile=True), - NS(length= 256, workgroup_size= 64, threads_per_transform= 64, factors=(4, 4, 4, 4)), - NS(length= 260, workgroup_size= 64, threads_per_transform= 26, factors=(13, 10, 2), half_lds=False, runtime_compile=True), - NS(length= 264, workgroup_size=256, threads_per_transform= 33, factors=(8, 3, 11), half_lds=False, runtime_compile=True), - NS(length= 270, workgroup_size=128, threads_per_transform= 27, factors=(10, 3, 3, 3)), - NS(length= 272, workgroup_size=128, threads_per_transform= 17, factors=(16, 17), runtime_compile=True), - NS(length= 273, workgroup_size= 64, threads_per_transform= 13, factors=(13, 3, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 275, workgroup_size= 64, threads_per_transform= 55, factors=(11, 5, 5), half_lds=False, runtime_compile=True), - NS(length= 280, workgroup_size= 64, threads_per_transform= 56, factors=(8, 7, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 286, workgroup_size= 64, threads_per_transform= 26, factors=(13, 11, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 288, workgroup_size=128, threads_per_transform= 24, factors=(6, 6, 4, 2), runtime_compile=True), - NS(length= 289, workgroup_size=128, threads_per_transform= 17, factors=(17, 17), runtime_compile=True), - NS(length= 294, workgroup_size=128, threads_per_transform= 42, factors=(6, 7, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 297, workgroup_size=256, threads_per_transform= 33, factors=(9, 3, 11), runtime_compile=True), - NS(length= 300, workgroup_size= 64, threads_per_transform= 30, factors=(10, 10, 3), runtime_compile=True), - NS(length= 306, workgroup_size=256, threads_per_transform= 34, factors=(17, 2, 9), runtime_compile=True), - NS(length= 308, workgroup_size= 64, threads_per_transform= 44, factors=(11, 7, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 312, workgroup_size= 64, threads_per_transform= 26, factors=(13, 4, 3, 2), half_lds=False, runtime_compile=True), - NS(length= 315, workgroup_size= 64, threads_per_transform= 63, factors=(7, 3, 3, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 320, workgroup_size= 64, threads_per_transform= 16, factors=(10, 4, 4, 2), runtime_compile=True), - NS(length= 324, workgroup_size= 64, threads_per_transform= 54, factors=(3, 6, 6, 3), runtime_compile=True), - NS(length= 325, workgroup_size= 64, threads_per_transform= 13, factors=(13, 5, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 330, workgroup_size=128, threads_per_transform= 33, factors=(11, 10, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 336, workgroup_size=128, threads_per_transform= 56, factors=(8, 7, 6)), - NS(length= 338, workgroup_size= 64, threads_per_transform= 26, factors=(13, 2, 13), runtime_compile=True), - NS(length= 340, workgroup_size=128, threads_per_transform= 34, factors=(17, 2, 10), runtime_compile=True), - NS(length= 343, workgroup_size=256, threads_per_transform= 49, factors=(7, 7, 7), runtime_compile=True), - NS(length= 350, workgroup_size= 64, threads_per_transform= 50, factors=(5, 7, 10), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 351, workgroup_size=128, threads_per_transform= 39, factors=(13, 3, 9), half_lds=False, runtime_compile=True), - NS(length= 352, workgroup_size= 64, threads_per_transform= 32, factors=(11, 2, 16), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 357, workgroup_size=256, threads_per_transform= 17, factors=(17, 3, 7), runtime_compile=True), - NS(length= 360, workgroup_size=256, threads_per_transform= 60, factors=(10, 6, 6), runtime_compile=True), - NS(length= 363, workgroup_size=128, threads_per_transform= 33, factors=(11, 3, 11), runtime_compile=True), - NS(length= 364, workgroup_size= 64, threads_per_transform= 52, factors=(13, 7, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 374, workgroup_size=256, threads_per_transform= 34, factors=(17, 2, 11), runtime_compile=True), - NS(length= 375, workgroup_size=128, threads_per_transform= 25, factors=(5, 5, 5, 3), runtime_compile=True), - NS(length= 378, workgroup_size=128, threads_per_transform=126, factors=(6, 3, 3, 7), half_lds=False, runtime_compile=True), - NS(length= 384, workgroup_size=128, threads_per_transform= 32, factors=(6, 4, 4, 4), runtime_compile=True), - NS(length= 385, workgroup_size= 64, threads_per_transform= 55, factors=(11, 7, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 390, workgroup_size=128, threads_per_transform= 39, factors=(13, 3, 10), half_lds=False, runtime_compile=True), - NS(length= 392, workgroup_size= 64, threads_per_transform= 56, factors=(8, 7, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 396, workgroup_size= 64, threads_per_transform= 44, factors=(11, 9, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 400, workgroup_size=128, threads_per_transform= 40, factors=(4, 10, 10), runtime_compile=True), - NS(length= 405, workgroup_size=128, threads_per_transform= 27, factors=(5, 3, 3, 3, 3), runtime_compile=True), - NS(length= 408, workgroup_size= 64, threads_per_transform= 17, factors=(17, 3, 8), runtime_compile=True), - NS(length= 416, workgroup_size= 64, threads_per_transform= 32, factors=(13, 2, 16), half_lds=False, runtime_compile=True), - NS(length= 420, workgroup_size= 64, threads_per_transform= 60, factors=(10, 7, 6), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 425, workgroup_size= 64, threads_per_transform= 17, factors=(17, 5, 5), runtime_compile=True), - NS(length= 429, workgroup_size=128, threads_per_transform= 39, factors=(13, 3, 11), half_lds=False, runtime_compile=True), - NS(length= 432, workgroup_size= 64, threads_per_transform= 27, factors=(3, 16, 3, 3), runtime_compile=True), - NS(length= 440, workgroup_size= 64, threads_per_transform= 55, factors=(11, 8, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 441, workgroup_size= 64, threads_per_transform= 63, factors=(9, 7, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 442, workgroup_size=256, threads_per_transform= 34, factors=(17, 2, 13), runtime_compile=True), - NS(length= 448, workgroup_size=128, threads_per_transform= 64, factors=(8, 7, 8), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 450, workgroup_size=128, threads_per_transform= 30, factors=(10, 5, 3, 3), runtime_compile=True), - NS(length= 455, workgroup_size=256, threads_per_transform= 65, factors=(13, 5, 7), half_lds=False, runtime_compile=True), - NS(length= 459, workgroup_size=256, threads_per_transform= 51, factors=(17, 3, 9), runtime_compile=True), - NS(length= 462, workgroup_size=256, threads_per_transform= 77, factors=(11, 6, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 468, workgroup_size= 64, threads_per_transform= 52, factors=(13, 9, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 476, workgroup_size=128, threads_per_transform= 34, factors=(17, 2, 7, 2), runtime_compile=True), - NS(length= 480, workgroup_size= 64, threads_per_transform= 16, factors=(10, 8, 6), runtime_compile=True), - NS(length= 484, workgroup_size= 64, threads_per_transform= 44, factors=(4, 11, 11), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 486, workgroup_size=256, threads_per_transform=162, factors=(6, 3, 3, 3, 3), runtime_compile=True), - NS(length= 490, workgroup_size=256, threads_per_transform= 70, factors=(10, 7, 7), half_lds=False, runtime_compile=True), - NS(length= 495, workgroup_size= 64, threads_per_transform= 55, factors=(11, 9, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 500, workgroup_size=128, threads_per_transform=100, factors=(10, 5, 10), runtime_compile=True), - NS(length= 504, workgroup_size= 64, threads_per_transform= 63, factors=(7, 9, 4, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 507, workgroup_size=128, threads_per_transform= 39, factors=(13, 3, 13), runtime_compile=True), - NS(length= 510, workgroup_size=256, threads_per_transform= 34, factors=(17, 2, 3, 5), runtime_compile=True), - NS(length= 512, workgroup_size= 64, threads_per_transform= 64, factors=(8, 8, 8)), - NS(length= 520, workgroup_size= 64, threads_per_transform= 52, factors=(13, 10, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 525, workgroup_size= 128, threads_per_transform=105, factors=(7, 3, 5, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 528, workgroup_size= 64, threads_per_transform= 48, factors=(4, 4, 3, 11), runtime_compile=True), - NS(length= 539, workgroup_size=256, threads_per_transform= 77, factors=(11, 7, 7), runtime_compile=True), - NS(length= 540, workgroup_size=256, threads_per_transform= 54, factors=(3, 10, 6, 3), runtime_compile=True), - NS(length= 544, workgroup_size=128, threads_per_transform= 34, factors=(17, 2, 16), runtime_compile=True), - NS(length= 546, workgroup_size=128, threads_per_transform= 39, factors=(13, 3, 7, 2), runtime_compile=True), - NS(length= 550, workgroup_size= 64, threads_per_transform= 55, factors=(11, 10, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 560, workgroup_size= 64, threads_per_transform= 56, factors=(8, 7, 5, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 561, workgroup_size=256, threads_per_transform= 51, factors=(17, 3, 11), runtime_compile=True), - NS(length= 567, workgroup_size= 64, threads_per_transform= 63, factors=(7, 9, 3, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 572, workgroup_size= 64, threads_per_transform= 52, factors=(13, 11, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 576, workgroup_size=128, threads_per_transform= 96, factors=(16, 6, 6), runtime_compile=True), - NS(length= 578, workgroup_size= 256, threads_per_transform=34, factors=(17, 17, 2), runtime_compile=True), - NS(length= 585, workgroup_size= 256, threads_per_transform=65, factors=(13, 5, 9), half_lds=False, runtime_compile=True), - NS(length= 588, workgroup_size= 256, threads_per_transform=84, factors=(7, 3, 4, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 594, workgroup_size=128, threads_per_transform= 99, factors=(11, 3, 6, 3), half_lds=False, runtime_compile=True), - NS(length= 595, workgroup_size= 64, threads_per_transform= 17, factors=(7, 17, 5), runtime_compile=True), - NS(length= 600, workgroup_size= 64, threads_per_transform= 60, factors=(10, 6, 10), runtime_compile=True), - NS(length= 605, workgroup_size= 64, threads_per_transform= 55, factors=(11, 5, 11), half_lds=False, runtime_compile=True), - NS(length= 612, workgroup_size= 64, threads_per_transform= 51, factors=(17, 3, 6, 2), runtime_compile=True), - NS(length= 616, workgroup_size=128, threads_per_transform= 88, factors=(11, 7, 8), half_lds=False, runtime_compile=True), - NS(length= 624, workgroup_size= 64, threads_per_transform= 52, factors=(13, 4, 6, 2), half_lds=False, runtime_compile=True), - NS(length= 625, workgroup_size=128, threads_per_transform=125, factors=(5, 5, 5, 5), runtime_compile=True), - NS(length= 630, workgroup_size= 64, threads_per_transform= 63, factors=(3, 3, 5, 7, 2), runtime_compile=True), - NS(length= 637, workgroup_size=128, threads_per_transform= 91, factors=(13, 7, 7), runtime_compile=True), - NS(length= 640, workgroup_size=128, threads_per_transform= 64, factors=(8, 10, 8), runtime_compile=True), - NS(length= 648, workgroup_size=256, threads_per_transform=216, factors=(8, 3, 3, 3, 3), runtime_compile=True), - NS(length= 650, workgroup_size= 256, threads_per_transform=65, factors=(10, 5, 13), half_lds=False, runtime_compile=True), - NS(length= 660, workgroup_size=128, threads_per_transform=110, factors=(11, 6, 10), runtime_compile=True), - NS(length= 663, workgroup_size= 64, threads_per_transform= 51, factors=(17, 13, 3), half_lds=False, runtime_compile=True), - NS(length= 672, workgroup_size= 64, threads_per_transform= 56, factors=(2, 2, 2, 2, 2, 3, 7), runtime_compile=True), - NS(length= 675, workgroup_size=256, threads_per_transform=225, factors=(5, 5, 3, 3, 3), runtime_compile=True), - NS(length= 676, workgroup_size= 64, threads_per_transform= 52, factors=(13, 13, 4), half_lds=False, runtime_compile=True), - NS(length= 680, workgroup_size=256, threads_per_transform= 68, factors=(17, 4, 10), runtime_compile=True), - NS(length= 686, workgroup_size= 64, threads_per_transform= 49, factors=(7, 7, 7, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 693, workgroup_size=128, threads_per_transform= 99, factors=(11, 7, 9), runtime_compile=True), - NS(length= 700, workgroup_size= 128, threads_per_transform=100, factors=(10, 7, 10), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 702, workgroup_size= 128, threads_per_transform=117, factors=(13, 3, 6, 3), runtime_compile=True), - NS(length= 704, workgroup_size=256, threads_per_transform=88, factors=(2, 2, 2, 2, 11, 2, 2), runtime_compile=True), - NS(length= 714, workgroup_size=64, threads_per_transform=51, factors=(3, 17, 7, 2), runtime_compile=True), - NS(length= 715, workgroup_size=256, threads_per_transform= 65, factors=(13, 5, 11), runtime_compile=True), - NS(length= 720, workgroup_size=256, threads_per_transform=120, factors=(10, 3, 8, 3), runtime_compile=True), - NS(length= 726, workgroup_size=256, threads_per_transform= 66, factors=(11, 6, 11), half_lds=False, runtime_compile=True), - NS(length= 728, workgroup_size=128, threads_per_transform=104, factors=(13, 7, 8), runtime_compile=True), - NS(length= 729, workgroup_size=256, threads_per_transform=243, factors=(3, 3, 3, 3, 3, 3), runtime_compile=True), - NS(length= 735, workgroup_size= 256, threads_per_transform=147, factors=(7, 3, 5, 7), half_lds=False, runtime_compile=True), - NS(length= 748, workgroup_size= 256, threads_per_transform=68, factors=(17, 4, 11), runtime_compile=True), - NS(length= 750, workgroup_size=256, threads_per_transform=250, factors=(10, 5, 3, 5), runtime_compile=True), - NS(length= 756, workgroup_size= 64, threads_per_transform= 63, factors=(2, 2, 3, 3, 3, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 765, workgroup_size=256, threads_per_transform= 51, factors=(17, 3, 5, 3), runtime_compile=True), - NS(length= 768, workgroup_size= 64, threads_per_transform= 48, factors=(16, 3, 16), runtime_compile=True), - NS(length= 770, workgroup_size=256, threads_per_transform=110, factors=(11, 10, 7), half_lds=False, runtime_compile=True), - NS(length= 780, workgroup_size=256, threads_per_transform= 78, factors=(2, 3, 13, 5, 2), runtime_compile=True), - NS(length= 784, workgroup_size= 64, threads_per_transform= 56, factors=(2, 2, 2, 2, 7, 7), runtime_compile=True), - NS(length= 792, workgroup_size=256, threads_per_transform= 88, factors=(2, 2, 2, 3, 3, 11), half_lds=False, runtime_compile=True), - NS(length= 800, workgroup_size=256, threads_per_transform=160, factors=(16, 5, 10), runtime_compile=True), - NS(length= 810, workgroup_size=128, threads_per_transform= 81, factors=(3, 10, 3, 3, 3), runtime_compile=True), - NS(length= 816, workgroup_size= 64, threads_per_transform= 51, factors=(17, 2, 3, 2, 2, 2), runtime_compile=True), - NS(length= 819, workgroup_size=128, threads_per_transform=117, factors=(9, 7, 13), half_lds=False, runtime_compile=True), - NS(length= 825, workgroup_size= 64, threads_per_transform= 55, factors=(11, 5, 5, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 832, workgroup_size=128, threads_per_transform=104, factors=(13, 2, 2, 2, 2, 2, 2), runtime_compile=True), - NS(length= 833, workgroup_size=128, threads_per_transform=119, factors=(17, 7, 7), runtime_compile=True), - NS(length= 840, workgroup_size= 64, threads_per_transform= 56, factors=(2, 2, 2, 3, 5, 7), runtime_compile=True), - NS(length= 845, workgroup_size= 256, threads_per_transform=65, factors=(13, 5, 13), runtime_compile=True), - NS(length= 847, workgroup_size= 256, threads_per_transform=77, factors=(11, 7, 11), runtime_compile=True), - NS(length= 850, workgroup_size= 128, threads_per_transform=85, factors=(10, 5, 17), half_lds=False, runtime_compile=True), - NS(length= 858, workgroup_size= 256, threads_per_transform=78, factors=(13, 11, 6), runtime_compile=True), - NS(length= 864, workgroup_size= 64, threads_per_transform= 54, factors=(3, 6, 16, 3), runtime_compile=True), - NS(length= 867, workgroup_size= 64, threads_per_transform=51, factors=(17, 17, 3), runtime_compile=True), - NS(length= 875, workgroup_size= 256, threads_per_transform=175, factors=(7, 5, 5, 5), half_lds=False, runtime_compile=True), - NS(length= 880, workgroup_size=256, threads_per_transform= 88, factors=(2, 2, 2, 2, 11, 5), runtime_compile=True), - NS(length= 882, workgroup_size= 64, threads_per_transform=63, factors=(9, 7, 7, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 884, workgroup_size= 256, threads_per_transform=68, factors=(13, 4, 17), runtime_compile=True), - NS(length= 891, workgroup_size= 256, threads_per_transform=99, factors=(9, 11, 3, 3), runtime_compile=True), - NS(length= 896, workgroup_size=128, threads_per_transform=112, factors=(2, 2, 2, 2, 2, 2, 2, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 900, workgroup_size=256, threads_per_transform= 90, factors=(10, 10, 3, 3), runtime_compile=True), - NS(length= 910, workgroup_size=256, threads_per_transform= 91, factors=(13, 2, 7, 5), half_lds=False, runtime_compile=True), - NS(length= 918, workgroup_size=128, threads_per_transform=102, factors=(17, 9, 2, 3), runtime_compile=True), - NS(length= 924, workgroup_size= 64, threads_per_transform= 44, factors=(2, 2, 3, 7, 11), runtime_compile=True), - NS(length= 935, workgroup_size= 256, threads_per_transform= 85, factors=(17, 11, 5), runtime_compile=True), - NS(length= 936, workgroup_size=256, threads_per_transform= 78, factors=(2, 2, 13, 2, 3, 3), runtime_compile=True), - NS(length= 945, workgroup_size= 64, threads_per_transform= 63, factors=(3, 3, 3, 5, 7), runtime_compile=True), - NS(length= 952, workgroup_size=256, threads_per_transform= 68, factors=(17, 4, 2, 7), runtime_compile=True), - NS(length= 960, workgroup_size=256, threads_per_transform=160, factors=(16, 10, 6), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 968, workgroup_size=256, threads_per_transform= 88, factors=(2, 2, 2, 11, 11), half_lds=False, runtime_compile=True), - NS(length= 972, workgroup_size=256, threads_per_transform=162, factors=(3, 6, 3, 6, 3), runtime_compile=True), - NS(length= 975, workgroup_size=128, threads_per_transform= 39, factors=(13, 5, 3, 5), runtime_compile=True), - NS(length= 980, workgroup_size= 256, threads_per_transform=196, factors=(7, 5, 7, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length= 990, workgroup_size=128, threads_per_transform=110, factors=(2, 3, 3, 5, 11), half_lds=False, runtime_compile=True), - NS(length=1000, workgroup_size=128, threads_per_transform=100, factors=(10, 10, 10), runtime_compile=True), - NS(length=1001, workgroup_size=256, threads_per_transform= 91, factors=(13, 7, 11), runtime_compile=True), - NS(length=1008, workgroup_size= 64, threads_per_transform= 56, factors=(2, 2, 2, 2, 3, 3, 7), runtime_compile=True), - NS(length=1014, workgroup_size=256, threads_per_transform= 78, factors=(13, 6, 13), half_lds=False, runtime_compile=True), - NS(length=1020, workgroup_size=256, threads_per_transform= 68, factors=(2, 17, 2, 3, 5), runtime_compile=True), - NS(length=1024, workgroup_size=128, threads_per_transform=128, factors=(8, 8, 4, 4)), - NS(length=1040, workgroup_size=256, threads_per_transform=208, factors=(13, 16, 5), runtime_compile=True), - NS(length=1050, workgroup_size=256, threads_per_transform=210, factors=(2, 3, 5, 5, 7), half_lds=False, runtime_compile=True), - NS(length=1053, workgroup_size=128, threads_per_transform=117, factors=(3, 3, 13, 3, 3), runtime_compile=True), - NS(length=1056, workgroup_size=256, threads_per_transform=176, factors=(2, 2, 2, 2, 11, 6), runtime_compile=True), - NS(length=1071, workgroup_size=128, threads_per_transform=119, factors=(17, 7, 9), runtime_compile=True), - NS(length=1078, workgroup_size=256, threads_per_transform= 77, factors=(2, 11, 7, 7), runtime_compile=True), - NS(length=1080, workgroup_size=256, threads_per_transform=108, factors=(6, 10, 6, 3), runtime_compile=True), - NS(length=1088, workgroup_size=256, threads_per_transform= 68, factors=(17, 4, 4, 2, 2), runtime_compile=True), - NS(length=1089, workgroup_size=128, threads_per_transform=121, factors=(3, 11, 3, 11), half_lds=False, runtime_compile=True), - NS(length=1092, workgroup_size= 64, threads_per_transform= 52, factors=(2, 2, 13, 7, 3), runtime_compile=True), - NS(length=1100, workgroup_size=128, threads_per_transform=110, factors=(2, 2, 11, 5, 5), half_lds=False, runtime_compile=True), - NS(length=1105, workgroup_size=256, threads_per_transform= 85, factors=(17, 13, 5), runtime_compile=True), - NS(length=1120, workgroup_size=256, threads_per_transform=224, factors=(2, 2, 2, 2, 2, 5, 7), runtime_compile=True), - NS(length=1122, workgroup_size=256, threads_per_transform=102, factors=(17, 11, 6), runtime_compile=True), - NS(length=1125, workgroup_size=256, threads_per_transform=225, factors=(5, 5, 3, 3, 5), runtime_compile=True), - NS(length=1134, workgroup_size=128, threads_per_transform=126, factors=(2, 3, 3, 3, 3, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length=1144, workgroup_size=128, threads_per_transform=104, factors=(13, 11, 8), half_lds=False, direct_to_from_reg=False, runtime_compile=True), - NS(length=1152, workgroup_size=256, threads_per_transform=144, factors=(4, 3, 8, 3, 4), runtime_compile=True), - NS(length=1155, workgroup_size= 64, threads_per_transform= 55, factors=(11, 5, 7, 3), runtime_compile=True), - NS(length=1156, workgroup_size=256, threads_per_transform= 68, factors=(17, 2, 17, 2), runtime_compile=True), - NS(length=1170, workgroup_size=256, threads_per_transform=117, factors=(2, 13, 3, 5, 3), half_lds=False, runtime_compile=True), - NS(length=1176, workgroup_size= 64, threads_per_transform= 56, factors=(2, 2, 2, 3, 7, 7), runtime_compile=True), - NS(length=1183, workgroup_size=256, threads_per_transform= 91, factors=(7, 13, 13), runtime_compile=True), - NS(length=1188, workgroup_size=256, threads_per_transform= 66, factors=(6, 11, 2, 3, 3), runtime_compile=True), - NS(length=1190, workgroup_size=256, threads_per_transform= 85, factors=(17, 2, 5, 7), runtime_compile=True), - NS(length=1200, workgroup_size=256, threads_per_transform= 75, factors=(5, 5, 16, 3), runtime_compile=True), - NS(length=1210, workgroup_size=128, threads_per_transform=110, factors=(2, 5, 11, 11), runtime_compile=True), - NS(length=1215, workgroup_size=256, threads_per_transform=243, factors=(5, 3, 3, 3, 3, 3), runtime_compile=True), - NS(length=1224, workgroup_size=256, threads_per_transform=102, factors=(17, 3, 4, 6), runtime_compile=True), - NS(length=1225, workgroup_size=256, threads_per_transform=175, factors=(5, 5, 7, 7), runtime_compile=True), - NS(length=1232, workgroup_size=256, threads_per_transform=176, factors=(2, 2, 2, 2, 11, 7), runtime_compile=True), - NS(length=1248, workgroup_size= 64, threads_per_transform= 52, factors=(2, 2, 13, 2, 3, 2, 2), runtime_compile=True), - NS(length=1250, workgroup_size=256, threads_per_transform=250, factors=(5, 10, 5, 5), runtime_compile=True), - NS(length=1260, workgroup_size= 64, threads_per_transform= 63, factors=(2, 2, 3, 3, 5, 7), runtime_compile=True), - NS(length=1274, workgroup_size=256, threads_per_transform=182, factors=(2, 13, 7, 7), runtime_compile=True), - NS(length=1275, workgroup_size=256, threads_per_transform= 85, factors=(17, 3, 5, 5), runtime_compile=True), - NS(length=1280, workgroup_size=128, threads_per_transform= 80, factors=(16, 5, 16), runtime_compile=True), - NS(length=1287, workgroup_size=128, threads_per_transform=117, factors=(3, 13, 3, 11), half_lds=False, runtime_compile=True), - NS(length=1296, workgroup_size=128, threads_per_transform=108, factors=(6, 6, 6, 6), runtime_compile=True), - NS(length=1300, workgroup_size=256, threads_per_transform=130, factors=(10, 10, 13), half_lds=False, runtime_compile=True), - NS(length=1309, workgroup_size=128, threads_per_transform=119, factors=(17, 7, 11), runtime_compile=True), - NS(length=1320, workgroup_size=256, threads_per_transform=165, factors=(11, 2, 3, 5, 4), half_lds=False, runtime_compile=True), - NS(length=1323, workgroup_size=256, threads_per_transform=189, factors=(3, 3, 3, 7, 7), half_lds=False, runtime_compile=True), - NS(length=1326, workgroup_size=256, threads_per_transform=102, factors=(17, 6, 13), runtime_compile=True), - NS(length=1331, workgroup_size=256, threads_per_transform=121, factors=(11, 11, 11), runtime_compile=True), - NS(length=1344, workgroup_size=256, threads_per_transform=224, factors=(2, 2, 2, 2, 2, 2, 3, 7), runtime_compile=True), - NS(length=1350, workgroup_size=256, threads_per_transform=135, factors=(5, 10, 3, 3, 3), runtime_compile=True), - NS(length=1352, workgroup_size= 64, threads_per_transform= 52, factors=(2, 13, 13, 4), runtime_compile=True), - NS(length=1360, workgroup_size=256, threads_per_transform= 85, factors=(17, 5, 16), runtime_compile=True), - NS(length=1365, workgroup_size=256, threads_per_transform= 91, factors=(13, 7, 5, 3), runtime_compile=True), - NS(length=1372, workgroup_size=256, threads_per_transform= 98, factors=(2, 2, 7, 7, 7), runtime_compile=True), - NS(length=1375, workgroup_size= 64, threads_per_transform= 55, factors=(11, 5, 5, 5), runtime_compile=True), - NS(length=1377, workgroup_size= 64, threads_per_transform= 51, factors=(17, 3, 9, 3), runtime_compile=True), - NS(length=1386, workgroup_size=256, threads_per_transform=231, factors=(2, 7, 3, 11, 3), runtime_compile=True), - NS(length=1400, workgroup_size= 64, threads_per_transform= 56, factors=(2, 2, 2, 5, 7, 5), runtime_compile=True), - NS(length=1404, workgroup_size=128, threads_per_transform=117, factors=(2, 2, 3, 13, 3, 3), runtime_compile=True), - NS(length=1408, workgroup_size=256, threads_per_transform=176, factors=(2, 2, 2, 2, 2, 2, 11, 2), runtime_compile=True), - NS(length=1428, workgroup_size=128, threads_per_transform=119, factors=(17, 2, 7, 6), runtime_compile=True), - NS(length=1430, workgroup_size=256, threads_per_transform=143, factors=(13, 11, 10), half_lds=False, runtime_compile=True), - NS(length=1440, workgroup_size=128, threads_per_transform= 90, factors=(10, 16, 3, 3), runtime_compile=True), - NS(length=1445, workgroup_size=128, threads_per_transform= 85, factors=(17, 5, 17), runtime_compile=True), - NS(length=1452, workgroup_size=256, threads_per_transform=132, factors=(11, 3, 11, 4), runtime_compile=True), - NS(length=1456, workgroup_size=256, threads_per_transform=182, factors=(13, 4, 7, 2, 2), runtime_compile=True), - NS(length=1458, workgroup_size=256, threads_per_transform=243, factors=(6, 3, 3, 3, 3, 3), runtime_compile=True), - NS(length=1470, workgroup_size=256, threads_per_transform=210, factors=(2, 3, 5, 7, 7), runtime_compile=True), - NS(length=1485, workgroup_size=256, threads_per_transform=165, factors=(3, 5, 11, 3, 3), half_lds=False, runtime_compile=True), - NS(length=1496, workgroup_size=256, threads_per_transform=187, factors=(17, 8, 11), runtime_compile=True), - NS(length=1500, workgroup_size=256, threads_per_transform=150, factors=(5, 10, 10, 3), runtime_compile=True), - NS(length=1512, workgroup_size= 64, threads_per_transform= 63, factors=(2, 2, 2, 3, 3, 3, 7), runtime_compile=True), - NS(length=1521, workgroup_size=128, threads_per_transform=117, factors=(13, 3, 3, 13), runtime_compile=True), - NS(length=1530, workgroup_size=128, threads_per_transform=102, factors=(17, 3, 6, 5), runtime_compile=True), - NS(length=1536, workgroup_size=256, threads_per_transform=256, factors=(16, 16, 6), runtime_compile=True), - NS(length=1540, workgroup_size=256, threads_per_transform=154, factors=(11, 2, 7, 5, 2), runtime_compile=True), - NS(length=1547, workgroup_size=128, threads_per_transform=119, factors=(17, 7, 13), runtime_compile=True), - NS(length=1560, workgroup_size=256, threads_per_transform=156, factors=(13, 2, 2, 10, 3), half_lds=False, runtime_compile=True), - NS(length=1568, workgroup_size=256, threads_per_transform=224, factors=(2, 2, 2, 2, 2, 7, 7), runtime_compile=True), - NS(length=1573, workgroup_size=256, threads_per_transform=143, factors=(13, 11, 11), half_lds=False, runtime_compile=True), - NS(length=1575, workgroup_size= 64, threads_per_transform= 63, factors=(3, 3, 5, 7, 5), runtime_compile=True), - NS(length=1584, workgroup_size=256, threads_per_transform=176, factors=(4, 2, 2, 11, 3, 3), runtime_compile=True), - NS(length=1600, workgroup_size=256, threads_per_transform=100, factors=(10, 16, 10), runtime_compile=True), - NS(length=1617, workgroup_size=256, threads_per_transform=231, factors=(3, 7, 7, 11), half_lds=False, runtime_compile=True), - NS(length=1620, workgroup_size=256, threads_per_transform=162, factors=(10, 3, 3, 6, 3), runtime_compile=True), - NS(length=1625, workgroup_size=256, threads_per_transform= 65, factors=(13, 5, 5, 5), runtime_compile=True), - NS(length=1632, workgroup_size=128, threads_per_transform=102, factors=(17, 2, 2, 3, 8), runtime_compile=True), - NS(length=1638, workgroup_size=256, threads_per_transform=182, factors=(13, 2, 3, 7, 3), runtime_compile=True), - NS(length=1650, workgroup_size=128, threads_per_transform=110, factors=(11, 2, 3, 5, 5), runtime_compile=True), - NS(length=1664, workgroup_size=256, threads_per_transform=208, factors=(13, 2, 2, 4, 2, 2, 2), runtime_compile=True), - NS(length=1666, workgroup_size=128, threads_per_transform=119, factors=(17, 2, 7, 7), runtime_compile=True), - NS(length=1680, workgroup_size=128, threads_per_transform=112, factors=(2, 2, 2, 2, 3, 7, 5), runtime_compile=True), - NS(length=1683, workgroup_size= 64, threads_per_transform= 51, factors=(17, 3, 11, 3), runtime_compile=True), - NS(length=1690, workgroup_size=256, threads_per_transform=169, factors=(13, 10, 13), half_lds=False, runtime_compile=True), - NS(length=1694, workgroup_size=256, threads_per_transform=154, factors=(11, 2, 11, 7), runtime_compile=True), - NS(length=1700, workgroup_size=256, threads_per_transform=170, factors=(17, 10, 10), runtime_compile=True), - NS(length=1701, workgroup_size= 64, threads_per_transform= 63, factors=(3, 3, 3, 3, 3, 7), runtime_compile=True), - NS(length=1715, workgroup_size=256, threads_per_transform=245, factors=(5, 7, 7, 7), runtime_compile=True), - NS(length=1716, workgroup_size=256, threads_per_transform=156, factors=(13, 2, 6, 11), half_lds=False, runtime_compile=True), - NS(length=1728, workgroup_size=128, threads_per_transform=108, factors=(3, 6, 6, 16), runtime_compile=True), - NS(length=1734, workgroup_size=128, threads_per_transform=102, factors=(17, 17, 6), runtime_compile=True), - NS(length=1750, workgroup_size=256, threads_per_transform=175, factors=(2, 5, 5, 7, 5), runtime_compile=True), - NS(length=1755, workgroup_size=128, threads_per_transform=117, factors=(13, 3, 3, 3, 5), runtime_compile=True), - NS(length=1760, workgroup_size=256, threads_per_transform=176, factors=(2, 2, 2, 2, 2, 11, 5), runtime_compile=True), - NS(length=1764, workgroup_size=128, threads_per_transform=126, factors=(2, 2, 3, 3, 7, 7), runtime_compile=True), - NS(length=1768, workgroup_size=256, threads_per_transform=136, factors=(17, 13, 8), runtime_compile=True), - NS(length=1782, workgroup_size=128, threads_per_transform= 99, factors=(11, 3, 3, 3, 3, 2), runtime_compile=True), - NS(length=1785, workgroup_size=128, threads_per_transform=119, factors=(17, 3, 5, 7), runtime_compile=True), - NS(length=1792, workgroup_size=256, threads_per_transform=224, factors=(4, 4, 4, 4, 7), runtime_compile=True), - NS(length=1800, workgroup_size=256, threads_per_transform=180, factors=(10, 6, 10, 3), runtime_compile=True), - NS(length=1815, workgroup_size=256, threads_per_transform=165, factors=(11, 3, 5, 11), half_lds=False, runtime_compile=True), - NS(length=1820, workgroup_size=256, threads_per_transform=182, factors=(10, 13, 7, 2), runtime_compile=True), - NS(length=1836, workgroup_size=256, threads_per_transform=153, factors=(17, 3, 3, 2, 6), runtime_compile=True), - NS(length=1848, workgroup_size=256, threads_per_transform=231, factors=(3, 11, 7, 4, 2), runtime_compile=True), - NS(length=1859, workgroup_size=256, threads_per_transform=169, factors=(13, 11, 13), runtime_compile=True), - NS(length=1870, workgroup_size=256, threads_per_transform=187, factors=(17, 10, 11), runtime_compile=True), - NS(length=1872, workgroup_size=256, threads_per_transform=156, factors=(13, 3, 4, 6, 2), runtime_compile=True), - NS(length=1875, workgroup_size=256, threads_per_transform=125, factors=(5, 5, 5, 5, 3), runtime_compile=True), - NS(length=1890, workgroup_size=128, threads_per_transform=126, factors=(2, 3, 3, 3, 7, 5), runtime_compile=True), - NS(length=1904, workgroup_size=128, threads_per_transform=119, factors=(17, 2, 2, 7, 4), runtime_compile=True), - NS(length=1911, workgroup_size=128, threads_per_transform= 91, factors=(13, 7, 7, 3), runtime_compile=True), - NS(length=1920, workgroup_size=256, threads_per_transform=120, factors=(10, 6, 16, 2), runtime_compile=True), - NS(length=1925, workgroup_size= 64, threads_per_transform= 55, factors=(7, 11, 5, 5), runtime_compile=True), - NS(length=1936, workgroup_size=256, threads_per_transform=176, factors=(2, 2, 4, 11, 11), half_lds=False, runtime_compile=True), - NS(length=1944, workgroup_size=256, threads_per_transform=243, factors=(3, 3, 3, 3, 8, 3), runtime_compile=True), - NS(length=1950, workgroup_size=256, threads_per_transform=195, factors=(13, 5, 10, 3), half_lds=False, runtime_compile=True), - NS(length=1960, workgroup_size= 64, threads_per_transform= 56, factors=(4, 7, 2, 7, 5), runtime_compile=True), - NS(length=1980, workgroup_size=256, threads_per_transform=198, factors=(11, 2, 3, 3, 5, 2), runtime_compile=True), - NS(length=1989, workgroup_size=256, threads_per_transform=153, factors=(17, 13, 9), runtime_compile=True), - NS(length=2000, workgroup_size=128, threads_per_transform=125, factors=(5, 5, 5, 16), runtime_compile=True), - NS(length=2002, workgroup_size=256, threads_per_transform=182, factors=(2, 13, 7, 11), runtime_compile=True), - NS(length=2016, workgroup_size=256, threads_per_transform=112, factors=(2, 2, 2, 2, 2, 3, 3, 7), runtime_compile=True), - NS(length=2023, workgroup_size=128, threads_per_transform=119, factors=(17, 7, 17), runtime_compile=True), - NS(length=2025, workgroup_size=256, threads_per_transform=135, factors=(3, 3, 5, 5, 3, 3), runtime_compile=True), - NS(length=2028, workgroup_size=256, threads_per_transform=156, factors=(13, 4, 3, 13), half_lds=False, runtime_compile=True), - NS(length=2040, workgroup_size=256, threads_per_transform=170, factors=(17, 4, 3, 10), runtime_compile=True), - NS(length=2048, workgroup_size=256, threads_per_transform=256, factors=(16, 16, 8), runtime_compile=True), - NS(length=2160, workgroup_size=256, threads_per_transform= 60, factors=(10, 6, 6, 6), runtime_compile=True), - NS(length=2187, workgroup_size=256, threads_per_transform=243, factors=(3, 3, 3, 3, 3, 3, 3), runtime_compile=True), - NS(length=2197, workgroup_size=256, threads_per_transform=169, factors=(13, 13, 13), runtime_compile=True), - NS(length=2250, workgroup_size=256, threads_per_transform= 90, factors=(10, 3, 5, 3, 5), runtime_compile=True), - NS(length=2304, workgroup_size=256, threads_per_transform=192, factors=(6, 6, 4, 4, 4), runtime_compile=True), - NS(length=2400, workgroup_size=256, threads_per_transform=240, factors=(4, 10, 10, 6), runtime_compile=True), - NS(length=2401, workgroup_size=256, threads_per_transform= 49, factors=(7, 7, 7, 7), runtime_compile=True), - NS(length=2430, workgroup_size=256, threads_per_transform= 81, factors=(10, 3, 3, 3, 3, 3), runtime_compile=True), - NS(length=2500, workgroup_size=256, threads_per_transform=250, factors=(10, 5, 10, 5), runtime_compile=True), - NS(length=2560, workgroup_size=128, threads_per_transform=128, factors=(4, 4, 4, 10, 4), runtime_compile=True), - NS(length=2592, workgroup_size=256, threads_per_transform=216, factors=(6, 6, 6, 6, 2), runtime_compile=True), - NS(length=2700, workgroup_size=128, threads_per_transform= 90, factors=(3, 10, 10, 3, 3), runtime_compile=True), - NS(length=2880, workgroup_size=256, threads_per_transform= 96, factors=(10, 6, 6, 2, 2, 2), runtime_compile=True), - NS(length=2916, workgroup_size=256, threads_per_transform=243, factors=(6, 6, 3, 3, 3, 3), runtime_compile=True), - NS(length=3000, workgroup_size=128, threads_per_transform=100, factors=(10, 3, 10, 10), runtime_compile=True), - NS(length=3072, workgroup_size=256, threads_per_transform=256, factors=(6, 4, 4, 4, 4, 2), runtime_compile=True), - NS(length=3125, workgroup_size=128, threads_per_transform=125, factors=(5, 5, 5, 5, 5), runtime_compile=True), - NS(length=3200, workgroup_size=256, threads_per_transform=160, factors=(10, 10, 4, 4, 2), runtime_compile=True), - NS(length=3240, workgroup_size=128, threads_per_transform=108, factors=(3, 3, 10, 6, 6), runtime_compile=True), - NS(length=3375, workgroup_size=256, threads_per_transform=225, factors=(5, 5, 5, 3, 3, 3), runtime_compile=True), - NS(length=3456, workgroup_size=256, threads_per_transform=144, factors=(6, 6, 6, 4, 4), runtime_compile=True), - NS(length=3600, workgroup_size=256, threads_per_transform=120, factors=(10, 10, 6, 6), runtime_compile=True), - NS(length=3645, workgroup_size=256, threads_per_transform=243, factors=(5, 3, 3, 3, 3, 3, 3), runtime_compile=True), - NS(length=3750, workgroup_size=256, threads_per_transform=125, factors=(3, 5, 5, 10, 5), runtime_compile=True), - NS(length=3840, workgroup_size=256, threads_per_transform=128, factors=(10, 6, 2, 2, 2, 2, 2, 2), runtime_compile=True), - NS(length=3888, workgroup_size=512, threads_per_transform=324, factors=(16, 3, 3, 3, 3, 3), runtime_compile=True), - NS(length=4000, workgroup_size=256, threads_per_transform=200, factors=(10, 10, 10, 4), runtime_compile=True), - NS(length=4050, workgroup_size=256, threads_per_transform=135, factors=(10, 5, 3, 3, 3, 3), runtime_compile=True), - NS(length=4096, workgroup_size=256, threads_per_transform=256, factors=(16, 16, 16), runtime_compile=True), - NS(length=4704, workgroup_size=256, threads_per_transform=224, factors=(8, 4, 7, 7, 3), double_precision=False, runtime_compile=True), - NS(length=5488, workgroup_size=256, threads_per_transform=196, factors=(7, 4, 7, 4, 7), double_precision=False, runtime_compile=True), - NS(length=6144, workgroup_size=512, threads_per_transform=512, factors=(16, 4, 8, 3, 4), double_precision=False, runtime_compile=True), - NS(length=6561, workgroup_size=256, threads_per_transform=243, factors=(3, 3, 3, 3, 3, 3, 3, 3), double_precision=False, runtime_compile=True), - NS(length=8192, workgroup_size=512, threads_per_transform=512, factors=(16, 4, 4, 4, 8), double_precision=False, runtime_compile=True), - - # configs for 160kiB LDS - NS(length=4704, workgroup_size=256, threads_per_transform=224, factors=(8, 4, 7, 7, 3), lds_size_bytes=LDS_160k, runtime_compile=True), - NS(length=5488, workgroup_size=256, threads_per_transform=196, factors=(7, 4, 7, 4, 7), lds_size_bytes=LDS_160k, runtime_compile=True), - NS(length=6144, workgroup_size=384, threads_per_transform=256, factors=(4, 8, 8, 8, 3), lds_size_bytes=LDS_160k, runtime_compile=True), - NS(length=6561, workgroup_size=256, threads_per_transform=243, factors=(3, 3, 3, 3, 3, 3, 3, 3), lds_size_bytes=LDS_160k, runtime_compile=True), - NS(length=8192, workgroup_size=512, threads_per_transform=512, factors=(16, 4, 16, 8), lds_size_bytes=LDS_160k, runtime_compile=True), - NS(length=9216, workgroup_size=512, threads_per_transform=512, factors=(4, 8, 4, 4, 3, 6), lds_size_bytes=LDS_160k, runtime_compile=True), - NS(length=10000, workgroup_size=512, threads_per_transform=500, factors=(4, 5, 5, 10, 10), lds_size_bytes=LDS_160k, runtime_compile=True), - NS(length=10240, workgroup_size=512, threads_per_transform=512, factors=(8, 4, 4, 4, 5, 4), lds_size_bytes=LDS_160k, runtime_compile=True), - NS(length=10752, workgroup_size=512, threads_per_transform=512, factors=(4, 16, 8, 7, 3), double_precision=False, lds_size_bytes=LDS_160k, runtime_compile=True), - NS(length=11200, workgroup_size=512, threads_per_transform=448, factors=(4, 7, 5, 16, 5), double_precision=False, lds_size_bytes=LDS_160k, runtime_compile=True), - NS(length=12288, workgroup_size=512, threads_per_transform=512, factors=(8, 8, 4, 6, 8), double_precision=False, lds_size_bytes=LDS_160k, runtime_compile=True), - NS(length=16384, workgroup_size=512, threads_per_transform=512, factors=(8, 16, 4, 8, 4), double_precision=False, lds_size_bytes=LDS_160k, runtime_compile=True), - NS(length=16807, workgroup_size=384, threads_per_transform=343, factors=(7, 7, 7, 7, 7), double_precision=False, lds_size_bytes=LDS_160k, runtime_compile=True), - NS(length=18816, workgroup_size=512, threads_per_transform=448, factors=(8, 8, 7, 7, 6), double_precision=False, lds_size_bytes=LDS_160k, runtime_compile=True), - NS(length=19200, workgroup_size=512, threads_per_transform=480, factors=(8, 10, 8, 5, 6), double_precision=False, lds_size_bytes=LDS_160k, runtime_compile=True), - NS(length=20480, workgroup_size=512, threads_per_transform=512, factors=(4, 4, 16, 10, 8), double_precision=False, lds_size_bytes=LDS_160k, runtime_compile=True), - ] + kernels1d = config_sbrr.sbrr_kernels kernels = [NS(**kernel.__dict__, - scheme='CS_KERNEL_STOCKHAM', - precision=['sp','dp'] if not hasattr(kernel, 'double_precision') or kernel.double_precision else ['sp']) for kernel in kernels1d] - + scheme='CS_KERNEL_STOCKHAM', + precision=['sp','dp'] if not hasattr(kernel, 'double_precision') or kernel.double_precision else ['sp']) for kernel in kernels1d] + return kernels - -def list_2d_kernels(): - """Return list of fused 2D kernels to generate.""" - - fused_kernels = [ - NS(length=[4,4], factors=[[2,2],[2,2]], threads_per_transform=[2,2], workgroup_size=8), - NS(length=[4,8], factors=[[2,2],[4,2]], threads_per_transform=[2,2], workgroup_size=16), - NS(length=[4,9], factors=[[2,2],[3,3]], threads_per_transform=[2,3], workgroup_size=18), - NS(length=[4,16], factors=[[2,2],[4,4]], threads_per_transform=[2,4], workgroup_size=32), - NS(length=[4,25], factors=[[2,2],[5,5]], threads_per_transform=[2,5], workgroup_size=50), - NS(length=[4,27], factors=[[2,2],[3,3,3]], threads_per_transform=[2,9], workgroup_size=54), - NS(length=[4,32], factors=[[2,2],[8,4]], threads_per_transform=[2,4], workgroup_size=64), - NS(length=[4,64], factors=[[2,2],[4,4,4]], threads_per_transform=[2,16], workgroup_size=128), - NS(length=[4,81], factors=[[2,2],[3,3,3,3]], threads_per_transform=[2,27], workgroup_size=162), - NS(length=[4,125], factors=[[2,2],[5,5,5]], threads_per_transform=[2,25], workgroup_size=250), - NS(length=[4,128], factors=[[2,2],[8,4,4]], threads_per_transform=[2,16], workgroup_size=256), - NS(length=[4,243], factors=[[2,2],[3,3,3,3,3]], threads_per_transform=[2,81], workgroup_size=486), - NS(length=[4,256], factors=[[2,2],[4,4,4,4]], threads_per_transform=[2,64], workgroup_size=512), - NS(length=[8,4], factors=[[4,2],[2,2]], threads_per_transform=[2,2], workgroup_size=16), - NS(length=[8,8], factors=[[4,2],[4,2]], threads_per_transform=[2,2], workgroup_size=16), - NS(length=[8,9], factors=[[4,2],[3,3]], threads_per_transform=[2,3], workgroup_size=24), - NS(length=[8,16], factors=[[4,2],[4,4]], threads_per_transform=[2,4], workgroup_size=32), - NS(length=[8,25], factors=[[4,2],[5,5]], threads_per_transform=[2,5], workgroup_size=50), - NS(length=[8,27], factors=[[4,2],[3,3,3]], threads_per_transform=[2,9], workgroup_size=72), - NS(length=[8,32], factors=[[4,2],[8,4]], threads_per_transform=[2,4], workgroup_size=64), - NS(length=[8,64], factors=[[4,2],[4,4,4]], threads_per_transform=[2,16], workgroup_size=128), - NS(length=[8,81], factors=[[4,2],[3,3,3,3]], threads_per_transform=[2,27], workgroup_size=216), - NS(length=[8,125], factors=[[4,2],[5,5,5]], threads_per_transform=[2,25], workgroup_size=250), - NS(length=[8,128], factors=[[4,2],[8,4,4]], threads_per_transform=[2,16], workgroup_size=256), - NS(length=[8,243], factors=[[4,2],[3,3,3,3,3]], threads_per_transform=[2,81], workgroup_size=648), - NS(length=[8,256], factors=[[4,2],[4,4,4,4]], threads_per_transform=[2,64], workgroup_size=512), - NS(length=[9,4], factors=[[3,3],[2,2]], threads_per_transform=[3,2], workgroup_size=18), - NS(length=[9,8], factors=[[3,3],[4,2]], threads_per_transform=[3,2], workgroup_size=24), - NS(length=[9,9], factors=[[3,3],[3,3]], threads_per_transform=[3,3], workgroup_size=27), - NS(length=[9,16], factors=[[3,3],[4,4]], threads_per_transform=[3,4], workgroup_size=48), - NS(length=[9,25], factors=[[3,3],[5,5]], threads_per_transform=[3,5], workgroup_size=75), - NS(length=[9,27], factors=[[3,3],[3,3,3]], threads_per_transform=[3,9], workgroup_size=81), - NS(length=[9,32], factors=[[3,3],[8,4]], threads_per_transform=[3,4], workgroup_size=96), - NS(length=[9,64], factors=[[3,3],[4,4,4]], threads_per_transform=[3,16], workgroup_size=192), - NS(length=[9,81], factors=[[3,3],[3,3,3,3]], threads_per_transform=[3,27], workgroup_size=243), - NS(length=[9,125], factors=[[3,3],[5,5,5]], threads_per_transform=[3,25], workgroup_size=375), - NS(length=[9,128], factors=[[3,3],[8,4,4]], threads_per_transform=[3,16], workgroup_size=384), - NS(length=[9,243], factors=[[3,3],[3,3,3,3,3]], threads_per_transform=[3,81], workgroup_size=729), - NS(length=[9,256], factors=[[3,3],[4,4,4,4]], threads_per_transform=[3,64], workgroup_size=768), - NS(length=[16,4], factors=[[4,4],[2,2]], threads_per_transform=[4,2], workgroup_size=32), - NS(length=[16,8], factors=[[4,4],[4,2]], threads_per_transform=[4,2], workgroup_size=32), - NS(length=[16,9], factors=[[4,4],[3,3]], threads_per_transform=[4,3], workgroup_size=48), - NS(length=[16,16], factors=[[4,4],[4,4]], threads_per_transform=[4,4], workgroup_size=64), - NS(length=[16,25], factors=[[4,4],[5,5]], threads_per_transform=[4,5], workgroup_size=100), - NS(length=[16,27], factors=[[4,4],[3,3,3]], threads_per_transform=[4,9], workgroup_size=144), - NS(length=[16,32], factors=[[4,4],[8,4]], threads_per_transform=[4,4], workgroup_size=128), - NS(length=[16,64], factors=[[4,4],[4,4,4]], threads_per_transform=[4,16], workgroup_size=256), - NS(length=[16,81], factors=[[4,4],[3,3,3,3]], threads_per_transform=[4,27], workgroup_size=432), - NS(length=[16,125], factors=[[4,4],[5,5,5]], threads_per_transform=[4,25], workgroup_size=500), - NS(length=[16,128], factors=[[4,4],[8,4,4]], threads_per_transform=[4,16], workgroup_size=512), - NS(length=[25,4], factors=[[5,5],[2,2]], threads_per_transform=[5,2], workgroup_size=50), - NS(length=[25,8], factors=[[5,5],[4,2]], threads_per_transform=[5,2], workgroup_size=50), - NS(length=[25,9], factors=[[5,5],[3,3]], threads_per_transform=[5,3], workgroup_size=75), - NS(length=[25,16], factors=[[5,5],[4,4]], threads_per_transform=[5,4], workgroup_size=100), - NS(length=[25,25], factors=[[5,5],[5,5]], threads_per_transform=[5,5], workgroup_size=125), - NS(length=[25,27], factors=[[5,5],[3,3,3]], threads_per_transform=[5,9], workgroup_size=225), - NS(length=[25,32], factors=[[5,5],[8,4]], threads_per_transform=[5,4], workgroup_size=160), - NS(length=[25,64], factors=[[5,5],[4,4,4]], threads_per_transform=[5,16], workgroup_size=400), - NS(length=[25,81], factors=[[5,5],[3,3,3,3]], threads_per_transform=[5,27], workgroup_size=675), - NS(length=[25,125], factors=[[5,5],[5,5,5]], threads_per_transform=[5,25], workgroup_size=625), - NS(length=[25,128], factors=[[5,5],[8,4,4]], threads_per_transform=[5,16], workgroup_size=640), - NS(length=[27,4], factors=[[3,3,3],[2,2]], threads_per_transform=[9,2], workgroup_size=54), - NS(length=[27,8], factors=[[3,3,3],[4,2]], threads_per_transform=[9,2], workgroup_size=72), - NS(length=[27,9], factors=[[3,3,3],[3,3]], threads_per_transform=[9,3], workgroup_size=81), - NS(length=[27,16], factors=[[3,3,3],[4,4]], threads_per_transform=[9,4], workgroup_size=144), - NS(length=[27,25], factors=[[3,3,3],[5,5]], threads_per_transform=[9,5], workgroup_size=225), - NS(length=[27,27], factors=[[3,3,3],[3,3,3]], threads_per_transform=[9,9], workgroup_size=243), - NS(length=[27,32], factors=[[3,3,3],[8,4]], threads_per_transform=[9,4], workgroup_size=288), - NS(length=[27,64], factors=[[3,3,3],[4,4,4]], threads_per_transform=[9,16], workgroup_size=576), - NS(length=[27,81], factors=[[3,3,3],[3,3,3,3]], threads_per_transform=[9,27], workgroup_size=729), - NS(length=[32,4], factors=[[8,4],[2,2]], threads_per_transform=[4,2], workgroup_size=64), - NS(length=[32,8], factors=[[8,4],[4,2]], threads_per_transform=[4,2], workgroup_size=64), - NS(length=[32,9], factors=[[8,4],[3,3]], threads_per_transform=[4,3], workgroup_size=96), - NS(length=[32,16], factors=[[8,4],[4,4]], threads_per_transform=[4,4], workgroup_size=128), - NS(length=[32,25], factors=[[8,4],[5,5]], threads_per_transform=[4,5], workgroup_size=160), - NS(length=[32,27], factors=[[8,4],[3,3,3]], threads_per_transform=[4,9], workgroup_size=288), - NS(length=[32,32], factors=[[8,4],[8,4]], threads_per_transform=[4,4], workgroup_size=128), - NS(length=[32,64], factors=[[8,4],[4,4,4]], threads_per_transform=[4,16], workgroup_size=512), - NS(length=[32,81], factors=[[8,4],[3,3,3,3]], threads_per_transform=[4,27], workgroup_size=864), - NS(length=[32,125], factors=[[8,4],[5,5,5]], threads_per_transform=[4,25], workgroup_size=800), - NS(length=[32,128], factors=[[8,4],[8,4,4]], threads_per_transform=[4,16], workgroup_size=512), - NS(length=[64,4], factors=[[4,4,4],[2,2]], threads_per_transform=[16,2], workgroup_size=128), - NS(length=[64,8], factors=[[4,4,4],[4,2]], threads_per_transform=[16,2], workgroup_size=128), - NS(length=[64,9], factors=[[4,4,4],[3,3]], threads_per_transform=[16,3], workgroup_size=192), - NS(length=[64,16], factors=[[4,4,4],[4,4]], threads_per_transform=[16,4], workgroup_size=256), - NS(length=[64,25], factors=[[4,4,4],[5,5]], threads_per_transform=[16,5], workgroup_size=400), - NS(length=[64,27], factors=[[4,4,4],[3,3,3]], threads_per_transform=[16,9], workgroup_size=576), - NS(length=[64,32], factors=[[4,4,4],[8,4]], threads_per_transform=[16,4], workgroup_size=512), - NS(length=[81,4], factors=[[3,3,3,3],[2,2]], threads_per_transform=[27,2], workgroup_size=162), - NS(length=[81,8], factors=[[3,3,3,3],[4,2]], threads_per_transform=[27,2], workgroup_size=216), - NS(length=[81,9], factors=[[3,3,3,3],[3,3]], threads_per_transform=[27,3], workgroup_size=243), - NS(length=[81,16], factors=[[3,3,3,3],[4,4]], threads_per_transform=[27,4], workgroup_size=432), - NS(length=[81,25], factors=[[3,3,3,3],[5,5]], threads_per_transform=[27,5], workgroup_size=675), - NS(length=[81,27], factors=[[3,3,3,3],[3,3,3]], threads_per_transform=[27,9], workgroup_size=729), - NS(length=[81,32], factors=[[3,3,3,3],[8,4]], threads_per_transform=[27,4], workgroup_size=864), - NS(length=[125,4], factors=[[5,5,5],[2,2]], threads_per_transform=[25,2], workgroup_size=250), - NS(length=[125,8], factors=[[5,5,5],[4,2]], threads_per_transform=[25,2], workgroup_size=250), - NS(length=[125,9], factors=[[5,5,5],[3,3]], threads_per_transform=[25,3], workgroup_size=375), - NS(length=[125,16], factors=[[5,5,5],[4,4]], threads_per_transform=[25,4], workgroup_size=500), - NS(length=[125,25], factors=[[5,5,5],[5,5]], threads_per_transform=[25,5], workgroup_size=625), - NS(length=[125,32], factors=[[5,5,5],[8,4]], threads_per_transform=[25,4], workgroup_size=800), - NS(length=[128,4], factors=[[8,4,4],[2,2]], threads_per_transform=[16,2], workgroup_size=256), - NS(length=[128,8], factors=[[8,4,4],[4,2]], threads_per_transform=[16,2], workgroup_size=256), - NS(length=[128,9], factors=[[8,4,4],[3,3]], threads_per_transform=[16,3], workgroup_size=384), - NS(length=[128,16], factors=[[8,4,4],[4,4]], threads_per_transform=[16,4], workgroup_size=512), - NS(length=[128,25], factors=[[8,4,4],[5,5]], threads_per_transform=[16,5], workgroup_size=640), - NS(length=[128,32], factors=[[8,4,4],[8,4]], threads_per_transform=[16,4], workgroup_size=512), - NS(length=[243,4], factors=[[3,3,3,3,3],[2,2]], threads_per_transform=[81,2], workgroup_size=486), - NS(length=[243,8], factors=[[3,3,3,3,3],[4,2]], threads_per_transform=[81,2], workgroup_size=648), - NS(length=[243,9], factors=[[3,3,3,3,3],[3,3]], threads_per_transform=[81,3], workgroup_size=729), - NS(length=[256,4], factors=[[4,4,4,4],[2,2]], threads_per_transform=[64,2], workgroup_size=512), - NS(length=[256,8], factors=[[4,4,4,4],[4,2]], threads_per_transform=[64,2], workgroup_size=512), - NS(length=[256,9], factors=[[4,4,4,4],[3,3]], threads_per_transform=[64,3], workgroup_size=768), - # ----- new for r2c/c2r - NS(length=[7,84], factors=[[7],[7,2,6]], threads_per_transform=[1,12], workgroup_size=84), - NS(length=[84,7], factors=[[7,2,6],[7]], threads_per_transform=[12,1], workgroup_size=84), - NS(length=[10,20], factors=[[10],[5,4]], threads_per_transform=[1,5], workgroup_size=50), - NS(length=[20,10], factors=[[5,4],[10]], threads_per_transform=[5,1], workgroup_size=50), - NS(length=[26,64], factors=[[13,2],[4,4,4]], threads_per_transform=[2,16], workgroup_size=416), - NS(length=[64,26], factors=[[4,4,4],[13,2]], threads_per_transform=[16,2], workgroup_size=416), - NS(length=[26,72], factors=[[13,2],[8,3,3]], threads_per_transform=[2,9], workgroup_size=234), - NS(length=[72,26], factors=[[8,3,3],[13,2]], threads_per_transform=[9,2], workgroup_size=234), - NS(length=[30,60], factors=[[10,3],[6,10]], threads_per_transform=[3,10], workgroup_size=300), - NS(length=[60,30], factors=[[6,10],[10,3]], threads_per_transform=[10,3], workgroup_size=300), - NS(length=[36,72], factors=[[6,6],[8,3,3]], threads_per_transform=[6,9], workgroup_size=432), - NS(length=[72,36], factors=[[8,3,3],[6,6]], threads_per_transform=[9,6], workgroup_size=432), - NS(length=[36,80], factors=[[6,6],[5,2,8]], threads_per_transform=[6,10], workgroup_size=480), - NS(length=[80,36], factors=[[5,2,8],[6,6]], threads_per_transform=[10,6], workgroup_size=480), - NS(length=[36,84], factors=[[6,6],[7,2,6]], threads_per_transform=[6,12], workgroup_size=504), - NS(length=[84,36], factors=[[7,2,6],[6,6]], threads_per_transform=[12,6], workgroup_size=504), - NS(length=[40,80], factors=[[10,4],[5,2,8]], threads_per_transform=[4,10], workgroup_size=400), - NS(length=[80,40], factors=[[5,2,8],[10,4]], threads_per_transform=[10,4], workgroup_size=400), - NS(length=[42,84], factors=[[7,6],[7,2,6]], threads_per_transform=[6,12], workgroup_size=504), - NS(length=[84,42], factors=[[7,2,6],[7,6]], threads_per_transform=[12,6], workgroup_size=504), - NS(length=[42,96], factors=[[7,6],[6,16]], threads_per_transform=[6,6], workgroup_size=576), - NS(length=[96,42], factors=[[6,16],[7,6]], threads_per_transform=[6,6], workgroup_size=576), - ] - - expanded = [] - expanded.extend(NS(**kernel.__dict__, - scheme='CS_KERNEL_2D_SINGLE', runtime_compile=True) for kernel in fused_kernels) - - return expanded - - def list_large_kernels(): """Return list of large kernels to generate.""" - # Note: Default direct_to_from_reg is True - sbcc_kernels = [ - NS(length=50, factors=[10, 5], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'true'}, workgroup_size=256), - NS(length=52, factors=[13, 4], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'true'}), - NS(length=60, factors=[6, 10], use_3steps_large_twd={ - 'sp': 'false', 'dp': 'false'}), - NS(length=64, factors=[8, 8], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'false'}, workgroup_size=256), - NS(length=72, factors=[8, 3, 3], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'false'}), - NS(length=80, factors=[10, 8], use_3steps_large_twd={ - 'sp': 'false', 'dp': 'false'}), - # 9,9 is good when direct-to-reg, but bad for Navi, so still uses radix-3 - NS(length=81, factors=[3, 3, 3, 3], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'true'}), - NS(length=84, factors=[7, 2, 6], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'true'}, threads_per_transform=14), - NS(length=96, factors=[8, 3, 4], use_3steps_large_twd={ - 'sp': 'false', 'dp': 'false'}, workgroup_size=256), - NS(length=100, factors=[5, 5, 4], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'false'}, workgroup_size=100, half_lds=True), - NS(length=104, factors=[13, 8], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'false'}), - NS(length=108, factors=[6, 6, 3], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'false'}), - NS(length=112, factors=[4, 7, 4], use_3steps_large_twd={ - 'sp': 'false', 'dp': 'false'}), - NS(length=121, factors=[11, 11], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'true'}, workgroup_size=128, runtime_compile=True), - NS(length=125, factors=[5, 5, 5], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'false'}), - NS(length=128, factors=[16, 8], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'true'}, workgroup_size=256, threads_per_transform= 16), - NS(length=160, factors=[4, 10, 4], use_3steps_large_twd={ - 'sp': 'false', 'dp': 'false'}, flavour='wide'), - NS(length=168, factors=[7, 6, 4], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'false'}, workgroup_size=128, half_lds=True), - NS(length=169, factors=[13, 13], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'false'}, workgroup_size=256, runtime_compile=True), - NS(length=192, factors=[8, 6, 4], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'true'}), - NS(length=200, factors=[5, 8, 5], use_3steps_large_twd={ - 'sp': 'false', 'dp': 'false'}), - NS(length=208, factors=[13, 16], use_3steps_large_twd={ - 'sp': 'false', 'dp': 'false'}), - NS(length=216, factors=(6, 6, 6), use_3steps_large_twd={ - 'sp': 'false', 'dp': 'false'}, threads_per_transform=36), - NS(length=224, factors=[8, 7, 4], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'false'}), - NS(length=240, factors=[8, 5, 6], use_3steps_large_twd={ - 'sp': 'false', 'dp': 'false'}), - # 9,9,3 isn't better on all archs, some are much better, some get regressions - NS(length=243, factors=[3, 3, 3, 3, 3], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'false'}, workgroup_size=243), - NS(length=256, factors=[8, 4, 8], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'false'}, flavour='wide'), - NS(length=280, factors=[8, 5, 7], use_3steps_large_twd={ - 'sp': 'false', 'dp': 'false'}, runtime_compile=True), - NS(length=289, factors=[17, 17], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'true'}, runtime_compile=True), - NS(length=336, factors=[6, 7, 8], use_3steps_large_twd={ - 'sp': 'false', 'dp': 'false'}), - NS(length=343, factors=[7, 7, 7], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'true'}), - NS(length=512, factors=[8, 8, 8], use_3steps_large_twd={ - 'sp': 'true', 'dp': 'false'}), - ] - # for SBCC kernel, increase desired workgroup_size so that columns per # thread block is also increased. currently targeting for 16 columns block_width = 16 + sbcc_kernels = config_sbcc.sbcc_kernels for k in sbcc_kernels: k.scheme = 'CS_KERNEL_STOCKHAM_BLOCK_CC' if not hasattr(k, 'workgroup_size'): @@ -973,57 +274,10 @@ def list_large_kernels(): k.workgroup_size = min(1024, k.workgroup_size * 2) if not hasattr(k, 'length'): k.length = functools.reduce(lambda a, b: a * b, k.factors) - - # for SBRC, if direct_to_from_reg is True, we do store-from-reg, but will not do load-to-reg - # And since SBRC is is dir-from-lds but NOT dir-to-reg, the global load part requires full LDS - # So, SBRC is able to use half-lds. - sbrc_kernels = [ - NS(length=17, factors=[17], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=256, threads_per_transform=1, runtime_compile=True), - NS(length=49, factors=[7, 7], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=196, threads_per_transform=7), # block_width=28 - NS(length=50, factors=[10, 5], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=50, threads_per_transform=5, direct_to_from_reg=False), # block_width=10 - # SBRC64: wgs=256 poor in MI50 - NS(length=64, factors=[4, 4, 4], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=128, threads_per_transform=16), # block_width=8 - # 9,9 not good by experiments - NS(length=81, factors=[3, 3, 3, 3], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=243, threads_per_transform=27), # block_width=9 - NS(length=100, factors=[5, 5, 4], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=100, threads_per_transform=25), # block_width=4 - NS(length=112, factors=[4, 7, 4], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=448, threads_per_transform=28), # block_width=16 - NS(length=121, factors=[11, 11], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=128, threads_per_transform=11, runtime_compile=True), - NS(length=125, factors=[5, 5, 5], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=250, threads_per_transform=25), # block_width=10 - NS(length=128, factors=[8, 4, 4], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=128, threads_per_transform=16), # block_width=8 - NS(length=169, factors=[13, 13], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=256, threads_per_transform=13, runtime_compile=True), - NS(length=192, factors=[6, 4, 4, 2], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=256, threads_per_transform=32), # block_width=8 - NS(length=200, factors=[8, 5, 5], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=400, threads_per_transform=40), # block_width=10 - NS(length=243, factors=[3, 3, 3, 3, 3], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=256, threads_per_transform=27, runtime_compile=True), # block_width=10 - NS(length=256, factors=[4, 4, 4, 4], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=256, threads_per_transform=32), # block_width=8 - NS(length=289, factors=[17, 17], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=128, threads_per_transform=17, runtime_compile=True), - NS(length=343, factors=[7, 7, 7], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=256, threads_per_transform=49, runtime_compile=True), - NS(length=512, factors=[8, 8, 8], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=512, threads_per_transform=128), - NS(length=625, factors=[5, 5, 5, 5], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=128, threads_per_transform=125, runtime_compile=True), - NS(length=1331, factors=[11, 11, 11], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=256, threads_per_transform=121, runtime_compile=True), - ] - - for k in sbrc_kernels: - k.half_lds = False - - # NB: - # Technically, we could have SBCR kernels the same amount as SBCC. - # - # sbcr_kernels = copy.deepcopy(sbcc_kernels) - # for k in sbcr_kernels: - # k.scheme = 'CS_KERNEL_STOCKHAM_BLOCK_CR' - # - - # for SBCR, if direct_to_from_reg is True, we do load-to-reg, but will not do store-from-reg - # And since sbcr is dir-to-reg BUT NOT dir-from-reg, the global store part requires full LDS - # So, we can't satifly half_lds in SBCR ! - sbcr_kernels = [ - NS(length=56, factors=[7, 8], direct_to_from_reg=False), - NS(length=100, factors=[10, 10], workgroup_size=100), - NS(length=200, factors=[8, 5, 5]), - NS(length=336, factors=[6, 7, 8]) - ] - + + block_width = 16 + sbcr_kernels = config_sbcr.sbcr_kernels for k in sbcr_kernels: k.scheme = 'CS_KERNEL_STOCKHAM_BLOCK_CR' k.half_lds = False @@ -1033,9 +287,24 @@ def list_large_kernels(): if not hasattr(k, 'length'): k.length = functools.reduce(lambda a, b: a * b, k.factors) - return sbcc_kernels + sbcr_kernels + sbrc_kernels -# yapf: enable + sbrc_kernels = config_sbrc.sbrc_kernels + for k in sbrc_kernels: + k.half_lds = False + + + return config_sbcc.sbcc_kernels + config_sbcr.sbcr_kernels + config_sbrc.sbrc_kernels + +def list_2d_kernels(): + """Return list of fused 2D kernels to generate.""" + + fused_2d_kernels = config_2d_single.fused_2d_kernels + + expanded = [] + expanded.extend(NS(**kernel.__dict__, + scheme='CS_KERNEL_2D_SINGLE', runtime_compile=True) for kernel in fused_2d_kernels) + + return expanded def default_runtime_compile(kernels, default_val): '''Returns a copy of input kernel list with a default value for runtime_compile.''' diff --git a/library/src/device/kernels/configs/config_2d_single.py b/library/src/device/kernels/configs/config_2d_single.py new file mode 100644 index 00000000000..661e0f93912 --- /dev/null +++ b/library/src/device/kernels/configs/config_2d_single.py @@ -0,0 +1,162 @@ +# Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from kernels.configs import config_lds +from types import SimpleNamespace as NS + +# yapf: disable +fused_2d_kernels = [ + NS(length=[4,4], factors=[[2,2],[2,2]], threads_per_transform=[2,2], workgroup_size=8), + NS(length=[4,8], factors=[[2,2],[4,2]], threads_per_transform=[2,2], workgroup_size=16), + NS(length=[4,9], factors=[[2,2],[3,3]], threads_per_transform=[2,3], workgroup_size=18), + NS(length=[4,16], factors=[[2,2],[4,4]], threads_per_transform=[2,4], workgroup_size=32), + NS(length=[4,25], factors=[[2,2],[5,5]], threads_per_transform=[2,5], workgroup_size=50), + NS(length=[4,27], factors=[[2,2],[3,3,3]], threads_per_transform=[2,9], workgroup_size=54), + NS(length=[4,32], factors=[[2,2],[8,4]], threads_per_transform=[2,4], workgroup_size=64), + NS(length=[4,64], factors=[[2,2],[4,4,4]], threads_per_transform=[2,16], workgroup_size=128), + NS(length=[4,81], factors=[[2,2],[3,3,3,3]], threads_per_transform=[2,27], workgroup_size=162), + NS(length=[4,125], factors=[[2,2],[5,5,5]], threads_per_transform=[2,25], workgroup_size=250), + NS(length=[4,128], factors=[[2,2],[8,4,4]], threads_per_transform=[2,16], workgroup_size=256), + NS(length=[4,243], factors=[[2,2],[3,3,3,3,3]], threads_per_transform=[2,81], workgroup_size=486), + NS(length=[4,256], factors=[[2,2],[4,4,4,4]], threads_per_transform=[2,64], workgroup_size=512), + NS(length=[8,4], factors=[[4,2],[2,2]], threads_per_transform=[2,2], workgroup_size=16), + NS(length=[8,8], factors=[[4,2],[4,2]], threads_per_transform=[2,2], workgroup_size=16), + NS(length=[8,9], factors=[[4,2],[3,3]], threads_per_transform=[2,3], workgroup_size=24), + NS(length=[8,16], factors=[[4,2],[4,4]], threads_per_transform=[2,4], workgroup_size=32), + NS(length=[8,25], factors=[[4,2],[5,5]], threads_per_transform=[2,5], workgroup_size=50), + NS(length=[8,27], factors=[[4,2],[3,3,3]], threads_per_transform=[2,9], workgroup_size=72), + NS(length=[8,32], factors=[[4,2],[8,4]], threads_per_transform=[2,4], workgroup_size=64), + NS(length=[8,64], factors=[[4,2],[4,4,4]], threads_per_transform=[2,16], workgroup_size=128), + NS(length=[8,81], factors=[[4,2],[3,3,3,3]], threads_per_transform=[2,27], workgroup_size=216), + NS(length=[8,125], factors=[[4,2],[5,5,5]], threads_per_transform=[2,25], workgroup_size=250), + NS(length=[8,128], factors=[[4,2],[8,4,4]], threads_per_transform=[2,16], workgroup_size=256), + NS(length=[8,243], factors=[[4,2],[3,3,3,3,3]], threads_per_transform=[2,81], workgroup_size=648), + NS(length=[8,256], factors=[[4,2],[4,4,4,4]], threads_per_transform=[2,64], workgroup_size=512), + NS(length=[9,4], factors=[[3,3],[2,2]], threads_per_transform=[3,2], workgroup_size=18), + NS(length=[9,8], factors=[[3,3],[4,2]], threads_per_transform=[3,2], workgroup_size=24), + NS(length=[9,9], factors=[[3,3],[3,3]], threads_per_transform=[3,3], workgroup_size=27), + NS(length=[9,16], factors=[[3,3],[4,4]], threads_per_transform=[3,4], workgroup_size=48), + NS(length=[9,25], factors=[[3,3],[5,5]], threads_per_transform=[3,5], workgroup_size=75), + NS(length=[9,27], factors=[[3,3],[3,3,3]], threads_per_transform=[3,9], workgroup_size=81), + NS(length=[9,32], factors=[[3,3],[8,4]], threads_per_transform=[3,4], workgroup_size=96), + NS(length=[9,64], factors=[[3,3],[4,4,4]], threads_per_transform=[3,16], workgroup_size=192), + NS(length=[9,81], factors=[[3,3],[3,3,3,3]], threads_per_transform=[3,27], workgroup_size=243), + NS(length=[9,125], factors=[[3,3],[5,5,5]], threads_per_transform=[3,25], workgroup_size=375), + NS(length=[9,128], factors=[[3,3],[8,4,4]], threads_per_transform=[3,16], workgroup_size=384), + NS(length=[9,243], factors=[[3,3],[3,3,3,3,3]], threads_per_transform=[3,81], workgroup_size=729), + NS(length=[9,256], factors=[[3,3],[4,4,4,4]], threads_per_transform=[3,64], workgroup_size=768), + NS(length=[16,4], factors=[[4,4],[2,2]], threads_per_transform=[4,2], workgroup_size=32), + NS(length=[16,8], factors=[[4,4],[4,2]], threads_per_transform=[4,2], workgroup_size=32), + NS(length=[16,9], factors=[[4,4],[3,3]], threads_per_transform=[4,3], workgroup_size=48), + NS(length=[16,16], factors=[[4,4],[4,4]], threads_per_transform=[4,4], workgroup_size=64), + NS(length=[16,25], factors=[[4,4],[5,5]], threads_per_transform=[4,5], workgroup_size=100), + NS(length=[16,27], factors=[[4,4],[3,3,3]], threads_per_transform=[4,9], workgroup_size=144), + NS(length=[16,32], factors=[[4,4],[8,4]], threads_per_transform=[4,4], workgroup_size=128), + NS(length=[16,64], factors=[[4,4],[4,4,4]], threads_per_transform=[4,16], workgroup_size=256), + NS(length=[16,81], factors=[[4,4],[3,3,3,3]], threads_per_transform=[4,27], workgroup_size=432), + NS(length=[16,125], factors=[[4,4],[5,5,5]], threads_per_transform=[4,25], workgroup_size=500), + NS(length=[16,128], factors=[[4,4],[8,4,4]], threads_per_transform=[4,16], workgroup_size=512), + NS(length=[25,4], factors=[[5,5],[2,2]], threads_per_transform=[5,2], workgroup_size=50), + NS(length=[25,8], factors=[[5,5],[4,2]], threads_per_transform=[5,2], workgroup_size=50), + NS(length=[25,9], factors=[[5,5],[3,3]], threads_per_transform=[5,3], workgroup_size=75), + NS(length=[25,16], factors=[[5,5],[4,4]], threads_per_transform=[5,4], workgroup_size=100), + NS(length=[25,25], factors=[[5,5],[5,5]], threads_per_transform=[5,5], workgroup_size=125), + NS(length=[25,27], factors=[[5,5],[3,3,3]], threads_per_transform=[5,9], workgroup_size=225), + NS(length=[25,32], factors=[[5,5],[8,4]], threads_per_transform=[5,4], workgroup_size=160), + NS(length=[25,64], factors=[[5,5],[4,4,4]], threads_per_transform=[5,16], workgroup_size=400), + NS(length=[25,81], factors=[[5,5],[3,3,3,3]], threads_per_transform=[5,27], workgroup_size=675), + NS(length=[25,125], factors=[[5,5],[5,5,5]], threads_per_transform=[5,25], workgroup_size=625), + NS(length=[25,128], factors=[[5,5],[8,4,4]], threads_per_transform=[5,16], workgroup_size=640), + NS(length=[27,4], factors=[[3,3,3],[2,2]], threads_per_transform=[9,2], workgroup_size=54), + NS(length=[27,8], factors=[[3,3,3],[4,2]], threads_per_transform=[9,2], workgroup_size=72), + NS(length=[27,9], factors=[[3,3,3],[3,3]], threads_per_transform=[9,3], workgroup_size=81), + NS(length=[27,16], factors=[[3,3,3],[4,4]], threads_per_transform=[9,4], workgroup_size=144), + NS(length=[27,25], factors=[[3,3,3],[5,5]], threads_per_transform=[9,5], workgroup_size=225), + NS(length=[27,27], factors=[[3,3,3],[3,3,3]], threads_per_transform=[9,9], workgroup_size=243), + NS(length=[27,32], factors=[[3,3,3],[8,4]], threads_per_transform=[9,4], workgroup_size=288), + NS(length=[27,64], factors=[[3,3,3],[4,4,4]], threads_per_transform=[9,16], workgroup_size=576), + NS(length=[27,81], factors=[[3,3,3],[3,3,3,3]], threads_per_transform=[9,27], workgroup_size=729), + NS(length=[32,4], factors=[[8,4],[2,2]], threads_per_transform=[4,2], workgroup_size=64), + NS(length=[32,8], factors=[[8,4],[4,2]], threads_per_transform=[4,2], workgroup_size=64), + NS(length=[32,9], factors=[[8,4],[3,3]], threads_per_transform=[4,3], workgroup_size=96), + NS(length=[32,16], factors=[[8,4],[4,4]], threads_per_transform=[4,4], workgroup_size=128), + NS(length=[32,25], factors=[[8,4],[5,5]], threads_per_transform=[4,5], workgroup_size=160), + NS(length=[32,27], factors=[[8,4],[3,3,3]], threads_per_transform=[4,9], workgroup_size=288), + NS(length=[32,32], factors=[[8,4],[8,4]], threads_per_transform=[4,4], workgroup_size=128), + NS(length=[32,64], factors=[[8,4],[4,4,4]], threads_per_transform=[4,16], workgroup_size=512), + NS(length=[32,81], factors=[[8,4],[3,3,3,3]], threads_per_transform=[4,27], workgroup_size=864), + NS(length=[32,125], factors=[[8,4],[5,5,5]], threads_per_transform=[4,25], workgroup_size=800), + NS(length=[32,128], factors=[[8,4],[8,4,4]], threads_per_transform=[4,16], workgroup_size=512), + NS(length=[64,4], factors=[[4,4,4],[2,2]], threads_per_transform=[16,2], workgroup_size=128), + NS(length=[64,8], factors=[[4,4,4],[4,2]], threads_per_transform=[16,2], workgroup_size=128), + NS(length=[64,9], factors=[[4,4,4],[3,3]], threads_per_transform=[16,3], workgroup_size=192), + NS(length=[64,16], factors=[[4,4,4],[4,4]], threads_per_transform=[16,4], workgroup_size=256), + NS(length=[64,25], factors=[[4,4,4],[5,5]], threads_per_transform=[16,5], workgroup_size=400), + NS(length=[64,27], factors=[[4,4,4],[3,3,3]], threads_per_transform=[16,9], workgroup_size=576), + NS(length=[64,32], factors=[[4,4,4],[8,4]], threads_per_transform=[16,4], workgroup_size=512), + NS(length=[81,4], factors=[[3,3,3,3],[2,2]], threads_per_transform=[27,2], workgroup_size=162), + NS(length=[81,8], factors=[[3,3,3,3],[4,2]], threads_per_transform=[27,2], workgroup_size=216), + NS(length=[81,9], factors=[[3,3,3,3],[3,3]], threads_per_transform=[27,3], workgroup_size=243), + NS(length=[81,16], factors=[[3,3,3,3],[4,4]], threads_per_transform=[27,4], workgroup_size=432), + NS(length=[81,25], factors=[[3,3,3,3],[5,5]], threads_per_transform=[27,5], workgroup_size=675), + NS(length=[81,27], factors=[[3,3,3,3],[3,3,3]], threads_per_transform=[27,9], workgroup_size=729), + NS(length=[81,32], factors=[[3,3,3,3],[8,4]], threads_per_transform=[27,4], workgroup_size=864), + NS(length=[125,4], factors=[[5,5,5],[2,2]], threads_per_transform=[25,2], workgroup_size=250), + NS(length=[125,8], factors=[[5,5,5],[4,2]], threads_per_transform=[25,2], workgroup_size=250), + NS(length=[125,9], factors=[[5,5,5],[3,3]], threads_per_transform=[25,3], workgroup_size=375), + NS(length=[125,16], factors=[[5,5,5],[4,4]], threads_per_transform=[25,4], workgroup_size=500), + NS(length=[125,25], factors=[[5,5,5],[5,5]], threads_per_transform=[25,5], workgroup_size=625), + NS(length=[125,32], factors=[[5,5,5],[8,4]], threads_per_transform=[25,4], workgroup_size=800), + NS(length=[128,4], factors=[[8,4,4],[2,2]], threads_per_transform=[16,2], workgroup_size=256), + NS(length=[128,8], factors=[[8,4,4],[4,2]], threads_per_transform=[16,2], workgroup_size=256), + NS(length=[128,9], factors=[[8,4,4],[3,3]], threads_per_transform=[16,3], workgroup_size=384), + NS(length=[128,16], factors=[[8,4,4],[4,4]], threads_per_transform=[16,4], workgroup_size=512), + NS(length=[128,25], factors=[[8,4,4],[5,5]], threads_per_transform=[16,5], workgroup_size=640), + NS(length=[128,32], factors=[[8,4,4],[8,4]], threads_per_transform=[16,4], workgroup_size=512), + NS(length=[243,4], factors=[[3,3,3,3,3],[2,2]], threads_per_transform=[81,2], workgroup_size=486), + NS(length=[243,8], factors=[[3,3,3,3,3],[4,2]], threads_per_transform=[81,2], workgroup_size=648), + NS(length=[243,9], factors=[[3,3,3,3,3],[3,3]], threads_per_transform=[81,3], workgroup_size=729), + NS(length=[256,4], factors=[[4,4,4,4],[2,2]], threads_per_transform=[64,2], workgroup_size=512), + NS(length=[256,8], factors=[[4,4,4,4],[4,2]], threads_per_transform=[64,2], workgroup_size=512), + NS(length=[256,9], factors=[[4,4,4,4],[3,3]], threads_per_transform=[64,3], workgroup_size=768), + # ----- new for r2c/c2r + NS(length=[7,84], factors=[[7],[7,2,6]], threads_per_transform=[1,12], workgroup_size=84), + NS(length=[84,7], factors=[[7,2,6],[7]], threads_per_transform=[12,1], workgroup_size=84), + NS(length=[10,20], factors=[[10],[5,4]], threads_per_transform=[1,5], workgroup_size=50), + NS(length=[20,10], factors=[[5,4],[10]], threads_per_transform=[5,1], workgroup_size=50), + NS(length=[26,64], factors=[[13,2],[4,4,4]], threads_per_transform=[2,16], workgroup_size=416), + NS(length=[64,26], factors=[[4,4,4],[13,2]], threads_per_transform=[16,2], workgroup_size=416), + NS(length=[26,72], factors=[[13,2],[8,3,3]], threads_per_transform=[2,9], workgroup_size=234), + NS(length=[72,26], factors=[[8,3,3],[13,2]], threads_per_transform=[9,2], workgroup_size=234), + NS(length=[30,60], factors=[[10,3],[6,10]], threads_per_transform=[3,10], workgroup_size=300), + NS(length=[60,30], factors=[[6,10],[10,3]], threads_per_transform=[10,3], workgroup_size=300), + NS(length=[36,72], factors=[[6,6],[8,3,3]], threads_per_transform=[6,9], workgroup_size=432), + NS(length=[72,36], factors=[[8,3,3],[6,6]], threads_per_transform=[9,6], workgroup_size=432), + NS(length=[36,80], factors=[[6,6],[5,2,8]], threads_per_transform=[6,10], workgroup_size=480), + NS(length=[80,36], factors=[[5,2,8],[6,6]], threads_per_transform=[10,6], workgroup_size=480), + NS(length=[36,84], factors=[[6,6],[7,2,6]], threads_per_transform=[6,12], workgroup_size=504), + NS(length=[84,36], factors=[[7,2,6],[6,6]], threads_per_transform=[12,6], workgroup_size=504), + NS(length=[40,80], factors=[[10,4],[5,2,8]], threads_per_transform=[4,10], workgroup_size=400), + NS(length=[80,40], factors=[[5,2,8],[10,4]], threads_per_transform=[10,4], workgroup_size=400), + NS(length=[42,84], factors=[[7,6],[7,2,6]], threads_per_transform=[6,12], workgroup_size=504), + NS(length=[84,42], factors=[[7,2,6],[7,6]], threads_per_transform=[12,6], workgroup_size=504), + NS(length=[42,96], factors=[[7,6],[6,16]], threads_per_transform=[6,6], workgroup_size=576), + NS(length=[96,42], factors=[[6,16],[7,6]], threads_per_transform=[6,6], workgroup_size=576), +] \ No newline at end of file diff --git a/library/src/device/kernels/configs/config_lds.py b/library/src/device/kernels/configs/config_lds.py new file mode 100644 index 00000000000..559c918bc24 --- /dev/null +++ b/library/src/device/kernels/configs/config_lds.py @@ -0,0 +1,21 @@ +# Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +LDS_160k = 160 * 1024 \ No newline at end of file diff --git a/library/src/device/kernels/configs/config_pp_3d.py b/library/src/device/kernels/configs/config_pp_3d.py new file mode 100644 index 00000000000..17f50d21077 --- /dev/null +++ b/library/src/device/kernels/configs/config_pp_3d.py @@ -0,0 +1,25 @@ +# Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from kernels.configs import config_lds +from types import SimpleNamespace as NS + +# yapf: disable + diff --git a/library/src/device/kernels/configs/config_sbcc.py b/library/src/device/kernels/configs/config_sbcc.py new file mode 100644 index 00000000000..b205e93906a --- /dev/null +++ b/library/src/device/kernels/configs/config_sbcc.py @@ -0,0 +1,94 @@ +# Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from kernels.configs import config_lds +from types import SimpleNamespace as NS + +# Note: Default direct_to_from_reg is True + +# yapf: disable +sbcc_kernels = [ + NS(length=50, factors=[10, 5], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'true'}, workgroup_size=256), + NS(length=52, factors=[13, 4], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'true'}), + NS(length=60, factors=[6, 10], use_3steps_large_twd={ + 'sp': 'false', 'dp': 'false'}), + NS(length=64, factors=[8, 8], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'false'}, workgroup_size=256), + NS(length=72, factors=[8, 3, 3], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'false'}), + NS(length=80, factors=[10, 8], use_3steps_large_twd={ + 'sp': 'false', 'dp': 'false'}), + # 9,9 is good when direct-to-reg, but bad for Navi, so still uses radix-3 + NS(length=81, factors=[3, 3, 3, 3], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'true'}), + NS(length=84, factors=[7, 2, 6], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'true'}, threads_per_transform=14), + NS(length=96, factors=[8, 3, 4], use_3steps_large_twd={ + 'sp': 'false', 'dp': 'false'}, workgroup_size=256), + NS(length=100, factors=[5, 5, 4], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'false'}, workgroup_size=100, half_lds=True), + NS(length=104, factors=[13, 8], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'false'}), + NS(length=108, factors=[6, 6, 3], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'false'}), + NS(length=112, factors=[4, 7, 4], use_3steps_large_twd={ + 'sp': 'false', 'dp': 'false'}), + NS(length=121, factors=[11, 11], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'true'}, workgroup_size=128, runtime_compile=True), + NS(length=125, factors=[5, 5, 5], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'false'}), + NS(length=128, factors=[16, 8], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'true'}, workgroup_size=256, threads_per_transform= 16), + NS(length=160, factors=[4, 10, 4], use_3steps_large_twd={ + 'sp': 'false', 'dp': 'false'}, flavour='wide'), + NS(length=168, factors=[7, 6, 4], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'false'}, workgroup_size=128, half_lds=True), + NS(length=169, factors=[13, 13], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'false'}, workgroup_size=256, runtime_compile=True), + NS(length=192, factors=[8, 6, 4], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'true'}), + NS(length=200, factors=[5, 8, 5], use_3steps_large_twd={ + 'sp': 'false', 'dp': 'false'}), + NS(length=208, factors=[13, 16], use_3steps_large_twd={ + 'sp': 'false', 'dp': 'false'}), + NS(length=216, factors=(6, 6, 6), use_3steps_large_twd={ + 'sp': 'false', 'dp': 'false'}, threads_per_transform=36), + NS(length=224, factors=[8, 7, 4], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'false'}), + NS(length=240, factors=[8, 5, 6], use_3steps_large_twd={ + 'sp': 'false', 'dp': 'false'}), + # 9,9,3 isn't better on all archs, some are much better, some get regressions + NS(length=243, factors=[3, 3, 3, 3, 3], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'false'}, workgroup_size=243), + NS(length=256, factors=[8, 4, 8], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'false'}, flavour='wide'), + NS(length=280, factors=[8, 5, 7], use_3steps_large_twd={ + 'sp': 'false', 'dp': 'false'}, runtime_compile=True), + NS(length=289, factors=[17, 17], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'true'}, runtime_compile=True), + NS(length=336, factors=[6, 7, 8], use_3steps_large_twd={ + 'sp': 'false', 'dp': 'false'}), + NS(length=343, factors=[7, 7, 7], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'true'}), + NS(length=512, factors=[8, 8, 8], use_3steps_large_twd={ + 'sp': 'true', 'dp': 'false'}), +] \ No newline at end of file diff --git a/library/src/device/kernels/configs/config_sbcr.py b/library/src/device/kernels/configs/config_sbcr.py new file mode 100644 index 00000000000..1198ce44116 --- /dev/null +++ b/library/src/device/kernels/configs/config_sbcr.py @@ -0,0 +1,42 @@ +# Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from kernels.configs import config_lds +from types import SimpleNamespace as NS + +# NB: +# Technically, we could have SBCR kernels the same amount as SBCC. +# +# sbcr_kernels = copy.deepcopy(sbcc_kernels) +# for k in sbcr_kernels: +# k.scheme = 'CS_KERNEL_STOCKHAM_BLOCK_CR' +# + +# for SBCR, if direct_to_from_reg is True, we do load-to-reg, but will not do store-from-reg +# And since sbcr is dir-to-reg BUT NOT dir-from-reg, the global store part requires full LDS +# So, we can't satifly half_lds in SBCR ! + +# yapf: disable +sbcr_kernels = [ + NS(length=56, factors=[7, 8], direct_to_from_reg=False), + NS(length=100, factors=[10, 10], workgroup_size=100), + NS(length=200, factors=[8, 5, 5]), + NS(length=336, factors=[6, 7, 8]) +] \ No newline at end of file diff --git a/library/src/device/kernels/configs/config_sbrc.py b/library/src/device/kernels/configs/config_sbrc.py new file mode 100644 index 00000000000..4a382c32224 --- /dev/null +++ b/library/src/device/kernels/configs/config_sbrc.py @@ -0,0 +1,53 @@ +# Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from kernels.configs import config_lds +from types import SimpleNamespace as NS + +# yapf: disable +# for SBRC, if direct_to_from_reg is True, we do store-from-reg, but will not do load-to-reg +# And since SBRC is is dir-from-lds but NOT dir-to-reg, the global load part requires full LDS +# So, SBRC is able to use half-lds. + +# yapf: disable +sbrc_kernels = [ + NS(length=17, factors=[17], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=256, threads_per_transform=1, runtime_compile=True), + NS(length=49, factors=[7, 7], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=196, threads_per_transform=7), # block_width=28 + NS(length=50, factors=[10, 5], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=50, threads_per_transform=5, direct_to_from_reg=False), # block_width=10 + # SBRC64: wgs=256 poor in MI50 + NS(length=64, factors=[4, 4, 4], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=128, threads_per_transform=16), # block_width=8 + # 9,9 not good by experiments + NS(length=81, factors=[3, 3, 3, 3], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=243, threads_per_transform=27), # block_width=9 + NS(length=100, factors=[5, 5, 4], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=100, threads_per_transform=25), # block_width=4 + NS(length=112, factors=[4, 7, 4], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=448, threads_per_transform=28), # block_width=16 + NS(length=121, factors=[11, 11], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=128, threads_per_transform=11, runtime_compile=True), + NS(length=125, factors=[5, 5, 5], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=250, threads_per_transform=25), # block_width=10 + NS(length=128, factors=[8, 4, 4], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=128, threads_per_transform=16), # block_width=8 + NS(length=169, factors=[13, 13], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=256, threads_per_transform=13, runtime_compile=True), + NS(length=192, factors=[6, 4, 4, 2], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=256, threads_per_transform=32), # block_width=8 + NS(length=200, factors=[8, 5, 5], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=400, threads_per_transform=40), # block_width=10 + NS(length=243, factors=[3, 3, 3, 3, 3], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=256, threads_per_transform=27, runtime_compile=True), # block_width=10 + NS(length=256, factors=[4, 4, 4, 4], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=256, threads_per_transform=32), # block_width=8 + NS(length=289, factors=[17, 17], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=128, threads_per_transform=17, runtime_compile=True), + NS(length=343, factors=[7, 7, 7], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=256, threads_per_transform=49, runtime_compile=True), + NS(length=512, factors=[8, 8, 8], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=512, threads_per_transform=128), + NS(length=625, factors=[5, 5, 5, 5], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=128, threads_per_transform=125, runtime_compile=True), + NS(length=1331, factors=[11, 11, 11], scheme='CS_KERNEL_STOCKHAM_BLOCK_RC', workgroup_size=256, threads_per_transform=121, runtime_compile=True), +] diff --git a/library/src/device/kernels/configs/config_sbrr.py b/library/src/device/kernels/configs/config_sbrr.py new file mode 100644 index 00000000000..9f5d83f4ed6 --- /dev/null +++ b/library/src/device/kernels/configs/config_sbrr.py @@ -0,0 +1,509 @@ +# Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from kernels.configs import config_lds +from types import SimpleNamespace as NS + +# yapf: disable +# Note: Default half_lds is True and default direct_to_from_reg is True as well. +# TODO- Currently, if half_lds is True, then direct_to_from_reg must be True +# but if half_lds is False, direct_to_from_reg can be either (still can be True). + +# yapf: disable +sbrr_kernels = [ + NS(length= 1, workgroup_size= 64, threads_per_transform= 1, factors=(1,), runtime_compile=True), + NS(length= 2, workgroup_size= 64, threads_per_transform= 1, factors=(2,), runtime_compile=True), + NS(length= 3, workgroup_size= 64, threads_per_transform= 1, factors=(3,), runtime_compile=True), + NS(length= 4, workgroup_size=128, threads_per_transform= 1, factors=(4,), runtime_compile=True), + NS(length= 5, workgroup_size=128, threads_per_transform= 1, factors=(5,), runtime_compile=True), + NS(length= 6, workgroup_size=128, threads_per_transform= 1, factors=(6,), runtime_compile=True), + NS(length= 7, workgroup_size= 64, threads_per_transform= 1, factors=(7,), runtime_compile=True), + NS(length= 8, workgroup_size= 64, threads_per_transform= 4, factors=(4, 2), runtime_compile=True), + NS(length= 9, workgroup_size= 64, threads_per_transform= 3, factors=(3, 3), runtime_compile=True), + NS(length= 10, workgroup_size= 64, threads_per_transform= 1, factors=(10,), runtime_compile=True), + NS(length= 11, workgroup_size=128, threads_per_transform= 1, factors=(11,), runtime_compile=True), + NS(length= 12, workgroup_size=128, threads_per_transform= 6, factors=(6, 2), runtime_compile=True), + NS(length= 13, workgroup_size= 64, threads_per_transform= 1, factors=(13,), runtime_compile=True), + NS(length= 14, workgroup_size=128, threads_per_transform= 7, factors=(7, 2), runtime_compile=True), + NS(length= 15, workgroup_size=128, threads_per_transform= 5, factors=(3, 5), runtime_compile=True), + NS(length= 16, workgroup_size= 64, threads_per_transform= 4, factors=(4, 4), runtime_compile=True), + NS(length= 17, workgroup_size=256, threads_per_transform= 1, factors=(17,), runtime_compile=True), + NS(length= 18, workgroup_size= 64, threads_per_transform= 6, factors=(3, 6), runtime_compile=True), + NS(length= 20, workgroup_size=256, threads_per_transform= 10, factors=(5, 4), runtime_compile=True), + NS(length= 21, workgroup_size=128, threads_per_transform= 7, factors=(3, 7), runtime_compile=True), + NS(length= 22, workgroup_size= 64, threads_per_transform= 2, factors=(11, 2), runtime_compile=True), + NS(length= 24, workgroup_size=256, threads_per_transform= 8, factors=(8, 3), runtime_compile=True), + NS(length= 25, workgroup_size=256, threads_per_transform= 5, factors=(5, 5), runtime_compile=True), + NS(length= 26, workgroup_size= 64, threads_per_transform= 2, factors=(13, 2), runtime_compile=True), + NS(length= 27, workgroup_size=256, threads_per_transform= 9, factors=(3, 3, 3), runtime_compile=True), + NS(length= 28, workgroup_size= 64, threads_per_transform= 4, factors=(7, 4), runtime_compile=True), + NS(length= 30, workgroup_size=128, threads_per_transform= 10, factors=(10, 3), runtime_compile=True), + NS(length= 32, workgroup_size=128, threads_per_transform= 16, factors=(8, 4)), + NS(length= 33, workgroup_size=256, threads_per_transform= 11, factors=(11, 3), runtime_compile=True), + NS(length= 34, workgroup_size=256, threads_per_transform= 17, factors=(17, 2), runtime_compile=True), + NS(length= 35, workgroup_size=256, threads_per_transform= 7, factors=(5, 7), half_lds=False, runtime_compile=True), + NS(length= 36, workgroup_size= 64, threads_per_transform= 6, factors=(6, 6)), + NS(length= 39, workgroup_size=256, threads_per_transform= 13, factors=(13, 3), runtime_compile=True), + NS(length= 40, workgroup_size=128, threads_per_transform= 10, factors=(10, 4)), + NS(length= 42, workgroup_size=256, threads_per_transform= 7, factors=(7, 6)), + NS(length= 44, workgroup_size= 64, threads_per_transform= 4, factors=(11, 4)), + NS(length= 45, workgroup_size=128, threads_per_transform= 15, factors=(5, 3, 3)), + NS(length= 48, workgroup_size= 64, threads_per_transform= 16, factors=(4, 3, 4)), + NS(length= 49, workgroup_size= 64, threads_per_transform= 7, factors=(7, 7)), + NS(length= 50, workgroup_size=256, threads_per_transform= 10, factors=(10, 5)), + NS(length= 51, workgroup_size=256, threads_per_transform= 17, factors=(17, 3), runtime_compile=True), + NS(length= 52, workgroup_size= 64, threads_per_transform= 4, factors=(13, 4)), + NS(length= 54, workgroup_size=256, threads_per_transform= 18, factors=(6, 3, 3)), + NS(length= 55, workgroup_size=256, threads_per_transform= 11, factors=(5, 11), half_lds=False, runtime_compile=True), + NS(length= 56, workgroup_size=128, threads_per_transform= 8, factors=(7, 8)), + NS(length= 60, workgroup_size= 64, threads_per_transform= 10, factors=(6, 10)), + NS(length= 63, workgroup_size=256, threads_per_transform= 21, factors=(3, 3, 7), half_lds=False, runtime_compile=True), + NS(length= 64, workgroup_size= 64, threads_per_transform= 16, factors=(4, 4, 4), half_lds=False, direct_to_from_reg=True), + NS(length= 65, workgroup_size=256, threads_per_transform= 13, factors=(13, 5), runtime_compile=True), + NS(length= 66, workgroup_size=256, threads_per_transform= 11, factors=(6, 11), half_lds=False, runtime_compile=True), + NS(length= 68, workgroup_size=256, threads_per_transform= 17, factors=(17, 4), runtime_compile=True), + NS(length= 70, workgroup_size=256, threads_per_transform= 14, factors=(2, 5, 7), runtime_compile=True), + NS(length= 72, workgroup_size= 64, threads_per_transform= 9, factors=(8, 3, 3)), + NS(length= 75, workgroup_size=256, threads_per_transform= 25, factors=(5, 5, 3)), + NS(length= 77, workgroup_size=256, threads_per_transform= 11, factors=(7, 11), runtime_compile=True), + NS(length= 78, workgroup_size=256, threads_per_transform= 13, factors=(6, 13), half_lds=False, runtime_compile=True), + NS(length= 80, workgroup_size= 64, threads_per_transform= 10, factors=(5, 2, 8)), + NS(length= 81, workgroup_size=128, threads_per_transform= 27, factors=(3, 3, 3, 3)), + NS(length= 84, workgroup_size=128, threads_per_transform= 12, factors=(7, 2, 6)), + NS(length= 85, workgroup_size=256, threads_per_transform= 17, factors=(17, 5), runtime_compile=True), + NS(length= 88, workgroup_size=128, threads_per_transform= 11, factors=(11, 8)), + NS(length= 90, workgroup_size= 64, threads_per_transform= 9, factors=(3, 3, 10)), + NS(length= 91, workgroup_size=256, threads_per_transform= 13, factors=(7, 13), half_lds=False, runtime_compile=True), + NS(length= 96, workgroup_size=128, threads_per_transform= 16, factors=(6, 16), half_lds=False, direct_to_from_reg=False), + NS(length= 98, workgroup_size= 256, threads_per_transform= 14, factors=(2, 7, 7), half_lds=False, runtime_compile=True), + NS(length= 99, workgroup_size= 256, threads_per_transform= 11, factors=(3, 3, 11), half_lds=False, runtime_compile=True), + NS(length= 100, workgroup_size= 64, threads_per_transform= 10, factors=(10, 10)), + NS(length= 102, workgroup_size=128, threads_per_transform= 17, factors=(17, 6), runtime_compile=True), + NS(length= 104, workgroup_size= 64, threads_per_transform= 8, factors=(13, 8)), + NS(length= 105, workgroup_size=256, threads_per_transform= 21, factors=(7, 3, 5), half_lds=False, runtime_compile=True), + NS(length= 108, workgroup_size=256, threads_per_transform= 36, factors=(6, 6, 3)), + NS(length= 110, workgroup_size=256, threads_per_transform= 11, factors=(2, 5, 11), half_lds=False, runtime_compile=True), + NS(length= 112, workgroup_size=256, threads_per_transform= 16, factors=(16, 7), half_lds=False, direct_to_from_reg=False), + NS(length= 117, workgroup_size= 64, threads_per_transform= 13, factors=(13, 9), runtime_compile=True), + NS(length= 119, workgroup_size=256, threads_per_transform= 17, factors=(17, 7), runtime_compile=True), + NS(length= 120, workgroup_size= 64, threads_per_transform= 12, factors=(6, 10, 2), runtime_compile=True), + NS(length= 121, workgroup_size=128, threads_per_transform= 11, factors=(11, 11), runtime_compile=True), + NS(length= 125, workgroup_size=256, threads_per_transform= 25, factors=(5, 5, 5), half_lds=False, direct_to_from_reg=False), + NS(length= 126, workgroup_size= 256, threads_per_transform= 42, factors=(6, 7, 3), half_lds=False, runtime_compile=True), + NS(length= 128, workgroup_size=256, threads_per_transform= 16, factors=(16, 8)), + NS(length= 130, workgroup_size= 64, threads_per_transform= 13, factors=(13, 10), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 132, workgroup_size=128, threads_per_transform= 22, factors=(11, 6, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 135, workgroup_size=128, threads_per_transform= 9, factors=(5, 3, 3, 3), runtime_compile=True), + NS(length= 136, workgroup_size=128, threads_per_transform=17, factors=(17, 8), runtime_compile=True), + NS(length= 140, workgroup_size= 64, threads_per_transform= 28, factors=(7, 5, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 143, workgroup_size=256, threads_per_transform= 13, factors=(13, 11), half_lds=False, runtime_compile=True), + NS(length= 144, workgroup_size=128, threads_per_transform= 12, factors=(6, 6, 4)), + NS(length= 147, workgroup_size= 64, threads_per_transform= 21, factors=(7, 7, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 150, workgroup_size= 64, threads_per_transform= 5, factors=(10, 5, 3), runtime_compile=True), + NS(length= 153, workgroup_size=128, threads_per_transform= 17, factors=(17, 9), runtime_compile=True), + NS(length= 154, workgroup_size=128, threads_per_transform= 22, factors=(11, 7, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 156, workgroup_size= 128, threads_per_transform=13, factors=(3, 4, 13), half_lds=False, runtime_compile=True), + NS(length= 160, workgroup_size=256, threads_per_transform= 16, factors=(16, 10)), + NS(length= 162, workgroup_size=256, threads_per_transform= 27, factors=(6, 3, 3, 3), runtime_compile=True), + NS(length= 165, workgroup_size= 64, threads_per_transform= 11, factors=(11, 5, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 168, workgroup_size=256, threads_per_transform= 56, factors=(8, 7, 3), half_lds=False, direct_to_from_reg=False), + NS(length= 169, workgroup_size=256, threads_per_transform= 13, factors=(13, 13), runtime_compile=True), + NS(length= 170, workgroup_size=128, threads_per_transform= 17, factors=(17, 10), runtime_compile=True), + NS(length= 175, workgroup_size=256, threads_per_transform= 35, factors=(5, 5, 7), half_lds=False, runtime_compile=True), + NS(length= 176, workgroup_size= 64, threads_per_transform= 16, factors=(11, 16), runtime_compile=True), + NS(length= 180, workgroup_size=256, threads_per_transform= 60, factors=(10, 6, 3), half_lds=False, direct_to_from_reg=False), + NS(length= 182, workgroup_size= 64, threads_per_transform= 13, factors=(13, 2, 7), half_lds=False, runtime_compile=True), + NS(length= 187, workgroup_size=128, threads_per_transform= 17, factors=(17, 11), runtime_compile=True), + NS(length= 189, workgroup_size= 64, threads_per_transform= 21, factors=(7, 3, 3, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 192, workgroup_size=128, threads_per_transform= 16, factors=(6, 4, 4, 2)), + NS(length= 195, workgroup_size= 64, threads_per_transform= 13, factors=(13, 5, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 196, workgroup_size= 64, threads_per_transform= 28, factors=(4, 7, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 198, workgroup_size=128, threads_per_transform= 22, factors=(11, 2, 9), half_lds=False, runtime_compile=True), + NS(length= 200, workgroup_size= 64, threads_per_transform= 20, factors=(10, 10, 2)), + NS(length= 204, workgroup_size=128, threads_per_transform= 17, factors=(17, 4, 3), runtime_compile=True), + NS(length= 208, workgroup_size= 64, threads_per_transform= 16, factors=(13, 16)), + NS(length= 210, workgroup_size= 64, threads_per_transform= 30, factors=(10, 7, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 216, workgroup_size=256, threads_per_transform= 36, factors=(6, 6, 6)), + NS(length= 220, workgroup_size=128, threads_per_transform= 22, factors=(10, 2, 11), half_lds=False, runtime_compile=True), + NS(length= 221, workgroup_size=128, threads_per_transform= 17, factors=(17, 13), runtime_compile=True), + NS(length= 224, workgroup_size= 64, threads_per_transform= 16, factors=(7, 2, 2, 2, 2, 2)), + NS(length= 225, workgroup_size=256, threads_per_transform= 75, factors=(5, 5, 3, 3), runtime_compile=True), + NS(length= 231, workgroup_size=256, threads_per_transform= 33, factors=(11, 7, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 234, workgroup_size= 64, threads_per_transform= 26, factors=(13, 9, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 238, workgroup_size= 64, threads_per_transform= 17, factors=(17, 7, 2), runtime_compile=True), + NS(length= 240, workgroup_size=128, threads_per_transform= 48, factors=(8, 5, 6)), + NS(length= 242, workgroup_size=128, threads_per_transform= 22, factors=(11, 2, 11), half_lds=False, runtime_compile=True), + NS(length= 243, workgroup_size=256, threads_per_transform= 81, factors=(3, 3, 3, 3, 3)), + NS(length= 245, workgroup_size=256, threads_per_transform= 35, factors=(7, 5, 7), half_lds=False, runtime_compile=True), + NS(length= 250, workgroup_size=128, threads_per_transform= 25, factors=(10, 5, 5), runtime_compile=True), + NS(length= 252, workgroup_size= 64, threads_per_transform= 63, factors=(7, 3, 3, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 255, workgroup_size= 64, threads_per_transform= 17, factors=(17, 5, 3), runtime_compile=True), + NS(length= 256, workgroup_size= 64, threads_per_transform= 64, factors=(4, 4, 4, 4)), + NS(length= 260, workgroup_size= 64, threads_per_transform= 26, factors=(13, 10, 2), half_lds=False, runtime_compile=True), + NS(length= 264, workgroup_size=256, threads_per_transform= 33, factors=(8, 3, 11), half_lds=False, runtime_compile=True), + NS(length= 270, workgroup_size=128, threads_per_transform= 27, factors=(10, 3, 3, 3)), + NS(length= 272, workgroup_size=128, threads_per_transform= 17, factors=(16, 17), runtime_compile=True), + NS(length= 273, workgroup_size= 64, threads_per_transform= 13, factors=(13, 3, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 275, workgroup_size= 64, threads_per_transform= 55, factors=(11, 5, 5), half_lds=False, runtime_compile=True), + NS(length= 280, workgroup_size= 64, threads_per_transform= 56, factors=(8, 7, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 286, workgroup_size= 64, threads_per_transform= 26, factors=(13, 11, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 288, workgroup_size=128, threads_per_transform= 24, factors=(6, 6, 4, 2), runtime_compile=True), + NS(length= 289, workgroup_size=128, threads_per_transform= 17, factors=(17, 17), runtime_compile=True), + NS(length= 294, workgroup_size=128, threads_per_transform= 42, factors=(6, 7, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 297, workgroup_size=256, threads_per_transform= 33, factors=(9, 3, 11), runtime_compile=True), + NS(length= 300, workgroup_size= 64, threads_per_transform= 30, factors=(10, 10, 3), runtime_compile=True), + NS(length= 306, workgroup_size=256, threads_per_transform= 34, factors=(17, 2, 9), runtime_compile=True), + NS(length= 308, workgroup_size= 64, threads_per_transform= 44, factors=(11, 7, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 312, workgroup_size= 64, threads_per_transform= 26, factors=(13, 4, 3, 2), half_lds=False, runtime_compile=True), + NS(length= 315, workgroup_size= 64, threads_per_transform= 63, factors=(7, 3, 3, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 320, workgroup_size= 64, threads_per_transform= 16, factors=(10, 4, 4, 2), runtime_compile=True), + NS(length= 324, workgroup_size= 64, threads_per_transform= 54, factors=(3, 6, 6, 3), runtime_compile=True), + NS(length= 325, workgroup_size= 64, threads_per_transform= 13, factors=(13, 5, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 330, workgroup_size=128, threads_per_transform= 33, factors=(11, 10, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 336, workgroup_size=128, threads_per_transform= 56, factors=(8, 7, 6)), + NS(length= 338, workgroup_size= 64, threads_per_transform= 26, factors=(13, 2, 13), runtime_compile=True), + NS(length= 340, workgroup_size=128, threads_per_transform= 34, factors=(17, 2, 10), runtime_compile=True), + NS(length= 343, workgroup_size=256, threads_per_transform= 49, factors=(7, 7, 7), runtime_compile=True), + NS(length= 350, workgroup_size= 64, threads_per_transform= 50, factors=(5, 7, 10), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 351, workgroup_size=128, threads_per_transform= 39, factors=(13, 3, 9), half_lds=False, runtime_compile=True), + NS(length= 352, workgroup_size= 64, threads_per_transform= 32, factors=(11, 2, 16), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 357, workgroup_size=256, threads_per_transform= 17, factors=(17, 3, 7), runtime_compile=True), + NS(length= 360, workgroup_size=256, threads_per_transform= 60, factors=(10, 6, 6), runtime_compile=True), + NS(length= 363, workgroup_size=128, threads_per_transform= 33, factors=(11, 3, 11), runtime_compile=True), + NS(length= 364, workgroup_size= 64, threads_per_transform= 52, factors=(13, 7, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 374, workgroup_size=256, threads_per_transform= 34, factors=(17, 2, 11), runtime_compile=True), + NS(length= 375, workgroup_size=128, threads_per_transform= 25, factors=(5, 5, 5, 3), runtime_compile=True), + NS(length= 378, workgroup_size=128, threads_per_transform=126, factors=(6, 3, 3, 7), half_lds=False, runtime_compile=True), + NS(length= 384, workgroup_size=128, threads_per_transform= 32, factors=(6, 4, 4, 4), runtime_compile=True), + NS(length= 385, workgroup_size= 64, threads_per_transform= 55, factors=(11, 7, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 390, workgroup_size=128, threads_per_transform= 39, factors=(13, 3, 10), half_lds=False, runtime_compile=True), + NS(length= 392, workgroup_size= 64, threads_per_transform= 56, factors=(8, 7, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 396, workgroup_size= 64, threads_per_transform= 44, factors=(11, 9, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 400, workgroup_size=128, threads_per_transform= 40, factors=(4, 10, 10), runtime_compile=True), + NS(length= 405, workgroup_size=128, threads_per_transform= 27, factors=(5, 3, 3, 3, 3), runtime_compile=True), + NS(length= 408, workgroup_size= 64, threads_per_transform= 17, factors=(17, 3, 8), runtime_compile=True), + NS(length= 416, workgroup_size= 64, threads_per_transform= 32, factors=(13, 2, 16), half_lds=False, runtime_compile=True), + NS(length= 420, workgroup_size= 64, threads_per_transform= 60, factors=(10, 7, 6), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 425, workgroup_size= 64, threads_per_transform= 17, factors=(17, 5, 5), runtime_compile=True), + NS(length= 429, workgroup_size=128, threads_per_transform= 39, factors=(13, 3, 11), half_lds=False, runtime_compile=True), + NS(length= 432, workgroup_size= 64, threads_per_transform= 27, factors=(3, 16, 3, 3), runtime_compile=True), + NS(length= 440, workgroup_size= 64, threads_per_transform= 55, factors=(11, 8, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 441, workgroup_size= 64, threads_per_transform= 63, factors=(9, 7, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 442, workgroup_size=256, threads_per_transform= 34, factors=(17, 2, 13), runtime_compile=True), + NS(length= 448, workgroup_size=128, threads_per_transform= 64, factors=(8, 7, 8), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 450, workgroup_size=128, threads_per_transform= 30, factors=(10, 5, 3, 3), runtime_compile=True), + NS(length= 455, workgroup_size=256, threads_per_transform= 65, factors=(13, 5, 7), half_lds=False, runtime_compile=True), + NS(length= 459, workgroup_size=256, threads_per_transform= 51, factors=(17, 3, 9), runtime_compile=True), + NS(length= 462, workgroup_size=256, threads_per_transform= 77, factors=(11, 6, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 468, workgroup_size= 64, threads_per_transform= 52, factors=(13, 9, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 476, workgroup_size=128, threads_per_transform= 34, factors=(17, 2, 7, 2), runtime_compile=True), + NS(length= 480, workgroup_size= 64, threads_per_transform= 16, factors=(10, 8, 6), runtime_compile=True), + NS(length= 484, workgroup_size= 64, threads_per_transform= 44, factors=(4, 11, 11), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 486, workgroup_size=256, threads_per_transform=162, factors=(6, 3, 3, 3, 3), runtime_compile=True), + NS(length= 490, workgroup_size=256, threads_per_transform= 70, factors=(10, 7, 7), half_lds=False, runtime_compile=True), + NS(length= 495, workgroup_size= 64, threads_per_transform= 55, factors=(11, 9, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 500, workgroup_size=128, threads_per_transform=100, factors=(10, 5, 10), runtime_compile=True), + NS(length= 504, workgroup_size= 64, threads_per_transform= 63, factors=(7, 9, 4, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 507, workgroup_size=128, threads_per_transform= 39, factors=(13, 3, 13), runtime_compile=True), + NS(length= 510, workgroup_size=256, threads_per_transform= 34, factors=(17, 2, 3, 5), runtime_compile=True), + NS(length= 512, workgroup_size= 64, threads_per_transform= 64, factors=(8, 8, 8)), + NS(length= 520, workgroup_size= 64, threads_per_transform= 52, factors=(13, 10, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 525, workgroup_size= 128, threads_per_transform=105, factors=(7, 3, 5, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 528, workgroup_size= 64, threads_per_transform= 48, factors=(4, 4, 3, 11), runtime_compile=True), + NS(length= 539, workgroup_size=256, threads_per_transform= 77, factors=(11, 7, 7), runtime_compile=True), + NS(length= 540, workgroup_size=256, threads_per_transform= 54, factors=(3, 10, 6, 3), runtime_compile=True), + NS(length= 544, workgroup_size=128, threads_per_transform= 34, factors=(17, 2, 16), runtime_compile=True), + NS(length= 546, workgroup_size=128, threads_per_transform= 39, factors=(13, 3, 7, 2), runtime_compile=True), + NS(length= 550, workgroup_size= 64, threads_per_transform= 55, factors=(11, 10, 5), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 560, workgroup_size= 64, threads_per_transform= 56, factors=(8, 7, 5, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 561, workgroup_size=256, threads_per_transform= 51, factors=(17, 3, 11), runtime_compile=True), + NS(length= 567, workgroup_size= 64, threads_per_transform= 63, factors=(7, 9, 3, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 572, workgroup_size= 64, threads_per_transform= 52, factors=(13, 11, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 576, workgroup_size=128, threads_per_transform= 96, factors=(16, 6, 6), runtime_compile=True), + NS(length= 578, workgroup_size= 256, threads_per_transform=34, factors=(17, 17, 2), runtime_compile=True), + NS(length= 585, workgroup_size= 256, threads_per_transform=65, factors=(13, 5, 9), half_lds=False, runtime_compile=True), + NS(length= 588, workgroup_size= 256, threads_per_transform=84, factors=(7, 3, 4, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 594, workgroup_size=128, threads_per_transform= 99, factors=(11, 3, 6, 3), half_lds=False, runtime_compile=True), + NS(length= 595, workgroup_size= 64, threads_per_transform= 17, factors=(7, 17, 5), runtime_compile=True), + NS(length= 600, workgroup_size= 64, threads_per_transform= 60, factors=(10, 6, 10), runtime_compile=True), + NS(length= 605, workgroup_size= 64, threads_per_transform= 55, factors=(11, 5, 11), half_lds=False, runtime_compile=True), + NS(length= 612, workgroup_size= 64, threads_per_transform= 51, factors=(17, 3, 6, 2), runtime_compile=True), + NS(length= 616, workgroup_size=128, threads_per_transform= 88, factors=(11, 7, 8), half_lds=False, runtime_compile=True), + NS(length= 624, workgroup_size= 64, threads_per_transform= 52, factors=(13, 4, 6, 2), half_lds=False, runtime_compile=True), + NS(length= 625, workgroup_size=128, threads_per_transform=125, factors=(5, 5, 5, 5), runtime_compile=True), + NS(length= 630, workgroup_size= 64, threads_per_transform= 63, factors=(3, 3, 5, 7, 2), runtime_compile=True), + NS(length= 637, workgroup_size=128, threads_per_transform= 91, factors=(13, 7, 7), runtime_compile=True), + NS(length= 640, workgroup_size=128, threads_per_transform= 64, factors=(8, 10, 8), runtime_compile=True), + NS(length= 648, workgroup_size=256, threads_per_transform=216, factors=(8, 3, 3, 3, 3), runtime_compile=True), + NS(length= 650, workgroup_size= 256, threads_per_transform=65, factors=(10, 5, 13), half_lds=False, runtime_compile=True), + NS(length= 660, workgroup_size=128, threads_per_transform=110, factors=(11, 6, 10), runtime_compile=True), + NS(length= 663, workgroup_size= 64, threads_per_transform= 51, factors=(17, 13, 3), half_lds=False, runtime_compile=True), + NS(length= 672, workgroup_size= 64, threads_per_transform= 56, factors=(2, 2, 2, 2, 2, 3, 7), runtime_compile=True), + NS(length= 675, workgroup_size=256, threads_per_transform=225, factors=(5, 5, 3, 3, 3), runtime_compile=True), + NS(length= 676, workgroup_size= 64, threads_per_transform= 52, factors=(13, 13, 4), half_lds=False, runtime_compile=True), + NS(length= 680, workgroup_size=256, threads_per_transform= 68, factors=(17, 4, 10), runtime_compile=True), + NS(length= 686, workgroup_size= 64, threads_per_transform= 49, factors=(7, 7, 7, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 693, workgroup_size=128, threads_per_transform= 99, factors=(11, 7, 9), runtime_compile=True), + NS(length= 700, workgroup_size= 128, threads_per_transform=100, factors=(10, 7, 10), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 702, workgroup_size= 128, threads_per_transform=117, factors=(13, 3, 6, 3), runtime_compile=True), + NS(length= 704, workgroup_size=256, threads_per_transform=88, factors=(2, 2, 2, 2, 11, 2, 2), runtime_compile=True), + NS(length= 714, workgroup_size=64, threads_per_transform=51, factors=(3, 17, 7, 2), runtime_compile=True), + NS(length= 715, workgroup_size=256, threads_per_transform= 65, factors=(13, 5, 11), runtime_compile=True), + NS(length= 720, workgroup_size=256, threads_per_transform=120, factors=(10, 3, 8, 3), runtime_compile=True), + NS(length= 726, workgroup_size=256, threads_per_transform= 66, factors=(11, 6, 11), half_lds=False, runtime_compile=True), + NS(length= 728, workgroup_size=128, threads_per_transform=104, factors=(13, 7, 8), runtime_compile=True), + NS(length= 729, workgroup_size=256, threads_per_transform=243, factors=(3, 3, 3, 3, 3, 3), runtime_compile=True), + NS(length= 735, workgroup_size= 256, threads_per_transform=147, factors=(7, 3, 5, 7), half_lds=False, runtime_compile=True), + NS(length= 748, workgroup_size= 256, threads_per_transform=68, factors=(17, 4, 11), runtime_compile=True), + NS(length= 750, workgroup_size=256, threads_per_transform=250, factors=(10, 5, 3, 5), runtime_compile=True), + NS(length= 756, workgroup_size= 64, threads_per_transform= 63, factors=(2, 2, 3, 3, 3, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 765, workgroup_size=256, threads_per_transform= 51, factors=(17, 3, 5, 3), runtime_compile=True), + NS(length= 768, workgroup_size= 64, threads_per_transform= 48, factors=(16, 3, 16), runtime_compile=True), + NS(length= 770, workgroup_size=256, threads_per_transform=110, factors=(11, 10, 7), half_lds=False, runtime_compile=True), + NS(length= 780, workgroup_size=256, threads_per_transform= 78, factors=(2, 3, 13, 5, 2), runtime_compile=True), + NS(length= 784, workgroup_size= 64, threads_per_transform= 56, factors=(2, 2, 2, 2, 7, 7), runtime_compile=True), + NS(length= 792, workgroup_size=256, threads_per_transform= 88, factors=(2, 2, 2, 3, 3, 11), half_lds=False, runtime_compile=True), + NS(length= 800, workgroup_size=256, threads_per_transform=160, factors=(16, 5, 10), runtime_compile=True), + NS(length= 810, workgroup_size=128, threads_per_transform= 81, factors=(3, 10, 3, 3, 3), runtime_compile=True), + NS(length= 816, workgroup_size= 64, threads_per_transform= 51, factors=(17, 2, 3, 2, 2, 2), runtime_compile=True), + NS(length= 819, workgroup_size=128, threads_per_transform=117, factors=(9, 7, 13), half_lds=False, runtime_compile=True), + NS(length= 825, workgroup_size= 64, threads_per_transform= 55, factors=(11, 5, 5, 3), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 832, workgroup_size=128, threads_per_transform=104, factors=(13, 2, 2, 2, 2, 2, 2), runtime_compile=True), + NS(length= 833, workgroup_size=128, threads_per_transform=119, factors=(17, 7, 7), runtime_compile=True), + NS(length= 840, workgroup_size= 64, threads_per_transform= 56, factors=(2, 2, 2, 3, 5, 7), runtime_compile=True), + NS(length= 845, workgroup_size= 256, threads_per_transform=65, factors=(13, 5, 13), runtime_compile=True), + NS(length= 847, workgroup_size= 256, threads_per_transform=77, factors=(11, 7, 11), runtime_compile=True), + NS(length= 850, workgroup_size= 128, threads_per_transform=85, factors=(10, 5, 17), half_lds=False, runtime_compile=True), + NS(length= 858, workgroup_size= 256, threads_per_transform=78, factors=(13, 11, 6), runtime_compile=True), + NS(length= 864, workgroup_size= 64, threads_per_transform= 54, factors=(3, 6, 16, 3), runtime_compile=True), + NS(length= 867, workgroup_size= 64, threads_per_transform=51, factors=(17, 17, 3), runtime_compile=True), + NS(length= 875, workgroup_size= 256, threads_per_transform=175, factors=(7, 5, 5, 5), half_lds=False, runtime_compile=True), + NS(length= 880, workgroup_size=256, threads_per_transform= 88, factors=(2, 2, 2, 2, 11, 5), runtime_compile=True), + NS(length= 882, workgroup_size= 64, threads_per_transform=63, factors=(9, 7, 7, 2), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 884, workgroup_size= 256, threads_per_transform=68, factors=(13, 4, 17), runtime_compile=True), + NS(length= 891, workgroup_size= 256, threads_per_transform=99, factors=(9, 11, 3, 3), runtime_compile=True), + NS(length= 896, workgroup_size=128, threads_per_transform=112, factors=(2, 2, 2, 2, 2, 2, 2, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 900, workgroup_size=256, threads_per_transform= 90, factors=(10, 10, 3, 3), runtime_compile=True), + NS(length= 910, workgroup_size=256, threads_per_transform= 91, factors=(13, 2, 7, 5), half_lds=False, runtime_compile=True), + NS(length= 918, workgroup_size=128, threads_per_transform=102, factors=(17, 9, 2, 3), runtime_compile=True), + NS(length= 924, workgroup_size= 64, threads_per_transform= 44, factors=(2, 2, 3, 7, 11), runtime_compile=True), + NS(length= 935, workgroup_size= 256, threads_per_transform= 85, factors=(17, 11, 5), runtime_compile=True), + NS(length= 936, workgroup_size=256, threads_per_transform= 78, factors=(2, 2, 13, 2, 3, 3), runtime_compile=True), + NS(length= 945, workgroup_size= 64, threads_per_transform= 63, factors=(3, 3, 3, 5, 7), runtime_compile=True), + NS(length= 952, workgroup_size=256, threads_per_transform= 68, factors=(17, 4, 2, 7), runtime_compile=True), + NS(length= 960, workgroup_size=256, threads_per_transform=160, factors=(16, 10, 6), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 968, workgroup_size=256, threads_per_transform= 88, factors=(2, 2, 2, 11, 11), half_lds=False, runtime_compile=True), + NS(length= 972, workgroup_size=256, threads_per_transform=162, factors=(3, 6, 3, 6, 3), runtime_compile=True), + NS(length= 975, workgroup_size=128, threads_per_transform= 39, factors=(13, 5, 3, 5), runtime_compile=True), + NS(length= 980, workgroup_size= 256, threads_per_transform=196, factors=(7, 5, 7, 4), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length= 990, workgroup_size=128, threads_per_transform=110, factors=(2, 3, 3, 5, 11), half_lds=False, runtime_compile=True), + NS(length=1000, workgroup_size=128, threads_per_transform=100, factors=(10, 10, 10), runtime_compile=True), + NS(length=1001, workgroup_size=256, threads_per_transform= 91, factors=(13, 7, 11), runtime_compile=True), + NS(length=1008, workgroup_size= 64, threads_per_transform= 56, factors=(2, 2, 2, 2, 3, 3, 7), runtime_compile=True), + NS(length=1014, workgroup_size=256, threads_per_transform= 78, factors=(13, 6, 13), half_lds=False, runtime_compile=True), + NS(length=1020, workgroup_size=256, threads_per_transform= 68, factors=(2, 17, 2, 3, 5), runtime_compile=True), + NS(length=1024, workgroup_size=128, threads_per_transform=128, factors=(8, 8, 4, 4)), + NS(length=1040, workgroup_size=256, threads_per_transform=208, factors=(13, 16, 5), runtime_compile=True), + NS(length=1050, workgroup_size=256, threads_per_transform=210, factors=(2, 3, 5, 5, 7), half_lds=False, runtime_compile=True), + NS(length=1053, workgroup_size=128, threads_per_transform=117, factors=(3, 3, 13, 3, 3), runtime_compile=True), + NS(length=1056, workgroup_size=256, threads_per_transform=176, factors=(2, 2, 2, 2, 11, 6), runtime_compile=True), + NS(length=1071, workgroup_size=128, threads_per_transform=119, factors=(17, 7, 9), runtime_compile=True), + NS(length=1078, workgroup_size=256, threads_per_transform= 77, factors=(2, 11, 7, 7), runtime_compile=True), + NS(length=1080, workgroup_size=256, threads_per_transform=108, factors=(6, 10, 6, 3), runtime_compile=True), + NS(length=1088, workgroup_size=256, threads_per_transform= 68, factors=(17, 4, 4, 2, 2), runtime_compile=True), + NS(length=1089, workgroup_size=128, threads_per_transform=121, factors=(3, 11, 3, 11), half_lds=False, runtime_compile=True), + NS(length=1092, workgroup_size= 64, threads_per_transform= 52, factors=(2, 2, 13, 7, 3), runtime_compile=True), + NS(length=1100, workgroup_size=128, threads_per_transform=110, factors=(2, 2, 11, 5, 5), half_lds=False, runtime_compile=True), + NS(length=1105, workgroup_size=256, threads_per_transform= 85, factors=(17, 13, 5), runtime_compile=True), + NS(length=1120, workgroup_size=256, threads_per_transform=224, factors=(2, 2, 2, 2, 2, 5, 7), runtime_compile=True), + NS(length=1122, workgroup_size=256, threads_per_transform=102, factors=(17, 11, 6), runtime_compile=True), + NS(length=1125, workgroup_size=256, threads_per_transform=225, factors=(5, 5, 3, 3, 5), runtime_compile=True), + NS(length=1134, workgroup_size=128, threads_per_transform=126, factors=(2, 3, 3, 3, 3, 7), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length=1144, workgroup_size=128, threads_per_transform=104, factors=(13, 11, 8), half_lds=False, direct_to_from_reg=False, runtime_compile=True), + NS(length=1152, workgroup_size=256, threads_per_transform=144, factors=(4, 3, 8, 3, 4), runtime_compile=True), + NS(length=1155, workgroup_size= 64, threads_per_transform= 55, factors=(11, 5, 7, 3), runtime_compile=True), + NS(length=1156, workgroup_size=256, threads_per_transform= 68, factors=(17, 2, 17, 2), runtime_compile=True), + NS(length=1170, workgroup_size=256, threads_per_transform=117, factors=(2, 13, 3, 5, 3), half_lds=False, runtime_compile=True), + NS(length=1176, workgroup_size= 64, threads_per_transform= 56, factors=(2, 2, 2, 3, 7, 7), runtime_compile=True), + NS(length=1183, workgroup_size=256, threads_per_transform= 91, factors=(7, 13, 13), runtime_compile=True), + NS(length=1188, workgroup_size=256, threads_per_transform= 66, factors=(6, 11, 2, 3, 3), runtime_compile=True), + NS(length=1190, workgroup_size=256, threads_per_transform= 85, factors=(17, 2, 5, 7), runtime_compile=True), + NS(length=1200, workgroup_size=256, threads_per_transform= 75, factors=(5, 5, 16, 3), runtime_compile=True), + NS(length=1210, workgroup_size=128, threads_per_transform=110, factors=(2, 5, 11, 11), runtime_compile=True), + NS(length=1215, workgroup_size=256, threads_per_transform=243, factors=(5, 3, 3, 3, 3, 3), runtime_compile=True), + NS(length=1224, workgroup_size=256, threads_per_transform=102, factors=(17, 3, 4, 6), runtime_compile=True), + NS(length=1225, workgroup_size=256, threads_per_transform=175, factors=(5, 5, 7, 7), runtime_compile=True), + NS(length=1232, workgroup_size=256, threads_per_transform=176, factors=(2, 2, 2, 2, 11, 7), runtime_compile=True), + NS(length=1248, workgroup_size= 64, threads_per_transform= 52, factors=(2, 2, 13, 2, 3, 2, 2), runtime_compile=True), + NS(length=1250, workgroup_size=256, threads_per_transform=250, factors=(5, 10, 5, 5), runtime_compile=True), + NS(length=1260, workgroup_size= 64, threads_per_transform= 63, factors=(2, 2, 3, 3, 5, 7), runtime_compile=True), + NS(length=1274, workgroup_size=256, threads_per_transform=182, factors=(2, 13, 7, 7), runtime_compile=True), + NS(length=1275, workgroup_size=256, threads_per_transform= 85, factors=(17, 3, 5, 5), runtime_compile=True), + NS(length=1280, workgroup_size=128, threads_per_transform= 80, factors=(16, 5, 16), runtime_compile=True), + NS(length=1287, workgroup_size=128, threads_per_transform=117, factors=(3, 13, 3, 11), half_lds=False, runtime_compile=True), + NS(length=1296, workgroup_size=128, threads_per_transform=108, factors=(6, 6, 6, 6), runtime_compile=True), + NS(length=1300, workgroup_size=256, threads_per_transform=130, factors=(10, 10, 13), half_lds=False, runtime_compile=True), + NS(length=1309, workgroup_size=128, threads_per_transform=119, factors=(17, 7, 11), runtime_compile=True), + NS(length=1320, workgroup_size=256, threads_per_transform=165, factors=(11, 2, 3, 5, 4), half_lds=False, runtime_compile=True), + NS(length=1323, workgroup_size=256, threads_per_transform=189, factors=(3, 3, 3, 7, 7), half_lds=False, runtime_compile=True), + NS(length=1326, workgroup_size=256, threads_per_transform=102, factors=(17, 6, 13), runtime_compile=True), + NS(length=1331, workgroup_size=256, threads_per_transform=121, factors=(11, 11, 11), runtime_compile=True), + NS(length=1344, workgroup_size=256, threads_per_transform=224, factors=(2, 2, 2, 2, 2, 2, 3, 7), runtime_compile=True), + NS(length=1350, workgroup_size=256, threads_per_transform=135, factors=(5, 10, 3, 3, 3), runtime_compile=True), + NS(length=1352, workgroup_size= 64, threads_per_transform= 52, factors=(2, 13, 13, 4), runtime_compile=True), + NS(length=1360, workgroup_size=256, threads_per_transform= 85, factors=(17, 5, 16), runtime_compile=True), + NS(length=1365, workgroup_size=256, threads_per_transform= 91, factors=(13, 7, 5, 3), runtime_compile=True), + NS(length=1372, workgroup_size=256, threads_per_transform= 98, factors=(2, 2, 7, 7, 7), runtime_compile=True), + NS(length=1375, workgroup_size= 64, threads_per_transform= 55, factors=(11, 5, 5, 5), runtime_compile=True), + NS(length=1377, workgroup_size= 64, threads_per_transform= 51, factors=(17, 3, 9, 3), runtime_compile=True), + NS(length=1386, workgroup_size=256, threads_per_transform=231, factors=(2, 7, 3, 11, 3), runtime_compile=True), + NS(length=1400, workgroup_size= 64, threads_per_transform= 56, factors=(2, 2, 2, 5, 7, 5), runtime_compile=True), + NS(length=1404, workgroup_size=128, threads_per_transform=117, factors=(2, 2, 3, 13, 3, 3), runtime_compile=True), + NS(length=1408, workgroup_size=256, threads_per_transform=176, factors=(2, 2, 2, 2, 2, 2, 11, 2), runtime_compile=True), + NS(length=1428, workgroup_size=128, threads_per_transform=119, factors=(17, 2, 7, 6), runtime_compile=True), + NS(length=1430, workgroup_size=256, threads_per_transform=143, factors=(13, 11, 10), half_lds=False, runtime_compile=True), + NS(length=1440, workgroup_size=128, threads_per_transform= 90, factors=(10, 16, 3, 3), runtime_compile=True), + NS(length=1445, workgroup_size=128, threads_per_transform= 85, factors=(17, 5, 17), runtime_compile=True), + NS(length=1452, workgroup_size=256, threads_per_transform=132, factors=(11, 3, 11, 4), runtime_compile=True), + NS(length=1456, workgroup_size=256, threads_per_transform=182, factors=(13, 4, 7, 2, 2), runtime_compile=True), + NS(length=1458, workgroup_size=256, threads_per_transform=243, factors=(6, 3, 3, 3, 3, 3), runtime_compile=True), + NS(length=1470, workgroup_size=256, threads_per_transform=210, factors=(2, 3, 5, 7, 7), runtime_compile=True), + NS(length=1485, workgroup_size=256, threads_per_transform=165, factors=(3, 5, 11, 3, 3), half_lds=False, runtime_compile=True), + NS(length=1496, workgroup_size=256, threads_per_transform=187, factors=(17, 8, 11), runtime_compile=True), + NS(length=1500, workgroup_size=256, threads_per_transform=150, factors=(5, 10, 10, 3), runtime_compile=True), + NS(length=1512, workgroup_size= 64, threads_per_transform= 63, factors=(2, 2, 2, 3, 3, 3, 7), runtime_compile=True), + NS(length=1521, workgroup_size=128, threads_per_transform=117, factors=(13, 3, 3, 13), runtime_compile=True), + NS(length=1530, workgroup_size=128, threads_per_transform=102, factors=(17, 3, 6, 5), runtime_compile=True), + NS(length=1536, workgroup_size=256, threads_per_transform=256, factors=(16, 16, 6), runtime_compile=True), + NS(length=1540, workgroup_size=256, threads_per_transform=154, factors=(11, 2, 7, 5, 2), runtime_compile=True), + NS(length=1547, workgroup_size=128, threads_per_transform=119, factors=(17, 7, 13), runtime_compile=True), + NS(length=1560, workgroup_size=256, threads_per_transform=156, factors=(13, 2, 2, 10, 3), half_lds=False, runtime_compile=True), + NS(length=1568, workgroup_size=256, threads_per_transform=224, factors=(2, 2, 2, 2, 2, 7, 7), runtime_compile=True), + NS(length=1573, workgroup_size=256, threads_per_transform=143, factors=(13, 11, 11), half_lds=False, runtime_compile=True), + NS(length=1575, workgroup_size= 64, threads_per_transform= 63, factors=(3, 3, 5, 7, 5), runtime_compile=True), + NS(length=1584, workgroup_size=256, threads_per_transform=176, factors=(4, 2, 2, 11, 3, 3), runtime_compile=True), + NS(length=1600, workgroup_size=256, threads_per_transform=100, factors=(10, 16, 10), runtime_compile=True), + NS(length=1617, workgroup_size=256, threads_per_transform=231, factors=(3, 7, 7, 11), half_lds=False, runtime_compile=True), + NS(length=1620, workgroup_size=256, threads_per_transform=162, factors=(10, 3, 3, 6, 3), runtime_compile=True), + NS(length=1625, workgroup_size=256, threads_per_transform= 65, factors=(13, 5, 5, 5), runtime_compile=True), + NS(length=1632, workgroup_size=128, threads_per_transform=102, factors=(17, 2, 2, 3, 8), runtime_compile=True), + NS(length=1638, workgroup_size=256, threads_per_transform=182, factors=(13, 2, 3, 7, 3), runtime_compile=True), + NS(length=1650, workgroup_size=128, threads_per_transform=110, factors=(11, 2, 3, 5, 5), runtime_compile=True), + NS(length=1664, workgroup_size=256, threads_per_transform=208, factors=(13, 2, 2, 4, 2, 2, 2), runtime_compile=True), + NS(length=1666, workgroup_size=128, threads_per_transform=119, factors=(17, 2, 7, 7), runtime_compile=True), + NS(length=1680, workgroup_size=128, threads_per_transform=112, factors=(2, 2, 2, 2, 3, 7, 5), runtime_compile=True), + NS(length=1683, workgroup_size= 64, threads_per_transform= 51, factors=(17, 3, 11, 3), runtime_compile=True), + NS(length=1690, workgroup_size=256, threads_per_transform=169, factors=(13, 10, 13), half_lds=False, runtime_compile=True), + NS(length=1694, workgroup_size=256, threads_per_transform=154, factors=(11, 2, 11, 7), runtime_compile=True), + NS(length=1700, workgroup_size=256, threads_per_transform=170, factors=(17, 10, 10), runtime_compile=True), + NS(length=1701, workgroup_size= 64, threads_per_transform= 63, factors=(3, 3, 3, 3, 3, 7), runtime_compile=True), + NS(length=1715, workgroup_size=256, threads_per_transform=245, factors=(5, 7, 7, 7), runtime_compile=True), + NS(length=1716, workgroup_size=256, threads_per_transform=156, factors=(13, 2, 6, 11), half_lds=False, runtime_compile=True), + NS(length=1728, workgroup_size=128, threads_per_transform=108, factors=(3, 6, 6, 16), runtime_compile=True), + NS(length=1734, workgroup_size=128, threads_per_transform=102, factors=(17, 17, 6), runtime_compile=True), + NS(length=1750, workgroup_size=256, threads_per_transform=175, factors=(2, 5, 5, 7, 5), runtime_compile=True), + NS(length=1755, workgroup_size=128, threads_per_transform=117, factors=(13, 3, 3, 3, 5), runtime_compile=True), + NS(length=1760, workgroup_size=256, threads_per_transform=176, factors=(2, 2, 2, 2, 2, 11, 5), runtime_compile=True), + NS(length=1764, workgroup_size=128, threads_per_transform=126, factors=(2, 2, 3, 3, 7, 7), runtime_compile=True), + NS(length=1768, workgroup_size=256, threads_per_transform=136, factors=(17, 13, 8), runtime_compile=True), + NS(length=1782, workgroup_size=128, threads_per_transform= 99, factors=(11, 3, 3, 3, 3, 2), runtime_compile=True), + NS(length=1785, workgroup_size=128, threads_per_transform=119, factors=(17, 3, 5, 7), runtime_compile=True), + NS(length=1792, workgroup_size=256, threads_per_transform=224, factors=(4, 4, 4, 4, 7), runtime_compile=True), + NS(length=1800, workgroup_size=256, threads_per_transform=180, factors=(10, 6, 10, 3), runtime_compile=True), + NS(length=1815, workgroup_size=256, threads_per_transform=165, factors=(11, 3, 5, 11), half_lds=False, runtime_compile=True), + NS(length=1820, workgroup_size=256, threads_per_transform=182, factors=(10, 13, 7, 2), runtime_compile=True), + NS(length=1836, workgroup_size=256, threads_per_transform=153, factors=(17, 3, 3, 2, 6), runtime_compile=True), + NS(length=1848, workgroup_size=256, threads_per_transform=231, factors=(3, 11, 7, 4, 2), runtime_compile=True), + NS(length=1859, workgroup_size=256, threads_per_transform=169, factors=(13, 11, 13), runtime_compile=True), + NS(length=1870, workgroup_size=256, threads_per_transform=187, factors=(17, 10, 11), runtime_compile=True), + NS(length=1872, workgroup_size=256, threads_per_transform=156, factors=(13, 3, 4, 6, 2), runtime_compile=True), + NS(length=1875, workgroup_size=256, threads_per_transform=125, factors=(5, 5, 5, 5, 3), runtime_compile=True), + NS(length=1890, workgroup_size=128, threads_per_transform=126, factors=(2, 3, 3, 3, 7, 5), runtime_compile=True), + NS(length=1904, workgroup_size=128, threads_per_transform=119, factors=(17, 2, 2, 7, 4), runtime_compile=True), + NS(length=1911, workgroup_size=128, threads_per_transform= 91, factors=(13, 7, 7, 3), runtime_compile=True), + NS(length=1920, workgroup_size=256, threads_per_transform=120, factors=(10, 6, 16, 2), runtime_compile=True), + NS(length=1925, workgroup_size= 64, threads_per_transform= 55, factors=(7, 11, 5, 5), runtime_compile=True), + NS(length=1936, workgroup_size=256, threads_per_transform=176, factors=(2, 2, 4, 11, 11), half_lds=False, runtime_compile=True), + NS(length=1944, workgroup_size=256, threads_per_transform=243, factors=(3, 3, 3, 3, 8, 3), runtime_compile=True), + NS(length=1950, workgroup_size=256, threads_per_transform=195, factors=(13, 5, 10, 3), half_lds=False, runtime_compile=True), + NS(length=1960, workgroup_size= 64, threads_per_transform= 56, factors=(4, 7, 2, 7, 5), runtime_compile=True), + NS(length=1980, workgroup_size=256, threads_per_transform=198, factors=(11, 2, 3, 3, 5, 2), runtime_compile=True), + NS(length=1989, workgroup_size=256, threads_per_transform=153, factors=(17, 13, 9), runtime_compile=True), + NS(length=2000, workgroup_size=128, threads_per_transform=125, factors=(5, 5, 5, 16), runtime_compile=True), + NS(length=2002, workgroup_size=256, threads_per_transform=182, factors=(2, 13, 7, 11), runtime_compile=True), + NS(length=2016, workgroup_size=256, threads_per_transform=112, factors=(2, 2, 2, 2, 2, 3, 3, 7), runtime_compile=True), + NS(length=2023, workgroup_size=128, threads_per_transform=119, factors=(17, 7, 17), runtime_compile=True), + NS(length=2025, workgroup_size=256, threads_per_transform=135, factors=(3, 3, 5, 5, 3, 3), runtime_compile=True), + NS(length=2028, workgroup_size=256, threads_per_transform=156, factors=(13, 4, 3, 13), half_lds=False, runtime_compile=True), + NS(length=2040, workgroup_size=256, threads_per_transform=170, factors=(17, 4, 3, 10), runtime_compile=True), + NS(length=2048, workgroup_size=256, threads_per_transform=256, factors=(16, 16, 8), runtime_compile=True), + NS(length=2160, workgroup_size=256, threads_per_transform= 60, factors=(10, 6, 6, 6), runtime_compile=True), + NS(length=2187, workgroup_size=256, threads_per_transform=243, factors=(3, 3, 3, 3, 3, 3, 3), runtime_compile=True), + NS(length=2197, workgroup_size=256, threads_per_transform=169, factors=(13, 13, 13), runtime_compile=True), + NS(length=2250, workgroup_size=256, threads_per_transform= 90, factors=(10, 3, 5, 3, 5), runtime_compile=True), + NS(length=2304, workgroup_size=256, threads_per_transform=192, factors=(6, 6, 4, 4, 4), runtime_compile=True), + NS(length=2400, workgroup_size=256, threads_per_transform=240, factors=(4, 10, 10, 6), runtime_compile=True), + NS(length=2401, workgroup_size=256, threads_per_transform= 49, factors=(7, 7, 7, 7), runtime_compile=True), + NS(length=2430, workgroup_size=256, threads_per_transform= 81, factors=(10, 3, 3, 3, 3, 3), runtime_compile=True), + NS(length=2500, workgroup_size=256, threads_per_transform=250, factors=(10, 5, 10, 5), runtime_compile=True), + NS(length=2560, workgroup_size=128, threads_per_transform=128, factors=(4, 4, 4, 10, 4), runtime_compile=True), + NS(length=2592, workgroup_size=256, threads_per_transform=216, factors=(6, 6, 6, 6, 2), runtime_compile=True), + NS(length=2700, workgroup_size=128, threads_per_transform= 90, factors=(3, 10, 10, 3, 3), runtime_compile=True), + NS(length=2880, workgroup_size=256, threads_per_transform= 96, factors=(10, 6, 6, 2, 2, 2), runtime_compile=True), + NS(length=2916, workgroup_size=256, threads_per_transform=243, factors=(6, 6, 3, 3, 3, 3), runtime_compile=True), + NS(length=3000, workgroup_size=128, threads_per_transform=100, factors=(10, 3, 10, 10), runtime_compile=True), + NS(length=3072, workgroup_size=256, threads_per_transform=256, factors=(6, 4, 4, 4, 4, 2), runtime_compile=True), + NS(length=3125, workgroup_size=128, threads_per_transform=125, factors=(5, 5, 5, 5, 5), runtime_compile=True), + NS(length=3200, workgroup_size=256, threads_per_transform=160, factors=(10, 10, 4, 4, 2), runtime_compile=True), + NS(length=3240, workgroup_size=128, threads_per_transform=108, factors=(3, 3, 10, 6, 6), runtime_compile=True), + NS(length=3375, workgroup_size=256, threads_per_transform=225, factors=(5, 5, 5, 3, 3, 3), runtime_compile=True), + NS(length=3456, workgroup_size=256, threads_per_transform=144, factors=(6, 6, 6, 4, 4), runtime_compile=True), + NS(length=3600, workgroup_size=256, threads_per_transform=120, factors=(10, 10, 6, 6), runtime_compile=True), + NS(length=3645, workgroup_size=256, threads_per_transform=243, factors=(5, 3, 3, 3, 3, 3, 3), runtime_compile=True), + NS(length=3750, workgroup_size=256, threads_per_transform=125, factors=(3, 5, 5, 10, 5), runtime_compile=True), + NS(length=3840, workgroup_size=256, threads_per_transform=128, factors=(10, 6, 2, 2, 2, 2, 2, 2), runtime_compile=True), + NS(length=3888, workgroup_size=512, threads_per_transform=324, factors=(16, 3, 3, 3, 3, 3), runtime_compile=True), + NS(length=4000, workgroup_size=256, threads_per_transform=200, factors=(10, 10, 10, 4), runtime_compile=True), + NS(length=4050, workgroup_size=256, threads_per_transform=135, factors=(10, 5, 3, 3, 3, 3), runtime_compile=True), + NS(length=4096, workgroup_size=256, threads_per_transform=256, factors=(16, 16, 16), runtime_compile=True), + NS(length=4704, workgroup_size=256, threads_per_transform=224, factors=(8, 4, 7, 7, 3), double_precision=False, runtime_compile=True), + NS(length=5488, workgroup_size=256, threads_per_transform=196, factors=(7, 4, 7, 4, 7), double_precision=False, runtime_compile=True), + NS(length=6144, workgroup_size=512, threads_per_transform=512, factors=(16, 4, 8, 3, 4), double_precision=False, runtime_compile=True), + NS(length=6561, workgroup_size=256, threads_per_transform=243, factors=(3, 3, 3, 3, 3, 3, 3, 3), double_precision=False, runtime_compile=True), + NS(length=8192, workgroup_size=512, threads_per_transform=512, factors=(16, 4, 4, 4, 8), double_precision=False, runtime_compile=True), + + # configs for 160kiB LDS + NS(length=4704, workgroup_size=256, threads_per_transform=224, factors=(8, 4, 7, 7, 3), lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), + NS(length=5488, workgroup_size=256, threads_per_transform=196, factors=(7, 4, 7, 4, 7), lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), + NS(length=6144, workgroup_size=384, threads_per_transform=256, factors=(4, 8, 8, 8, 3), lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), + NS(length=6561, workgroup_size=256, threads_per_transform=243, factors=(3, 3, 3, 3, 3, 3, 3, 3), lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), + NS(length=8192, workgroup_size=512, threads_per_transform=512, factors=(16, 4, 16, 8), lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), + NS(length=9216, workgroup_size=512, threads_per_transform=512, factors=(4, 8, 4, 4, 3, 6), lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), + NS(length=10000, workgroup_size=512, threads_per_transform=500, factors=(4, 5, 5, 10, 10), lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), + NS(length=10240, workgroup_size=512, threads_per_transform=512, factors=(8, 4, 4, 4, 5, 4), lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), + NS(length=10752, workgroup_size=512, threads_per_transform=512, factors=(4, 16, 8, 7, 3), double_precision=False, lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), + NS(length=11200, workgroup_size=512, threads_per_transform=448, factors=(4, 7, 5, 16, 5), double_precision=False, lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), + NS(length=12288, workgroup_size=512, threads_per_transform=512, factors=(8, 8, 4, 6, 8), double_precision=False, lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), + NS(length=16384, workgroup_size=512, threads_per_transform=512, factors=(8, 16, 4, 8, 4), double_precision=False, lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), + NS(length=16807, workgroup_size=384, threads_per_transform=343, factors=(7, 7, 7, 7, 7), double_precision=False, lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), + NS(length=18816, workgroup_size=512, threads_per_transform=448, factors=(8, 8, 7, 7, 6), double_precision=False, lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), + NS(length=19200, workgroup_size=512, threads_per_transform=480, factors=(8, 10, 8, 5, 6), double_precision=False, lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), + NS(length=20480, workgroup_size=512, threads_per_transform=512, factors=(4, 4, 16, 10, 8), double_precision=False, lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), +] \ No newline at end of file From d1f241de9fbe5b25d63c723dd85a57f89839aeb5 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 6 Jun 2025 16:37:35 -0600 Subject: [PATCH 44/69] - Fix issue with wgs / tpt / tpb in partial-pass kernel configuration. --- library/src/device/generator/stockham_gen.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/library/src/device/generator/stockham_gen.cpp b/library/src/device/generator/stockham_gen.cpp index 65b9c5abc9b..f9d352117ea 100644 --- a/library/src/device/generator/stockham_gen.cpp +++ b/library/src/device/generator/stockham_gen.cpp @@ -547,17 +547,21 @@ void validate_pp_grid_params(const StockhamPartialPassParams& params_1, if((params_1.current_dim == 0 && params_2.current_dim == 2) || (params_1.current_dim == 2 && params_2.current_dim == 0)) { - // SBRR needs tpb to be at least max(pp_factors), + // SBRR needs tpb to be prod(pp_factors), // so that it has the required off-dim data in LDS - // to perform partial pass + // to perform partial passes auto tpb_sbrr = (params_1.current_dim == 0 && params_2.current_dim == 2) ? specs_1.workgroup_size / specs_1.threads_per_transform : specs_2.workgroup_size / specs_2.threads_per_transform; - if(tpb_sbrr < *std::max_element(params_1.factors_off_dim.begin(), - params_1.factors_off_dim.end())) + + unsigned int prod_factors_off_dim = std::accumulate(params_1.factors_off_dim.begin(), + params_1.factors_off_dim.end(), + 1, + std::multiplies()); + if(tpb_sbrr != prod_factors_off_dim) { throw std::runtime_error("CS_KERNEL_STOCKHAM_PP requires transform-per-block " - "to be at least max(pp_factors)"); + "to be prod(pp_factors)"); } } // SBCC_PP + SBCC_PP @@ -704,10 +708,12 @@ int main() StockhamGeneratorSpecs specs1(factors1, {}, precisions, workgroup_size[0], scheme); specs1.direct_to_from_reg = direct_to_from_reg[0]; specs1.threads_per_transform = threads_per_transform[0]; + specs1.wgs_is_derived = true; StockhamGeneratorSpecs specs2(factors2, {}, precisions, workgroup_size[1], scheme); specs2.direct_to_from_reg = direct_to_from_reg[1]; specs2.threads_per_transform = threads_per_transform[1]; + specs2.wgs_is_derived = true; StockhamPartialPassParams pp_params_1(parent_length, dims[0], off_dim, pp_factors1); StockhamPartialPassParams pp_params_2(parent_length, dims[1], off_dim, pp_factors2); From 78dba8f81a829f23dabca02165b6961008864ed6 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Mon, 9 Jun 2025 16:22:32 -0600 Subject: [PATCH 45/69] - Add validation and fix hardcoded offset calculation in partial pass kernel. --- .../src/device/generator/stockham_pp_gen_rr.h | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index ba740192ac3..9be0fbe511a 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -25,7 +25,6 @@ // - Revisit all usages of transform_per_block and max_factor_pp. // - Test with factors_pp.size() > 1 // - Revisit lstride usage and input/output strides -// - Revisit factor 64 logic in calculate_offsets() with different input lengths // Variation of StockhamKernelRR that implements the partial pass // method. Similarities of StockhamPartialPassKernelRR with @@ -53,11 +52,25 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR max_factor_pp = *std::max_element(factors_pp.begin(), factors_pp.end()); + length_off_dim = params.parent_length[params.off_dim]; + R.size = Expression{std::max(nregisters, max_factor_pp)}; + + // nregister must not be larger than max_factor_pp. + // If that were to be true, work in the off-dimension + // in perform_partial_pass_step_1_2() would require to + // to be applied to (nregisters-max_factor_pp) elements + // in the off-dimension, but this data is not available + // in the LDS, and the number of additional elements + // would need to be at least a multiple of max_factor_pp. + if(nregisters > max_factor_pp) + throw std::runtime_error( + "StockhamPartialPassKernelRR: nregisters cannot be larger than max_factor_pp"); } StockhamPartialPassParams params; + unsigned int length_off_dim; unsigned int max_factor_pp; std::vector factors_pp; unsigned int length_pp; @@ -84,8 +97,10 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR block_id * transforms_per_block + thread_id / threads_per_transform}; stmts += Assign{remaining, transform}; stmts += Assign{remaining_pp, - 64 * Parens(transform / 64) + Parens(transform % 64) / transforms_per_block - + Parens(transform * (64 / transforms_per_block)) % 64}; + length_off_dim * Parens(transform / length_off_dim) + + Parens(transform % length_off_dim) / transforms_per_block + + Parens(transform * (length_off_dim / transforms_per_block)) + % length_off_dim}; stmts += For{d, 1, From e232ad001a77fe71312789fbaf5150e1ed7b82cf Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Mon, 23 Jun 2025 15:35:28 -0600 Subject: [PATCH 46/69] - Fixes for local transposition in SBCC partial-pass kernel generator. - Add separate python files for kernel configurations. --- library/src/CMakeLists.txt | 10 + library/src/device/CMakeLists.txt | 6 + library/src/device/generator/stockham_gen.cpp | 51 ++++-- library/src/device/generator/stockham_gen.h | 9 +- .../src/device/generator/stockham_pp_gen_cc.h | 77 +++++--- .../src/device/generator/stockham_pp_gen_rr.h | 2 +- library/src/device/kernel-generator.py | 173 ++++++++++-------- .../device/kernels/configs/config_pp_3d.py | 4 +- library/src/include/function_pool.h | 22 ++- library/src/rtc_stockham_gen.cpp | 4 +- library/src/rtc_stockham_kernel.cpp | 7 +- library/src/tree_node.cpp | 4 +- 12 files changed, 226 insertions(+), 143 deletions(-) diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index cdecfaae93a..6a17f8bfe97 100644 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -150,6 +150,16 @@ set( kgen_logic_files # python code that decides kernel parameters ${CMAKE_SOURCE_DIR}/library/src/device/kernel-generator.py ${CMAKE_SOURCE_DIR}/library/src/device/generator.py + ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_lds.py + ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_sbrr.py + ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_sbcc.py + ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_sbcr.py + ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_sbrc.py + ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_2d_single.py + ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_pp_3d.py + + + # stockham generator code ${CMAKE_SOURCE_DIR}/library/src/device/generator/generator.h diff --git a/library/src/device/CMakeLists.txt b/library/src/device/CMakeLists.txt index f1df908cb6a..d4a2f53be12 100644 --- a/library/src/device/CMakeLists.txt +++ b/library/src/device/CMakeLists.txt @@ -58,6 +58,12 @@ endif() # Make it possible to let install.sh control this ? set( kgen ${CMAKE_SOURCE_DIR}/library/src/device/kernel-generator.py ) set( kgendeps ${CMAKE_SOURCE_DIR}/library/src/device/kernel-generator.py + ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_sbrr.py + ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_sbcc.py + ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_sbcr.py + ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_sbrc.py + ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_2d_single.py + ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_pp_3d.py ${CMAKE_SOURCE_DIR}/library/src/device/generator.py ) # create list of all N files that will initialize function pool diff --git a/library/src/device/generator/stockham_gen.cpp b/library/src/device/generator/stockham_gen.cpp index f9d352117ea..32657f58987 100644 --- a/library/src/device/generator/stockham_gen.cpp +++ b/library/src/device/generator/stockham_gen.cpp @@ -47,7 +47,8 @@ struct GeneratedLauncher GeneratedLauncher(StockhamKernel& kernel, const std::string& scheme, const std::string& pp_child_scheme, - const std::vector& pp_factors, + const std::vector& pp_factors_curr, + const std::vector& pp_factors_other, const unsigned int& pp_current_dim, const unsigned int& pp_off_dim, bool double_precision, @@ -55,7 +56,8 @@ struct GeneratedLauncher const std::string& sbrc_transpose_type) : scheme(scheme) , pp_child_scheme(pp_child_scheme) - , pp_factors(pp_factors) + , pp_factors_curr(pp_factors_curr) + , pp_factors_other(pp_factors_other) , pp_current_dim(pp_current_dim) , pp_off_dim(pp_off_dim) , lengths(kernel.launcher_lengths()) @@ -72,7 +74,8 @@ struct GeneratedLauncher std::string scheme; std::string pp_child_scheme; - std::vector pp_factors; + std::vector pp_factors_curr; + std::vector pp_factors_other; unsigned int pp_current_dim; unsigned int pp_off_dim; std::vector lengths; @@ -127,7 +130,8 @@ struct GeneratedLauncher add_member("sbrc_transpose_type", quote_str(sbrc_transpose_type)); add_member("double_precision", double_precision ? "true" : "false"); add_member("pp_child_scheme", quote_str(pp_child_scheme)); - add_member("pp_factors", vec_to_list(pp_factors)); + add_member("pp_factors_curr", vec_to_list(pp_factors_curr)); + add_member("pp_factors_other", vec_to_list(pp_factors_other)); add_member("pp_current_dim", std::to_string(pp_current_dim)); add_member("pp_off_dim", std::to_string(pp_off_dim)); @@ -148,7 +152,8 @@ void make_launcher(const std::vector& precision_types, const std::vector& launcher_suffixes, StockhamKernel& kernel, const std::string& pp_child_scheme, - const std::vector& pp_factors, + const std::vector& pp_factors_curr, + const std::vector& pp_factors_other, const unsigned int& pp_current_dim, const unsigned int& pp_off_dim, std::vector& generated_launchers) @@ -160,7 +165,8 @@ void make_launcher(const std::vector& precision_types, generated_launchers.emplace_back(kernel, launcher.scheme, pp_child_scheme, - pp_factors, + pp_factors_curr, + pp_factors_other, pp_current_dim, pp_off_dim, precision_type == rocfft_precision_double, @@ -252,7 +258,8 @@ void stockham_partial_pass_variants(const std::string& kernel_name {{"pp_stoc", specs1.scheme, "", ""}}, kernelRR, "CS_KERNEL_STOCKHAM_PP", - params_1.factors_off_dim, + params_1.pp_factors_curr, + params_1.pp_factors_other, params_1.current_dim, params_1.off_dim, launchers); @@ -262,7 +269,8 @@ void stockham_partial_pass_variants(const std::string& kernel_name {{"pp_sbcc", specs2.scheme, "", ""}}, kernelCC, "CS_KERNEL_STOCKHAM_PP_BLOCK_CC", - params_2.factors_off_dim, + params_2.pp_factors_curr, + params_2.pp_factors_other, params_2.current_dim, params_2.off_dim, launchers); @@ -274,7 +282,8 @@ void stockham_partial_pass_variants(const std::string& kernel_name {{"pp_sbcc", specs1.scheme, "", ""}}, kernelCC, "CS_KERNEL_STOCKHAM_PP_BLOCK_CC", - params_1.factors_off_dim, + params_1.pp_factors_curr, + params_1.pp_factors_other, params_1.current_dim, params_1.off_dim, launchers); @@ -284,7 +293,8 @@ void stockham_partial_pass_variants(const std::string& kernel_name {{"pp_stoc", specs2.scheme, "", ""}}, kernelRR, "CS_KERNEL_STOCKHAM_PP", - params_2.factors_off_dim, + params_2.pp_factors_curr, + params_2.pp_factors_other, params_2.current_dim, params_2.off_dim, launchers); @@ -334,6 +344,7 @@ void stockham_variants(const std::string& kernel_name, kernel, "CS_NONE", std::vector(), + std::vector(), 0, 0, launchers); @@ -346,6 +357,7 @@ void stockham_variants(const std::string& kernel_name, kernel, "CS_NONE", std::vector(), + std::vector(), 0, 0, launchers); @@ -396,6 +408,7 @@ void stockham_variants(const std::string& kernel_name, kernel, "CS_NONE", std::vector(), + std::vector(), 0, 0, launchers); @@ -409,6 +422,7 @@ void stockham_variants(const std::string& kernel_name, kernel, "CS_NONE", std::vector(), + std::vector(), 0, 0, launchers); @@ -424,6 +438,7 @@ void stockham_variants(const std::string& kernel_name, specs.scheme, "CS_NONE", std::vector(), + std::vector(), 0, 0, (prec_type == rocfft_precision_double), @@ -520,10 +535,10 @@ void validate_pp_length(const StockhamPartialPassParams& pp_params, void validate_pp_off_dim_length(const StockhamPartialPassParams& pp_params_1, const StockhamPartialPassParams& pp_params_2) { - auto off_factors_all = pp_params_1.factors_off_dim; + auto off_factors_all = pp_params_1.pp_factors_curr; off_factors_all.insert(off_factors_all.end(), - pp_params_2.factors_off_dim.begin(), - pp_params_2.factors_off_dim.end()); + pp_params_2.pp_factors_curr.begin(), + pp_params_2.pp_factors_curr.end()); unsigned int length_off_dim = std::accumulate( off_factors_all.begin(), off_factors_all.end(), 1, std::multiplies()); @@ -554,8 +569,8 @@ void validate_pp_grid_params(const StockhamPartialPassParams& params_1, ? specs_1.workgroup_size / specs_1.threads_per_transform : specs_2.workgroup_size / specs_2.threads_per_transform; - unsigned int prod_factors_off_dim = std::accumulate(params_1.factors_off_dim.begin(), - params_1.factors_off_dim.end(), + unsigned int prod_factors_off_dim = std::accumulate(params_1.pp_factors_curr.begin(), + params_1.pp_factors_curr.end(), 1, std::multiplies()); if(tpb_sbrr != prod_factors_off_dim) @@ -715,8 +730,10 @@ int main() specs2.threads_per_transform = threads_per_transform[1]; specs2.wgs_is_derived = true; - StockhamPartialPassParams pp_params_1(parent_length, dims[0], off_dim, pp_factors1); - StockhamPartialPassParams pp_params_2(parent_length, dims[1], off_dim, pp_factors2); + StockhamPartialPassParams pp_params_1( + parent_length, dims[0], off_dim, pp_factors1, pp_factors2); + StockhamPartialPassParams pp_params_2( + parent_length, dims[1], off_dim, pp_factors2, pp_factors1); validate_pp_length(pp_params_1, factors1); validate_pp_length(pp_params_2, factors2); diff --git a/library/src/device/generator/stockham_gen.h b/library/src/device/generator/stockham_gen.h index feb4b1acf4f..0c9ab2d9431 100644 --- a/library/src/device/generator/stockham_gen.h +++ b/library/src/device/generator/stockham_gen.h @@ -87,18 +87,21 @@ struct StockhamPartialPassParams StockhamPartialPassParams(const std::vector& parent_length, const unsigned int current_dim, const unsigned int off_dim, - const std::vector& factors_off_dim) + const std::vector& pp_factors_curr, + const std::vector& pp_factors_other) : parent_length(parent_length) , current_dim(current_dim) , off_dim(off_dim) - , factors_off_dim(factors_off_dim) + , pp_factors_curr(pp_factors_curr) + , pp_factors_other(pp_factors_other) { } std::vector parent_length; unsigned int current_dim = 0; unsigned int off_dim = 0; - std::vector factors_off_dim; + std::vector pp_factors_curr; + std::vector pp_factors_other; }; void stockham_partial_pass_variants(const std::string& kernel_name, diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index 1d765dd075f..562f4833c68 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -51,16 +51,21 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC bool largeTwdBatchIsTransformCount) : StockhamKernelCC(specs, largeTwdBatchIsTransformCount, false) , params(params) + , factors_pp_curr(params.pp_factors_curr) + , factors_pp_other(params.pp_factors_other) { - factors_pp = params.factors_off_dim; - - max_factor_pp = *std::max_element(factors_pp.begin(), factors_pp.end()); + max_factor_pp = *std::max_element(factors_pp_curr.begin(), factors_pp_curr.end()); transforms_per_block_pp = transforms_per_block; transforms_per_block *= max_factor_pp; workgroup_size *= max_factor_pp; + + pp_factors_curr_prod = std::accumulate( + factors_pp_curr.begin(), factors_pp_curr.end(), 1, std::multiplies()); + pp_factors_other_prod = std::accumulate( + factors_pp_other.begin(), factors_pp_other.end(), 1, std::multiplies()); } StockhamPartialPassParams params; @@ -68,7 +73,11 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC unsigned int transforms_per_block_pp; unsigned int max_factor_pp; - std::vector factors_pp; + std::vector factors_pp_curr; + unsigned int pp_factors_curr_prod; + + std::vector factors_pp_other; + unsigned int pp_factors_other_prod; Variable thread_lds{"thread_lds", "unsigned int"}; Variable stride_lds_pp{"stride_lds_pp", "unsigned int"}; @@ -255,6 +264,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += Assign{plength, num_of_tiles}; stmts += Assign{tile_index, block_id % num_of_tiles}; + // TODO: factor 128 stmts += Assign{remaining, (block_id % 128) / num_of_tiles}; stmts += Assign{offset, tile_index * transforms_per_block_pp * stride[1]}; stmts += For{d, @@ -287,6 +297,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += Declaration{thread_lds, thread_id / transforms_per_block_pp}; stmts += Declaration{tid_hor_lds, thread_id % transforms_per_block_pp}; + // TODO: length stmts += Declaration( tid_hor_pp, thread_id % transforms_per_block_pp + length * (thread % max_factor_pp)); stmts += Declaration(thread_new, thread_id / (transforms_per_block_pp * max_factor_pp)); @@ -295,6 +306,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += Declaration(thread_idx, thread_id); stmts += Declaration(block_idx, block_id); + // TODO: length and TODO: Literal{192} stmts += Declaration( offset_pp, offset + Parens(offset / length) * Literal{192} + batch_new * stride[dim]); stmts += Declaration(offset_tid_hor, offset_pp + tid_hor_pp * stride[1]); @@ -439,7 +451,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC tmp_stmts += StoreGlobal{ buf, CallExpr{"local_transpose_pp_length" + std::to_string(length) + "_device", - {offset_tile_wbuf(i)}}, + {offset_tile_wbuf(i), lengths}}, lds_complex[offset_tile_rlds(i)]}; stmts += CommentLines{ @@ -659,30 +671,43 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC = "local_transpose_pp_length" + std::to_string(length) + "_device"; Function f{function_name}; - f.arguments = ArgumentList{global_idx}; + f.arguments = ArgumentList{global_idx, lengths}; f.return_type = "unsigned int"; f.qualifier = "__device__"; StatementList& body = f.body; - auto factor_transpose_1 = (length * length) / max_factor_pp; - auto factor_transpose_2 = length * max_factor_pp; - auto factor_transpose_3 = length * length; - auto factor_transpose_4 = length * length - length; - auto factor_transpose_5 = length * length * length; - - body += Declaration{transpose_idx, global_idx % factor_transpose_5}; - body += Assign{transpose_idx, - Parens((transpose_idx % length) - + Parens(Parens(transpose_idx % factor_transpose_1) / length) - * factor_transpose_1) - % factor_transpose_4 - + (Parens(transpose_idx / factor_transpose_1) * factor_transpose_2) - + Parens(transpose_idx / factor_transpose_3) - * (factor_transpose_3 - factor_transpose_1)}; - - body += Assign{transpose_idx, - transpose_idx + global_idx / factor_transpose_5 * factor_transpose_5}; + auto len_1 = lengths[2]; + auto len_2 = lengths[1]; + auto len_3 = lengths[0]; + auto len_1_2 = len_1 * len_2; + auto len_1_2_3 = len_1 * len_2 * len_3; + + auto pp_factor_1 = pp_factors_curr_prod; + auto pp_factor_2 = pp_factors_other_prod; + + auto pp_factor_len = pp_factor_1 * len_2; + + body += Declaration{transpose_idx, global_idx % len_1_2_3}; + + Variable idx_1{"idx_1", "unsigned int"}; + body += Declaration{idx_1, transpose_idx % len_2}; + + Variable idx_2{"idx_2", "unsigned int"}; + body += Declaration{idx_2, transpose_idx % pp_factor_len}; + body += Assign{idx_2, idx_2 / len_2}; + body += Assign{idx_2, idx_2 * pp_factor_2 * len_2}; + + Variable idx_3{"idx_3", "unsigned int"}; + body += Declaration{idx_3, transpose_idx % len_1_2}; + body += Assign{idx_3, Parens{idx_3 / pp_factor_len} * len_2}; + + Variable idx_4{"idx_4", "unsigned int"}; + body += Declaration{idx_4, Parens{transpose_idx / len_1_2} * len_1_2}; + + body += Assign{transpose_idx, idx_1 + idx_2 + idx_3 + idx_4}; + + body += Assign{transpose_idx, transpose_idx + Parens{global_idx / len_1_2_3} * len_1_2_3}; body += ReturnExpr(transpose_idx); @@ -706,9 +731,9 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC pre_post_lds_args}; stmts += preLoad; - for(unsigned int npass = 0; npass < factors_pp.size(); ++npass) + for(unsigned int npass = 0; npass < factors_pp_curr.size(); ++npass) { - unsigned int width = factors_pp[npass]; + unsigned int width = factors_pp_curr[npass]; unsigned int height = threads_per_transform / max_factor_pp; auto butterfly = std::mem_fn(&StockhamKernel::butterfly_generator); diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index 9be0fbe511a..47acd47cc6c 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -48,7 +48,7 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR , params(params) { length_pp = params.parent_length[params.off_dim]; - factors_pp = params.factors_off_dim; + factors_pp = params.pp_factors_curr; max_factor_pp = *std::max_element(factors_pp.begin(), factors_pp.end()); diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index 23ebfdec582..d824d81a815 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -147,12 +147,16 @@ def __str__(self): f += ', ' f += 'true' if aot_rtc else 'false' - f += ', ' + str(self.function.meta.pp_child_scheme) + f += ', ' + str(self.function.meta.pp_child_scheme) f += ', ' + str(self.function.meta.pp_current_dim) f += ', ' + str(self.function.meta.pp_off_dim) - pp_factors = getattr(self.function.meta, 'pp_factors', None) - if pp_factors is not None: - f += ', {' + cjoin(pp_factors) + '}' + pp_factors_curr = getattr(self.function.meta, 'pp_factors_curr', None) + if pp_factors_curr is not None: + f += ', {' + cjoin(pp_factors_curr) + '}' + pp_factors_other = getattr(self.function.meta, 'pp_factors_other', + None) + if pp_factors_other is not None: + f += ', {' + cjoin(pp_factors_other) + '}' f += ')' return f @@ -197,8 +201,9 @@ def generate_cpu_function_pool_pieces(functions, pp_functions, num_files): # Init list to store contents of function_pool_init function per file being generated piece_contents = [ - StatementList() + var_kernel.declaration() + var_pp_kernel_1.declaration() + - var_pp_kernel_2.declaration() for _ in range(num_files) + StatementList() + var_kernel.declaration() + + var_pp_kernel_1.declaration() + var_pp_kernel_2.declaration() + for _ in range(num_files) ] # Cycles through each file per loop execution to distribute work amongst N files @@ -210,46 +215,45 @@ def generate_cpu_function_pool_pieces(functions, pp_functions, num_files): length, precision, scheme, transpose = f.meta.length, f.meta.precision, f.meta.scheme, f.meta.transpose if scheme == 'CS_3D_PP': - piece_contents[curr_file] += Assign(var_pp_kernel_1, FFTKernel(f)) + piece_contents[curr_file] += Assign(var_pp_kernel_1, FFTKernel(f)) f = all_functions[curr_func + curr_func_offset] piece_contents[curr_file] += Assign(var_pp_kernel_2, FFTKernel(f)) - - key = Call( - name='PPFMKey', - arguments=ArgumentList(length[0], length[1], length[2], - precisions[precision], - scheme, 'pp_kernel_1.get_kernel_config()', - 'pp_kernel_2.get_kernel_config()')).inline() + + key = Call(name='PPFMKey', + arguments=ArgumentList( + length[0], length[1], length[2], + precisions[precision], scheme, + 'pp_kernel_1.get_kernel_config()', + 'pp_kernel_2.get_kernel_config()')).inline() piece_contents[curr_file] += function_map.insert_pp( - key, var_pp_kernel_1, var_pp_kernel_2, 'std::get<1>(def_keys)', + key, var_pp_kernel_1, var_pp_kernel_2, 'std::get<1>(def_keys)', 'std::get<1>(function_maps)', f.meta.lds_size_bytes) - + curr_pp_func = curr_pp_func + 1 else: if isinstance(length, (int, str)): length = [length, 0] piece_contents[curr_file] += Assign(var_kernel, FFTKernel(f)) - key = Call( - name='FMKey', - arguments=ArgumentList(length[0], length[1], precisions[precision], - scheme, transpose or 'NONE', - 'kernel.get_kernel_config()')).inline() + key = Call(name='FMKey', + arguments=ArgumentList( + length[0], length[1], precisions[precision], scheme, + transpose or 'NONE', + 'kernel.get_kernel_config()')).inline() piece_contents[curr_file] += function_map.insert( - key, var_kernel, 'std::get<0>(def_keys)', 'std::get<0>(function_maps)', - f.meta.lds_size_bytes) - + key, var_kernel, 'std::get<0>(def_keys)', + 'std::get<0>(function_maps)', f.meta.lds_size_bytes) + if curr_pp_func == len(precisions): curr_func, curr_pp_func = curr_func + len(precisions) + 1, 0 else: curr_func = curr_func + 1 - + curr_file = (curr_file + 1) % num_files # Assemble contents of each file to return in a list pieces = [None] * num_files - piece_args = ArgumentList( - 'std::tuple& def_keys', - 'std::tuple& function_maps') + piece_args = ArgumentList('std::tuple& def_keys', + 'std::tuple& function_maps') for k in range(num_files): pieces[k] = StatementList( Include('"../include/function_pool.h"'), @@ -288,19 +292,23 @@ def list_small_kernels(): """Return list of small kernels to generate.""" kernels1d = config_sbrr.sbrr_kernels - kernels = [NS(**kernel.__dict__, - scheme='CS_KERNEL_STOCKHAM', - precision=['sp','dp'] if not hasattr(kernel, 'double_precision') or kernel.double_precision else ['sp']) for kernel in kernels1d] - + kernels = [ + NS(**kernel.__dict__, + scheme='CS_KERNEL_STOCKHAM', + precision=['sp', 'dp'] if not hasattr(kernel, 'double_precision') + or kernel.double_precision else ['sp']) for kernel in kernels1d + ] + return kernels + def list_large_kernels(): """Return list of large kernels to generate.""" # for SBCC kernel, increase desired workgroup_size so that columns per # thread block is also increased. currently targeting for 16 columns block_width = 16 - sbcc_kernels = config_sbcc.sbcc_kernels + sbcc_kernels = config_sbcc.sbcc_kernels for k in sbcc_kernels: k.scheme = 'CS_KERNEL_STOCKHAM_BLOCK_CC' if not hasattr(k, 'workgroup_size'): @@ -310,10 +318,9 @@ def list_large_kernels(): k.workgroup_size = min(1024, k.workgroup_size * 2) if not hasattr(k, 'length'): k.length = functools.reduce(lambda a, b: a * b, k.factors) - - + block_width = 16 - sbcr_kernels = config_sbcr.sbcr_kernels + sbcr_kernels = config_sbcr.sbcr_kernels for k in sbcr_kernels: k.scheme = 'CS_KERNEL_STOCKHAM_BLOCK_CR' k.half_lds = False @@ -323,38 +330,40 @@ def list_large_kernels(): if not hasattr(k, 'length'): k.length = functools.reduce(lambda a, b: a * b, k.factors) - sbrc_kernels = config_sbrc.sbrc_kernels for k in sbrc_kernels: k.half_lds = False - return config_sbcc.sbcc_kernels + config_sbcr.sbcr_kernels + config_sbrc.sbrc_kernels + def list_2d_kernels(): """Return list of fused 2D kernels to generate.""" fused_2d_kernels = config_2d_single.fused_2d_kernels expanded = [] - expanded.extend(NS(**kernel.__dict__, - scheme='CS_KERNEL_2D_SINGLE', runtime_compile=True) for kernel in fused_2d_kernels) + expanded.extend( + NS(**kernel.__dict__, + scheme='CS_KERNEL_2D_SINGLE', + runtime_compile=True) for kernel in fused_2d_kernels) return expanded + def list_3d_partial_pass_kernels(): """Return list of partial-pass 3D kernels to generate.""" - - pp_3d_kernels = [ - NS(length=[64,64,64], dims=[0, 2], factors=[[8, 8],[4, 4, 4]], factors_pp=[[4],[16]], threads_per_transform=[8, 8], workgroup_size=[64,128], direct_to_from_reg=[False, False]), - ] + + pp_3d_kernels = config_pp_3d.pp_3d_kernels expanded = [] - expanded.extend(NS(**kernel.__dict__, - scheme='CS_3D_PP', runtime_compile=True) for kernel in pp_3d_kernels) + expanded.extend( + NS(**kernel.__dict__, scheme='CS_3D_PP', runtime_compile=True) + for kernel in pp_3d_kernels) return expanded + def default_runtime_compile(kernels, default_val): '''Returns a copy of input kernel list with a default value for runtime_compile.''' @@ -381,13 +390,14 @@ def generate_kernel_functions(kernels, precisions, launchers_json): launcher = NS(**launcher_dict) factors = launcher.factors - + if len(launcher.lengths) == 1: length = launcher.lengths[0] elif len(launcher.lengths) == 2: length = (launcher.lengths[0], launcher.lengths[1]) elif len(launcher.lengths) == 3: - length = (launcher.lengths[0], launcher.lengths[1], launcher.lengths[2]) + length = (launcher.lengths[0], launcher.lengths[1], + launcher.lengths[2]) transforms_per_block = launcher.transforms_per_block workgroup_size = launcher.workgroup_size @@ -396,7 +406,8 @@ def generate_kernel_functions(kernels, precisions, launchers_json): direct_to_from_reg = launcher.direct_to_from_reg scheme = launcher.scheme pp_child_scheme = launcher.pp_child_scheme - pp_factors = launcher.pp_factors + pp_factors_curr = launcher.pp_factors_curr + pp_factors_other = launcher.pp_factors_other pp_current_dim = launcher.pp_current_dim pp_off_dim = launcher.pp_off_dim sbrc_transpose_type = launcher.sbrc_transpose_type @@ -419,24 +430,23 @@ def generate_kernel_functions(kernels, precisions, launchers_json): precisions.append('half') for p in precisions: f = Function(arguments=ArgumentList(data, back), - meta=NS( - factors=factors, - length=length, - params=params, - precision=p, - runtime_compile=runtime_compile, - scheme=scheme, - workgroup_size=workgroup_size, - transforms_per_block=transforms_per_block, - threads_per_transform=tpt_list, - transpose=sbrc_transpose_type, - use_3steps_large_twd=use_3steps_large_twd, - lds_size_bytes=kernel.lds_size_bytes, - pp_child_scheme=pp_child_scheme, - pp_factors = pp_factors, - pp_current_dim = pp_current_dim, - pp_off_dim = pp_off_dim - )) + meta=NS(factors=factors, + length=length, + params=params, + precision=p, + runtime_compile=runtime_compile, + scheme=scheme, + workgroup_size=workgroup_size, + transforms_per_block=transforms_per_block, + threads_per_transform=tpt_list, + transpose=sbrc_transpose_type, + use_3steps_large_twd=use_3steps_large_twd, + lds_size_bytes=kernel.lds_size_bytes, + pp_child_scheme=pp_child_scheme, + pp_factors_curr=pp_factors_curr, + pp_factors_other=pp_factors_other, + pp_current_dim=pp_current_dim, + pp_off_dim=pp_off_dim)) if (scheme == 'CS_3D_PP'): pp_kernel_functions.append(f) @@ -521,32 +531,36 @@ def generate_kernels(kernels, precisions, stockham_gen): half_lds = False # Send data over to subprocess - + if isinstance(k.workgroup_size, list): - proc.stdin.write(" " + ','.join([str(f) for f in k.workgroup_size])) - else: + proc.stdin.write(" " + + ','.join([str(f) for f in k.workgroup_size])) + else: proc.stdin.write(f' {str(k.workgroup_size)}') - + proc.stdin.write(' 1' if half_lds else ' 0') - + direct_to_from_reg = getattr(k, 'direct_to_from_reg', True) - + if isinstance(direct_to_from_reg, list): - proc.stdin.write(" " + ','.join(['1' if f else '0' for f in direct_to_from_reg])) + proc.stdin.write( + " " + + ','.join(['1' if f else '0' for f in direct_to_from_reg])) else: # for unspecified direct_to_from_reg, default is True only for CS_KERNEL_STOCKHAM and SBCC direct_to_from_reg = getattr(k, 'direct_to_from_reg', True) proc.stdin.write(' 1' if direct_to_from_reg else ' 0') - + # check for data specific to partial-pass 3D kernels if hasattr(k, 'dims'): proc.stdin.write(" " + ','.join([str(f) for f in k.dims])) - proc.stdin.write(" " + ','.join([str(f) + proc.stdin.write(" " + + ','.join([str(f) for f in k.factors_pp[0]]) + " ") proc.stdin.write(','.join([str(f) for f in k.factors_pp[1]]) + " ") proc.stdin.write(','.join([str(f) for f in k.length])) - + proc.stdin.write(f' {k.scheme}') proc.stdin.write(f' {kernel_name(k)}') proc.stdin.write(f' {k.lds_size_bytes}') @@ -631,9 +645,8 @@ def cli(): if args.command == 'generate': functions, pp_functions = generate_kernels(kernels, precisions, - args.stockham_gen) - func_files = generate_cpu_function_pool_pieces(functions, - pp_functions, + args.stockham_gen) + func_files = generate_cpu_function_pool_pieces(functions, pp_functions, args.num_files) for i in range(args.num_files): write(f'function_pool_init_{i}.cpp', func_files[i], format=False) diff --git a/library/src/device/kernels/configs/config_pp_3d.py b/library/src/device/kernels/configs/config_pp_3d.py index 17f50d21077..f18d2316d1d 100644 --- a/library/src/device/kernels/configs/config_pp_3d.py +++ b/library/src/device/kernels/configs/config_pp_3d.py @@ -22,4 +22,6 @@ from types import SimpleNamespace as NS # yapf: disable - +pp_3d_kernels = [ + NS(length=[64,64,64], dims=[0,2], factors=[[8,8],[4,4,4]], factors_pp=[[4],[16]], threads_per_transform=[8,16], workgroup_size=[64,256], direct_to_from_reg=[False,False]), +] diff --git a/library/src/include/function_pool.h b/library/src/include/function_pool.h index 5cc53c9335e..b3eb120bed0 100644 --- a/library/src/include/function_pool.h +++ b/library/src/include/function_pool.h @@ -61,21 +61,24 @@ struct PartialPassParams { PartialPassParams() = default; - PartialPassParams(ComputeScheme scheme, - unsigned int current_dim, - unsigned int off_dim, - std::vector factors_off_dim) + PartialPassParams(const ComputeScheme& scheme, + const unsigned int& current_dim, + const unsigned int& off_dim, + const std::vector& pp_factors_curr, + const std::vector& pp_factors_other) : scheme(scheme) , current_dim(current_dim) , off_dim(off_dim) - , factors_off_dim(factors_off_dim) + , pp_factors_curr(pp_factors_curr) + , pp_factors_other(pp_factors_other) { } ComputeScheme scheme = CS_NONE; unsigned int current_dim = 0; unsigned int off_dim = 0; - std::vector factors_off_dim; + std::vector pp_factors_curr; + std::vector pp_factors_other; }; struct FFTKernel @@ -116,7 +119,8 @@ struct FFTKernel ComputeScheme scheme = CS_NONE, unsigned int current_dim = 0, unsigned int off_dim = 0, - std::vector&& factors_off_dim = std::vector()) + std::vector&& pp_factors_curr = std::vector(), + std::vector&& pp_factors_other = std::vector()) : factors(factors) , transforms_per_block(tpb) , workgroup_size(wgs) @@ -125,7 +129,7 @@ struct FFTKernel , half_lds(half_lds) , direct_to_from_reg(direct_to_from_reg) , aot_rtc(aot_rtc) - , pp_params(scheme, current_dim, off_dim, factors_off_dim) + , pp_params(scheme, current_dim, off_dim, pp_factors_curr, pp_factors_other) { } @@ -259,7 +263,7 @@ class function_pool throw std::runtime_error("function_pool: max_lds_bytes not initialized"); } - function_pool(function_pool& p) = delete; + function_pool(function_pool& p) = delete; function_pool& operator=(const function_pool&) = delete; ~function_pool() = default; diff --git a/library/src/rtc_stockham_gen.cpp b/library/src/rtc_stockham_gen.cpp index 947ffd8ce76..1bdc5aaa726 100644 --- a/library/src/rtc_stockham_gen.cpp +++ b/library/src/rtc_stockham_gen.cpp @@ -409,8 +409,8 @@ std::string stockham_rtc(const StockhamGeneratorSpecs& specs, if(ppType != PPT_NONE) all_factors.insert(all_factors.end(), - params_pp.factors_off_dim.begin(), - params_pp.factors_off_dim.end()); + params_pp.pp_factors_curr.begin(), + params_pp.pp_factors_curr.end()); } // generated functions default to forward in-place interleaved. diff --git a/library/src/rtc_stockham_kernel.cpp b/library/src/rtc_stockham_kernel.cpp index 2f54ecb2a88..1439fdbc043 100644 --- a/library/src/rtc_stockham_kernel.cpp +++ b/library/src/rtc_stockham_kernel.cpp @@ -88,8 +88,11 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& { pp_params.off_dim = node.ppOffDim; pp_params.current_dim = node.ppCurrDim; - pp_params.factors_off_dim = std::vector( - kernel->pp_params.factors_off_dim.begin(), kernel->pp_params.factors_off_dim.end()); + pp_params.pp_factors_curr = std::vector( + kernel->pp_params.pp_factors_curr.begin(), kernel->pp_params.pp_factors_curr.end()); + pp_params.pp_factors_other + = std::vector(kernel->pp_params.pp_factors_other.begin(), + kernel->pp_params.pp_factors_other.end()); pp_params.parent_length = std::vector(node.length.begin(), node.length.end()); } diff --git a/library/src/tree_node.cpp b/library/src/tree_node.cpp index ab9a7f319b1..d51d30c5d8e 100644 --- a/library/src/tree_node.cpp +++ b/library/src/tree_node.cpp @@ -111,8 +111,8 @@ void LeafNode::GetKernelFactors() void LeafNode::GetKernelPartialPassFactors() { auto kernel = GetKernel(); - kernelFactorsPP = std::vector(kernel.pp_params.factors_off_dim.begin(), - kernel.pp_params.factors_off_dim.end()); + kernelFactorsPP = std::vector(kernel.pp_params.pp_factors_curr.begin(), + kernel.pp_params.pp_factors_curr.end()); switch(ppOffDim) { From 348e0a43057cf18d83367fafd228ced6df7a39ac Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Tue, 24 Jun 2025 15:07:58 -0600 Subject: [PATCH 47/69] - Fix for hardcoded value in offset computation. --- .../src/device/generator/stockham_pp_gen_cc.h | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index 562f4833c68..08cf783777d 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -22,13 +22,10 @@ #include "stockham_gen_cc.h" // TODO: Once partial pass is fully configurable in kernel-generator.py: -// - Test with factors_pp.size() > 1. +// - Support transforms with factors_pp.size() > 1. // - Revisit all usages of transforms_per_block_pp and threads_per_transform. // - Different input/output strides. -// - Revisit mod 128 usage in calculate_offsets() with different input lengths, -// (logic is required to work with nbatch > 1) // - Revisit factor 192 logic in calculate_offsets() with different input lengths -// - Revisit and test local transpose logic for different input lengths // Variation of StockhamKernelCC that implements the partial pass // method. Similarities of StockhamPartialPassKernelCC with @@ -66,10 +63,30 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC factors_pp_curr.begin(), factors_pp_curr.end(), 1, std::multiplies()); pp_factors_other_prod = std::accumulate( factors_pp_other.begin(), factors_pp_other.end(), 1, std::multiplies()); + + switch(params.off_dim) + { + case 0: + throw std::runtime_error( + "StockhamPartialPassKernelCC:: partial-passes along x not currently supported"); + break; + case 1: + num_blocks_per_batch = ((params.parent_length[1]) - 1) / transforms_per_block + 1; + num_blocks_per_batch *= params.parent_length[2]; + break; + case 2: + throw std::runtime_error( + "StockhamPartialPassKernelCC:: partial-passes along z not currently supported"); + break; + default: + throw std::runtime_error("StockhamPartialPassKernelCC:: Unexpected off_dim value"); + } } StockhamPartialPassParams params; + unsigned int num_blocks_per_batch; + unsigned int transforms_per_block_pp; unsigned int max_factor_pp; @@ -264,8 +281,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += Assign{plength, num_of_tiles}; stmts += Assign{tile_index, block_id % num_of_tiles}; - // TODO: factor 128 - stmts += Assign{remaining, (block_id % 128) / num_of_tiles}; + stmts += Assign{remaining, (block_id % num_blocks_per_batch) / num_of_tiles}; stmts += Assign{offset, tile_index * transforms_per_block_pp * stride[1]}; stmts += For{d, 2, From a3be0ff0c305d713ca44371b401a1212000098e9 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 26 Jun 2025 15:11:21 -0600 Subject: [PATCH 48/69] - Further fixes to calculate_offsets in partial-pass SBCC kernel. --- .../src/device/generator/stockham_pp_gen_cc.h | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index 08cf783777d..3cc29dc7e70 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -25,7 +25,6 @@ // - Support transforms with factors_pp.size() > 1. // - Revisit all usages of transforms_per_block_pp and threads_per_transform. // - Different input/output strides. -// - Revisit factor 192 logic in calculate_offsets() with different input lengths // Variation of StockhamKernelCC that implements the partial pass // method. Similarities of StockhamPartialPassKernelCC with @@ -300,7 +299,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC tile_index * transforms_per_block_pp + thread_id / threads_per_transform}; stmts += Assign{stride_lds, (length + get_lds_padding())}; - stmts += MultiplyAssign(stride_lds, Literal{max_factor_pp}); + stmts += MultiplyAssign(stride_lds, Literal{pp_factors_curr_prod}); stmts += Declaration{ in_bound, @@ -313,23 +312,25 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += Declaration{thread_lds, thread_id / transforms_per_block_pp}; stmts += Declaration{tid_hor_lds, thread_id % transforms_per_block_pp}; - // TODO: length - stmts += Declaration( - tid_hor_pp, thread_id % transforms_per_block_pp + length * (thread % max_factor_pp)); - stmts += Declaration(thread_new, thread_id / (transforms_per_block_pp * max_factor_pp)); - stmts += Declaration(batch_new, block_id / (plength / max_factor_pp)); + stmts += Declaration(tid_hor_pp, + thread_id % transforms_per_block_pp + + lengths[1] * (thread % pp_factors_curr_prod)); + stmts += Declaration(thread_new, + thread_id / (transforms_per_block_pp * pp_factors_curr_prod)); + stmts += Declaration(batch_new, block_id / (plength / pp_factors_curr_prod)); stmts += Declaration(thread_idx, thread_id); stmts += Declaration(block_idx, block_id); - // TODO: length and TODO: Literal{192} stmts += Declaration( - offset_pp, offset + Parens(offset / length) * Literal{192} + batch_new * stride[dim]); + offset_pp, + offset + Parens(offset / lengths[1]) * (lengths[1] * pp_factors_curr_prod - lengths[1]) + + batch_new * stride[dim]); stmts += Declaration(offset_tid_hor, offset_pp + tid_hor_pp * stride[1]); stmts += Assign{transform, tile_index * transforms_per_block_pp - + thread_id / (threads_per_transform * max_factor_pp)}; + + thread_id / (threads_per_transform * pp_factors_curr_prod)}; stmts += Assign{offset_lds, Ternary{lds_linear, From f52aefaa3c344c286bf547ea151fe79557df17d0 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 26 Jun 2025 16:14:37 -0600 Subject: [PATCH 49/69] - clang format. --- library/src/include/function_pool.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/src/include/function_pool.h b/library/src/include/function_pool.h index b3eb120bed0..14ed7d063e9 100644 --- a/library/src/include/function_pool.h +++ b/library/src/include/function_pool.h @@ -263,7 +263,7 @@ class function_pool throw std::runtime_error("function_pool: max_lds_bytes not initialized"); } - function_pool(function_pool& p) = delete; + function_pool(function_pool& p) = delete; function_pool& operator=(const function_pool&) = delete; ~function_pool() = default; From 3657b31e74269399d4553d0f396b94a98922fb32 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 11 Jul 2025 14:04:09 -0600 Subject: [PATCH 50/69] - Fixes for partial-pass SBCC kernel when using certain wgs and tpt combinations. --- .../src/device/generator/stockham_gen_base.h | 17 +++++++++------- .../src/device/generator/stockham_pp_gen_cc.h | 20 ++++++++++++++----- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/library/src/device/generator/stockham_gen_base.h b/library/src/device/generator/stockham_gen_base.h index afe26ec4131..7a508f304d5 100644 --- a/library/src/device/generator/stockham_gen_base.h +++ b/library/src/device/generator/stockham_gen_base.h @@ -503,7 +503,8 @@ struct StockhamKernel : public StockhamGeneratorSpecs unsigned int width, double height, ThreadGuardMode guard, - bool trans_dir = false) const + bool trans_dir = false, + const std::optional& guard_factor = std::nullopt) const { StatementList stmts; unsigned int iheight = std::floor(height); @@ -512,17 +513,19 @@ struct StockhamKernel : public StockhamGeneratorSpecs Expression guard_expr = Expression{Literal{"true"}}; + auto thread_guard_cond = (length / width) * (guard_factor ? *guard_factor : 1); + // do thread gurad when guard_by_if or guard_by_arg if(guard != ThreadGuardMode::NO_GUARD) { // using ">" : no need to test "if(thread < XXX)"" if it is always true - if((!trans_dir && threads_per_transform > length / width) - || (trans_dir && workgroup_size / transforms_per_block > length / width)) + if((!trans_dir && threads_per_transform > (length / width)) + || (trans_dir && workgroup_size / transforms_per_block > (length / width))) { if(writeGuard) - guard_expr = Expression{write && (thread < length / width)}; + guard_expr = Expression{write && (thread < thread_guard_cond)}; else - guard_expr = Expression{thread < length / width}; + guard_expr = Expression{thread < thread_guard_cond}; } else { @@ -553,9 +556,9 @@ struct StockhamKernel : public StockhamGeneratorSpecs // always do thread gurad if(writeGuard) - guard_expr = Expression{write && (thread + dt < length / width)}; + guard_expr = Expression{write && (thread + dt < thread_guard_cond)}; else - guard_expr = Expression{thread + dt < length / width}; + guard_expr = Expression{thread + dt < thread_guard_cond}; work = generator(0, iheight, width, dt, guard_expr); diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index 3cc29dc7e70..22d39fb0015 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -678,7 +678,9 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC body += add_work(std::bind(store_lds, this, _1, _2, _3, _4, _5, Component::BOTH, cumheight), width, height, - ThreadGuardMode::GUARD_BY_IF); + ThreadGuardMode::GUARD_BY_IF, + false, + max_factor_pp); return f; } @@ -884,7 +886,9 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC += add_work(std::bind(load_lds, this, _1, _2, _3, _4, _5, Component::BOTH), width, height, - ThreadGuardMode::GUARD_BY_IF); + ThreadGuardMode::GUARD_BY_IF, + false, + max_factor_pp); body += If{Not{lds_is_real}, lds2reg_full}; auto apply_twiddle @@ -924,7 +928,9 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC std::bind(store_lds, this, _1, _2, _3, _4, _5, component, cumheight), half_width, half_height, - ThreadGuardMode::GUARD_BY_IF); + ThreadGuardMode::GUARD_BY_IF, + false, + max_factor_pp); half_width = factors[npass + 1]; half_height = static_cast(length) / half_width / threads_per_transform; @@ -933,7 +939,9 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC += add_work(std::bind(load_lds, this, _1, _2, _3, _4, _5, component), half_width, half_height, - ThreadGuardMode::GUARD_BY_IF); + ThreadGuardMode::GUARD_BY_IF, + false, + max_factor_pp); } // internal full lds store (both linear/nonlinear variants) @@ -945,7 +953,9 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC std::bind(store_lds, this, _1, _2, _3, _4, _5, Component::BOTH, cumheight), width, height, - ThreadGuardMode::GUARD_BY_IF); + ThreadGuardMode::GUARD_BY_IF, + false, + max_factor_pp); body += If{Not{lds_is_real}, reg2lds_full}; body += Else{reg2lds_half}; From eee9bbdecd6294ef8cec995260550f9490694c83 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Tue, 15 Jul 2025 13:59:29 -0600 Subject: [PATCH 51/69] - Fixes for partial-pass step_3_4 lds-to-reg and reg-to-lds generators. - Fix step_3_4 lds offset calculation. --- .../src/device/generator/stockham_pp_gen_cc.h | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index 22d39fb0015..e7a4febf219 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -72,6 +72,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC case 1: num_blocks_per_batch = ((params.parent_length[1]) - 1) / transforms_per_block + 1; num_blocks_per_batch *= params.parent_length[2]; + off_dim_length = params.parent_length[2]; break; case 2: throw std::runtime_error( @@ -89,6 +90,8 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC unsigned int transforms_per_block_pp; unsigned int max_factor_pp; + unsigned int off_dim_length; + std::vector factors_pp_curr; unsigned int pp_factors_curr_prod; @@ -176,7 +179,8 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC StatementList work; for(unsigned int w = 0; w < width; ++w) - work += Assign(lds_complex[offset_lds + (w * stride_lds)], R[w]); + work += Assign(lds_complex[offset_lds + ((hr * width + w) * stride_lds)], + R[hr * width + w]); return work; } @@ -196,8 +200,8 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC lstride, Ternary{Parens{stride_type == "SB_UNIT"}, Parens{1}, Parens{stride_lds}}}; auto store_lds = std::mem_fn(&StockhamPartialPassKernelCC::store_pp_step_3_4_lds_generator); - // last pass of store (full) - unsigned int width = factors.back(); + // last pass of store (partial-pass) + unsigned int width = factors_pp_curr.back(); float height = static_cast(length) / width / threads_per_transform; body += SyncThreads(); body += add_work(std::bind(store_lds, this, _1, _2, _3, _4, _5), @@ -215,7 +219,8 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC StatementList work; for(unsigned int w = 0; w < width; ++w) - work += Assign(R[w], lds_complex[offset_lds + (w * stride_lds)]); + work += Assign(R[hr * width + w], + lds_complex[offset_lds + ((hr * width + w) * stride_lds)]); return work; } @@ -244,8 +249,8 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC StatementList& body = f.body; auto load_lds = std::mem_fn(&StockhamPartialPassKernelCC::load_lds_step_3_4_generator); - // first pass of load (full) - unsigned int width = factors[0]; + // first pass of load (partial-pass) + unsigned int width = factors_pp_curr[0]; float height = static_cast(length) / width / threads_per_transform; body += SyncThreads(); body += add_work(std::bind(load_lds, this, _1, _2, _3, _4, _5), @@ -738,7 +743,11 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC StatementList stmts; stmts += Declaration{stride_lds_pp, Literal{1}}; - stmts += Declaration{offset_lds_pp, thread_id * transforms_per_block_pp}; + + unsigned int width = factors_pp_curr[0]; + float height = static_cast(length) / width / threads_per_transform; + stmts += Declaration{offset_lds_pp, + thread_id * Literal{width * static_cast(height)}}; auto pre_post_lds_tmpl = device_lds_reg_inout_device_call_templates(); auto pre_post_lds_args = device_lds_reg_inout_pp_device_call_arguments(); @@ -753,7 +762,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC for(unsigned int npass = 0; npass < factors_pp_curr.size(); ++npass) { unsigned int width = factors_pp_curr[npass]; - unsigned int height = threads_per_transform / max_factor_pp; + unsigned int height = static_cast(length) / width / threads_per_transform; auto butterfly = std::mem_fn(&StockhamKernel::butterfly_generator); stmts += add_work(std::bind(butterfly, this, _1, _2, _3, _4, _5), @@ -762,6 +771,11 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC ThreadGuardMode::NO_GUARD); } + width = factors_pp_curr.back(); + height = static_cast(length) / width / threads_per_transform; + stmts += Assign{offset_lds_pp, + thread_id * Literal{width * static_cast(height)}}; + StatementList postStore; postStore += Call{"lds_from_reg_output_pp_step_3_4_length" + std::to_string(length) + "_device", From 25b4fa2d00a114371e8c114ed808d570f38377b1 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Tue, 15 Jul 2025 16:31:36 -0600 Subject: [PATCH 52/69] - Further resolved merge conflicts. --- library/src/CMakeLists.txt | 5 +---- library/src/device/generator.py | 6 ------ library/src/device/generator/stockham_gen.cpp | 10 +++++++++- library/src/rtc_stockham_kernel.cpp | 1 + 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index d39b7460693..943facac618 100644 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -151,16 +151,13 @@ set( kgen_logic_files # python code that decides kernel parameters ${CMAKE_SOURCE_DIR}/library/src/device/kernel-generator.py ${CMAKE_SOURCE_DIR}/library/src/device/generator.py - ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_lds.py + ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_lds.py ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_sbrr.py ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_sbcc.py ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_sbcr.py ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_sbrc.py ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_2d_single.py ${CMAKE_SOURCE_DIR}/library/src/device/kernels/configs/config_pp_3d.py - - - # stockham generator code ${CMAKE_SOURCE_DIR}/library/src/device/generator/generator.h diff --git a/library/src/device/generator.py b/library/src/device/generator.py index 478b3fe3eec..8e8ebfe2f86 100644 --- a/library/src/device/generator.py +++ b/library/src/device/generator.py @@ -822,12 +822,6 @@ def insert_pp(self, key, value_1, value_2, def_key_pool, function_map, arguments=ArgumentList(key, value_1, value_2, def_key_pool, function_map, lds_size_bytes)) - def insert_pp(self, key, value_1, value_2, def_key_pool, function_map, - lds_size_bytes): - return Call('insert_default_entry', - arguments=ArgumentList(key, value_1, value_2, def_key_pool, - function_map, lds_size_bytes)) - # def __getitem__(self, idx): # return ArrayElement(self.name, idx) diff --git a/library/src/device/generator/stockham_gen.cpp b/library/src/device/generator/stockham_gen.cpp index 1ef8139f314..5201121dcdd 100644 --- a/library/src/device/generator/stockham_gen.cpp +++ b/library/src/device/generator/stockham_gen.cpp @@ -238,6 +238,7 @@ void output_json(const std::vector& launchers, output << "]"; } +// Render stockham partial-pass kernel generated launchers in JSON format. void stockham_partial_pass_variants(const std::string& kernel_name, const StockhamGeneratorSpecs& specs1, const StockhamGeneratorSpecs& specs2, @@ -577,7 +578,6 @@ void validate_pp_off_dim_length(const StockhamPartialPassParams& pp_params_1, if(length_off_dim != pp_params_1.parent_length[pp_params_1.off_dim]) throw std::runtime_error("Invalid partial-pass kernel off-dimension length"); } - // Validate grid parameters for partial pass kernels. void validate_pp_grid_params(const StockhamPartialPassParams& params_1, const StockhamPartialPassParams& params_2, @@ -787,6 +787,14 @@ int main() if(!threads_per_transform.empty()) specs2d.threads_per_transform = threads_per_transform.back(); + // 2D_SINGLE kernels use the specified workgroup size + // directly + if(scheme == "CS_KERNEL_2D_SINGLE") + { + specs.wgs_is_derived = true; + specs2d.wgs_is_derived = true; + } + // aim for occupancy-2 by default specs.lds_byte_limit = lds_size_bytes / 2; specs2d.lds_byte_limit = lds_size_bytes / 2; diff --git a/library/src/rtc_stockham_kernel.cpp b/library/src/rtc_stockham_kernel.cpp index 4f76610dc8c..eb78e0407be 100644 --- a/library/src/rtc_stockham_kernel.cpp +++ b/library/src/rtc_stockham_kernel.cpp @@ -83,6 +83,7 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& specs->threads_per_transform = kernel->threads_per_transform[0]; specs->half_lds = kernel->half_lds; specs->direct_to_from_reg = kernel->direct_to_from_reg; + specs->ebtype = node.ebtype; if(node.isPartialPassEnabled()) { From a08e2d10bcb1592ca6c422c117999b956112fcb6 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Tue, 15 Jul 2025 16:36:15 -0600 Subject: [PATCH 53/69] - Remove no longer needed lines from python config files. --- library/src/device/kernels/configs/config_sbrc.py | 1 - library/src/device/kernels/configs/config_sbrr.py | 1 - 2 files changed, 2 deletions(-) diff --git a/library/src/device/kernels/configs/config_sbrc.py b/library/src/device/kernels/configs/config_sbrc.py index 4a382c32224..38ce5f0f37c 100644 --- a/library/src/device/kernels/configs/config_sbrc.py +++ b/library/src/device/kernels/configs/config_sbrc.py @@ -21,7 +21,6 @@ from kernels.configs import config_lds from types import SimpleNamespace as NS -# yapf: disable # for SBRC, if direct_to_from_reg is True, we do store-from-reg, but will not do load-to-reg # And since SBRC is is dir-from-lds but NOT dir-to-reg, the global load part requires full LDS # So, SBRC is able to use half-lds. diff --git a/library/src/device/kernels/configs/config_sbrr.py b/library/src/device/kernels/configs/config_sbrr.py index 9f5d83f4ed6..6a0757d4f4e 100644 --- a/library/src/device/kernels/configs/config_sbrr.py +++ b/library/src/device/kernels/configs/config_sbrr.py @@ -21,7 +21,6 @@ from kernels.configs import config_lds from types import SimpleNamespace as NS -# yapf: disable # Note: Default half_lds is True and default direct_to_from_reg is True as well. # TODO- Currently, if half_lds is True, then direct_to_from_reg must be True # but if half_lds is False, direct_to_from_reg can be either (still can be True). From 36608cd30f439067fcb2f3fa4bd7ee08b47e04bc Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 17 Jul 2025 13:42:52 -0600 Subject: [PATCH 54/69] - Further fixes to partial-pass SBCC calculate_offsets(). - Clean-up. - Add new md test suite length to partial-pass 3D kernel list. --- .../src/device/generator/stockham_gen_base.h | 3 ++ .../src/device/generator/stockham_pp_gen_cc.h | 28 ++++++++----------- .../device/kernels/configs/config_pp_3d.py | 1 + 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/library/src/device/generator/stockham_gen_base.h b/library/src/device/generator/stockham_gen_base.h index 4753d4ba8f5..9bc44aeb20c 100644 --- a/library/src/device/generator/stockham_gen_base.h +++ b/library/src/device/generator/stockham_gen_base.h @@ -149,6 +149,9 @@ struct StockhamKernel : public StockhamGeneratorSpecs Variable lds_complex{"lds_complex", "scalar_type", true, true}; Variable lds_row_padding{"lds_row_padding", "unsigned int"}; + // hip thread grid dim + Variable grid_dim{"gridDim.x", "unsigned int"}; + // hip thread block id Variable block_id{"blockIdx.x", "unsigned int"}; diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index e7a4febf219..a7e43d58299 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -106,11 +106,9 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC Variable tid_hor_pp{"tid_hor_pp", "unsigned int"}; Variable offset_tid_hor{"offset_tid_hor", "unsigned int"}; Variable offset_pp{"offset_pp", "unsigned int"}; - Variable thread_new{"thread_new", "unsigned int"}; - Variable batch_new{"batch_new", "unsigned int"}; + Variable thread_pp{"thread_pp", "unsigned int"}; - Variable thread_idx{"thread_idx", "unsigned int"}; - Variable block_idx{"block_idx", "unsigned int"}; + Variable block_idx_pp{"block_idx_pp", "unsigned int"}; Variable thread_in_device_twd{"thread_in_device_twd", "unsigned int"}; @@ -274,6 +272,8 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += Declaration{tile_index}; stmts += Declaration{num_of_tiles}; + stmts += Declaration(block_idx_pp, block_id % Parens{grid_dim / nbatch}); + stmts += LineBreak{}; stmts += CommentLines{"calculate offset for each tile:", " tile_index now means index of the tile along dim1", @@ -283,9 +283,9 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += Declaration{index_along_d}; stmts += Assign{num_of_tiles, (lengths[1] - 1) / transforms_per_block_pp + 1}; stmts += Assign{plength, num_of_tiles}; - stmts += Assign{tile_index, block_id % num_of_tiles}; + stmts += Assign{tile_index, block_idx_pp % num_of_tiles}; - stmts += Assign{remaining, (block_id % num_blocks_per_batch) / num_of_tiles}; + stmts += Assign{remaining, (block_idx_pp % num_blocks_per_batch) / num_of_tiles}; stmts += Assign{offset, tile_index * transforms_per_block_pp * stride[1]}; stmts += For{d, 2, @@ -298,7 +298,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += LineBreak{}; - stmts += Assign{batch, block_id / plength}; + stmts += Assign{batch, block_id / (plength / pp_factors_curr_prod)}; stmts += Assign{transform, tile_index * transforms_per_block_pp + thread_id / threads_per_transform}; @@ -320,17 +320,13 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += Declaration(tid_hor_pp, thread_id % transforms_per_block_pp + lengths[1] * (thread % pp_factors_curr_prod)); - stmts += Declaration(thread_new, - thread_id / (transforms_per_block_pp * pp_factors_curr_prod)); - stmts += Declaration(batch_new, block_id / (plength / pp_factors_curr_prod)); - - stmts += Declaration(thread_idx, thread_id); - stmts += Declaration(block_idx, block_id); + stmts + += Declaration(thread_pp, thread_id / (transforms_per_block_pp * pp_factors_curr_prod)); stmts += Declaration( offset_pp, offset + Parens(offset / lengths[1]) * (lengths[1] * pp_factors_curr_prod - lengths[1]) - + batch_new * stride[dim]); + + batch * stride[dim]); stmts += Declaration(offset_tid_hor, offset_pp + tid_hor_pp * stride[1]); stmts += Assign{transform, @@ -357,7 +353,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC auto stripmine_h = workgroup_size / stripmine_w; auto offset_tile_rbuf - = [&](unsigned int i) { return (thread_new + i * stripmine_h) * stride0; }; + = [&](unsigned int i) { return (thread_pp + i * stripmine_h) * stride0; }; auto offset_tile_wlds = [&](unsigned int i) { return tid_hor_lds * stride_lds + (thread_lds + i * stripmine_h * max_factor_pp) * 1; @@ -463,7 +459,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC auto stripmine_h = workgroup_size / stripmine_w; auto offset_tile_wbuf = [&](unsigned int i) { - return offset_tid_hor + (thread_new + i * stripmine_h) * stride0; + return offset_tid_hor + (thread_pp + i * stripmine_h) * stride0; }; auto offset_tile_rlds = [&](unsigned int i) { return tid_hor_lds * stride_lds + (thread_lds + i * stripmine_h * max_factor_pp) * 1; diff --git a/library/src/device/kernels/configs/config_pp_3d.py b/library/src/device/kernels/configs/config_pp_3d.py index f18d2316d1d..fc3b85e5352 100644 --- a/library/src/device/kernels/configs/config_pp_3d.py +++ b/library/src/device/kernels/configs/config_pp_3d.py @@ -24,4 +24,5 @@ # yapf: disable pp_3d_kernels = [ NS(length=[64,64,64], dims=[0,2], factors=[[8,8],[4,4,4]], factors_pp=[[4],[16]], threads_per_transform=[8,16], workgroup_size=[64,256], direct_to_from_reg=[False,False]), + NS(length=[64,64,52], dims=[0,2], factors=[[8,8],[13,4]], factors_pp=[[4],[16]], threads_per_transform=[8,4], workgroup_size=[64,64], direct_to_from_reg=[False,False]), ] From 3050aaad39286860ba5fdd8bdbd7854a3fdad9c1 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Tue, 22 Jul 2025 13:19:39 -0600 Subject: [PATCH 55/69] - Fixes for partial-pass twiddle table generation. --- library/src/rtc_twiddle_gen.cpp | 2 +- library/src/twiddles.cpp | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/library/src/rtc_twiddle_gen.cpp b/library/src/rtc_twiddle_gen.cpp index 924c606e9a5..0a25ae772dc 100644 --- a/library/src/rtc_twiddle_gen.cpp +++ b/library/src/rtc_twiddle_gen.cpp @@ -195,7 +195,7 @@ static std::string twiddle_rtc_body(TwiddleTableType type) auto i_row = threadIdx.x + blockIdx.x * blockDim.x; auto i_col = threadIdx.y + blockIdx.y * blockDim.y; - if(i_row * i_col < N * N) + if(i_row < N && i_col < N) { auto n = i_row * i_col; double arg = -TWO_PI * n / N; diff --git a/library/src/twiddles.cpp b/library/src/twiddles.cpp index 17dc9ef46d8..d8ee5064bac 100644 --- a/library/src/twiddles.cpp +++ b/library/src/twiddles.cpp @@ -217,7 +217,12 @@ class TwiddleTable auto numBlocks_N = DivRoundingUp(N, blockSize); - kernel.launch(kargs, dim3(numBlocks_N), dim3(blockSize), 0, deviceProp, stream); + kernel.launch(kargs, + dim3(numBlocks_N, numBlocks_N), + dim3(blockSize, blockSize, 1), + 0, + deviceProp, + stream); } void launch_half_N_kernel(hipStream_t& stream, T* output, size_t half_N, size_t N) From ab395afbd7d16970af696ab2bee995f400c3bf24 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Wed, 23 Jul 2025 10:18:01 -0600 Subject: [PATCH 56/69] - Add more partial-pass 3D lengths. - Further fixes to first partial-pass kernel generator. --- .../src/device/generator/stockham_gen_base.h | 12 ++++--- .../src/device/generator/stockham_pp_gen_rr.h | 32 +++++++++++++------ .../device/kernels/configs/config_pp_3d.py | 5 +++ 3 files changed, 34 insertions(+), 15 deletions(-) diff --git a/library/src/device/generator/stockham_gen_base.h b/library/src/device/generator/stockham_gen_base.h index 9bc44aeb20c..48bb1badda1 100644 --- a/library/src/device/generator/stockham_gen_base.h +++ b/library/src/device/generator/stockham_gen_base.h @@ -508,7 +508,8 @@ struct StockhamKernel : public StockhamGeneratorSpecs double height, ThreadGuardMode guard, bool trans_dir = false, - const std::optional& guard_factor = std::nullopt) const + const std::optional& guard_factor = std::nullopt, + const std::optional& work_length = std::nullopt) const { StatementList stmts; unsigned int iheight = std::floor(height); @@ -517,14 +518,15 @@ struct StockhamKernel : public StockhamGeneratorSpecs Expression guard_expr = Expression{Literal{"true"}}; - auto thread_guard_cond = (length / width) * (guard_factor ? *guard_factor : 1); + auto effective_length = work_length ? *work_length : length; + auto thread_guard_cond = (effective_length / width) * (guard_factor ? *guard_factor : 1); // do thread gurad when guard_by_if or guard_by_arg if(guard != ThreadGuardMode::NO_GUARD) { // using ">" : no need to test "if(thread < XXX)"" if it is always true - if((!trans_dir && threads_per_transform > (length / width)) - || (trans_dir && workgroup_size / transforms_per_block > (length / width))) + if((!trans_dir && threads_per_transform > (effective_length / width)) + || (trans_dir && workgroup_size / transforms_per_block > (effective_length / width))) { if(writeGuard) guard_expr = Expression{write && (thread < thread_guard_cond)}; @@ -553,7 +555,7 @@ struct StockhamKernel : public StockhamGeneratorSpecs stmts += work; } - if(height > iheight && threads_per_transform < length / width) + if(height > iheight && threads_per_transform < effective_length / width) { stmts += CommentLines{"not enough threads, some threads do extra work"}; unsigned int dt = iheight * threads_per_transform; diff --git a/library/src/device/generator/stockham_pp_gen_rr.h b/library/src/device/generator/stockham_pp_gen_rr.h index 7cb451a4cfa..6d083fed48d 100644 --- a/library/src/device/generator/stockham_pp_gen_rr.h +++ b/library/src/device/generator/stockham_pp_gen_rr.h @@ -233,20 +233,21 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR StatementList work; for(unsigned int w = 0; w < width; ++w) - work += Assign(R[w], lds_complex[offset_lds + (w * stride_lds)]); + work += Assign(R[hr * width + w], + lds_complex[offset_lds + ((hr * width + w) * stride_lds)]); return work; } ArgumentList device_lds_reg_inout_pp_arguments() { - ArgumentList args{R, lds_complex, stride_lds, offset_lds}; + ArgumentList args{R, lds_complex, stride_lds, offset_lds, thread}; return args; } std::vector device_lds_reg_inout_pp_device_call_arguments() { - return {R, lds_complex, stride_lds_pp, offset_lds_pp}; + return {R, lds_complex, stride_lds_pp, offset_lds_pp, thread_id / max_factor_pp}; } TemplateList device_lds_reg_inout_pp_templates() @@ -266,17 +267,22 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR f.arguments = device_lds_reg_inout_pp_arguments(); f.qualifier = "__device__"; + auto effective_length = std::max(length, length_pp); + StatementList& body = f.body; auto load_lds = std::mem_fn(&StockhamPartialPassKernelRR::load_lds_step_1_2_generator); // first pass of load (full) - unsigned int width = max_factor_pp; - float height = static_cast(length) / width / threads_per_transform; + unsigned int width = factors_pp[0]; + float height = static_cast(effective_length) / width / threads_per_transform; body += SyncThreads(); body += add_work(std::bind(load_lds, this, _1, _2, _3, _4, _5), width, height, - ThreadGuardMode::NO_GUARD); + ThreadGuardMode::GUARD_BY_IF, + false, + std::nullopt, + effective_length); return f; } @@ -289,7 +295,8 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR StatementList work; for(unsigned int w = 0; w < width; ++w) - work += Assign(lds_complex[offset_lds + (w * stride_lds)], R[w]); + work += Assign(lds_complex[offset_lds + ((hr * width + w) * stride_lds)], + R[hr * width + w]); return work; } @@ -304,17 +311,22 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR f.arguments = device_lds_reg_inout_pp_arguments(); f.qualifier = "__device__"; + auto effective_length = std::max(length_pp, length); + StatementList& body = f.body; auto store_lds = std::mem_fn(&StockhamPartialPassKernelRR::store_pp_step_1_2_lds_generator); // last pass of store (full) - unsigned int width = max_factor_pp; - float height = static_cast(length) / width / threads_per_transform; + unsigned int width = factors_pp.back(); + float height = static_cast(effective_length) / width / threads_per_transform; body += SyncThreads(); body += add_work(std::bind(store_lds, this, _1, _2, _3, _4, _5), width, height, - ThreadGuardMode::NO_GUARD); + ThreadGuardMode::GUARD_BY_IF, + false, + std::nullopt, + effective_length); return f; } diff --git a/library/src/device/kernels/configs/config_pp_3d.py b/library/src/device/kernels/configs/config_pp_3d.py index fc3b85e5352..e432ab69773 100644 --- a/library/src/device/kernels/configs/config_pp_3d.py +++ b/library/src/device/kernels/configs/config_pp_3d.py @@ -23,6 +23,11 @@ # yapf: disable pp_3d_kernels = [ + NS(length=[32,32,128], dims=[0,2], factors=[[8,4],[2,4,4,4]], factors_pp=[[4],[8]], threads_per_transform=[8,16], workgroup_size=[64,128], direct_to_from_reg=[False,False]), + NS(length=[32,32,64], dims=[0,2], factors=[[4,8],[4,4,4]], factors_pp=[[4],[8]], threads_per_transform=[8,16], workgroup_size=[64,128], direct_to_from_reg=[False,False]), + NS(length=[64,32,128], dims=[0,2], factors=[[8,8],[2,4,4,4]], factors_pp=[[4],[8]], threads_per_transform=[8,16], workgroup_size=[64,128], direct_to_from_reg=[False,False]), + NS(length=[64,64,128], dims=[0,2], factors=[[8,8],[2,4,4,4]], factors_pp=[[4],[16]], threads_per_transform=[8,16], workgroup_size=[64,256], direct_to_from_reg=[False,False]), NS(length=[64,64,64], dims=[0,2], factors=[[8,8],[4,4,4]], factors_pp=[[4],[16]], threads_per_transform=[8,16], workgroup_size=[64,256], direct_to_from_reg=[False,False]), NS(length=[64,64,52], dims=[0,2], factors=[[8,8],[13,4]], factors_pp=[[4],[16]], threads_per_transform=[8,4], workgroup_size=[64,64], direct_to_from_reg=[False,False]), + NS(length=[60,60,60], dims=[0,2], factors=[[6,10],[6,10]], factors_pp=[[6],[10]], threads_per_transform=[10,10], workgroup_size=[20,100], direct_to_from_reg=[False,False]), ] From 3e40ba8ddb21541450cfc73e56b0b06bb818e9a2 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Wed, 23 Jul 2025 10:33:55 -0600 Subject: [PATCH 57/69] - Clean up. --- library/src/device/generator/stockham_pp_gen_cc.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index a7e43d58299..9846ca583bb 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -72,7 +72,6 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC case 1: num_blocks_per_batch = ((params.parent_length[1]) - 1) / transforms_per_block + 1; num_blocks_per_batch *= params.parent_length[2]; - off_dim_length = params.parent_length[2]; break; case 2: throw std::runtime_error( @@ -90,8 +89,6 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC unsigned int transforms_per_block_pp; unsigned int max_factor_pp; - unsigned int off_dim_length; - std::vector factors_pp_curr; unsigned int pp_factors_curr_prod; From 1061964b4070370796204aec00d4ced275651abc Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Wed, 23 Jul 2025 16:26:25 -0600 Subject: [PATCH 58/69] - Changes to reduce cost of local transpose in partial-pass steps 3-4. --- .../src/device/generator/stockham_pp_gen_cc.h | 67 +++++++++++-------- library/src/rtc_stockham_kernel.cpp | 52 +++++++++++--- 2 files changed, 80 insertions(+), 39 deletions(-) diff --git a/library/src/device/generator/stockham_pp_gen_cc.h b/library/src/device/generator/stockham_pp_gen_cc.h index 9846ca583bb..b0be9408b68 100644 --- a/library/src/device/generator/stockham_pp_gen_cc.h +++ b/library/src/device/generator/stockham_pp_gen_cc.h @@ -112,6 +112,11 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC Variable global_idx{"global_idx", "unsigned int"}; Variable transpose_idx{"transpose_idx", "unsigned int"}; + Variable len_1_2{"len_1_2", "unsigned int"}; + Variable len_1_2_3{"len_1_2_3", "unsigned int"}; + Variable len_pp_factors_curr_prod{"len_pp_factors_curr_prod", "unsigned int"}; + Variable len_pp_factors_other_prod{"len_pp_factors_other_prod", "unsigned int"}; + std::vector launcher_lengths() override { return params.parent_length; @@ -466,7 +471,12 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC tmp_stmts += StoreGlobal{ buf, CallExpr{"local_transpose_pp_length" + std::to_string(length) + "_device", - {offset_tile_wbuf(i), lengths}}, + {offset_tile_wbuf(i), + lengths, + len_1_2, + len_1_2_3, + len_pp_factors_curr_prod, + len_pp_factors_other_prod}}, lds_complex[offset_tile_rlds(i)]}; stmts += CommentLines{ @@ -688,41 +698,30 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC = "local_transpose_pp_length" + std::to_string(length) + "_device"; Function f{function_name}; - f.arguments = ArgumentList{global_idx, lengths}; + f.arguments = ArgumentList{global_idx, + lengths, + len_1_2, + len_1_2_3, + len_pp_factors_curr_prod, + len_pp_factors_other_prod}; f.return_type = "unsigned int"; f.qualifier = "__device__"; StatementList& body = f.body; - auto len_1 = lengths[2]; - auto len_2 = lengths[1]; - auto len_3 = lengths[0]; - auto len_1_2 = len_1 * len_2; - auto len_1_2_3 = len_1 * len_2 * len_3; - - auto pp_factor_1 = pp_factors_curr_prod; - auto pp_factor_2 = pp_factors_other_prod; - - auto pp_factor_len = pp_factor_1 * len_2; + auto len_1 = lengths[2]; + auto len_2 = lengths[1]; + auto len_3 = lengths[0]; body += Declaration{transpose_idx, global_idx % len_1_2_3}; - Variable idx_1{"idx_1", "unsigned int"}; - body += Declaration{idx_1, transpose_idx % len_2}; - - Variable idx_2{"idx_2", "unsigned int"}; - body += Declaration{idx_2, transpose_idx % pp_factor_len}; - body += Assign{idx_2, idx_2 / len_2}; - body += Assign{idx_2, idx_2 * pp_factor_2 * len_2}; - - Variable idx_3{"idx_3", "unsigned int"}; - body += Declaration{idx_3, transpose_idx % len_1_2}; - body += Assign{idx_3, Parens{idx_3 / pp_factor_len} * len_2}; - - Variable idx_4{"idx_4", "unsigned int"}; - body += Declaration{idx_4, Parens{transpose_idx / len_1_2} * len_1_2}; - - body += Assign{transpose_idx, idx_1 + idx_2 + idx_3 + idx_4}; + body += Assign{ + transpose_idx, + Parens{transpose_idx % len_2} + + Parens{Parens{Parens{transpose_idx % len_pp_factors_curr_prod} / len_2} + * len_pp_factors_other_prod} + + Parens{Parens{Parens{transpose_idx % len_1_2} / len_pp_factors_curr_prod} * len_2} + + Parens{Parens{transpose_idx / len_1_2} * len_1_2}}; body += Assign{transpose_idx, transpose_idx + Parens{global_idx / len_1_2_3} * len_1_2_3}; @@ -971,6 +970,18 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC return f; } + ArgumentList global_arguments() override + { + // insert large twiddles + ArgumentList arglist = StockhamKernel::global_arguments(); + arglist.arguments.insert(arglist.arguments.begin() + 1, large_twiddles); + arglist.arguments.insert(arglist.arguments.begin() + 2, len_1_2); + arglist.arguments.insert(arglist.arguments.begin() + 3, len_1_2_3); + arglist.arguments.insert(arglist.arguments.begin() + 4, len_pp_factors_curr_prod); + arglist.arguments.insert(arglist.arguments.begin() + 5, len_pp_factors_other_prod); + return arglist; + } + Function generate_global_function() override { Function f("forward_pp_length" + std::to_string(length) + "_" + tiling_name()); diff --git a/library/src/rtc_stockham_kernel.cpp b/library/src/rtc_stockham_kernel.cpp index eb78e0407be..d07ab6108e3 100644 --- a/library/src/rtc_stockham_kernel.cpp +++ b/library/src/rtc_stockham_kernel.cpp @@ -29,6 +29,27 @@ #include "device/kernel-generator-embed.h" +static StockhamPartialPassParams get_partial_pass_params(const TreeNode& node, + const FFTKernel& kernel) +{ + if(node.isPartialPassEnabled()) + { + StockhamPartialPassParams pp_params; + + pp_params.off_dim = node.ppOffDim; + pp_params.current_dim = node.ppCurrDim; + pp_params.pp_factors_curr = std::vector( + kernel.pp_params.pp_factors_curr.begin(), kernel.pp_params.pp_factors_curr.end()); + pp_params.pp_factors_other = std::vector( + kernel.pp_params.pp_factors_other.begin(), kernel.pp_params.pp_factors_other.end()); + pp_params.parent_length = std::vector(node.length.begin(), node.length.end()); + + return pp_params; + } + + throw std::runtime_error("Invalid scheme for partial pass"); +} + RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& node, const std::string& gpu_arch, bool enable_callbacks) @@ -86,17 +107,7 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& specs->ebtype = node.ebtype; if(node.isPartialPassEnabled()) - { - pp_params.off_dim = node.ppOffDim; - pp_params.current_dim = node.ppCurrDim; - pp_params.pp_factors_curr = std::vector( - kernel->pp_params.pp_factors_curr.begin(), kernel->pp_params.pp_factors_curr.end()); - pp_params.pp_factors_other - = std::vector(kernel->pp_params.pp_factors_other.begin(), - kernel->pp_params.pp_factors_other.end()); - pp_params.parent_length - = std::vector(node.length.begin(), node.length.end()); - } + pp_params = get_partial_pass_params(node, *kernel); break; } @@ -244,6 +255,25 @@ RTCKernelArgs RTCKernelStockham::get_launch_args(DeviceCallIn& data) if(data.node->scheme == CS_KERNEL_STOCKHAM_BLOCK_CC || data.node->scheme == CS_KERNEL_STOCKHAM_PP_BLOCK_CC) kargs.append_ptr(data.node->twiddles_large); + if(data.node->scheme == CS_KERNEL_STOCKHAM_PP_BLOCK_CC) + { + auto kernel = data.node->GetKernel(); + auto pp_params = get_partial_pass_params(*data.node, kernel); + auto pp_factors_curr_prod = std::accumulate(pp_params.pp_factors_curr.begin(), + pp_params.pp_factors_curr.end(), + 1, + std::multiplies()); + auto pp_factors_other_prod = std::accumulate(pp_params.pp_factors_other.begin(), + pp_params.pp_factors_other.end(), + 1, + std::multiplies()); + + kargs.append_unsigned_int(data.node->length[1] * data.node->length[2]); + kargs.append_unsigned_int(data.node->length[0] * data.node->length[1] + * data.node->length[2]); + kargs.append_unsigned_int(data.node->length[1] * pp_factors_curr_prod); + kargs.append_unsigned_int(data.node->length[1] * pp_factors_other_prod); + } if(!hardcoded_dim) kargs.append_size_t(data.node->length.size()); // lengths From 73d6b8c15ebdba46657c4cfe6b2672ddef3f8637 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Wed, 23 Jul 2025 17:04:24 -0600 Subject: [PATCH 59/69] - Add accuracy test coverage for new 3D lengths. --- clients/tests/accuracy_tests_range.h | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/clients/tests/accuracy_tests_range.h b/clients/tests/accuracy_tests_range.h index 103a97a29eb..5972cfd0427 100644 --- a/clients/tests/accuracy_tests_range.h +++ b/clients/tests/accuracy_tests_range.h @@ -167,7 +167,15 @@ const static std::vector inner_batch_3D_batch_range = {3, 2, 1}; // partial pass test problems //----------------------------------------------------------------------- //----------------------------------------------------------------------- -const static std::vector> partial_pass_adhoc_3D = {{64, 64, 64}}; -const static std::vector partial_pass_batch_range_3D = {1, 5, 10, 20, 50}; +const static std::vector> partial_pass_adhoc_3D = { + {64, 64, 128}, + {64, 64, 64}, + {64, 64, 52}, + {60, 60, 60}, + {32, 32, 128}, + {32, 32, 64}, + {64, 32, 128}, +}; +const static std::vector partial_pass_batch_range_3D = {1, 5, 10, 20, 50}; #endif // ACCURACY_TESTS_RANGE_H \ No newline at end of file From 28a36891c0d7b42209a3a2e0f7e316009a444938 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 24 Jul 2025 09:55:07 -0600 Subject: [PATCH 60/69] - Add more lengths to partial-pass perf suite. --- scripts/perf/suites.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/scripts/perf/suites.py b/scripts/perf/suites.py index 95c49c9ace5..3e855ef0795 100644 --- a/scripts/perf/suites.py +++ b/scripts/perf/suites.py @@ -1046,25 +1046,25 @@ def tuning_suite(): def partial_pass(): - for direction in [-1, 1]: - for precision in ['single', 'double']: - for place in all_inplaces: - for batch in [ - 1, 2, 5, 10, 15, 20, 25, 50, 75, 100, 150, 200, 250, - 500, 750, 1000, 1250, 1500, 1575, 2000, 2500, 3000, - 3500, 4000, 4500, 5000, 7500, 10000 - ]: - - length = (64, 64, 64) - yield Problem(length, - tag=mktag("partial_pass", 1, precision, - direction, place, False), - nbatch=batch, - direction=direction, - inplace=place, - real=False, - meta={'ivariable': 'batch'}, - precision=precision) + for length in [(64, 64, 128), (64, 64, 64), (64, 64, 52), (60, 60, 60), + (32, 32, 128), (32, 32, 64), (64, 32, 128)]: + for direction in [-1, 1]: + for precision in ['single', 'double']: + for place in all_inplaces: + for batch in [ + 1, 5, 20, 50, 100, 200, 500, 1000, 1500, 3000, + 5000, 7500, 10000 + ]: + + yield Problem(length, + tag=mktag("partial_pass", 1, precision, + direction, place, False), + nbatch=batch, + direction=direction, + inplace=place, + real=False, + meta={'ivariable': 'batch'}, + precision=precision) def large_1d_extended(): From 280d78ae61b0fe7106d9a3f8ffd3f20d2421096e Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 24 Jul 2025 18:37:34 -0600 Subject: [PATCH 61/69] - Fix function pool init for partial-pass kernels. - Enable 60x60x60 partial-pass optimization only in double-precision. --- library/src/device/kernel-generator.py | 99 ++++++++++++------- .../device/kernels/configs/config_pp_3d.py | 2 +- 2 files changed, 67 insertions(+), 34 deletions(-) diff --git a/library/src/device/kernel-generator.py b/library/src/device/kernel-generator.py index bb3525d96d8..7f40d4d25ce 100644 --- a/library/src/device/kernel-generator.py +++ b/library/src/device/kernel-generator.py @@ -187,8 +187,6 @@ def generate_cpu_function_pool_main(num_files): def generate_cpu_function_pool_pieces(functions, pp_functions, num_files): """Generate function(s) to populate the kernel function pool.""" - all_functions = functions + pp_functions - function_map = Map('function_map') precisions = { 'sp': 'rocfft_precision_single', @@ -206,19 +204,72 @@ def generate_cpu_function_pool_pieces(functions, pp_functions, num_files): for _ in range(num_files) ] - # Cycles through each file per loop execution to distribute work amongst N files - curr_func, curr_pp_func = 0, 0 - curr_func_offset = 0 if len(pp_functions) == 0 else len(precisions) - curr_file = 0 - while curr_func < len(all_functions) - curr_func_offset: - f = all_functions[curr_func] + # Cycles through each file per loop execution to distribute regular kernels work amongst N files + curr_func, curr_file = 0, 0 + while curr_func < len(functions): + f = functions[curr_func] length, precision, scheme, transpose = f.meta.length, f.meta.precision, f.meta.scheme, f.meta.transpose - if scheme == 'CS_3D_PP': - piece_contents[curr_file] += Assign(var_pp_kernel_1, FFTKernel(f)) - f = all_functions[curr_func + curr_func_offset] - piece_contents[curr_file] += Assign(var_pp_kernel_2, FFTKernel(f)) + if isinstance(length, (int, str)): + length = [length, 0] + piece_contents[curr_file] += Assign(var_kernel, FFTKernel(f)) + key = Call( + name='FMKey', + arguments=ArgumentList(length[0], length[1], precisions[precision], + scheme, transpose or 'NONE', + 'kernel.get_kernel_config()')).inline() + piece_contents[curr_file] += function_map.insert( + key, var_kernel, 'std::get<0>(def_keys)', + 'std::get<0>(function_maps)', f.meta.lds_size_bytes) + + curr_func = curr_func + 1 + curr_file = (curr_file + 1) % num_files + # Partial-pass kernels are handled separately. + if len(pp_functions) > 0: + counter_f_pp_1 = 0 + skip_to_next_iter = False + # Cycles through each file per loop execution to distribute partial-pass kernels work amongst N files + while True: + if counter_f_pp_1 >= len(pp_functions): + break + # get first pp kernel + f_pp_1 = pp_functions[counter_f_pp_1] + + # PPFMKey entry needs two kernels with same length and precision, but different pp_current_dim + counter_f_pp_2 = counter_f_pp_1 + 1 + if counter_f_pp_2 >= len(pp_functions): + break + # loop to get the second pp kernel + while counter_f_pp_2 < len(pp_functions): + f_pp_2 = pp_functions[counter_f_pp_2] + if (f_pp_1.meta.length == f_pp_2.meta.length + and f_pp_1.meta.precision == f_pp_2.meta.precision + and f_pp_1.meta.pp_current_dim != + f_pp_2.meta.pp_current_dim): + break + if (f_pp_1.meta.length != f_pp_2.meta.length): + # we hit a new kernel with different length + # start next iteration looking for the next pair + counter_f_pp_1 = counter_f_pp_2 + skip_to_next_iter = True + break + counter_f_pp_2 = counter_f_pp_2 + 1 + + if skip_to_next_iter: + skip_to_next_iter = False + continue + # get second pp kernel + f_pp_2 = pp_functions[counter_f_pp_2] + + piece_contents[curr_file] += Assign(var_pp_kernel_1, + FFTKernel(f_pp_1)) + piece_contents[curr_file] += Assign(var_pp_kernel_2, + FFTKernel(f_pp_2)) + + length = f_pp_1.meta.length + precision = f_pp_1.meta.precision + scheme = f_pp_1.meta.scheme key = Call(name='PPFMKey', arguments=ArgumentList( length[0], length[1], length[2], @@ -227,28 +278,10 @@ def generate_cpu_function_pool_pieces(functions, pp_functions, num_files): 'pp_kernel_2.get_kernel_config()')).inline() piece_contents[curr_file] += function_map.insert_pp( key, var_pp_kernel_1, var_pp_kernel_2, 'std::get<1>(def_keys)', - 'std::get<1>(function_maps)', f.meta.lds_size_bytes) + 'std::get<1>(function_maps)', f_pp_1.meta.lds_size_bytes) - curr_pp_func = curr_pp_func + 1 - else: - if isinstance(length, (int, str)): - length = [length, 0] - piece_contents[curr_file] += Assign(var_kernel, FFTKernel(f)) - key = Call(name='FMKey', - arguments=ArgumentList( - length[0], length[1], precisions[precision], scheme, - transpose or 'NONE', - 'kernel.get_kernel_config()')).inline() - piece_contents[curr_file] += function_map.insert( - key, var_kernel, 'std::get<0>(def_keys)', - 'std::get<0>(function_maps)', f.meta.lds_size_bytes) - - if curr_pp_func == len(precisions): - curr_func, curr_pp_func = curr_func + len(precisions) + 1, 0 - else: - curr_func = curr_func + 1 - - curr_file = (curr_file + 1) % num_files + counter_f_pp_1 = counter_f_pp_1 + 1 + curr_file = (curr_file + 1) % num_files # Assemble contents of each file to return in a list pieces = [None] * num_files diff --git a/library/src/device/kernels/configs/config_pp_3d.py b/library/src/device/kernels/configs/config_pp_3d.py index 65da97f22b1..b017808b8d2 100644 --- a/library/src/device/kernels/configs/config_pp_3d.py +++ b/library/src/device/kernels/configs/config_pp_3d.py @@ -29,5 +29,5 @@ NS(length=[64,64,128], dims=[0,2], factors=[[8,8],[2,4,4,4]], factors_pp=[[4],[16]], threads_per_transform=[8,16], workgroup_size=[64,256], direct_to_from_reg=[False,False]), NS(length=[64,64,64], dims=[0,2], factors=[[8,8],[4,4,4]], factors_pp=[[4],[16]], threads_per_transform=[8,16], workgroup_size=[64,256], direct_to_from_reg=[False,False]), NS(length=[64,64,52], dims=[0,2], factors=[[8,8],[13,4]], factors_pp=[[4],[16]], threads_per_transform=[8,4], workgroup_size=[64,64], direct_to_from_reg=[False,False]), - NS(length=[60,60,60], dims=[0,2], factors=[[6,10],[6,10]], factors_pp=[[6],[10]], threads_per_transform=[10,10], workgroup_size=[20,100], direct_to_from_reg=[False,False]), + NS(length=[60,60,60], dims=[0,2], factors=[[6,10],[6,10]], factors_pp=[[6],[10]], threads_per_transform=[10,10], workgroup_size=[20,100], direct_to_from_reg=[False,False], precision=['dp']), ] From 7cdcad5edcabcbef57d7bd77251fa85d1c578b93 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Thu, 24 Jul 2025 20:40:50 -0600 Subject: [PATCH 62/69] - Further double-precision restrictions for some of the new partial-pass 3D kernels. --- library/src/device/kernels/configs/config_pp_3d.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/library/src/device/kernels/configs/config_pp_3d.py b/library/src/device/kernels/configs/config_pp_3d.py index b017808b8d2..7ce240a409e 100644 --- a/library/src/device/kernels/configs/config_pp_3d.py +++ b/library/src/device/kernels/configs/config_pp_3d.py @@ -23,9 +23,9 @@ # yapf: disable pp_3d_kernels = [ - NS(length=[32,32,128], dims=[0,2], factors=[[8,4],[2,4,4,4]], factors_pp=[[4],[8]], threads_per_transform=[8,16], workgroup_size=[64,128], direct_to_from_reg=[False,False]), - NS(length=[32,32,64], dims=[0,2], factors=[[4,8],[4,4,4]], factors_pp=[[4],[8]], threads_per_transform=[8,16], workgroup_size=[64,128], direct_to_from_reg=[False,False]), - NS(length=[64,32,128], dims=[0,2], factors=[[8,8],[2,4,4,4]], factors_pp=[[4],[8]], threads_per_transform=[8,16], workgroup_size=[64,128], direct_to_from_reg=[False,False]), + NS(length=[32,32,128], dims=[0,2], factors=[[8,4],[2,4,4,4]], factors_pp=[[4],[8]], threads_per_transform=[8,16], workgroup_size=[64,128], direct_to_from_reg=[False,False], precision=['dp']), + NS(length=[32,32,64], dims=[0,2], factors=[[4,8],[4,4,4]], factors_pp=[[4],[8]], threads_per_transform=[8,16], workgroup_size=[64,128], direct_to_from_reg=[False,False], precision=['dp']), + NS(length=[64,32,128], dims=[0,2], factors=[[8,8],[2,4,4,4]], factors_pp=[[4],[8]], threads_per_transform=[8,16], workgroup_size=[64,128], direct_to_from_reg=[False,False], precision=['dp']), NS(length=[64,64,128], dims=[0,2], factors=[[8,8],[2,4,4,4]], factors_pp=[[4],[16]], threads_per_transform=[8,16], workgroup_size=[64,256], direct_to_from_reg=[False,False]), NS(length=[64,64,64], dims=[0,2], factors=[[8,8],[4,4,4]], factors_pp=[[4],[16]], threads_per_transform=[8,16], workgroup_size=[64,256], direct_to_from_reg=[False,False]), NS(length=[64,64,52], dims=[0,2], factors=[[8,8],[13,4]], factors_pp=[[4],[16]], threads_per_transform=[8,4], workgroup_size=[64,64], direct_to_from_reg=[False,False]), From 5a06adb591202f885b7770f5823ee3dec8347935 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Tue, 29 Jul 2025 09:43:37 -0600 Subject: [PATCH 63/69] - Clang format. --- clients/tests/accuracy_tests_range.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/tests/accuracy_tests_range.h b/clients/tests/accuracy_tests_range.h index 5972cfd0427..6b71b27018f 100644 --- a/clients/tests/accuracy_tests_range.h +++ b/clients/tests/accuracy_tests_range.h @@ -174,7 +174,7 @@ const static std::vector> partial_pass_adhoc_3D = { {60, 60, 60}, {32, 32, 128}, {32, 32, 64}, - {64, 32, 128}, + {64, 32, 128}, }; const static std::vector partial_pass_batch_range_3D = {1, 5, 10, 20, 50}; From a5339195b470dd7138a54ef380e93d776900f32c Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Tue, 29 Jul 2025 10:36:17 -0600 Subject: [PATCH 64/69] - More clang format. --- library/src/device/kernels/configs/config_2d_single.py | 2 +- library/src/device/kernels/configs/config_lds.py | 2 +- library/src/device/kernels/configs/config_sbcc.py | 2 +- library/src/device/kernels/configs/config_sbcr.py | 2 +- library/src/device/kernels/configs/config_sbrr.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/library/src/device/kernels/configs/config_2d_single.py b/library/src/device/kernels/configs/config_2d_single.py index b435be2d1d6..181f63fb9aa 100644 --- a/library/src/device/kernels/configs/config_2d_single.py +++ b/library/src/device/kernels/configs/config_2d_single.py @@ -177,4 +177,4 @@ NS(length=[128,64], factors=[[4,8,4],[4,4,4]], threads_per_transform=[16,8], double_precision=False, workgroup_size=512, lds_size_bytes=config_lds.LDS_160k), NS(length=[96,96], factors=[[4,6,4],[4,6,4]], threads_per_transform=[8,8], workgroup_size=256, lds_size_bytes=config_lds.LDS_160k), NS(length=[100,100], factors=[[10,10],[10,10]], threads_per_transform=[10,10], double_precision=False, workgroup_size=500, lds_size_bytes=config_lds.LDS_160k), -] \ No newline at end of file +] diff --git a/library/src/device/kernels/configs/config_lds.py b/library/src/device/kernels/configs/config_lds.py index 559c918bc24..52dd1a558af 100644 --- a/library/src/device/kernels/configs/config_lds.py +++ b/library/src/device/kernels/configs/config_lds.py @@ -18,4 +18,4 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -LDS_160k = 160 * 1024 \ No newline at end of file +LDS_160k = 160 * 1024 diff --git a/library/src/device/kernels/configs/config_sbcc.py b/library/src/device/kernels/configs/config_sbcc.py index b205e93906a..6eeed173374 100644 --- a/library/src/device/kernels/configs/config_sbcc.py +++ b/library/src/device/kernels/configs/config_sbcc.py @@ -91,4 +91,4 @@ 'sp': 'true', 'dp': 'true'}), NS(length=512, factors=[8, 8, 8], use_3steps_large_twd={ 'sp': 'true', 'dp': 'false'}), -] \ No newline at end of file +] diff --git a/library/src/device/kernels/configs/config_sbcr.py b/library/src/device/kernels/configs/config_sbcr.py index 1198ce44116..625d2a1d217 100644 --- a/library/src/device/kernels/configs/config_sbcr.py +++ b/library/src/device/kernels/configs/config_sbcr.py @@ -39,4 +39,4 @@ NS(length=100, factors=[10, 10], workgroup_size=100), NS(length=200, factors=[8, 5, 5]), NS(length=336, factors=[6, 7, 8]) -] \ No newline at end of file +] diff --git a/library/src/device/kernels/configs/config_sbrr.py b/library/src/device/kernels/configs/config_sbrr.py index 01eb643ad3e..2d495a775f9 100644 --- a/library/src/device/kernels/configs/config_sbrr.py +++ b/library/src/device/kernels/configs/config_sbrr.py @@ -505,4 +505,4 @@ NS(length=18816, workgroup_size=512, threads_per_transform=448, factors=(8, 8, 7, 7, 6), double_precision=False, lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), NS(length=19200, workgroup_size=512, threads_per_transform=480, factors=(8, 10, 8, 5, 6), double_precision=False, lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), NS(length=20480, workgroup_size=512, threads_per_transform=512, factors=(4, 4, 16, 10, 8), double_precision=False, lds_size_bytes=config_lds.LDS_160k, runtime_compile=True), -] \ No newline at end of file +] From ea4e8f79892e2eaaa9a85d2bda7fcf28132c5f4e Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 1 Aug 2025 11:17:35 -0600 Subject: [PATCH 65/69] - Replace std::accumulate instances with arithmetic helper. --- .../src/device/generator/stockham_gen.cpp | 13 +++++-------- .../src/device/generator/stockham_pp_gen_cc.h | 6 ++---- .../rocfft/library/src/rtc_stockham_kernel.cpp | 17 +++++++---------- 3 files changed, 14 insertions(+), 22 deletions(-) diff --git a/projects/rocfft/library/src/device/generator/stockham_gen.cpp b/projects/rocfft/library/src/device/generator/stockham_gen.cpp index 5201121dcdd..e7f921cd42f 100644 --- a/projects/rocfft/library/src/device/generator/stockham_gen.cpp +++ b/projects/rocfft/library/src/device/generator/stockham_gen.cpp @@ -21,6 +21,7 @@ #include using namespace std::placeholders; +#include "../../../../shared/arithmetic.h" #include "../../../../shared/precision_type.h" #include "generator.h" #include "stockham_gen.h" @@ -550,8 +551,7 @@ void validate_pp_length(const StockhamPartialPassParams& pp_params, const std::vector& factors) { - unsigned int length_curr - = std::accumulate(factors.begin(), factors.end(), 1, std::multiplies()); + unsigned int length_curr = product(factors.begin(), factors.end()); auto curr_dim = pp_params.current_dim; if(length_curr != pp_params.parent_length[curr_dim]) @@ -568,8 +568,7 @@ void validate_pp_off_dim_length(const StockhamPartialPassParams& pp_params_1, pp_params_2.pp_factors_curr.begin(), pp_params_2.pp_factors_curr.end()); - unsigned int length_off_dim = std::accumulate( - off_factors_all.begin(), off_factors_all.end(), 1, std::multiplies()); + unsigned int length_off_dim = product(off_factors_all.begin(), off_factors_all.end()); if(pp_params_1.parent_length[pp_params_1.off_dim] != pp_params_2.parent_length[pp_params_2.off_dim]) @@ -597,10 +596,8 @@ void validate_pp_grid_params(const StockhamPartialPassParams& params_1, ? specs_1.workgroup_size / specs_1.threads_per_transform : specs_2.workgroup_size / specs_2.threads_per_transform; - unsigned int prod_factors_off_dim = std::accumulate(params_1.pp_factors_curr.begin(), - params_1.pp_factors_curr.end(), - 1, - std::multiplies()); + auto prod_factors_off_dim + = product(params_1.pp_factors_curr.begin(), params_1.pp_factors_curr.end()); if(tpb_sbrr != prod_factors_off_dim) { throw std::runtime_error("CS_KERNEL_STOCKHAM_PP requires transform-per-block " diff --git a/projects/rocfft/library/src/device/generator/stockham_pp_gen_cc.h b/projects/rocfft/library/src/device/generator/stockham_pp_gen_cc.h index b0be9408b68..e587d1c018f 100644 --- a/projects/rocfft/library/src/device/generator/stockham_pp_gen_cc.h +++ b/projects/rocfft/library/src/device/generator/stockham_pp_gen_cc.h @@ -58,10 +58,8 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC transforms_per_block *= max_factor_pp; workgroup_size *= max_factor_pp; - pp_factors_curr_prod = std::accumulate( - factors_pp_curr.begin(), factors_pp_curr.end(), 1, std::multiplies()); - pp_factors_other_prod = std::accumulate( - factors_pp_other.begin(), factors_pp_other.end(), 1, std::multiplies()); + pp_factors_curr_prod = product(factors_pp_curr.begin(), factors_pp_curr.end()); + pp_factors_other_prod = product(factors_pp_other.begin(), factors_pp_other.end()); switch(params.off_dim) { diff --git a/projects/rocfft/library/src/rtc_stockham_kernel.cpp b/projects/rocfft/library/src/rtc_stockham_kernel.cpp index d07ab6108e3..50006bc89a5 100644 --- a/projects/rocfft/library/src/rtc_stockham_kernel.cpp +++ b/projects/rocfft/library/src/rtc_stockham_kernel.cpp @@ -20,6 +20,7 @@ #include +#include "../../shared/arithmetic.h" #include "../../shared/array_predicate.h" #include "function_pool.h" #include "kernel_launch.h" @@ -257,16 +258,12 @@ RTCKernelArgs RTCKernelStockham::get_launch_args(DeviceCallIn& data) kargs.append_ptr(data.node->twiddles_large); if(data.node->scheme == CS_KERNEL_STOCKHAM_PP_BLOCK_CC) { - auto kernel = data.node->GetKernel(); - auto pp_params = get_partial_pass_params(*data.node, kernel); - auto pp_factors_curr_prod = std::accumulate(pp_params.pp_factors_curr.begin(), - pp_params.pp_factors_curr.end(), - 1, - std::multiplies()); - auto pp_factors_other_prod = std::accumulate(pp_params.pp_factors_other.begin(), - pp_params.pp_factors_other.end(), - 1, - std::multiplies()); + auto kernel = data.node->GetKernel(); + auto pp_params = get_partial_pass_params(*data.node, kernel); + auto pp_factors_curr_prod + = product(pp_params.pp_factors_curr.begin(), pp_params.pp_factors_curr.end()); + auto pp_factors_other_prod + = product(pp_params.pp_factors_other.begin(), pp_params.pp_factors_other.end()); kargs.append_unsigned_int(data.node->length[1] * data.node->length[2]); kargs.append_unsigned_int(data.node->length[0] * data.node->length[1] From e76ba667924ff18a4590898b8a74a4b5509ef000 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 1 Aug 2025 11:54:58 -0600 Subject: [PATCH 66/69] - Address code review suggestions. --- .../src/device/generator/stockham_gen_base.h | 5 +++-- .../src/device/generator/stockham_pp_gen_cc.h | 15 ++++++++------- .../src/device/generator/stockham_pp_gen_rr.h | 4 ++-- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/projects/rocfft/library/src/device/generator/stockham_gen_base.h b/projects/rocfft/library/src/device/generator/stockham_gen_base.h index 48bb1badda1..4fc6aba481b 100644 --- a/projects/rocfft/library/src/device/generator/stockham_gen_base.h +++ b/projects/rocfft/library/src/device/generator/stockham_gen_base.h @@ -518,8 +518,9 @@ struct StockhamKernel : public StockhamGeneratorSpecs Expression guard_expr = Expression{Literal{"true"}}; - auto effective_length = work_length ? *work_length : length; - auto thread_guard_cond = (effective_length / width) * (guard_factor ? *guard_factor : 1); + const auto effective_length = work_length ? *work_length : length; + const auto thread_guard_cond + = (effective_length / width) * (guard_factor ? *guard_factor : 1); // do thread gurad when guard_by_if or guard_by_arg if(guard != ThreadGuardMode::NO_GUARD) diff --git a/projects/rocfft/library/src/device/generator/stockham_pp_gen_cc.h b/projects/rocfft/library/src/device/generator/stockham_pp_gen_cc.h index e587d1c018f..34bf7df4cb1 100644 --- a/projects/rocfft/library/src/device/generator/stockham_pp_gen_cc.h +++ b/projects/rocfft/library/src/device/generator/stockham_pp_gen_cc.h @@ -68,7 +68,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC "StockhamPartialPassKernelCC:: partial-passes along x not currently supported"); break; case 1: - num_blocks_per_batch = ((params.parent_length[1]) - 1) / transforms_per_block + 1; + num_blocks_per_batch = (params.parent_length[1] - 1) / transforms_per_block + 1; num_blocks_per_batch *= params.parent_length[2]; break; case 2: @@ -177,7 +177,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC StatementList work; for(unsigned int w = 0; w < width; ++w) - work += Assign(lds_complex[offset_lds + ((hr * width + w) * stride_lds)], + work += Assign(lds_complex[offset_lds + (hr * width + w) * stride_lds], R[hr * width + w]); return work; @@ -218,7 +218,7 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC for(unsigned int w = 0; w < width; ++w) work += Assign(R[hr * width + w], - lds_complex[offset_lds + ((hr * width + w) * stride_lds)]); + lds_complex[offset_lds + (hr * width + w) * stride_lds]); return work; } @@ -751,13 +751,14 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC for(unsigned int npass = 0; npass < factors_pp_curr.size(); ++npass) { - unsigned int width = factors_pp_curr[npass]; - unsigned int height = static_cast(length) / width / threads_per_transform; + unsigned int pass_width = factors_pp_curr[npass]; + unsigned int pass_height + = static_cast(length) / pass_width / threads_per_transform; auto butterfly = std::mem_fn(&StockhamKernel::butterfly_generator); stmts += add_work(std::bind(butterfly, this, _1, _2, _3, _4, _5), - width, - height, + pass_width, + pass_height, ThreadGuardMode::NO_GUARD); } diff --git a/projects/rocfft/library/src/device/generator/stockham_pp_gen_rr.h b/projects/rocfft/library/src/device/generator/stockham_pp_gen_rr.h index 6d083fed48d..57dce719674 100644 --- a/projects/rocfft/library/src/device/generator/stockham_pp_gen_rr.h +++ b/projects/rocfft/library/src/device/generator/stockham_pp_gen_rr.h @@ -234,7 +234,7 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR for(unsigned int w = 0; w < width; ++w) work += Assign(R[hr * width + w], - lds_complex[offset_lds + ((hr * width + w) * stride_lds)]); + lds_complex[offset_lds + (hr * width + w) * stride_lds]); return work; } @@ -295,7 +295,7 @@ struct StockhamPartialPassKernelRR : public StockhamKernelRR StatementList work; for(unsigned int w = 0; w < width; ++w) - work += Assign(lds_complex[offset_lds + ((hr * width + w) * stride_lds)], + work += Assign(lds_complex[offset_lds + (hr * width + w) * stride_lds], R[hr * width + w]); return work; From 3819868c16058e0d9c29cca79f37a345074582d4 Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 1 Aug 2025 13:53:13 -0600 Subject: [PATCH 67/69] - More code review suggestions. --- .../rocfft/library/src/rtc_stockham_kernel.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/projects/rocfft/library/src/rtc_stockham_kernel.cpp b/projects/rocfft/library/src/rtc_stockham_kernel.cpp index 50006bc89a5..407a5051fba 100644 --- a/projects/rocfft/library/src/rtc_stockham_kernel.cpp +++ b/projects/rocfft/library/src/rtc_stockham_kernel.cpp @@ -37,13 +37,13 @@ static StockhamPartialPassParams get_partial_pass_params(const TreeNode& node, { StockhamPartialPassParams pp_params; - pp_params.off_dim = node.ppOffDim; - pp_params.current_dim = node.ppCurrDim; - pp_params.pp_factors_curr = std::vector( - kernel.pp_params.pp_factors_curr.begin(), kernel.pp_params.pp_factors_curr.end()); - pp_params.pp_factors_other = std::vector( - kernel.pp_params.pp_factors_other.begin(), kernel.pp_params.pp_factors_other.end()); - pp_params.parent_length = std::vector(node.length.begin(), node.length.end()); + pp_params.off_dim = node.ppOffDim; + pp_params.current_dim = node.ppCurrDim; + pp_params.pp_factors_curr.assign(kernel.pp_params.pp_factors_curr.begin(), + kernel.pp_params.pp_factors_curr.end()); + pp_params.pp_factors_other.assign(kernel.pp_params.pp_factors_other.begin(), + kernel.pp_params.pp_factors_other.end()); + pp_params.parent_length.assign(node.length.begin(), node.length.end()); return pp_params; } From 5205082bed068e0f85a925cfe48dd4693694601f Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 1 Aug 2025 14:35:49 -0600 Subject: [PATCH 68/69] - Remove casts from offset calculation. --- .../library/src/device/generator/stockham_pp_gen_cc.h | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/projects/rocfft/library/src/device/generator/stockham_pp_gen_cc.h b/projects/rocfft/library/src/device/generator/stockham_pp_gen_cc.h index 34bf7df4cb1..18ec85fac9f 100644 --- a/projects/rocfft/library/src/device/generator/stockham_pp_gen_cc.h +++ b/projects/rocfft/library/src/device/generator/stockham_pp_gen_cc.h @@ -735,9 +735,8 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC stmts += Declaration{stride_lds_pp, Literal{1}}; unsigned int width = factors_pp_curr[0]; - float height = static_cast(length) / width / threads_per_transform; - stmts += Declaration{offset_lds_pp, - thread_id * Literal{width * static_cast(height)}}; + unsigned int height = length / width / threads_per_transform; + stmts += Declaration{offset_lds_pp, thread_id * Literal{width * height}}; auto pre_post_lds_tmpl = device_lds_reg_inout_device_call_templates(); auto pre_post_lds_args = device_lds_reg_inout_pp_device_call_arguments(); @@ -763,9 +762,8 @@ struct StockhamPartialPassKernelCC : public StockhamKernelCC } width = factors_pp_curr.back(); - height = static_cast(length) / width / threads_per_transform; - stmts += Assign{offset_lds_pp, - thread_id * Literal{width * static_cast(height)}}; + height = length / width / threads_per_transform; + stmts += Assign{offset_lds_pp, thread_id * Literal{width * height}}; StatementList postStore; postStore From 719c7f8dca766d2933e16031b9abed1cf399873a Mon Sep 17 00:00:00 2001 From: Flavio Teixeira Date: Fri, 1 Aug 2025 16:38:44 -0600 Subject: [PATCH 69/69] - Add parent length to partial-pass kernel name. --- .../library/src/include/rtc_stockham_gen.h | 41 +++++++-------- .../rocfft/library/src/rocfft_aot_helper.cpp | 2 + .../rocfft/library/src/rtc_stockham_gen.cpp | 51 ++++++++++--------- .../library/src/rtc_stockham_kernel.cpp | 1 + 4 files changed, 52 insertions(+), 43 deletions(-) diff --git a/projects/rocfft/library/src/include/rtc_stockham_gen.h b/projects/rocfft/library/src/include/rtc_stockham_gen.h index c3b41aa92f6..ab79309f7f3 100644 --- a/projects/rocfft/library/src/include/rtc_stockham_gen.h +++ b/projects/rocfft/library/src/include/rtc_stockham_gen.h @@ -32,26 +32,27 @@ #include "../device/kernels/common.h" // generate name for RTC stockham kernel -std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, - const StockhamGeneratorSpecs& specs2d, - ComputeScheme scheme, - int direction, - rocfft_precision precision, - rocfft_result_placement placement, - rocfft_array_type inArrayType, - rocfft_array_type outArrayType, - bool unitstride, - size_t largeTwdBase, - size_t largeTwdSteps, - bool largeTwdBatchIsTransformCount, - DirectRegType dir2regMode, - IntrinsicAccessType intrinsicMode, - SBRC_TRANSPOSE_TYPE transpose_type, - CallbackType cbtype, - BluesteinFuseType fuseBlue, - PartialPassType ppType, - const LoadOps& loadOps, - const StoreOps& storeOps); +std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, + const StockhamGeneratorSpecs& specs2d, + ComputeScheme scheme, + int direction, + rocfft_precision precision, + rocfft_result_placement placement, + rocfft_array_type inArrayType, + rocfft_array_type outArrayType, + bool unitstride, + size_t largeTwdBase, + size_t largeTwdSteps, + bool largeTwdBatchIsTransformCount, + DirectRegType dir2regMode, + IntrinsicAccessType intrinsicMode, + SBRC_TRANSPOSE_TYPE transpose_type, + CallbackType cbtype, + BluesteinFuseType fuseBlue, + PartialPassType ppType, + const StockhamPartialPassParams& ppParams, + const LoadOps& loadOps, + const StoreOps& storeOps); // generate source for RTC stockham kernel. transforms_per_block may // be nullptr, but if non-null, stockham_rtc stores the number of diff --git a/projects/rocfft/library/src/rocfft_aot_helper.cpp b/projects/rocfft/library/src/rocfft_aot_helper.cpp index eb48f072aed..5468e211f1c 100644 --- a/projects/rocfft/library/src/rocfft_aot_helper.cpp +++ b/projects/rocfft/library/src/rocfft_aot_helper.cpp @@ -301,6 +301,7 @@ void build_stockham_function_pool(CompileQueue& queue) cbtype, fuseBlue, ppType, + ppParams, {}, {}); std::function generate_src @@ -692,6 +693,7 @@ void build_solution_kernels(CompileQueue& queue) cbtype, fuseBlue, ppType, + ppParams, {}, {}); diff --git a/projects/rocfft/library/src/rtc_stockham_gen.cpp b/projects/rocfft/library/src/rtc_stockham_gen.cpp index 4861f3879bf..9d234403555 100644 --- a/projects/rocfft/library/src/rtc_stockham_gen.cpp +++ b/projects/rocfft/library/src/rtc_stockham_gen.cpp @@ -43,26 +43,27 @@ using namespace std::placeholders; #include "device/kernel-generator-embed.h" // generate name for RTC stockham kernel -std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, - const StockhamGeneratorSpecs& specs2d, - ComputeScheme scheme, - int direction, - rocfft_precision precision, - rocfft_result_placement placement, - rocfft_array_type inArrayType, - rocfft_array_type outArrayType, - bool unitstride, - size_t largeTwdBase, - size_t largeTwdSteps, - bool largeTwdBatchIsTransformCount, - DirectRegType dir2regMode, - IntrinsicAccessType intrinsicMode, - SBRC_TRANSPOSE_TYPE transpose_type, - CallbackType cbtype, - BluesteinFuseType fuseBlue, - PartialPassType ppType, - const LoadOps& loadOps, - const StoreOps& storeOps) +std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, + const StockhamGeneratorSpecs& specs2d, + ComputeScheme scheme, + int direction, + rocfft_precision precision, + rocfft_result_placement placement, + rocfft_array_type inArrayType, + rocfft_array_type outArrayType, + bool unitstride, + size_t largeTwdBase, + size_t largeTwdSteps, + bool largeTwdBatchIsTransformCount, + DirectRegType dir2regMode, + IntrinsicAccessType intrinsicMode, + SBRC_TRANSPOSE_TYPE transpose_type, + CallbackType cbtype, + BluesteinFuseType fuseBlue, + PartialPassType ppType, + const StockhamPartialPassParams& ppParams, + const LoadOps& loadOps, + const StoreOps& storeOps) { std::string kernel_name = "fft_rtc"; @@ -77,10 +78,14 @@ std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, break; case PPT_SBCC: case PPT_SBRR: - kernel_name += "_pp"; + kernel_name += "_partial_pass"; + kernel_name += "_parent_len"; + for(auto f : ppParams.parent_length) + kernel_name += "_" + std::to_string(f); + break; } - kernel_name += "_len"; + kernel_name += "_len_"; kernel_name += std::to_string(specs.length); if(scheme == CS_KERNEL_2D_SINGLE) kernel_name += "x" + std::to_string(specs2d.length); @@ -113,7 +118,7 @@ std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, if(specs.static_dim) { - kernel_name += "_dim"; + kernel_name += "_dim_"; kernel_name += std::to_string(specs.static_dim); } diff --git a/projects/rocfft/library/src/rtc_stockham_kernel.cpp b/projects/rocfft/library/src/rtc_stockham_kernel.cpp index d07ab6108e3..1f13188efa1 100644 --- a/projects/rocfft/library/src/rtc_stockham_kernel.cpp +++ b/projects/rocfft/library/src/rtc_stockham_kernel.cpp @@ -205,6 +205,7 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& node.GetCallbackType(enable_callbacks), node.fuseBlue, ppType, + pp_params, node.loadOps, node.storeOps); };