diff --git a/projects/rocfft/library/src/include/rtc_stockham_gen.h b/projects/rocfft/library/src/include/rtc_stockham_gen.h index c3b41aa92f6..ab79309f7f3 100644 --- a/projects/rocfft/library/src/include/rtc_stockham_gen.h +++ b/projects/rocfft/library/src/include/rtc_stockham_gen.h @@ -32,26 +32,27 @@ #include "../device/kernels/common.h" // generate name for RTC stockham kernel -std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, - const StockhamGeneratorSpecs& specs2d, - ComputeScheme scheme, - int direction, - rocfft_precision precision, - rocfft_result_placement placement, - rocfft_array_type inArrayType, - rocfft_array_type outArrayType, - bool unitstride, - size_t largeTwdBase, - size_t largeTwdSteps, - bool largeTwdBatchIsTransformCount, - DirectRegType dir2regMode, - IntrinsicAccessType intrinsicMode, - SBRC_TRANSPOSE_TYPE transpose_type, - CallbackType cbtype, - BluesteinFuseType fuseBlue, - PartialPassType ppType, - const LoadOps& loadOps, - const StoreOps& storeOps); +std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, + const StockhamGeneratorSpecs& specs2d, + ComputeScheme scheme, + int direction, + rocfft_precision precision, + rocfft_result_placement placement, + rocfft_array_type inArrayType, + rocfft_array_type outArrayType, + bool unitstride, + size_t largeTwdBase, + size_t largeTwdSteps, + bool largeTwdBatchIsTransformCount, + DirectRegType dir2regMode, + IntrinsicAccessType intrinsicMode, + SBRC_TRANSPOSE_TYPE transpose_type, + CallbackType cbtype, + BluesteinFuseType fuseBlue, + PartialPassType ppType, + const StockhamPartialPassParams& ppParams, + const LoadOps& loadOps, + const StoreOps& storeOps); // generate source for RTC stockham kernel. transforms_per_block may // be nullptr, but if non-null, stockham_rtc stores the number of diff --git a/projects/rocfft/library/src/rocfft_aot_helper.cpp b/projects/rocfft/library/src/rocfft_aot_helper.cpp index eb48f072aed..5468e211f1c 100644 --- a/projects/rocfft/library/src/rocfft_aot_helper.cpp +++ b/projects/rocfft/library/src/rocfft_aot_helper.cpp @@ -301,6 +301,7 @@ void build_stockham_function_pool(CompileQueue& queue) cbtype, fuseBlue, ppType, + ppParams, {}, {}); std::function generate_src @@ -692,6 +693,7 @@ void build_solution_kernels(CompileQueue& queue) cbtype, fuseBlue, ppType, + ppParams, {}, {}); diff --git a/projects/rocfft/library/src/rtc_stockham_gen.cpp b/projects/rocfft/library/src/rtc_stockham_gen.cpp index 4861f3879bf..9d234403555 100644 --- a/projects/rocfft/library/src/rtc_stockham_gen.cpp +++ b/projects/rocfft/library/src/rtc_stockham_gen.cpp @@ -43,26 +43,27 @@ using namespace std::placeholders; #include "device/kernel-generator-embed.h" // generate name for RTC stockham kernel -std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, - const StockhamGeneratorSpecs& specs2d, - ComputeScheme scheme, - int direction, - rocfft_precision precision, - rocfft_result_placement placement, - rocfft_array_type inArrayType, - rocfft_array_type outArrayType, - bool unitstride, - size_t largeTwdBase, - size_t largeTwdSteps, - bool largeTwdBatchIsTransformCount, - DirectRegType dir2regMode, - IntrinsicAccessType intrinsicMode, - SBRC_TRANSPOSE_TYPE transpose_type, - CallbackType cbtype, - BluesteinFuseType fuseBlue, - PartialPassType ppType, - const LoadOps& loadOps, - const StoreOps& storeOps) +std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, + const StockhamGeneratorSpecs& specs2d, + ComputeScheme scheme, + int direction, + rocfft_precision precision, + rocfft_result_placement placement, + rocfft_array_type inArrayType, + rocfft_array_type outArrayType, + bool unitstride, + size_t largeTwdBase, + size_t largeTwdSteps, + bool largeTwdBatchIsTransformCount, + DirectRegType dir2regMode, + IntrinsicAccessType intrinsicMode, + SBRC_TRANSPOSE_TYPE transpose_type, + CallbackType cbtype, + BluesteinFuseType fuseBlue, + PartialPassType ppType, + const StockhamPartialPassParams& ppParams, + const LoadOps& loadOps, + const StoreOps& storeOps) { std::string kernel_name = "fft_rtc"; @@ -77,10 +78,14 @@ std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, break; case PPT_SBCC: case PPT_SBRR: - kernel_name += "_pp"; + kernel_name += "_partial_pass"; + kernel_name += "_parent_len"; + for(auto f : ppParams.parent_length) + kernel_name += "_" + std::to_string(f); + break; } - kernel_name += "_len"; + kernel_name += "_len_"; kernel_name += std::to_string(specs.length); if(scheme == CS_KERNEL_2D_SINGLE) kernel_name += "x" + std::to_string(specs2d.length); @@ -113,7 +118,7 @@ std::string stockham_rtc_kernel_name(const StockhamGeneratorSpecs& specs, if(specs.static_dim) { - kernel_name += "_dim"; + kernel_name += "_dim_"; kernel_name += std::to_string(specs.static_dim); } diff --git a/projects/rocfft/library/src/rtc_stockham_kernel.cpp b/projects/rocfft/library/src/rtc_stockham_kernel.cpp index 407a5051fba..d2e955be9fd 100644 --- a/projects/rocfft/library/src/rtc_stockham_kernel.cpp +++ b/projects/rocfft/library/src/rtc_stockham_kernel.cpp @@ -206,6 +206,7 @@ RTCKernel::RTCGenerator RTCKernelStockham::generate_from_node(const LeafNode& node.GetCallbackType(enable_callbacks), node.fuseBlue, ppType, + pp_params, node.loadOps, node.storeOps); };