diff --git a/.githooks/pre-commit b/.githooks/pre-commit index e166dadd03..d91036e8fc 100755 --- a/.githooks/pre-commit +++ b/.githooks/pre-commit @@ -40,4 +40,3 @@ do "$format" -i -style=file "$file" fi done - diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 02bcb88622..3e56c89a9d 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -32,3 +32,4 @@ The MIOpen API library is structured as follows: * :doc:`GroupNorm <../doxygen/html/group__groupnorm>` (experimental) * :doc:`Cat <../doxygen/html/group__cat>` (experimental) * :doc:`Argmax<./argmax>` (experimental) + * :doc:`NLLLoss<../doxygen/html/group__nllloss>` (experimental) diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 224e550fed..13566589d8 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -45,6 +45,7 @@ add_executable(MIOpenDriver dm_groupnorm.cpp dm_layernorm.cpp dm_lrn.cpp + dm_nllloss.cpp dm_pool.cpp dm_reduce.cpp dm_rnn.cpp diff --git a/driver/dm_nllloss.cpp b/driver/dm_nllloss.cpp new file mode 100644 index 0000000000..1d4fa4c0dc --- /dev/null +++ b/driver/dm_nllloss.cpp @@ -0,0 +1,40 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include "registry_driver_maker.hpp" +#include "nllloss_driver.hpp" + +static Driver* makeDriver(const std::string& base_arg) +{ + if(base_arg == "nllloss") + return new NLLLossDriver(); + if(base_arg == "nlllossfp16") + return new NLLLossDriver(); + if(base_arg == "nlllossbfp16") + return new NLLLossDriver(); + return nullptr; +} + +REGISTER_DRIVER_MAKER(makeDriver); diff --git a/driver/driver.hpp b/driver/driver.hpp index 4cfc2b544e..3e9969f40b 100644 --- a/driver/driver.hpp +++ b/driver/driver.hpp @@ -151,7 +151,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz) "pool[fp16], lrn[fp16], " "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], " "tensorop[fp16], reduce[fp16|fp64], layernorm[bfp16|fp16], sum[bfp16|fp16], " - "argmax[bfp16|fp16], groupnorm[bfp16|fp16], cat[bfp16|fp16]\n"); + "argmax[bfp16|fp16], groupnorm[bfp16|fp16], cat[bfp16|fp16], nllloss[bfp16|fp16]\n"); exit(0); // NOLINT (concurrency-mt-unsafe) } @@ -176,7 +176,8 @@ inline std::string ParseBaseArg(int argc, char* argv[]) arg != "layernormfp16" && arg != "layernormbfp16" && arg != "sum" && arg != "sumfp16" && arg != "sumbfp16" && arg != "argmax" && arg != "argmaxfp16" && arg != "argmaxbfp16" && arg != "groupnorm" && arg != "groupnormfp16" && arg != "groupnormbfp16" && arg != "cat" && - arg != "catfp16" && arg != "catbfp16" && arg != "--version") + arg != "catfp16" && arg != "catbfp16" && arg != "nllloss" && arg != "nlllossfp16" && + arg != "nlllossbfp16" && arg != "--version") { printf("FAILED: Invalid Base Input Argument\n"); Usage(); diff --git a/driver/mloNLLLossHost.hpp b/driver/mloNLLLossHost.hpp new file mode 100644 index 0000000000..df2eb3251d --- /dev/null +++ b/driver/mloNLLLossHost.hpp @@ -0,0 +1,74 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef MLO_NLLLOSSHOST_H_ +#define MLO_NLLLOSSHOST_H_ + +#include + +template +int32_t mloNLLLossForwardRunHost(miopenTensorDescriptor_t inputDesc, + Tgpu* input, + int32_t* target, + Tgpu* weight, + Tcheck* outputhost, + int32_t ignore_index) +{ + auto dims = miopen::deref(inputDesc).GetLengths(); + + size_t N = dims[0]; + size_t C = dims[1]; + size_t D1 = dims[2]; + size_t D2 = dims[3]; + + for(size_t n = 0; n < N; n++) + { + for(size_t d1 = 0; d1 < D1; d1++) + { + for(size_t d2 = 0; d2 < D2; d2++) + { + size_t target_index = n * D1 * D2 + d1 * D2 + d2; + int32_t t = target[target_index]; + size_t input_index = (n * C + t) * D1 * D2 + d1 * D2 + d2; + size_t weight_index = t; + size_t output_index = target_index; + + if(t < 0 || t == ignore_index || t >= C) + { + outputhost[output_index] = static_cast(0); + } + else + { + outputhost[output_index] = static_cast(-1) * + static_cast(weight[weight_index]) * + static_cast(input[input_index]); + } + } + } + } + + return 0; +} +#endif // MLO_NLLLOSSHOST_H_ diff --git a/driver/nllloss_driver.hpp b/driver/nllloss_driver.hpp new file mode 100644 index 0000000000..742d8f995e --- /dev/null +++ b/driver/nllloss_driver.hpp @@ -0,0 +1,327 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef GUARD_MIOPEN_NLLLOSS_DRIVER_HPP +#define GUARD_MIOPEN_NLLLOSS_DRIVER_HPP + +#include "InputFlags.hpp" +#include "driver.hpp" +#include "mloNLLLossHost.hpp" +#include "random.hpp" +#include "tensor_driver.hpp" +#include "timer.hpp" +#include "util_driver.hpp" + +#include <../test/tensor_holder.hpp> +#include <../test/verify.hpp> + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +template +class NLLLossDriver : public Driver +{ +public: + NLLLossDriver() : Driver() + { + miopenCreateTensorDescriptor(&inputDesc); + miopenCreateTensorDescriptor(&targetDesc); + miopenCreateTensorDescriptor(&weightDesc); + miopenCreateTensorDescriptor(&outputDesc); + + data_type = miopen_type{}; + } + + int AddCmdLineArgs() override; + int ParseCmdLineArgs(int argc, char* argv[]) override; + InputFlags& GetInputFlags() override { return inflags; } + + int GetandSetData() override; + + int AllocateBuffersAndCopy() override; + + int RunForwardGPU() override; + int RunForwardCPU(); + + int RunBackwardGPU() override; + + Tref GetTolerance(); + int VerifyBackward() override; + int VerifyForward() override; + ~NLLLossDriver() override + { + miopenDestroyTensorDescriptor(inputDesc); + miopenDestroyTensorDescriptor(targetDesc); + miopenDestroyTensorDescriptor(weightDesc); + miopenDestroyTensorDescriptor(outputDesc); + } + +private: + InputFlags inflags; + + int forw; + + miopenTensorDescriptor_t inputDesc; + miopenTensorDescriptor_t targetDesc; + miopenTensorDescriptor_t weightDesc; + miopenTensorDescriptor_t outputDesc; + + std::unique_ptr in_dev; + std::unique_ptr target_dev; + std::unique_ptr weight_dev; + std::unique_ptr out_dev; + + std::vector in; + std::vector target; + std::vector weight; + std::vector out; + std::vector out_host; + + size_t N; + size_t C; + size_t D1; + size_t D2; + int ignore_index; +}; + +template +int NLLLossDriver::ParseCmdLineArgs(int argc, char* argv[]) +{ + inflags.Parse(argc, argv); + + if(inflags.GetValueInt("time") == 1) + { + miopenEnableProfiling(GetHandle(), true); + } + return miopenStatusSuccess; +} + +template +int NLLLossDriver::GetandSetData() +{ + N = inflags.GetValueInt("batchsize"); + C = inflags.GetValueInt("numclasses"); + D1 = inflags.GetValueInt("D1"); + D2 = inflags.GetValueInt("D2"); + ignore_index = static_cast(inflags.GetValueInt("ignore_index")); + + if(N <= 0 || C <= 0 || D1 <= 0 || D2 <= 0) + { + MIOPEN_THROW("Error Input Tensor Lengths"); + } + + std::vector in_len = {N, C, D1, D2}; + std::vector target_len = {N, D1, D2}; + std::vector weight_len = {C}; + std::vector out_len = {N, D1, D2}; + + SetTensorNd(inputDesc, in_len, data_type); + SetTensorNd(targetDesc, target_len, data_type); + SetTensorNd(weightDesc, weight_len, data_type); + SetTensorNd(outputDesc, out_len, data_type); + + return 0; +} + +template +int NLLLossDriver::AddCmdLineArgs() +{ + inflags.AddInputFlag("forw", 'F', "1", "Run only Forward NLLLoss (Default=1)", "int"); + inflags.AddInputFlag("batchsize", 'N', "5", "Batch size", "int"); + inflags.AddInputFlag("numclasses", 'C', "10", "Number of classes", "int"); + inflags.AddInputFlag("D1", 'd', "17", "Size D1", "int"); + inflags.AddInputFlag("D2", 'D', "19", "Size D2", "int"); + inflags.AddInputFlag("ignore_index", 'g', "-1", "Ignore index", "int"); + + inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int"); + inflags.AddInputFlag("verify", 'V', "1", "Verify (Default=1)", "int"); + inflags.AddInputFlag("time", 't', "1", "Time (Default=0)", "int"); + inflags.AddInputFlag( + "wall", 'w', "1", "Wall-clock Time, Requires time == 1 (Default=0)", "int"); + + return miopenStatusSuccess; +} + +template +int NLLLossDriver::AllocateBuffersAndCopy() +{ + size_t in_sz = GetTensorSize(inputDesc); + size_t target_sz = GetTensorSize(targetDesc); + size_t weight_sz = GetTensorSize(weightDesc); + size_t out_sz = GetTensorSize(outputDesc); + + uint32_t ctx = 0; + + in_dev = std::unique_ptr(new GPUMem(ctx, in_sz, sizeof(Tgpu))); + target_dev = std::unique_ptr(new GPUMem(ctx, target_sz, sizeof(int))); + weight_dev = std::unique_ptr(new GPUMem(ctx, weight_sz, sizeof(Tgpu))); + out_dev = std::unique_ptr(new GPUMem(ctx, out_sz, sizeof(Tgpu))); + + in = std::vector(in_sz, static_cast(0)); + target = std::vector(target_sz, static_cast(0)); + weight = std::vector(weight_sz, static_cast(1)); + out = std::vector(out_sz, static_cast(0)); + out_host = std::vector(out_sz, static_cast(0)); + + int status; + + for(int i = 0; i < in_sz; i++) + { + in[i] = prng::gen_A_to_B(static_cast(-10.0), static_cast(-(1e-2))); + } + status = in_dev->ToGPU(q, in.data()); + + for(int i = 0; i < target_sz; i++) + { + target[i] = prng::gen_A_to_B(static_cast(0), static_cast(C - 1)); + } + status |= target_dev->ToGPU(q, target.data()); + + for(int i = 0; i < weight_sz; i++) + { + weight[i] = prng::gen_A_to_B(static_cast(-10.0), static_cast(10.0)); + } + status |= weight_dev->ToGPU(q, weight.data()); + + status |= out_dev->ToGPU(q, out.data()); + + if(status != 0) + std::cout << "Error copying data to GPU\n" << std::endl; + + return miopenStatusSuccess; +} + +template +int NLLLossDriver::RunForwardGPU() +{ + float kernel_total_time = 0.0; + float kernel_first_time = 0.0; + + Timer t; + START_TIME + + for(int i = 0; i < inflags.GetValueInt("iter"); i++) + { + miopenNLLLossForward(GetHandle(), + inputDesc, + in_dev->GetMem(), + targetDesc, + target_dev->GetMem(), + weightDesc, + weight_dev->GetMem(), + outputDesc, + out_dev->GetMem(), + ignore_index); + + float time = 0.0; + miopenGetKernelTime(GetHandle(), &time); + kernel_total_time += time; + if(i == 0) + kernel_first_time = time; + } + + if(inflags.GetValueInt("time") == 1) + { + STOP_TIME + int iter = inflags.GetValueInt("iter"); + if(WALL_CLOCK) + printf("Wall-clock Time Forward NLLLoss Elapsed: %f ms\n", t.gettime_ms() / iter); + + float kernel_average_time = + iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + printf("GPU Kernel Time Forward NLLLoss Elapsed: %f ms\n", kernel_average_time); + } + + out_dev->FromGPU(GetStream(), out.data()); + + return miopenStatusSuccess; +} + +template +int NLLLossDriver::RunForwardCPU() +{ + mloNLLLossForwardRunHost( + inputDesc, in.data(), target.data(), weight.data(), out_host.data(), ignore_index); + + return miopenStatusSuccess; +} + +template +int NLLLossDriver::RunBackwardGPU() +{ + return miopenStatusSuccess; +} + +template +Tref NLLLossDriver::GetTolerance() +{ + // Computation error of fp16 is ~2^13 (=8192) bigger than + // the one of fp32 because mantissa is shorter by 13 bits. + auto tolerance = std::is_same::value ? 1.5e-6 : 8.2e-3; + + // bf16 mantissa has 7 bits, by 3 bits shorter than fp16. + if(std::is_same::value) + tolerance *= 8.0; + return tolerance; +} + +template +int NLLLossDriver::VerifyForward() +{ + RunForwardCPU(); + const Tref tolerance = GetTolerance(); + auto error = miopen::rms_range(out_host, out); + + if(!std::isfinite(error) || error > tolerance) + { + std::cout << "Forward NLLLoss FAILED: " << error << std::endl; + return EC_VerifyFwd; + } + else + { + printf("Forward NLLLoss Verifies on CPU and GPU (err=%f)\n", error); + } + + return miopenStatusSuccess; +} + +template +int NLLLossDriver::VerifyBackward() +{ + return miopenStatusSuccess; +} + +#endif // GUARD_MIOPEN_NLLLOSS_DRIVER_HPP diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index e768c7b349..f50619c6f2 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -6582,6 +6582,40 @@ MIOPEN_EXPORT miopenStatus_t miopenBackendInitialize(miopenBackendDescriptor_t d // CLOSEOUT BackendAPI DOXYGEN GROUP #endif // MIOPEN_BETA_API +#ifdef MIOPEN_BETA_API +// NLLLoss APIs +/** @addtogroup nllloss + * + * @{ + */ +/*! @brief Execute a nllloss forward layer + * + * @param handle MIOpen handle (input) + * @param inputDesc Tensor descriptor for data input tensor input (input) + * @param input Data tensor input (input) + * @param targetDesc Tensor descriptor for data input tensor target (input) + * @param target Data tensor target (input) + * @param weightDesc Tensor descriptor for data input tensor weight (input) + * @param weight Data tensor weight (input) + * @param outputDesc Tensor descriptor for output data tensor y (input) + * @param output Data tensor y (output) + * @param ignore_index Class index to ignore (input) + * @return miopenStatus_t + */ +MIOPEN_EXPORT miopenStatus_t miopenNLLLossForward(miopenHandle_t handle, + const miopenTensorDescriptor_t inputDesc, + const void* input, + const miopenTensorDescriptor_t targetDesc, + const void* target, + const miopenTensorDescriptor_t weightDesc, + const void* weight, + const miopenTensorDescriptor_t outputDesc, + void* output, + int ignore_index); +/** @} */ +// CLOSEOUT nllloss DOXYGEN GROUP +#endif // MIOPEN_BETA_API + #ifdef __cplusplus } #endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9671eed03c..a7b9f0b380 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -142,6 +142,8 @@ set( MIOpen_Source lrn_api.cpp mha/mha_descriptor.cpp mha/problem_description.cpp + nllloss_api.cpp + nllloss/problem_description.cpp op_args.cpp operator.cpp performance_config.cpp @@ -264,6 +266,7 @@ set( MIOpen_Source solver/layernorm/forward_layernorm2d_ck.cpp solver/layernorm/forward_layernorm4d_ck.cpp solver/mha/mha_solver.cpp + solver/nllloss/forward_nllloss.cpp solver/pooling/forward2d.cpp solver/pooling/forwardNaive.cpp solver/pooling/forwardNd.cpp @@ -459,6 +462,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/MIOpenLRNBwd.cl kernels/MIOpenLRNFwd.cl kernels/MIOpenNeuron.cl + kernels/MIOpenNLLLoss.cpp kernels/MIOpenPooling.cl kernels/MIOpenPoolingBwd.cl kernels/MIOpenPoolingBwdND.cl @@ -583,6 +587,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN layer_norm.cpp lrn.cpp mlo_dir_conv.cpp + nllloss.cpp exec_utils.cpp ocl/activ_ocl.cpp ocl/batchnormocl.cpp diff --git a/src/include/miopen/nllloss.hpp b/src/include/miopen/nllloss.hpp new file mode 100644 index 0000000000..a29a5cf3fc --- /dev/null +++ b/src/include/miopen/nllloss.hpp @@ -0,0 +1,49 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#ifndef MIOPEN_NLLLOSS_HPP_ +#define MIOPEN_NLLLOSS_HPP_ + +#include + +namespace miopen { + +struct Handle; +struct TensorDescriptor; + +miopenStatus_t NLLLossForward(Handle& handle, + const TensorDescriptor& inputDesc, + ConstData_t input, + const TensorDescriptor& targetDesc, + ConstData_t target, + const TensorDescriptor& weightDesc, + ConstData_t weight, + const TensorDescriptor& outputDesc, + Data_t output, + int ignore_index); + +} // namespace miopen +#endif // _MIOPEN_NLLLOSS_HPP_ diff --git a/src/include/miopen/nllloss/invoke_params.hpp b/src/include/miopen/nllloss/invoke_params.hpp new file mode 100644 index 0000000000..94ee1d816a --- /dev/null +++ b/src/include/miopen/nllloss/invoke_params.hpp @@ -0,0 +1,56 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include "miopen/common.hpp" +#include +#include + +namespace miopen { +namespace nllloss { + +struct InvokeParams : public miopen::InvokeParams +{ + + InvokeParams() = default; + + const TensorDescriptor* inputDesc = nullptr; + const TensorDescriptor* outputDesc = nullptr; + + ConstData_t input = nullptr; + ConstData_t target = nullptr; + ConstData_t weight = nullptr; + Data_t output = nullptr; + + int ignore_index = -1; + + std::size_t GetWorkspaceSize() const { return 0; } + Data_t GetWorkspace() const { return nullptr; } +}; + +} // namespace nllloss +} // namespace miopen diff --git a/src/include/miopen/nllloss/problem_description.hpp b/src/include/miopen/nllloss/problem_description.hpp new file mode 100644 index 0000000000..aedc78616c --- /dev/null +++ b/src/include/miopen/nllloss/problem_description.hpp @@ -0,0 +1,135 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +namespace miopen { + +struct NetworkConfig; + +namespace nllloss { + +struct ProblemDescription : ProblemDescriptionBase +{ + ProblemDescription(const TensorDescriptor& inputDesc_, + const TensorDescriptor& targetDesc_, + const TensorDescriptor& weightDesc_, + const TensorDescriptor& outputDesc_, + int ignore_index_) + : inputDesc(inputDesc_), + targetDesc(targetDesc_), + weightDesc(weightDesc_), + outputDesc(outputDesc_), + ignore_index(ignore_index_), + N_total(outputDesc_.GetElementSize()), + N(inputDesc_.GetLengths()[0]), + C(inputDesc_.GetLengths()[1]), + D1(inputDesc_.GetLengths()[2]), + D2(inputDesc_.GetLengths()[3]) + { + } + + const TensorDescriptor& GetInputDesc() const { return inputDesc; } + const TensorDescriptor& GetTargetDesc() const { return targetDesc; } + const TensorDescriptor& GetWeightDesc() const { return weightDesc; } + const TensorDescriptor& GetOutputDesc() const { return outputDesc; } + int GetIgnoreIndex() const { return ignore_index; } + size_t GetNtotal() const { return N_total; } + size_t GetC() const { return C; } + size_t GetD1() const { return D1; } + size_t GetD2() const { return D2; } + + bool IsRightDim() const + { + if(outputDesc.GetLengths()[0] != N || outputDesc.GetLengths()[1] != D1 || + outputDesc.GetLengths()[2] != D2 || targetDesc.GetLengths()[0] != N || + targetDesc.GetLengths()[1] != D1 || targetDesc.GetLengths()[2] != D2 || + weightDesc.GetLengths()[0] != C) + { +#if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG + MIOPEN_THROW(miopenStatusBadParm, "NLLLoss: Tensors dimension do not match."); +#else + return false; +#endif + } + return true; + } + + bool IsSameType() const + { + if(inputDesc.GetType() != weightDesc.GetType()) + { +#if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG + MIOPEN_THROW(miopenStatusBadParm, + "NLLLoss: Tensor types of Input and Weight do not match."); +#else + return false; +#endif + } + return true; + } + + bool IsAllPacked() const + { + if(!(inputDesc.IsPacked() && targetDesc.IsPacked() && weightDesc.IsPacked() && + outputDesc.IsPacked())) + { +#if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG + MIOPEN_THROW(miopenStatusBadParm, "NLLLoss: Unpacked tensors not supported."); +#else + return false; +#endif + } + return true; + } + + NetworkConfig MakeNetworkConfig() const override; + +private: + TensorDescriptor inputDesc; + TensorDescriptor targetDesc; + TensorDescriptor weightDesc; + TensorDescriptor outputDesc; + + int ignore_index; + size_t N_total; + size_t N; + size_t C; + size_t D1; + size_t D2; + + NetworkConfig MakeForwardNetworkConfig() const; +}; + +} // namespace nllloss + +} // namespace miopen diff --git a/src/include/miopen/nllloss/solvers.hpp b/src/include/miopen/nllloss/solvers.hpp new file mode 100644 index 0000000000..579286c212 --- /dev/null +++ b/src/include/miopen/nllloss/solvers.hpp @@ -0,0 +1,60 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include "miopen/conv_solution.hpp" +#include "miopen/execution_context.hpp" +#include +#include + +#include + +namespace miopen { + +namespace solver { + +namespace nllloss { + +using NormalizationSolver = + NonTunableSolverBase; + +struct NLLLossForward final : NormalizationSolver +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool IsApplicable(const ExecutionContext& context, + const miopen::nllloss::ProblemDescription& problem) const override; + + ConvSolution GetSolution(const ExecutionContext& context, + const miopen::nllloss::ProblemDescription& problem) const override; +}; + +} // namespace nllloss + +} // namespace solver + +} // namespace miopen diff --git a/src/include/miopen/solver_id.hpp b/src/include/miopen/solver_id.hpp index c52dc020ac..9c50c8b782 100644 --- a/src/include/miopen/solver_id.hpp +++ b/src/include/miopen/solver_id.hpp @@ -56,7 +56,8 @@ enum class Primitive Reduce, Cat, Mha, - Softmax + Softmax, + NLLLossForward }; struct MIOPEN_EXPORT Id diff --git a/src/kernels/MIOpenNLLLoss.cpp b/src/kernels/MIOpenNLLLoss.cpp new file mode 100644 index 0000000000..50c9476f1d --- /dev/null +++ b/src/kernels/MIOpenNLLLoss.cpp @@ -0,0 +1,83 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include +#include +#endif + +#include "float_types.h" + +template +__device__ void nlllossUnreducedForward4dContiguous(const TI* __restrict__ input, + const int32_t* __restrict__ target, + const TI* weight, + TO* __restrict__ output, + int32_t ignore_index, + size_t N_total, + size_t C, + size_t D1, + size_t D2) +{ + uint64_t gid = threadIdx.x + blockIdx.x * blockDim.x; + + if(gid >= N_total) + return; + + size_t NWH[3]; + NWH[2] = (gid) % D2; + size_t nc = (gid) / D2; + NWH[1] = nc % D1; + NWH[0] = nc / D1; + + int32_t t = target[gid]; + if(t < 0 || t == ignore_index || t >= C) + { + output[gid] = static_cast(0); + return; + } + + FLOAT_ACCUM w = weight != nullptr ? CVT_FLOAT2ACCUM(weight[t]) : CVT_FP32_2ACCUM(1.0f); + + uint32_t input_offset = (NWH[0] * C + t) * D1 * D2 + NWH[1] * D2 + NWH[2]; + FLOAT_ACCUM input_value = CVT_FLOAT2ACCUM(input[input_offset]); + + FLOAT_ACCUM val = CVT_FP32_2ACCUM(-1.0f) * w * input_value; + output[gid] = CVT_ACCUM2FLOAT(val); +} + +extern "C" __global__ void NLLLossUnreducedForward4dContiguous(const INPUT_TYPE* __restrict__ input, + const int32_t* __restrict__ target, + const INPUT_TYPE* weight, + OUTPUT_TYPE* __restrict__ output, + uint64_t ignore_index, + size_t N_total, + size_t C, + size_t D1, + size_t D2) +{ + nlllossUnreducedForward4dContiguous( + input, target, weight, output, ignore_index, N_total, C, D1, D2); +} diff --git a/src/nllloss.cpp b/src/nllloss.cpp new file mode 100644 index 0000000000..b7888cb358 --- /dev/null +++ b/src/nllloss.cpp @@ -0,0 +1,71 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include +#include +#include +#include +#include +#include + +namespace miopen { + +miopenStatus_t NLLLossForward(Handle& handle, + const TensorDescriptor& inputDesc, + ConstData_t input, + const TensorDescriptor& targetDesc, + ConstData_t target, + const TensorDescriptor& weightDesc, + ConstData_t weight, + const TensorDescriptor& outputDesc, + Data_t output, + int32_t ignore_index) +{ + const auto problem = + nllloss::ProblemDescription{inputDesc, targetDesc, weightDesc, outputDesc, ignore_index}; + + const auto invoke_params = [&]() { + auto tmp = nllloss::InvokeParams{}; + tmp.inputDesc = &inputDesc; + tmp.outputDesc = &outputDesc; + + tmp.input = input; + tmp.target = target; + tmp.weight = weight; + tmp.output = output; + tmp.ignore_index = ignore_index; + return tmp; + }(); + + const auto algo = AlgorithmName{"NLLLossForward"}; + const auto solvers = solver::SolverContainer{}; + + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + + return miopenStatusSuccess; +} + +} // namespace miopen diff --git a/src/nllloss/problem_description.cpp b/src/nllloss/problem_description.cpp new file mode 100644 index 0000000000..d152a3924c --- /dev/null +++ b/src/nllloss/problem_description.cpp @@ -0,0 +1,61 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include + +#include + +namespace miopen { + +namespace nllloss { + +NetworkConfig ProblemDescription::MakeNetworkConfig() const +{ + auto dims = inputDesc.GetLengths(); + size_t numel = outputDesc.GetElementSize(); + size_t num_batches = dims[0]; + size_t num_classes = dims[1]; + + auto input_dtype = inputDesc.GetType(); + auto output_dtype = outputDesc.GetType(); + + std::ostringstream ss; + + ss << "input_dtype" << input_dtype; + ss << "output_dtype" << output_dtype; + ss << "numel" << numel; + ss << "num_batches" << num_batches; + ss << "num_classes" << num_classes; + ss << "D1" << dims[2]; + ss << "D2" << dims[3]; + + return NetworkConfig{ss.str()}; +} + +} // namespace nllloss + +} // namespace miopen diff --git a/src/nllloss_api.cpp b/src/nllloss_api.cpp new file mode 100644 index 0000000000..f10a98d5ec --- /dev/null +++ b/src/nllloss_api.cpp @@ -0,0 +1,100 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include +#include +#include +#include + +static void LogCmdNLLLoss(const miopenTensorDescriptor_t xDesc, bool is_fwd) +{ + if(miopen::IsLoggingCmd()) + { + std::stringstream ss; + auto dtype = miopen::deref(xDesc).GetType(); + if(dtype == miopenHalf) + { + ss << "nlllossfp16"; + } + else if(dtype == miopenFloat) + { + ss << "nllloss"; + } + else if(dtype == miopenBFloat16) + { + ss << "nlllossbfp16"; + } + + int32_t size = {0}; + miopenGetTensorDescriptorSize(xDesc, &size); + ss << " -N " << miopen::deref(xDesc).GetLengths()[0]; + ss << " -C " << miopen::deref(xDesc).GetLengths()[1] << " -d " + << miopen::deref(xDesc).GetLengths()[2] << " -D " + << miopen::deref(xDesc).GetLengths()[3]; + + ss << " -F " << ((is_fwd) ? "1" : "2"); + + MIOPEN_LOG_DRIVER_CMD(ss.str()); + } +} + +extern "C" miopenStatus_t miopenNLLLossForward(miopenHandle_t handle, + const miopenTensorDescriptor_t inputDesc, + const void* input, + const miopenTensorDescriptor_t targetDesc, + const void* target, + const miopenTensorDescriptor_t weightDesc, + const void* weight, + const miopenTensorDescriptor_t outputDesc, + void* output, + int ignore_index) +{ + MIOPEN_LOG_FUNCTION(handle, + inputDesc, + input, + targetDesc, + target, + weightDesc, + weight, + outputDesc, + output, + ignore_index); + + LogCmdNLLLoss(inputDesc, true); + return miopen::try_([&] { + miopen::NLLLossForward(miopen::deref(handle), + miopen::deref(inputDesc), + DataCast(input), + miopen::deref(targetDesc), + DataCast(target), + miopen::deref(weightDesc), + DataCast(weight), + miopen::deref(outputDesc), + DataCast(output), + ignore_index); + }); +} diff --git a/src/solver/nllloss/forward_nllloss.cpp b/src/solver/nllloss/forward_nllloss.cpp new file mode 100644 index 0000000000..cf137c17aa --- /dev/null +++ b/src/solver/nllloss/forward_nllloss.cpp @@ -0,0 +1,139 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "miopen/conv_solution.hpp" +#include "miopen/execution_context.hpp" +#include "miopen/invoke_params.hpp" +#include "miopen/kernel_info.hpp" +#include +#include + +#include +#include +#include +#include +#include + +#define LOCAL_SIZE 1024 + +namespace miopen { + +namespace solver { + +namespace nllloss { + +bool NLLLossForward::IsApplicable(const ExecutionContext&, + const miopen::nllloss::ProblemDescription& problem) const +{ + if(!problem.IsSameType()) + return false; + if(!problem.IsAllPacked()) + return false; + if(!problem.IsRightDim()) + return false; + return true; +} + +ConvSolution NLLLossForward::GetSolution(const ExecutionContext& context, + const miopen::nllloss::ProblemDescription& problem) const +{ + std::ignore = context; + + auto result = ConvSolution{miopenStatusSuccess}; + auto input_dtype = miopen::GetDataType(problem.GetInputDesc().GetType()); + auto output_dtype = miopen::GetDataType(problem.GetOutputDesc().GetType()); + + { + auto dtype = problem.GetInputDesc().GetType(); + size_t N_total = problem.GetNtotal(); + + size_t xlocalsize = LOCAL_SIZE; + // size_t xgridsize = (N_total + LOCAL_SIZE - 1) / LOCAL_SIZE * LOCAL_SIZE; + size_t xgridsize = AlignUp(N_total, xlocalsize); + + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + + kernel.kernel_file = "MIOpenNLLLoss.cpp"; + kernel.kernel_name = "NLLLossUnreducedForward4dContiguous"; + + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + {"INPUT_TYPE", input_dtype == "bfloat16" ? "ushort" : input_dtype}, + {"OUTPUT_TYPE", output_dtype == "bfloat16" ? "ushort" : output_dtype}, + }; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + } + + result.invoker_factory = [](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + + size_t N_total = params.outputDesc->GetElementSize(); + auto dims = params.inputDesc->GetLengths(); + size_t C = dims[1]; + size_t D1 = dims[2]; + size_t D2 = dims[3]; + + kernel(params.input, + params.target, + params.weight, + params.output, + params.ignore_index, + N_total, + C, + D1, + D2); + }; + }; + + return result; +} + +} // namespace nllloss + +} // namespace solver + +} // namespace miopen diff --git a/test/cpu_nllloss.hpp b/test/cpu_nllloss.hpp new file mode 100644 index 0000000000..e4bae54b90 --- /dev/null +++ b/test/cpu_nllloss.hpp @@ -0,0 +1,69 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef GUARD_CPU_NLLLOSS_HPP +#define GUARD_CPU_NLLLOSS_HPP + +#include "tensor_holder.hpp" + +template +void cpu_nllloss_forward_4d(tensor input, + tensor target, + tensor weight, + tensor& output, + int32_t ignore_index) +{ + auto dims = input.desc.GetLengths(); + size_t N = dims[0]; + size_t C = dims[1]; + size_t D1 = dims[2]; + size_t D2 = dims[3]; + + for(size_t n = 0; n < N; n++) + { + for(size_t d1 = 0; d1 < D1; d1++) + { + for(size_t d2 = 0; d2 < D2; d2++) + { + size_t target_index = n * D1 * D2 + d1 * D2 + d2; + int32_t t = target[target_index]; + size_t input_index = (n * C + t) * D1 * D2 + d1 * D2 + d2; + size_t weight_index = t; + size_t output_index = target_index; + + if(t < 0 || t == ignore_index || t >= C) + { + output[output_index] = static_cast(0); + } + else + { + output[output_index] = + static_cast(-1.0f) * weight[weight_index] * input[input_index]; + } + } + } + } +} +#endif diff --git a/test/gtest/nllloss.cpp b/test/gtest/nllloss.cpp new file mode 100644 index 0000000000..ee5d11f666 --- /dev/null +++ b/test/gtest/nllloss.cpp @@ -0,0 +1,103 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include +#include "nllloss.hpp" + +MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_TEST_FLOAT_ARG) +MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_ALL) + +namespace nllloss { + +std::string GetFloatArg() +{ + const auto& tmp = miopen::GetStringEnv(ENV(MIOPEN_TEST_FLOAT_ARG)); + if(tmp.empty()) + { + return ""; + } + return tmp; +} + +struct NLLLossTestFloat : NLLLossTest +{ +}; + +struct NLLLossTestHalf : NLLLossTest +{ +}; + +struct NLLLossTestBFloat16 : NLLLossTest +{ +}; + +} // namespace nllloss +using namespace nllloss; + +TEST_P(NLLLossTestFloat, NLLLossTestFw) +{ + if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--float")) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; + +TEST_P(NLLLossTestHalf, NLLLossTestFw) +{ + if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--half")) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; + +TEST_P(NLLLossTestBFloat16, NLLLossTestFw) +{ + if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--bfloat16")) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; + +INSTANTIATE_TEST_SUITE_P(NLLLossTestSet, NLLLossTestFloat, testing::ValuesIn(NLLLossTestConfigs())); +INSTANTIATE_TEST_SUITE_P(NLLLossTestSet, NLLLossTestHalf, testing::ValuesIn(NLLLossTestConfigs())); +INSTANTIATE_TEST_SUITE_P(NLLLossTestSet, + NLLLossTestBFloat16, + testing::ValuesIn(NLLLossTestConfigs())); diff --git a/test/gtest/nllloss.hpp b/test/gtest/nllloss.hpp new file mode 100644 index 0000000000..fec6892297 --- /dev/null +++ b/test/gtest/nllloss.hpp @@ -0,0 +1,163 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include "../driver/tensor_driver.hpp" +#include "cpu_nllloss.hpp" +#include "get_handle.hpp" +#include "random.hpp" +#include "tensor_holder.hpp" +#include "verify.hpp" +#include +#include +#include + +struct NLLLossTestCase +{ + size_t N = 0; + size_t C = 0; + size_t D1 = 0; + size_t D2 = 0; + bool weight_mode = false; + int32_t ignore_index = -1; + + std::vector input = {N, C, D1, D2}; + friend std::ostream& operator<<(std::ostream& os, const NLLLossTestCase& tc) + { + return os << " N:" << tc.N << " C:" << tc.C << " D1:" << tc.D1 << " D2:" << tc.D2 + << " weight_mode:" << tc.weight_mode << " ignore_index:" << tc.ignore_index; + } + + std::vector GetInput() const { return input; } +}; + +inline std::vector NLLLossTestConfigs() +{ // dim, dims + // clang-format off + return {{1, 2, 2, 2, false, -100}, + {2,10,128,128, false, 255}, + {5,13,17,11,true, 5}, + {8, 12, 256, 256, true, -1}, + {8, 16, 512, 512, true, 10}, + {16, 21,512,512,false, 255}}; + // clang-format on +} + +template +struct NLLLossTest : public ::testing::TestWithParam +{ +protected: + void SetUp() override + { + auto&& handle = get_handle(); + nllloss_config = GetParam(); + + ignore_index = nllloss_config.ignore_index; + weight_mode = nllloss_config.weight_mode; + + auto in_dim = nllloss_config.GetInput(); + std::vector target_dim = {in_dim[0], in_dim[2], in_dim[3]}; + std::vector weight_dim = {in_dim[1]}; + std::vector out_dim = {in_dim[0], in_dim[2], in_dim[3]}; + + auto gen_input_value = [](auto...) { + return prng::gen_A_to_B(static_cast(-100.0f), static_cast(-1e-2)); + }; + size_t numclass_C = in_dim[1]; + auto gen_target_value = [numclass_C](auto...) { + return prng::gen_A_to_B(0, numclass_C - 1); + }; + auto gen_weight_value = [](auto...) { + return prng::gen_A_to_B(static_cast(-10), static_cast(10)); + }; + auto gen_weight_one = [](auto...) { return static_cast(1); }; + + input = tensor{in_dim}.generate(gen_input_value); + + target = tensor{target_dim}.generate(gen_target_value); + + if(!weight_mode) + weight = tensor{weight_dim}.generate(gen_weight_one); + else + weight = tensor{weight_dim}.generate(gen_weight_value); + + output = tensor{out_dim}; + std::fill(output.begin(), output.end(), std::numeric_limits::quiet_NaN()); + + ref_output = tensor{out_dim}; + std::fill(ref_output.begin(), ref_output.end(), std::numeric_limits::quiet_NaN()); + + input_dev = handle.Write(input.data); + target_dev = handle.Write(target.data); + weight_dev = handle.Write(weight.data); + output_dev = handle.Write(output.data); + } + + void RunTest() + { + auto&& handle = get_handle(); + cpu_nllloss_forward_4d(input, target, weight, ref_output, ignore_index); + + miopenStatus_t status = miopen::NLLLossForward(handle, + input.desc, + input_dev.get(), + target.desc, + target_dev.get(), + weight.desc, + weight_dev.get(), + output.desc, + output_dev.get(), + ignore_index); + fflush(stdout); + + EXPECT_EQ(status, miopenStatusSuccess); + + output.data = handle.Read(output_dev, output.data.size()); + } + + void Verify() + { + double threshold = std::numeric_limits::epsilon(); + auto error = miopen::rms_range(ref_output, output); + + EXPECT_TRUE(miopen::range_distance(ref_output) == miopen::range_distance(output)); + EXPECT_TRUE(error < threshold * 10) << "Error output beyond tolerance Error:" << error + << ", Thresholdx10: " << threshold * 10; + } + NLLLossTestCase nllloss_config; + + tensor input; + tensor target; + tensor weight; + tensor output; + tensor ref_output; + + bool weight_mode; + int32_t ignore_index; + + miopen::Allocator::ManageDataPtr input_dev; + miopen::Allocator::ManageDataPtr target_dev; + miopen::Allocator::ManageDataPtr weight_dev; + miopen::Allocator::ManageDataPtr output_dev; +};