From 742a67d77d5f9e997e1063a09bf36f67fbec9e36 Mon Sep 17 00:00:00 2001 From: hieule88 Date: Thu, 11 Apr 2024 02:40:24 +0000 Subject: [PATCH 01/23] add nllloss_api with gtest --- .gitignore | 9 + include/miopen/miopen.h | 38 ++++ src/CMakeLists.txt | 5 + src/include/miopen/nllloss.hpp | 49 +++++ src/include/miopen/nllloss/invoke_params.hpp | 56 ++++++ .../miopen/nllloss/problem_description.hpp | 91 ++++++++++ src/include/miopen/nllloss/solvers.hpp | 59 ++++++ src/include/miopen/solver_id.hpp | 3 +- src/kernels/MIOpenNLLLoss.cpp | 78 ++++++++ src/nllloss.cpp | 71 ++++++++ src/nllloss/problem_description.cpp | 57 ++++++ src/nllloss_api.cpp | 70 +++++++ src/solver/nllloss/forward_nllloss.cpp | 124 +++++++++++++ test/cpu_nllloss.hpp | 68 +++++++ test/gtest/nllloss.cpp | 65 +++++++ test/gtest/nllloss.hpp | 171 ++++++++++++++++++ 16 files changed, 1013 insertions(+), 1 deletion(-) create mode 100644 .gitignore create mode 100644 src/include/miopen/nllloss.hpp create mode 100644 src/include/miopen/nllloss/invoke_params.hpp create mode 100644 src/include/miopen/nllloss/problem_description.hpp create mode 100644 src/include/miopen/nllloss/solvers.hpp create mode 100644 src/kernels/MIOpenNLLLoss.cpp create mode 100644 src/nllloss.cpp create mode 100644 src/nllloss/problem_description.cpp create mode 100644 src/nllloss_api.cpp create mode 100644 src/solver/nllloss/forward_nllloss.cpp create mode 100644 test/cpu_nllloss.hpp create mode 100644 test/gtest/nllloss.cpp create mode 100644 test/gtest/nllloss.hpp diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..6bee76431a --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ + +.cache/ +.devcontainer/ +.vscode/ + +install_dir/ +build/ + +.clangd \ No newline at end of file diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index e768c7b349..9e9917a17d 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -6582,6 +6582,44 @@ MIOPEN_EXPORT miopenStatus_t miopenBackendInitialize(miopenBackendDescriptor_t d // CLOSEOUT BackendAPI DOXYGEN GROUP #endif // MIOPEN_BETA_API +#ifdef MIOPEN_BETA_API +//-------------------------------------------------------------------------------------------------// +// NLLLoss APIs +/** @addtogroup nllloss + * + * @{ + */ +/*! @brief Execute a nllloss forward layer + * + * @param handle MIOpen handle (input) + * @param inputDesc Tensor descriptor for data input tensor x (input) + * @param input Data tensor x (input) + * @param targetDesc Tensor descriptor for data input tensor target (input) + * @param target Data tensor target (input) + * @param weightDesc Tensor descriptor for data input tensor weight (input) + * @param weight Data tensor weight (input) + * @param outputDesc Tensor descriptor for output data tensor y (input) + * @param output Data tensor y (output) + * @param ignore_index Index to ignore (input) + * @param N Number of elements in the output + * @return miopenStatus_t + */ +MIOPEN_EXPORT miopenStatus_t miopenNLLLossForward(miopenHandle_t handle, + const miopenTensorDescriptor_t inputDesc, + const void* input, + const miopenTensorDescriptor_t targetDesc, + const void* target, + const miopenTensorDescriptor_t weightDesc, + const void* weight, + const miopenTensorDescriptor_t outputDesc, + void* output, + long ignore_index); +/** @} */ +// CLOSEOUT nllloss DOXYGEN GROUP +#endif // MIOPEN_BETA_API + + + #ifdef __cplusplus } #endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1c7f8f7a8e..423deb2580 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -128,6 +128,8 @@ set( MIOpen_Source graphapi/tensor.cpp groupnorm_api.cpp groupnorm/problem_description.cpp + nllloss_api.cpp + nllloss/problem_description.cpp handle_api.cpp invoker_cache.cpp kernel_build_params.cpp @@ -258,6 +260,7 @@ set( MIOpen_Source solver/gemm_bwd.cpp solver/gemm_wrw.cpp solver/groupnorm/forward_groupnorm.cpp + solver/nllloss/forward_nllloss.cpp solver/layernorm/forward_layernorm.cpp solver/layernorm/forward_layernorm2d_ck.cpp solver/layernorm/forward_layernorm4d_ck.cpp @@ -453,6 +456,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/MIOpenConvDirBatchNormActiv.cl kernels/MIOpenConvDirGenFwd.cl kernels/MIOpenGroupNorm.cpp + kernels/MIOpenNLLLoss.cpp kernels/MIOpenLayerNorm.cpp kernels/MIOpenLRNBwd.cl kernels/MIOpenLRNFwd.cl @@ -577,6 +581,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN argmax.cpp cat.cpp groupnorm.cpp + nllloss.cpp kernel_cache.cpp layer_norm.cpp lrn.cpp diff --git a/src/include/miopen/nllloss.hpp b/src/include/miopen/nllloss.hpp new file mode 100644 index 0000000000..74a0d92e2d --- /dev/null +++ b/src/include/miopen/nllloss.hpp @@ -0,0 +1,49 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#ifndef MIOPEN_NLLLOSS_HPP_ +#define MIOPEN_NLLLOSS_HPP_ + +#include + +namespace miopen { + +struct Handle; +struct TensorDescriptor; + +miopenStatus_t NLLLossForward(Handle& handle, + const TensorDescriptor& inputDesc, + ConstData_t input, + const TensorDescriptor& targetDesc, + ConstData_t target, + const TensorDescriptor& weightDesc, + ConstData_t weight, + const TensorDescriptor& outputDesc, + Data_t output, + long ignore_index); + +} // namespace miopen +#endif // _MIOPEN_NLLLOSS_HPP_ diff --git a/src/include/miopen/nllloss/invoke_params.hpp b/src/include/miopen/nllloss/invoke_params.hpp new file mode 100644 index 0000000000..b3084685b7 --- /dev/null +++ b/src/include/miopen/nllloss/invoke_params.hpp @@ -0,0 +1,56 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include "miopen/common.hpp" +#include +#include + +namespace miopen { +namespace nllloss { + +struct InvokeParams : public miopen::InvokeParams +{ + + InvokeParams() = default; + + const TensorDescriptor* inputDesc = nullptr; + const TensorDescriptor* outputDesc = nullptr; + + ConstData_t input = nullptr; + ConstData_t target = nullptr; + ConstData_t weight = nullptr; + Data_t output = nullptr; + + long ignore_index = -1; + + std::size_t GetWorkspaceSize() const { return 0; } + Data_t GetWorkspace() const { return nullptr; } +}; + +} // namespace nllloss +} // namespace miopen diff --git a/src/include/miopen/nllloss/problem_description.hpp b/src/include/miopen/nllloss/problem_description.hpp new file mode 100644 index 0000000000..9b4fb67d74 --- /dev/null +++ b/src/include/miopen/nllloss/problem_description.hpp @@ -0,0 +1,91 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace miopen { + +struct NetworkConfig; + +namespace nllloss { + +struct ProblemDescription : ProblemDescriptionBase +{ + ProblemDescription(const TensorDescriptor& inputDesc_, + const TensorDescriptor& targetDesc_, + const TensorDescriptor& weightDesc_, + const TensorDescriptor& outputDesc_, + long ignore_index_) + : inputDesc(inputDesc_), + targetDesc(targetDesc_), + weightDesc(weightDesc_), + outputDesc(outputDesc_), + ignore_index(ignore_index_), + N_total(outputDesc_.GetElementSize()), + C(inputDesc_.GetLengths()[1]), + D1(inputDesc_.GetLengths()[2]), + D2(inputDesc_.GetLengths()[3]) + { + + } + + const TensorDescriptor& GetInputDesc() const { return inputDesc; } + const TensorDescriptor& GetTargetDesc() const { return targetDesc; } + const TensorDescriptor& GetWeightDesc() const { return weightDesc; } + const TensorDescriptor& GetOutputDesc() const { return outputDesc; } + long GetIgnoreIndex() const { return ignore_index; } + long GetNtotal() const { return N_total; } + long GetC() const { return C; } + long GetD1() const { return D1; } + long GetD2() const { return D2; } + + NetworkConfig MakeNetworkConfig() const override; + +private: + TensorDescriptor inputDesc; + TensorDescriptor targetDesc; + TensorDescriptor weightDesc; + TensorDescriptor outputDesc; + + long ignore_index; + long N_total; + long C; + long D1; + long D2; + + NetworkConfig MakeForwardNetworkConfig() const; +}; + +} // namespace nllloss + +} // namespace miopen diff --git a/src/include/miopen/nllloss/solvers.hpp b/src/include/miopen/nllloss/solvers.hpp new file mode 100644 index 0000000000..18c166a9da --- /dev/null +++ b/src/include/miopen/nllloss/solvers.hpp @@ -0,0 +1,59 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include "miopen/conv_solution.hpp" +#include "miopen/execution_context.hpp" +#include +#include + +#include + +namespace miopen { + +namespace solver { + +namespace nllloss { + +using NormalizationSolver = + NonTunableSolverBase; + +struct NLLLossForward final : NormalizationSolver { + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool IsApplicable(const ExecutionContext& context, + const miopen::nllloss::ProblemDescription& problem) const override; + + ConvSolution GetSolution(const ExecutionContext& context, + const miopen::nllloss::ProblemDescription& problem) const override; +}; + +} // namespace nllloss + +} // namespace solver + +} // namespace miopen diff --git a/src/include/miopen/solver_id.hpp b/src/include/miopen/solver_id.hpp index c52dc020ac..9c50c8b782 100644 --- a/src/include/miopen/solver_id.hpp +++ b/src/include/miopen/solver_id.hpp @@ -56,7 +56,8 @@ enum class Primitive Reduce, Cat, Mha, - Softmax + Softmax, + NLLLossForward }; struct MIOPEN_EXPORT Id diff --git a/src/kernels/MIOpenNLLLoss.cpp b/src/kernels/MIOpenNLLLoss.cpp new file mode 100644 index 0000000000..bb83a4699a --- /dev/null +++ b/src/kernels/MIOpenNLLLoss.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include +#endif + +#include "float_types.h" + +#if MIOPEN_USE_BFP16 == 1 +#define CVT_FLOAT2ACCUM(x) (bfloat16_to_float(x)) +#define CVT_ACCUM2FLOAT(x) (float_to_bfloat16(x)) +#define CVT_INTEGRAL2ACCUM(x) ((_FLOAT_ACCUM)(x)) +#define CVT_FP32_2FLOAT(x) (CVT_ACCUM2FLOAT(x)) +#define CVT_FP32_2ACCUM(x) (x) +#endif + +/* input(input): [N, C, D1, D2], target(target): [N, D1, D2], + * weight(weight): [C], output(output): [N, D1, D2] */ +/* Each thread computes one output: output[n0][n1][n2] */ +extern "C" __global__ void NLLLossUnreducedForward4dContiguous(const FLOAT_ACCUM* __restrict__ input, + const FLOAT_ACCUM* __restrict__ target, + const FLOAT_ACCUM* weight, + FLOAT_ACCUM* __restrict__ output, + long ignore_index, + size_t N_total, + size_t C, + size_t D1, + size_t D2) +{ + size_t gid = threadIdx.x + blockIdx.x * blockDim.x; + + if (gid >= N_total) return; + + size_t NWH[3]; + NWH[2] = (gid) % D2; + size_t nc = (gid) / D2; + NWH[1] = nc % D1; + NWH[0] = nc / D1; + + long t = static_cast(target[gid]); + // t: Class index + if (t < 0 || t == ignore_index || t >= C) { + output[gid] = 0; + return; + } + + FLOAT_ACCUM w = weight != nullptr ? weight[t] : 1.0f; + + // fix this + // FLOAT_ACCUM input_value = input[NWH[0]][t][NWH[1]][NWH[2]]; + FLOAT_ACCUM input_value = input[(NWH[0] * C + t) * D1 * D2 + NWH[1] * D2 + NWH[2]]; + + output[gid] = -1.0f * w * input_value; +} diff --git a/src/nllloss.cpp b/src/nllloss.cpp new file mode 100644 index 0000000000..c847f64017 --- /dev/null +++ b/src/nllloss.cpp @@ -0,0 +1,71 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include +#include +#include +#include +#include +#include + +namespace miopen { + +miopenStatus_t NLLLossForward(Handle& handle, + const TensorDescriptor& inputDesc, + ConstData_t input, + const TensorDescriptor& targetDesc, + ConstData_t target, + const TensorDescriptor& weightDesc, + ConstData_t weight, + const TensorDescriptor& outputDesc, + Data_t output, + long ignore_index) +{ + const auto problem = nllloss::ProblemDescription{ + inputDesc, targetDesc, weightDesc, outputDesc, ignore_index}; + + const auto invoke_params = [&]() { + auto tmp = nllloss::InvokeParams{}; + tmp.inputDesc = &inputDesc; + tmp.outputDesc = &outputDesc; + + tmp.input = input; + tmp.target = target; + tmp.weight = weight; + tmp.output = output; + tmp.ignore_index = ignore_index; + return tmp; + }(); + + const auto algo = AlgorithmName{"NLLLossForward"}; + const auto solvers = solver::SolverContainer{}; + + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + + return miopenStatusSuccess; +} + +} // namespace miopen diff --git a/src/nllloss/problem_description.cpp b/src/nllloss/problem_description.cpp new file mode 100644 index 0000000000..8190345e8d --- /dev/null +++ b/src/nllloss/problem_description.cpp @@ -0,0 +1,57 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include + +#include + +namespace miopen { + +namespace nllloss { + +NetworkConfig ProblemDescription::MakeNetworkConfig() const +{ + auto dims = inputDesc.GetLengths(); + size_t numel = outputDesc.GetElementSize(); + size_t num_batches = dims[0]; + size_t num_classes = dims[1]; + + auto dtype = inputDesc.GetType(); + + std::ostringstream ss; + + ss << "dtype" << dtype; + ss << "numel" << numel; + ss << "num_batches" << num_batches; + ss << "num_classes" << num_classes; + + return NetworkConfig{ss.str()}; +} + +} // namespace nllloss + +} // namespace miopen diff --git a/src/nllloss_api.cpp b/src/nllloss_api.cpp new file mode 100644 index 0000000000..56f1549e42 --- /dev/null +++ b/src/nllloss_api.cpp @@ -0,0 +1,70 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include +#include +#include +#include + +extern "C" miopenStatus_t miopenNLLLossForward(miopenHandle_t handle, + const miopenTensorDescriptor_t inputDesc, + const void* input, + const miopenTensorDescriptor_t targetDesc, + const void* target, + const miopenTensorDescriptor_t weightDesc, + const void* weight, + const miopenTensorDescriptor_t outputDesc, + void* output, + long ignore_index) +{ + MIOPEN_LOG_FUNCTION( + handle, + inputDesc, + input, + targetDesc, + target, + weightDesc, + weight, + outputDesc, + output, + ignore_index); + + // LogCmdNLLLossForward(inputDesc, targetDesc, weightDesc, outputDesc, ignore_index, N, C); + return miopen::try_([&] { + miopen::NLLLossForward( + miopen::deref(handle), + miopen::deref(inputDesc), + DataCast(input), + miopen::deref(targetDesc), + DataCast(target), + miopen::deref(weightDesc), + DataCast(weight), + miopen::deref(outputDesc), + DataCast(output), + ignore_index); + }); +} diff --git a/src/solver/nllloss/forward_nllloss.cpp b/src/solver/nllloss/forward_nllloss.cpp new file mode 100644 index 0000000000..27916be83d --- /dev/null +++ b/src/solver/nllloss/forward_nllloss.cpp @@ -0,0 +1,124 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "miopen/conv_solution.hpp" +#include "miopen/execution_context.hpp" +#include "miopen/invoke_params.hpp" +#include "miopen/kernel_info.hpp" +#include +#include + +#include +#include +#include +#include +#include + +#define LOCAL_SIZE 1024 + +namespace miopen { + +namespace solver { + +namespace nllloss { + +bool NLLLossForward::IsApplicable(const ExecutionContext&, + const miopen::nllloss::ProblemDescription& problem) const +{ + return true; +} + +ConvSolution +NLLLossForward::GetSolution(const ExecutionContext& context, + const miopen::nllloss::ProblemDescription& problem) const +{ + std::ignore = context; + + auto result = ConvSolution{miopenStatusSuccess}; + + { + auto dtype = problem.GetInputDesc().GetType(); + long N_total = problem.GetNtotal(); + + size_t xlocalsize = LOCAL_SIZE; + size_t xgridsize = N_total; + + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + + kernel.kernel_file = "MIOpenNLLLoss.cpp"; + kernel.kernel_name = "NLLLossUnreducedForward4dContiguous"; + + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + {"LOCAL_SIZE", LOCAL_SIZE}, + }; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + } + + result.invoker_factory = [](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + + size_t N_total = params.outputDesc->GetElementSize(); + auto dims = params.inputDesc->GetLengths(); + size_t C = dims[1]; + size_t D1 = dims[2]; + size_t D2 = dims[3]; + + kernel(params.input, params.target, params.weight, params.output, + params.ignore_index, + N_total, C, D1, D2); + }; + }; + + return result; +} + +} // namespace nllloss + +} // namespace solver + +} // namespace miopen diff --git a/test/cpu_nllloss.hpp b/test/cpu_nllloss.hpp new file mode 100644 index 0000000000..a8ae65c0d3 --- /dev/null +++ b/test/cpu_nllloss.hpp @@ -0,0 +1,68 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef GUARD_CPU_NLLLOSS_HPP +#define GUARD_CPU_NLLLOSS_HPP + +#include "tensor_holder.hpp" + +template +void cpu_nllloss_forward_4d(tensor input, + tensor target, + tensor weight, + tensor& output, + long ignore_index) +{ + auto dims = input.desc.GetLengths(); + size_t N = dims[0]; + size_t C = dims[1]; + size_t D1 = dims[2]; + size_t D2 = dims[3]; + + for (size_t n = 0; n < N; n++) + { + for (size_t d1 = 0; d1 < D1; d1++) + { + for (size_t d2 = 0; d2 < D2; d2++) + { + size_t target_index = n * D1 * D2 + d1 * D2 + d2; + int t = target[target_index]; + size_t input_index = (n * C + t) * D1 * D2 + d1 * D2 + d2; + size_t weight_index = t; + size_t output_index = target_index; + + if (t < 0 || t == ignore_index || t >= C) + { + output[output_index] = 0; + } + else + { + output[output_index] = -1.0f * weight[weight_index] * input[input_index]; + } + } + } + } +} +#endif diff --git a/test/gtest/nllloss.cpp b/test/gtest/nllloss.cpp new file mode 100644 index 0000000000..89f6be2815 --- /dev/null +++ b/test/gtest/nllloss.cpp @@ -0,0 +1,65 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include +#include "nllloss.hpp" + +MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_TEST_FLOAT_ARG) +MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_ALL) + +namespace nllloss { + +std::string GetFloatArg() +{ + const auto& tmp = miopen::GetStringEnv(ENV(MIOPEN_TEST_FLOAT_ARG)); + if(tmp.empty()) + { + return ""; + } + return tmp; +} + +struct NLLLossTestFloat : NLLLossTest +{ +}; + +} // namespace nllloss +using namespace nllloss; + +TEST_P(NLLLossTestFloat, NLLLossTestFw) +{ + // if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--float")) + // { + RunTest(); + Verify(); + // } + // else + // { + // GTEST_SKIP(); + // } +}; + +INSTANTIATE_TEST_SUITE_P(NLLLossTestSet, NLLLossTestFloat, testing::ValuesIn(NLLLossTestConfigs())); diff --git a/test/gtest/nllloss.hpp b/test/gtest/nllloss.hpp new file mode 100644 index 0000000000..027f5eb49a --- /dev/null +++ b/test/gtest/nllloss.hpp @@ -0,0 +1,171 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#define MIOPEN_BETA_API 1 +#include "../driver/tensor_driver.hpp" +#include "cpu_nllloss.hpp" +#include "get_handle.hpp" +#include "random.hpp" +#include "tensor_holder.hpp" +#include "verify.hpp" +#include +#include +#include + +struct NLLLossTestCase +{ + size_t N=0; + size_t C=0; + size_t D1=0; + size_t D2=0; + bool weight_mode=false; + long ignore_index=-1; + + std::vector input = {N, C, D1, D2}; + friend std::ostream& operator<<(std::ostream& os, const NLLLossTestCase& tc) + { + return os << " N:" << tc.N << " C:" << tc.C << " D1:" << tc.D1 << " D2:" << tc.D2 + << " weight_mode:" << tc.weight_mode << " ignore_index:" << tc.ignore_index; + } + + std::vector GetInput() const { return input; } +}; + +inline std::vector NLLLossTestConfigs() +{ // dim, dims + // clang-format off + return {{8, 1 ,6,5,false,-1}, + {1, 2, 2, 2, false, -1}, + {2,10,128,128, false, -1}, + {8, 12, 256, 256, true, -1}, + {8, 16, 512, 512, true, 10}}; + // clang-format on +} + +template +struct NLLLossTest : public ::testing::TestWithParam +{ +protected: + void SetUp() override + { + auto&& handle = get_handle(); + nllloss_config = GetParam(); + + // input < 0 + // 0 <= target < C + // weight = 1 + + /* input(input) : [N, C, D1, D2], + * target(target): [N, D1, D2], + * weight(weight): [C], + * output(output): [N, D1, D2] */ + + ignore_index = nllloss_config.ignore_index; + weight_mode = nllloss_config.weight_mode; + + auto in_dim = nllloss_config.GetInput(); + std::vector target_dim = {in_dim[0], in_dim[2], in_dim[3]}; + std::vector weight_dim = {in_dim[1]}; + std::vector out_dim = {in_dim[0], in_dim[2], in_dim[3]}; + + auto gen_input_value = [](auto...) -> T { static std::random_device rd; + static std::mt19937 gen(rd()); + std::uniform_real_distribution dis(-100, -1e-2); + return dis(gen); + }; + size_t numclass_C = in_dim[1]; + auto gen_target_value = [numclass_C](auto...) -> int { static std::random_device rd; + static std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, numclass_C-1); + return dis(gen); + }; + auto gen_weight_value = [](auto...) { return prng::gen_descreet_uniform_sign(1e-2, 10); }; + auto gen_weight_one = [](auto...) -> T { return 1; }; + + input = tensor{in_dim}.generate(gen_input_value); + + target = tensor{target_dim}.generate(gen_target_value); + + if (!weight_mode) + weight = tensor{weight_dim}.generate(gen_weight_one); + else + weight = tensor{weight_dim}.generate(gen_weight_value); + + output = tensor{out_dim}; + std::fill(output.begin(), output.end(), std::numeric_limits::quiet_NaN()); + + ref_output = tensor{out_dim}; + std::fill(ref_output.begin(), ref_output.end(), std::numeric_limits::quiet_NaN()); + + input_dev = handle.Write(input.data); + target_dev = handle.Write(target.data); + weight_dev = handle.Write(weight.data); + output_dev = handle.Write(output.data); + } + + void RunTest() + { + auto&& handle = get_handle(); + cpu_nllloss_forward_4d(input, target, weight, ref_output, ignore_index); + + miopenStatus_t status = miopen::NLLLossForward(handle, + input.desc, + input_dev.get(), + target.desc, + target_dev.get(), + weight.desc, + weight_dev.get(), + output.desc, + output_dev.get(), + ignore_index); + fflush(stdout); + + EXPECT_EQ(status, miopenStatusSuccess); + + output.data = handle.Read(output_dev, output.data.size()); + } + + void Verify() + { + auto error = miopen::rms_range(ref_output, output); + EXPECT_TRUE(miopen::range_distance(ref_output) == miopen::range_distance(output)); + EXPECT_TRUE(error == 0) << "Outputs do not match each other. Error:" << error; + } + NLLLossTestCase nllloss_config; + + tensor input; + tensor target; + tensor weight; + tensor output; + tensor ref_output; + + bool weight_mode; + long ignore_index; + + miopen::Allocator::ManageDataPtr input_dev; + miopen::Allocator::ManageDataPtr target_dev; + miopen::Allocator::ManageDataPtr weight_dev; + miopen::Allocator::ManageDataPtr output_dev; +}; From 0002973eaf165b87de34d4688b72d105a8330f24 Mon Sep 17 00:00:00 2001 From: hieule88 Date: Thu, 11 Apr 2024 09:48:05 +0000 Subject: [PATCH 02/23] add MIOpenDriver for NLLLoss --- driver/driver.hpp | 5 +- driver/main.cpp | 11 ++ driver/mloNLLLossHost.hpp | 74 ++++++++ driver/nllloss_driver.hpp | 331 ++++++++++++++++++++++++++++++++++ src/kernels/MIOpenNLLLoss.cpp | 18 +- src/nllloss_api.cpp | 34 ++++ 6 files changed, 462 insertions(+), 11 deletions(-) create mode 100644 driver/mloNLLLossHost.hpp create mode 100644 driver/nllloss_driver.hpp diff --git a/driver/driver.hpp b/driver/driver.hpp index 4cfc2b544e..3e9969f40b 100644 --- a/driver/driver.hpp +++ b/driver/driver.hpp @@ -151,7 +151,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz) "pool[fp16], lrn[fp16], " "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], " "tensorop[fp16], reduce[fp16|fp64], layernorm[bfp16|fp16], sum[bfp16|fp16], " - "argmax[bfp16|fp16], groupnorm[bfp16|fp16], cat[bfp16|fp16]\n"); + "argmax[bfp16|fp16], groupnorm[bfp16|fp16], cat[bfp16|fp16], nllloss[bfp16|fp16]\n"); exit(0); // NOLINT (concurrency-mt-unsafe) } @@ -176,7 +176,8 @@ inline std::string ParseBaseArg(int argc, char* argv[]) arg != "layernormfp16" && arg != "layernormbfp16" && arg != "sum" && arg != "sumfp16" && arg != "sumbfp16" && arg != "argmax" && arg != "argmaxfp16" && arg != "argmaxbfp16" && arg != "groupnorm" && arg != "groupnormfp16" && arg != "groupnormbfp16" && arg != "cat" && - arg != "catfp16" && arg != "catbfp16" && arg != "--version") + arg != "catfp16" && arg != "catbfp16" && arg != "nllloss" && arg != "nlllossfp16" && + arg != "nlllossbfp16" && arg != "--version") { printf("FAILED: Invalid Base Input Argument\n"); Usage(); diff --git a/driver/main.cpp b/driver/main.cpp index e1c5a62d1d..dc7abc09b6 100644 --- a/driver/main.cpp +++ b/driver/main.cpp @@ -46,6 +46,8 @@ #include "sum_driver.hpp" #include "argmax_driver.hpp" #include "cat_driver.hpp" +#include "nllloss_driver.hpp" + #include #include @@ -260,6 +262,15 @@ int main(int argc, char* argv[]) { drv = new CatDriver(); } + else if(base_arg == "nllloss") { + drv = new NLLLossDriver(); + } + else if(base_arg == "nlllossfp16") { + drv = new NLLLossDriver(); + } + else if(base_arg == "nlllossbfp16") { + drv = new NLLLossDriver(); + } else { printf("Incorrect BaseArg\n"); diff --git a/driver/mloNLLLossHost.hpp b/driver/mloNLLLossHost.hpp new file mode 100644 index 0000000000..70198e7cc3 --- /dev/null +++ b/driver/mloNLLLossHost.hpp @@ -0,0 +1,74 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef MLO_NLLLOSSHOST_H_ +#define MLO_NLLLOSSHOST_H_ + +//////////////////////////////////////////////////////////// +// +/////////////////////////////////////////////////////////// + +template +int32_t mloNLLLossForwardRunHost(miopenTensorDescriptor_t inputDesc, + Tgpu* input, + Tgpu* target, + Tgpu* weight, + Tcheck* outputhost, + long ignore_index) +{ + auto dims = miopen::deref(inputDesc).GetLengths(); + + size_t N = dims[0]; + size_t C = dims[1]; + size_t D1 = dims[2]; + size_t D2 = dims[3]; + + for (size_t n = 0; n < N; n++) + { + for (size_t d1 = 0; d1 < D1; d1++) + { + for (size_t d2 = 0; d2 < D2; d2++) + { + size_t target_index = n * D1 * D2 + d1 * D2 + d2; + long t = static_cast(target[target_index]); + size_t input_index = (n * C + t) * D1 * D2 + d1 * D2 + d2; + size_t weight_index = t; + size_t output_index = target_index; + + if (t < 0 || t == ignore_index || t >= C) + { + outputhost[output_index] = static_cast(0.0); + } + else + { + outputhost[output_index] = static_cast(-1.0f) * weight[weight_index] * input[input_index]; + } + } + } + } + + return 0; +} +#endif // MLO_NLLLOSSHOST_H_ diff --git a/driver/nllloss_driver.hpp b/driver/nllloss_driver.hpp new file mode 100644 index 0000000000..3abf3fcbe0 --- /dev/null +++ b/driver/nllloss_driver.hpp @@ -0,0 +1,331 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#ifndef GUARD_MIOPEN_NLLLOSS_DRIVER_HPP +#define GUARD_MIOPEN_NLLLOSS_DRIVER_HPP + +#include "InputFlags.hpp" +#include "driver.hpp" +#include "mloNLLLossHost.hpp" +#include "tensor_driver.hpp" +#include "timer.hpp" +#include <../test/verify.hpp> +#include +#include +#include +#include +#include +#include +#include +#include <../test/tensor_holder.hpp> +#include "random.hpp" + +template +class NLLLossDriver : public Driver +{ +public: + NLLLossDriver() : Driver() + { + miopenCreateTensorDescriptor(&inputDesc); + miopenCreateTensorDescriptor(&targetDesc); + miopenCreateTensorDescriptor(&weightDesc); + miopenCreateTensorDescriptor(&outputDesc); + + data_type = miopen_type{}; + } + + int AddCmdLineArgs() override; + int ParseCmdLineArgs(int argc, char* argv[]) override; + InputFlags& GetInputFlags() override { return inflags; } + + int GetandSetData() override; + std::vector GetInputTensorLengthsFromCmdLine(); + + int AllocateBuffersAndCopy() override; + + int RunForwardGPU() override; + int RunForwardCPU(); + + int RunBackwardGPU() override; + + Tref GetTolerance(); + int VerifyBackward() override; + int VerifyForward() override; + ~NLLLossDriver() override + { + miopenDestroyTensorDescriptor(inputDesc); + miopenDestroyTensorDescriptor(targetDesc); + miopenDestroyTensorDescriptor(weightDesc); + miopenDestroyTensorDescriptor(outputDesc); + } + +private: + InputFlags inflags; + + miopenTensorDescriptor_t inputDesc; + miopenTensorDescriptor_t targetDesc; + miopenTensorDescriptor_t weightDesc; + miopenTensorDescriptor_t outputDesc; + + std::unique_ptr in_dev; + std::unique_ptr target_dev; + std::unique_ptr weight_dev; + std::unique_ptr out_dev; + + std::vector in; + std::vector target; + std::vector weight; + std::vector out; + std::vector out_host; + + size_t N; + size_t C; + size_t D1; + size_t D2; + long ignore_index; +}; + +template +int NLLLossDriver::ParseCmdLineArgs(int argc, char* argv[]) +{ + inflags.Parse(argc, argv); + + if(inflags.GetValueInt("time") == 1) + { + miopenEnableProfiling(GetHandle(), true); + } + return miopenStatusSuccess; +} + +template +int NLLLossDriver::GetandSetData() +{ + N = inflags.GetValueInt("batchsize"); + C = inflags.GetValueInt("numclasses"); + D1 = inflags.GetValueInt("D1"); + D2 = inflags.GetValueInt("D2"); + ignore_index = static_cast(inflags.GetValueInt("ignore_index")); + + if (N<=0 || C<=0 || D1<=0 || D2<=0) + { + MIOPEN_THROW("Error Input Tensor Lengths"); + } + + std::vector in_len = {N, C, D1, D2}; + std::vector target_len = {N, D1, D2}; + std::vector weight_len = {C}; + std::vector out_len = {N, D1, D2}; + + SetTensorNd(inputDesc, in_len, data_type); + SetTensorNd(targetDesc, target_len, data_type); + SetTensorNd(weightDesc, weight_len, data_type); + SetTensorNd(outputDesc, out_len, data_type); + + return 0; +} + +template +int NLLLossDriver::AddCmdLineArgs() +{ + inflags.AddInputFlag("forw", 'F', "1", "Run only Forward NLLLoss (Default=1)", "int"); + inflags.AddInputFlag("batchsize", 'N', "1", "Batch size", "int"); + inflags.AddInputFlag("numclasses", 'C', "2", "Number of classes", "int"); + inflags.AddInputFlag("D1", 'd', "1", "Size D1", "int"); + inflags.AddInputFlag("D2", 'D', "1", "Size D2", "int"); + inflags.AddInputFlag("ignore_index", 'g', "-1", "Ignore index", "int"); + + inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int"); + inflags.AddInputFlag("verify", 'V', "1", "Verify (Default=1)", "int"); + inflags.AddInputFlag("time", 't', "1", "Time (Default=0)", "int"); + inflags.AddInputFlag( + "wall", 'w', "1", "Wall-clock Time, Requires time == 1 (Default=0)", "int"); + + return miopenStatusSuccess; +} + +template +int NLLLossDriver::AllocateBuffersAndCopy() +{ + size_t in_sz = GetTensorSize(inputDesc); + size_t target_sz = GetTensorSize(targetDesc); + size_t weight_sz = GetTensorSize(weightDesc); + size_t out_sz = GetTensorSize(outputDesc); + + uint32_t ctx = 0; + + in_dev = std::unique_ptr(new GPUMem(ctx, in_sz, sizeof(Tgpu))); + target_dev = std::unique_ptr(new GPUMem(ctx, target_sz, sizeof(Tgpu))); + weight_dev = std::unique_ptr(new GPUMem(ctx, weight_sz, sizeof(Tgpu))); + out_dev = std::unique_ptr(new GPUMem(ctx, out_sz, sizeof(Tgpu))); + + in = std::vector(in_sz, static_cast(0)); + target = std::vector(target_sz, static_cast(0)); + weight = std::vector(weight_sz, static_cast(0)); + out = std::vector(out_sz, static_cast(0)); + out_host = std::vector(out_sz, static_cast(0)); + + int status; + + for (int i = 0; i < in_sz; i++) + { + in[i] = prng::gen_A_to_B(static_cast(-10.0), static_cast(-(1e-2))); + } + status = in_dev->ToGPU(q, in.data()); + + for (int i = 0; i < target_sz; i++) + { + target[i] = prng::gen_A_to_B(static_cast(0), static_cast(C-1)); + } + status |= target_dev->ToGPU(q, target.data()); + + for (int i = 0; i < weight_sz; i++) + { + weight[i] = prng::gen_A_to_B(static_cast(-10.0), static_cast(10.0)); + } + status |= weight_dev->ToGPU(q, weight.data()); + + status |= out_dev->ToGPU(q, out.data()); + + if(status != 0) + std::cout << "Error copying data to GPU\n" << std::endl; + + return miopenStatusSuccess; +} + +template +int NLLLossDriver::RunForwardGPU() +{ + float kernel_total_time = 0.0; + float kernel_first_time = 0.0; + + Timer t; + START_TIME + + for(int i = 0; i < inflags.GetValueInt("iter"); i++) + { + miopenNLLLossForward(GetHandle(), + inputDesc, + in_dev->GetMem(), + targetDesc, + target_dev->GetMem(), + weightDesc, + weight_dev->GetMem(), + outputDesc, + out_dev->GetMem(), + ignore_index); + + float time = 0.0; + miopenGetKernelTime(GetHandle(), &time); + kernel_total_time += time; + if(i == 0) + kernel_first_time = time; + } + + if(inflags.GetValueInt("time") == 1) + { + STOP_TIME + int iter = inflags.GetValueInt("iter"); + if(WALL_CLOCK) + printf("Wall-clock Time Forward NLLLoss Elapsed: %f ms\n", t.gettime_ms() / iter); + + float kernel_average_time = + iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + printf("GPU Kernel Time Forward NLLLoss Elapsed: %f ms\n", kernel_average_time); + } + + out_dev->FromGPU(GetStream(), out.data()); + + return miopenStatusSuccess; +} + +template +int NLLLossDriver::RunForwardCPU() +{ + mloNLLLossForwardRunHost(inputDesc, + in.data(), + target.data(), + weight.data(), + out_host.data(), + ignore_index); + + return miopenStatusSuccess; +} + +template +int NLLLossDriver::RunBackwardGPU() +{ + return miopenStatusSuccess; +} + +template +Tref NLLLossDriver::GetTolerance() +{ + if(data_type == miopenHalf) + { + return 1e-3; + } + else if(data_type == miopenFloat) + { + return 5e-5; + } + else if(data_type == miopenDouble) + { + return 1e-10; + } + else if(data_type == miopenBFloat16) + { + return 5e-3; + } + return 0; +} + +template +int NLLLossDriver::VerifyForward() +{ + RunForwardCPU(); + const Tref tolerance = GetTolerance(); + auto error = miopen::rms_range(out_host, out); + + if(!std::isfinite(error) || error > tolerance) + { + std::cout << "Forward NLLLoss FAILED: " << error << std::endl; + return EC_VerifyFwd; + } + else + { + printf("Forward NLLLoss Verifies on CPU and GPU (err=%f)\n", error); + } + + return miopenStatusSuccess; +} + +template +int NLLLossDriver::VerifyBackward() +{ + return miopenStatusSuccess; +} + +#endif // GUARD_MIOPEN_NLLLOSS_DRIVER_HPP diff --git a/src/kernels/MIOpenNLLLoss.cpp b/src/kernels/MIOpenNLLLoss.cpp index bb83a4699a..222c0aefa0 100644 --- a/src/kernels/MIOpenNLLLoss.cpp +++ b/src/kernels/MIOpenNLLLoss.cpp @@ -41,10 +41,10 @@ /* input(input): [N, C, D1, D2], target(target): [N, D1, D2], * weight(weight): [C], output(output): [N, D1, D2] */ /* Each thread computes one output: output[n0][n1][n2] */ -extern "C" __global__ void NLLLossUnreducedForward4dContiguous(const FLOAT_ACCUM* __restrict__ input, - const FLOAT_ACCUM* __restrict__ target, - const FLOAT_ACCUM* weight, - FLOAT_ACCUM* __restrict__ output, +extern "C" __global__ void NLLLossUnreducedForward4dContiguous(const FLOAT* __restrict__ input, + const FLOAT* __restrict__ target, + const FLOAT* weight, + FLOAT* __restrict__ output, long ignore_index, size_t N_total, size_t C, @@ -64,15 +64,15 @@ extern "C" __global__ void NLLLossUnreducedForward4dContiguous(const FLOAT_ACCUM long t = static_cast(target[gid]); // t: Class index if (t < 0 || t == ignore_index || t >= C) { - output[gid] = 0; + output[gid] = static_cast(0.0); return; } - FLOAT_ACCUM w = weight != nullptr ? weight[t] : 1.0f; + FLOAT w = weight != nullptr ? weight[t] : static_cast(1.0); // fix this - // FLOAT_ACCUM input_value = input[NWH[0]][t][NWH[1]][NWH[2]]; - FLOAT_ACCUM input_value = input[(NWH[0] * C + t) * D1 * D2 + NWH[1] * D2 + NWH[2]]; + // FLOAT input_value = input[N][t][D1][D2]; + FLOAT input_value = input[(NWH[0] * C + t) * D1 * D2 + NWH[1] * D2 + NWH[2]]; - output[gid] = -1.0f * w * input_value; + output[gid] = static_cast(-1.0) * w * input_value; } diff --git a/src/nllloss_api.cpp b/src/nllloss_api.cpp index 56f1549e42..c153ea9a1b 100644 --- a/src/nllloss_api.cpp +++ b/src/nllloss_api.cpp @@ -30,6 +30,39 @@ #include #include +static void LogCmdNLLLoss(const miopenTensorDescriptor_t xDesc, bool is_fwd) +{ + if(miopen::IsLoggingCmd()) + { + std::stringstream ss; + auto dtype = miopen::deref(xDesc).GetType(); + if(dtype == miopenHalf) + { + ss << "nlllossfp16"; + } + else if(dtype == miopenFloat) + { + ss << "nlllossfp32"; + } + else if(dtype == miopenBFloat16) + { + ss << "nlllossbfp16"; + } + + int32_t size = {0}; + miopenGetTensorDescriptorSize(xDesc, &size); + ss << " -N " << miopen::deref(xDesc).GetLengths()[0]; + ss << " -C " << miopen::deref(xDesc).GetLengths()[1] + << " -d " << miopen::deref(xDesc).GetLengths()[2] + << " -D " << miopen::deref(xDesc).GetLengths()[3]; + + ss << " -F " << ((is_fwd) ? "1" : "2"); + + MIOPEN_LOG_DRIVER_CMD(ss.str()); + } +} + + extern "C" miopenStatus_t miopenNLLLossForward(miopenHandle_t handle, const miopenTensorDescriptor_t inputDesc, const void* input, @@ -53,6 +86,7 @@ extern "C" miopenStatus_t miopenNLLLossForward(miopenHandle_t handle, output, ignore_index); + LogCmdNLLLoss(inputDesc, true); // LogCmdNLLLossForward(inputDesc, targetDesc, weightDesc, outputDesc, ignore_index, N, C); return miopen::try_([&] { miopen::NLLLossForward( From d5ca312c565819e81d17153b1c2107b268c5efb8 Mon Sep 17 00:00:00 2001 From: hieule88 Date: Fri, 12 Apr 2024 08:02:48 +0000 Subject: [PATCH 03/23] fix bfp16 --- driver/mloNLLLossHost.hpp | 12 +++++++----- driver/nllloss_driver.hpp | 14 ++++++++------ src/kernels/MIOpenNLLLoss.cpp | 19 +++++++++---------- test/cpu_nllloss.hpp | 4 ++-- test/gtest/nllloss.hpp | 16 ++++++++-------- 5 files changed, 34 insertions(+), 31 deletions(-) diff --git a/driver/mloNLLLossHost.hpp b/driver/mloNLLLossHost.hpp index 70198e7cc3..d2e1091de7 100644 --- a/driver/mloNLLLossHost.hpp +++ b/driver/mloNLLLossHost.hpp @@ -33,10 +33,10 @@ template int32_t mloNLLLossForwardRunHost(miopenTensorDescriptor_t inputDesc, Tgpu* input, - Tgpu* target, + int* target, Tgpu* weight, Tcheck* outputhost, - long ignore_index) + int ignore_index) { auto dims = miopen::deref(inputDesc).GetLengths(); @@ -52,18 +52,20 @@ int32_t mloNLLLossForwardRunHost(miopenTensorDescriptor_t inputDesc, for (size_t d2 = 0; d2 < D2; d2++) { size_t target_index = n * D1 * D2 + d1 * D2 + d2; - long t = static_cast(target[target_index]); + int t = target[target_index]; size_t input_index = (n * C + t) * D1 * D2 + d1 * D2 + d2; size_t weight_index = t; size_t output_index = target_index; if (t < 0 || t == ignore_index || t >= C) { - outputhost[output_index] = static_cast(0.0); + outputhost[output_index] = static_cast(0); } else { - outputhost[output_index] = static_cast(-1.0f) * weight[weight_index] * input[input_index]; + outputhost[output_index] = static_cast(-1) + * static_cast(weight[weight_index]) + * static_cast(input[input_index]); } } } diff --git a/driver/nllloss_driver.hpp b/driver/nllloss_driver.hpp index 3abf3fcbe0..70d1932592 100644 --- a/driver/nllloss_driver.hpp +++ b/driver/nllloss_driver.hpp @@ -85,6 +85,8 @@ class NLLLossDriver : public Driver private: InputFlags inflags; + int forw; + miopenTensorDescriptor_t inputDesc; miopenTensorDescriptor_t targetDesc; miopenTensorDescriptor_t weightDesc; @@ -96,7 +98,7 @@ class NLLLossDriver : public Driver std::unique_ptr out_dev; std::vector in; - std::vector target; + std::vector target; std::vector weight; std::vector out; std::vector out_host; @@ -105,7 +107,7 @@ class NLLLossDriver : public Driver size_t C; size_t D1; size_t D2; - long ignore_index; + int ignore_index; }; template @@ -127,7 +129,7 @@ int NLLLossDriver::GetandSetData() C = inflags.GetValueInt("numclasses"); D1 = inflags.GetValueInt("D1"); D2 = inflags.GetValueInt("D2"); - ignore_index = static_cast(inflags.GetValueInt("ignore_index")); + ignore_index = static_cast(inflags.GetValueInt("ignore_index")); if (N<=0 || C<=0 || D1<=0 || D2<=0) { @@ -177,12 +179,12 @@ int NLLLossDriver::AllocateBuffersAndCopy() uint32_t ctx = 0; in_dev = std::unique_ptr(new GPUMem(ctx, in_sz, sizeof(Tgpu))); - target_dev = std::unique_ptr(new GPUMem(ctx, target_sz, sizeof(Tgpu))); + target_dev = std::unique_ptr(new GPUMem(ctx, target_sz, sizeof(int))); weight_dev = std::unique_ptr(new GPUMem(ctx, weight_sz, sizeof(Tgpu))); out_dev = std::unique_ptr(new GPUMem(ctx, out_sz, sizeof(Tgpu))); in = std::vector(in_sz, static_cast(0)); - target = std::vector(target_sz, static_cast(0)); + target = std::vector(target_sz, static_cast(0)); weight = std::vector(weight_sz, static_cast(0)); out = std::vector(out_sz, static_cast(0)); out_host = std::vector(out_sz, static_cast(0)); @@ -197,7 +199,7 @@ int NLLLossDriver::AllocateBuffersAndCopy() for (int i = 0; i < target_sz; i++) { - target[i] = prng::gen_A_to_B(static_cast(0), static_cast(C-1)); + target[i] = prng::gen_A_to_B(static_cast(0), static_cast(C-1)); } status |= target_dev->ToGPU(q, target.data()); diff --git a/src/kernels/MIOpenNLLLoss.cpp b/src/kernels/MIOpenNLLLoss.cpp index 222c0aefa0..bbefd90c20 100644 --- a/src/kernels/MIOpenNLLLoss.cpp +++ b/src/kernels/MIOpenNLLLoss.cpp @@ -23,7 +23,6 @@ * SOFTWARE. * *******************************************************************************/ -#include #ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS #include #endif @@ -42,16 +41,16 @@ * weight(weight): [C], output(output): [N, D1, D2] */ /* Each thread computes one output: output[n0][n1][n2] */ extern "C" __global__ void NLLLossUnreducedForward4dContiguous(const FLOAT* __restrict__ input, - const FLOAT* __restrict__ target, + const int* __restrict__ target, const FLOAT* weight, FLOAT* __restrict__ output, - long ignore_index, + int ignore_index, size_t N_total, size_t C, size_t D1, size_t D2) { - size_t gid = threadIdx.x + blockIdx.x * blockDim.x; + uint64_t gid = threadIdx.x + blockIdx.x * blockDim.x; if (gid >= N_total) return; @@ -61,18 +60,18 @@ extern "C" __global__ void NLLLossUnreducedForward4dContiguous(const FLOAT* __re NWH[1] = nc % D1; NWH[0] = nc / D1; - long t = static_cast(target[gid]); + int t = target[gid]; // t: Class index if (t < 0 || t == ignore_index || t >= C) { - output[gid] = static_cast(0.0); + output[gid] = static_cast(0); return; } - FLOAT w = weight != nullptr ? weight[t] : static_cast(1.0); + FLOAT_ACCUM w = weight != nullptr ? CVT_FLOAT2ACCUM(weight[t]) : CVT_FP32_2ACCUM(1.0f); - // fix this // FLOAT input_value = input[N][t][D1][D2]; - FLOAT input_value = input[(NWH[0] * C + t) * D1 * D2 + NWH[1] * D2 + NWH[2]]; + FLOAT_ACCUM input_value = CVT_FLOAT2ACCUM(input[(NWH[0] * C + t) * D1 * D2 + NWH[1] * D2 + NWH[2]]); - output[gid] = static_cast(-1.0) * w * input_value; + FLOAT_ACCUM val = CVT_FP32_2ACCUM(-1.0f) * w * input_value; + output[gid] = CVT_ACCUM2FLOAT(val); } diff --git a/test/cpu_nllloss.hpp b/test/cpu_nllloss.hpp index a8ae65c0d3..38a58ab2e8 100644 --- a/test/cpu_nllloss.hpp +++ b/test/cpu_nllloss.hpp @@ -30,10 +30,10 @@ template void cpu_nllloss_forward_4d(tensor input, - tensor target, + tensor target, tensor weight, tensor& output, - long ignore_index) + int ignore_index) { auto dims = input.desc.GetLengths(); size_t N = dims[0]; diff --git a/test/gtest/nllloss.hpp b/test/gtest/nllloss.hpp index 027f5eb49a..e666b62500 100644 --- a/test/gtest/nllloss.hpp +++ b/test/gtest/nllloss.hpp @@ -41,7 +41,7 @@ struct NLLLossTestCase size_t D1=0; size_t D2=0; bool weight_mode=false; - long ignore_index=-1; + int ignore_index=-1; std::vector input = {N, C, D1, D2}; friend std::ostream& operator<<(std::ostream& os, const NLLLossTestCase& tc) @@ -56,11 +56,11 @@ struct NLLLossTestCase inline std::vector NLLLossTestConfigs() { // dim, dims // clang-format off - return {{8, 1 ,6,5,false,-1}, - {1, 2, 2, 2, false, -1}, - {2,10,128,128, false, -1}, + return {{1, 2, 2, 2, false, -100}, + {2,10,128,128, false, 255}, {8, 12, 256, 256, true, -1}, - {8, 16, 512, 512, true, 10}}; + {8, 16, 512, 512, true, 10}, + {16, 21,512,512,false, 255}}; // clang-format on } @@ -106,7 +106,7 @@ struct NLLLossTest : public ::testing::TestWithParam input = tensor{in_dim}.generate(gen_input_value); - target = tensor{target_dim}.generate(gen_target_value); + target = tensor{target_dim}.generate(gen_target_value); if (!weight_mode) weight = tensor{weight_dim}.generate(gen_weight_one); @@ -156,13 +156,13 @@ struct NLLLossTest : public ::testing::TestWithParam NLLLossTestCase nllloss_config; tensor input; - tensor target; + tensor target; tensor weight; tensor output; tensor ref_output; bool weight_mode; - long ignore_index; + int ignore_index; miopen::Allocator::ManageDataPtr input_dev; miopen::Allocator::ManageDataPtr target_dev; From 79db9b3e6abc5b1f27041cab54124f94181af5a4 Mon Sep 17 00:00:00 2001 From: hieule88 Date: Fri, 12 Apr 2024 11:22:05 +0000 Subject: [PATCH 04/23] add problem desc functions --- include/miopen/miopen.h | 2 +- src/include/miopen/nllloss.hpp | 2 +- src/include/miopen/nllloss/invoke_params.hpp | 2 +- .../miopen/nllloss/problem_description.hpp | 79 ++++++++++++++++--- src/nllloss.cpp | 2 +- src/nllloss_api.cpp | 5 +- src/solver/nllloss/forward_nllloss.cpp | 6 ++ 7 files changed, 78 insertions(+), 20 deletions(-) diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 9e9917a17d..2b7f307815 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -6613,7 +6613,7 @@ MIOPEN_EXPORT miopenStatus_t miopenNLLLossForward(miopenHandle_t handle, const void* weight, const miopenTensorDescriptor_t outputDesc, void* output, - long ignore_index); + int ignore_index); /** @} */ // CLOSEOUT nllloss DOXYGEN GROUP #endif // MIOPEN_BETA_API diff --git a/src/include/miopen/nllloss.hpp b/src/include/miopen/nllloss.hpp index 74a0d92e2d..ea575b8f56 100644 --- a/src/include/miopen/nllloss.hpp +++ b/src/include/miopen/nllloss.hpp @@ -43,7 +43,7 @@ miopenStatus_t NLLLossForward(Handle& handle, ConstData_t weight, const TensorDescriptor& outputDesc, Data_t output, - long ignore_index); + int ignore_index); } // namespace miopen #endif // _MIOPEN_NLLLOSS_HPP_ diff --git a/src/include/miopen/nllloss/invoke_params.hpp b/src/include/miopen/nllloss/invoke_params.hpp index b3084685b7..830ba54762 100644 --- a/src/include/miopen/nllloss/invoke_params.hpp +++ b/src/include/miopen/nllloss/invoke_params.hpp @@ -46,7 +46,7 @@ struct InvokeParams : public miopen::InvokeParams ConstData_t weight = nullptr; Data_t output = nullptr; - long ignore_index = -1; + int ignore_index = -1; std::size_t GetWorkspaceSize() const { return 0; } Data_t GetWorkspace() const { return nullptr; } diff --git a/src/include/miopen/nllloss/problem_description.hpp b/src/include/miopen/nllloss/problem_description.hpp index 9b4fb67d74..f720c89a46 100644 --- a/src/include/miopen/nllloss/problem_description.hpp +++ b/src/include/miopen/nllloss/problem_description.hpp @@ -29,7 +29,6 @@ #include #include #include - #include #include @@ -45,29 +44,82 @@ struct ProblemDescription : ProblemDescriptionBase const TensorDescriptor& targetDesc_, const TensorDescriptor& weightDesc_, const TensorDescriptor& outputDesc_, - long ignore_index_) + int ignore_index_) : inputDesc(inputDesc_), targetDesc(targetDesc_), weightDesc(weightDesc_), outputDesc(outputDesc_), ignore_index(ignore_index_), N_total(outputDesc_.GetElementSize()), + N(inputDesc_.GetLengths()[0]), C(inputDesc_.GetLengths()[1]), D1(inputDesc_.GetLengths()[2]), D2(inputDesc_.GetLengths()[3]) { - } const TensorDescriptor& GetInputDesc() const { return inputDesc; } const TensorDescriptor& GetTargetDesc() const { return targetDesc; } const TensorDescriptor& GetWeightDesc() const { return weightDesc; } const TensorDescriptor& GetOutputDesc() const { return outputDesc; } - long GetIgnoreIndex() const { return ignore_index; } - long GetNtotal() const { return N_total; } - long GetC() const { return C; } - long GetD1() const { return D1; } - long GetD2() const { return D2; } + int GetIgnoreIndex() const { return ignore_index; } + size_t GetNtotal() const { return N_total; } + size_t GetC() const { return C; } + size_t GetD1() const { return D1; } + size_t GetD2() const { return D2; } + + /* input(input): [N, C, D1, D2], target(target): [N, D1, D2], + * weight(weight): [C], output(output): [N, D1, D2] */ + bool IsRightDim() const + { + if (outputDesc.GetLengths()[0] != N || outputDesc.GetLengths()[1] != D1 || outputDesc.GetLengths()[2] != D2 || + targetDesc.GetLengths()[0] != N || targetDesc.GetLengths()[1] != D1 || targetDesc.GetLengths()[2] != D2 || + weightDesc.GetLengths()[0] != C) + { +#if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG + MIOPEN_THROW( + miopenStatusBadParm, + "NLLLoss: Tensors dimension do not match."); +#else + return false; +#endif + } + return true; + } + + bool IsSameType() const + { + if(inputDesc.GetType() != outputDesc.GetType()) + { +#if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG + MIOPEN_THROW(miopenStatusBadParm, "NLLLoss: Tensor types of Input and Output do not match."); +#else + return false; +#endif + } + if(outputDesc.GetType() != weightDesc.GetType()) + { +#if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG + MIOPEN_THROW(miopenStatusBadParm, "NLLLoss: Tensor types of Output and Weight do not match."); +#else + return false; +#endif + } + return true; + } + + bool IsAllPacked() const + { + if(!(inputDesc.IsPacked() && targetDesc.IsPacked() && weightDesc.IsPacked() && outputDesc.IsPacked())) + { +#if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG + MIOPEN_THROW(miopenStatusBadParm, "NLLLoss: Unpacked tensors not supported."); +#else + return false; +#endif + } + return true; + } NetworkConfig MakeNetworkConfig() const override; @@ -77,11 +129,12 @@ struct ProblemDescription : ProblemDescriptionBase TensorDescriptor weightDesc; TensorDescriptor outputDesc; - long ignore_index; - long N_total; - long C; - long D1; - long D2; + int ignore_index; + size_t N_total; + size_t N; + size_t C; + size_t D1; + size_t D2; NetworkConfig MakeForwardNetworkConfig() const; }; diff --git a/src/nllloss.cpp b/src/nllloss.cpp index c847f64017..be1fe4950d 100644 --- a/src/nllloss.cpp +++ b/src/nllloss.cpp @@ -42,7 +42,7 @@ miopenStatus_t NLLLossForward(Handle& handle, ConstData_t weight, const TensorDescriptor& outputDesc, Data_t output, - long ignore_index) + int ignore_index) { const auto problem = nllloss::ProblemDescription{ inputDesc, targetDesc, weightDesc, outputDesc, ignore_index}; diff --git a/src/nllloss_api.cpp b/src/nllloss_api.cpp index c153ea9a1b..32dfe3d9c9 100644 --- a/src/nllloss_api.cpp +++ b/src/nllloss_api.cpp @@ -42,7 +42,7 @@ static void LogCmdNLLLoss(const miopenTensorDescriptor_t xDesc, bool is_fwd) } else if(dtype == miopenFloat) { - ss << "nlllossfp32"; + ss << "nllloss"; } else if(dtype == miopenBFloat16) { @@ -72,7 +72,7 @@ extern "C" miopenStatus_t miopenNLLLossForward(miopenHandle_t handle, const void* weight, const miopenTensorDescriptor_t outputDesc, void* output, - long ignore_index) + int ignore_index) { MIOPEN_LOG_FUNCTION( handle, @@ -87,7 +87,6 @@ extern "C" miopenStatus_t miopenNLLLossForward(miopenHandle_t handle, ignore_index); LogCmdNLLLoss(inputDesc, true); - // LogCmdNLLLossForward(inputDesc, targetDesc, weightDesc, outputDesc, ignore_index, N, C); return miopen::try_([&] { miopen::NLLLossForward( miopen::deref(handle), diff --git a/src/solver/nllloss/forward_nllloss.cpp b/src/solver/nllloss/forward_nllloss.cpp index 27916be83d..9cd0d3f8cb 100644 --- a/src/solver/nllloss/forward_nllloss.cpp +++ b/src/solver/nllloss/forward_nllloss.cpp @@ -48,6 +48,12 @@ namespace nllloss { bool NLLLossForward::IsApplicable(const ExecutionContext&, const miopen::nllloss::ProblemDescription& problem) const { + if(!problem.IsSameType()) + return false; + if(!problem.IsAllPacked()) + return false; + if(!problem.IsRightDim()) + return false; return true; } From cd7505c77b13163a7304692bad6b06386d2730ca Mon Sep 17 00:00:00 2001 From: hieule88 Date: Fri, 12 Apr 2024 11:24:00 +0000 Subject: [PATCH 05/23] rm gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 6bee76431a..0027c51c3f 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ install_dir/ build/ -.clangd \ No newline at end of file +.clangd +.gitignore \ No newline at end of file From f567e01a78ddf4afe77ddc196aa8d39c2b8a0f6f Mon Sep 17 00:00:00 2001 From: Hieu Le <54091202+hieule88@users.noreply.github.com> Date: Fri, 12 Apr 2024 18:25:42 +0700 Subject: [PATCH 06/23] Delete .gitignore --- .gitignore | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 .gitignore diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 0027c51c3f..0000000000 --- a/.gitignore +++ /dev/null @@ -1,10 +0,0 @@ - -.cache/ -.devcontainer/ -.vscode/ - -install_dir/ -build/ - -.clangd -.gitignore \ No newline at end of file From 2a060522f467787eca22d70a69fd6a80a351f8c3 Mon Sep 17 00:00:00 2001 From: hieule88 Date: Fri, 12 Apr 2024 11:34:15 +0000 Subject: [PATCH 07/23] test --- src/include/miopen/nllloss/problem_description.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/miopen/nllloss/problem_description.hpp b/src/include/miopen/nllloss/problem_description.hpp index f720c89a46..05ee852a8f 100644 --- a/src/include/miopen/nllloss/problem_description.hpp +++ b/src/include/miopen/nllloss/problem_description.hpp @@ -84,7 +84,7 @@ struct ProblemDescription : ProblemDescriptionBase return false; #endif } - return true; + return true; } bool IsSameType() const From 5a44d4b5507518a8e0ce2a8fb1f6d4f3deb5ac4f Mon Sep 17 00:00:00 2001 From: hieule88 Date: Fri, 12 Apr 2024 11:45:29 +0000 Subject: [PATCH 08/23] check exist gitignore --- src/include/miopen/nllloss/problem_description.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/miopen/nllloss/problem_description.hpp b/src/include/miopen/nllloss/problem_description.hpp index 05ee852a8f..f720c89a46 100644 --- a/src/include/miopen/nllloss/problem_description.hpp +++ b/src/include/miopen/nllloss/problem_description.hpp @@ -84,7 +84,7 @@ struct ProblemDescription : ProblemDescriptionBase return false; #endif } - return true; + return true; } bool IsSameType() const From 5a25d231f0c7c56fedaf0574f85fec440cf4c6c1 Mon Sep 17 00:00:00 2001 From: hieule88 Date: Sun, 14 Apr 2024 13:53:39 +0000 Subject: [PATCH 09/23] include ordering --- driver/nllloss_driver.hpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/driver/nllloss_driver.hpp b/driver/nllloss_driver.hpp index 70d1932592..e7fd4bc514 100644 --- a/driver/nllloss_driver.hpp +++ b/driver/nllloss_driver.hpp @@ -23,25 +23,33 @@ * SOFTWARE. * *******************************************************************************/ -#include #ifndef GUARD_MIOPEN_NLLLOSS_DRIVER_HPP #define GUARD_MIOPEN_NLLLOSS_DRIVER_HPP #include "InputFlags.hpp" #include "driver.hpp" #include "mloNLLLossHost.hpp" +#include "random.hpp" #include "tensor_driver.hpp" #include "timer.hpp" +#include "util_driver.hpp" + #include <../test/verify.hpp> + +#include +#include +#include +#include + #include +#include +#include #include #include #include -#include #include #include #include <../test/tensor_holder.hpp> -#include "random.hpp" template class NLLLossDriver : public Driver From 02376f4bc3b3a0757f85d9fbfbc32ca5e808c136 Mon Sep 17 00:00:00 2001 From: hieule88 Date: Mon, 15 Apr 2024 03:37:43 +0000 Subject: [PATCH 10/23] change driver format --- docs/reference/index.rst | 1 + driver/CMakeLists.txt | 1 + driver/dm_nllloss.cpp | 40 +++++++++++++++++++++++++++++++++++++++ driver/mloNLLLossHost.hpp | 2 ++ driver/nllloss_driver.hpp | 1 - include/miopen/miopen.h | 7 +++---- test/gtest/nllloss.hpp | 1 + 7 files changed, 48 insertions(+), 5 deletions(-) create mode 100644 driver/dm_nllloss.cpp diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 02bcb88622..3e56c89a9d 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -32,3 +32,4 @@ The MIOpen API library is structured as follows: * :doc:`GroupNorm <../doxygen/html/group__groupnorm>` (experimental) * :doc:`Cat <../doxygen/html/group__cat>` (experimental) * :doc:`Argmax<./argmax>` (experimental) + * :doc:`NLLLoss<../doxygen/html/group__nllloss>` (experimental) diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 224e550fed..5771a260fe 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -51,6 +51,7 @@ add_executable(MIOpenDriver dm_softmax.cpp dm_sum.cpp dm_tensorop.cpp + dm_nllloss.cpp main.cpp registry_driver_maker.cpp rocrand_wrapper.cpp) diff --git a/driver/dm_nllloss.cpp b/driver/dm_nllloss.cpp new file mode 100644 index 0000000000..1d4fa4c0dc --- /dev/null +++ b/driver/dm_nllloss.cpp @@ -0,0 +1,40 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include "registry_driver_maker.hpp" +#include "nllloss_driver.hpp" + +static Driver* makeDriver(const std::string& base_arg) +{ + if(base_arg == "nllloss") + return new NLLLossDriver(); + if(base_arg == "nlllossfp16") + return new NLLLossDriver(); + if(base_arg == "nlllossbfp16") + return new NLLLossDriver(); + return nullptr; +} + +REGISTER_DRIVER_MAKER(makeDriver); diff --git a/driver/mloNLLLossHost.hpp b/driver/mloNLLLossHost.hpp index d2e1091de7..92920c072e 100644 --- a/driver/mloNLLLossHost.hpp +++ b/driver/mloNLLLossHost.hpp @@ -26,6 +26,8 @@ #ifndef MLO_NLLLOSSHOST_H_ #define MLO_NLLLOSSHOST_H_ +#include + //////////////////////////////////////////////////////////// // /////////////////////////////////////////////////////////// diff --git a/driver/nllloss_driver.hpp b/driver/nllloss_driver.hpp index e7fd4bc514..edf17f1344 100644 --- a/driver/nllloss_driver.hpp +++ b/driver/nllloss_driver.hpp @@ -70,7 +70,6 @@ class NLLLossDriver : public Driver InputFlags& GetInputFlags() override { return inflags; } int GetandSetData() override; - std::vector GetInputTensorLengthsFromCmdLine(); int AllocateBuffersAndCopy() override; diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 2b7f307815..fcf32b8705 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -6592,16 +6592,15 @@ MIOPEN_EXPORT miopenStatus_t miopenBackendInitialize(miopenBackendDescriptor_t d /*! @brief Execute a nllloss forward layer * * @param handle MIOpen handle (input) - * @param inputDesc Tensor descriptor for data input tensor x (input) - * @param input Data tensor x (input) + * @param inputDesc Tensor descriptor for data input tensor input (input) + * @param input Data tensor input (input) * @param targetDesc Tensor descriptor for data input tensor target (input) * @param target Data tensor target (input) * @param weightDesc Tensor descriptor for data input tensor weight (input) * @param weight Data tensor weight (input) * @param outputDesc Tensor descriptor for output data tensor y (input) * @param output Data tensor y (output) - * @param ignore_index Index to ignore (input) - * @param N Number of elements in the output + * @param ignore_index Class index to ignore (input) * @return miopenStatus_t */ MIOPEN_EXPORT miopenStatus_t miopenNLLLossForward(miopenHandle_t handle, diff --git a/test/gtest/nllloss.hpp b/test/gtest/nllloss.hpp index e666b62500..828e53e5d9 100644 --- a/test/gtest/nllloss.hpp +++ b/test/gtest/nllloss.hpp @@ -58,6 +58,7 @@ inline std::vector NLLLossTestConfigs() // clang-format off return {{1, 2, 2, 2, false, -100}, {2,10,128,128, false, 255}, + {5,13,17,11,true, 5}, {8, 12, 256, 256, true, -1}, {8, 16, 512, 512, true, 10}, {16, 21,512,512,false, 255}}; From 2b72956757e160f06587479ee5797197bae97943 Mon Sep 17 00:00:00 2001 From: hieule88 Date: Mon, 15 Apr 2024 04:18:26 +0000 Subject: [PATCH 11/23] change HIP kernel format --- src/kernels/MIOpenNLLLoss.cpp | 35 ++++++++++++++++++-------- src/solver/nllloss/forward_nllloss.cpp | 4 +++ 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/src/kernels/MIOpenNLLLoss.cpp b/src/kernels/MIOpenNLLLoss.cpp index bbefd90c20..31218d97d3 100644 --- a/src/kernels/MIOpenNLLLoss.cpp +++ b/src/kernels/MIOpenNLLLoss.cpp @@ -24,6 +24,7 @@ * *******************************************************************************/ #ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include #include #endif @@ -40,15 +41,16 @@ /* input(input): [N, C, D1, D2], target(target): [N, D1, D2], * weight(weight): [C], output(output): [N, D1, D2] */ /* Each thread computes one output: output[n0][n1][n2] */ -extern "C" __global__ void NLLLossUnreducedForward4dContiguous(const FLOAT* __restrict__ input, - const int* __restrict__ target, - const FLOAT* weight, - FLOAT* __restrict__ output, - int ignore_index, - size_t N_total, - size_t C, - size_t D1, - size_t D2) +template +__device__ void nlllossUnreducedForward4dContiguous(const TI* __restrict__ input, + const int* __restrict__ target, + const TI* weight, + TO* __restrict__ output, + int ignore_index, + size_t N_total, + size_t C, + size_t D1, + size_t D2) { uint64_t gid = threadIdx.x + blockIdx.x * blockDim.x; @@ -63,7 +65,7 @@ extern "C" __global__ void NLLLossUnreducedForward4dContiguous(const FLOAT* __re int t = target[gid]; // t: Class index if (t < 0 || t == ignore_index || t >= C) { - output[gid] = static_cast(0); + output[gid] = static_cast(0); return; } @@ -75,3 +77,16 @@ extern "C" __global__ void NLLLossUnreducedForward4dContiguous(const FLOAT* __re FLOAT_ACCUM val = CVT_FP32_2ACCUM(-1.0f) * w * input_value; output[gid] = CVT_ACCUM2FLOAT(val); } + +extern "C" __global__ void NLLLossUnreducedForward4dContiguous(const INPUT_TYPE* __restrict__ input, + const int* __restrict__ target, + const INPUT_TYPE* weight, + OUTPUT_TYPE* __restrict__ output, + int ignore_index, + size_t N_total, + size_t C, + size_t D1, + size_t D2) +{ + nlllossUnreducedForward4dContiguous(input, target, weight, output, ignore_index, N_total, C, D1, D2); +} diff --git a/src/solver/nllloss/forward_nllloss.cpp b/src/solver/nllloss/forward_nllloss.cpp index 9cd0d3f8cb..7a54c15524 100644 --- a/src/solver/nllloss/forward_nllloss.cpp +++ b/src/solver/nllloss/forward_nllloss.cpp @@ -64,6 +64,8 @@ NLLLossForward::GetSolution(const ExecutionContext& context, std::ignore = context; auto result = ConvSolution{miopenStatusSuccess}; + auto input_dtype = miopen::GetDataType(problem.GetInputDesc().GetType()); + auto output_dtype = miopen::GetDataType(problem.GetOutputDesc().GetType()); { auto dtype = problem.GetInputDesc().GetType(); @@ -87,6 +89,8 @@ NLLLossForward::GetSolution(const ExecutionContext& context, {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + {"INPUT_TYPE", input_dtype}, + {"OUTPUT_TYPE", output_dtype}, {"LOCAL_SIZE", LOCAL_SIZE}, }; From 8cda6790fe9e47167a3370594fd4f2ee6d7c0498 Mon Sep 17 00:00:00 2001 From: hieule88 Date: Mon, 15 Apr 2024 04:23:13 +0000 Subject: [PATCH 12/23] fix INPUT OUTPUT TYPE --- src/solver/nllloss/forward_nllloss.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/solver/nllloss/forward_nllloss.cpp b/src/solver/nllloss/forward_nllloss.cpp index 7a54c15524..16b4db901c 100644 --- a/src/solver/nllloss/forward_nllloss.cpp +++ b/src/solver/nllloss/forward_nllloss.cpp @@ -89,8 +89,8 @@ NLLLossForward::GetSolution(const ExecutionContext& context, {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, - {"INPUT_TYPE", input_dtype}, - {"OUTPUT_TYPE", output_dtype}, + {"INPUT_TYPE", input_dtype == "bfloat16" ? "ushort" : input_dtype}, + {"OUTPUT_TYPE", output_dtype == "bfloat16" ? "ushort" : output_dtype}, {"LOCAL_SIZE", LOCAL_SIZE}, }; From fcbcc17bcfbc6ab52c1bd6cc47d7d9866931644b Mon Sep 17 00:00:00 2001 From: hieule88 Date: Mon, 15 Apr 2024 06:31:05 +0000 Subject: [PATCH 13/23] resolve license and / --- driver/mloNLLLossHost.hpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/driver/mloNLLLossHost.hpp b/driver/mloNLLLossHost.hpp index 92920c072e..ed68b03a88 100644 --- a/driver/mloNLLLossHost.hpp +++ b/driver/mloNLLLossHost.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2024 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,10 +28,6 @@ #include -//////////////////////////////////////////////////////////// -// -/////////////////////////////////////////////////////////// - template int32_t mloNLLLossForwardRunHost(miopenTensorDescriptor_t inputDesc, Tgpu* input, From 0e208c1f76962887980237cf8b1803db5a92800d Mon Sep 17 00:00:00 2001 From: hieule88 Date: Mon, 15 Apr 2024 09:41:14 +0000 Subject: [PATCH 14/23] apply review comments --- driver/mloNLLLossHost.hpp | 32 +++---- driver/nllloss_driver.hpp | 89 ++++++++----------- src/CMakeLists.txt | 10 +-- src/include/miopen/nllloss/invoke_params.hpp | 8 +- .../miopen/nllloss/problem_description.hpp | 30 ++++--- src/include/miopen/nllloss/solvers.hpp | 5 +- src/kernels/MIOpenNLLLoss.cpp | 48 +++++----- src/nllloss.cpp | 22 ++--- src/nllloss/problem_description.cpp | 6 +- src/solver/nllloss/forward_nllloss.cpp | 41 +++++---- test/cpu_nllloss.hpp | 27 +++--- test/gtest/nllloss.cpp | 54 +++++++++-- test/gtest/nllloss.hpp | 84 ++++++++--------- 13 files changed, 243 insertions(+), 213 deletions(-) diff --git a/driver/mloNLLLossHost.hpp b/driver/mloNLLLossHost.hpp index ed68b03a88..df2eb3251d 100644 --- a/driver/mloNLLLossHost.hpp +++ b/driver/mloNLLLossHost.hpp @@ -30,40 +30,40 @@ template int32_t mloNLLLossForwardRunHost(miopenTensorDescriptor_t inputDesc, - Tgpu* input, - int* target, - Tgpu* weight, - Tcheck* outputhost, - int ignore_index) + Tgpu* input, + int32_t* target, + Tgpu* weight, + Tcheck* outputhost, + int32_t ignore_index) { auto dims = miopen::deref(inputDesc).GetLengths(); - size_t N = dims[0]; - size_t C = dims[1]; + size_t N = dims[0]; + size_t C = dims[1]; size_t D1 = dims[2]; size_t D2 = dims[3]; - for (size_t n = 0; n < N; n++) + for(size_t n = 0; n < N; n++) { - for (size_t d1 = 0; d1 < D1; d1++) + for(size_t d1 = 0; d1 < D1; d1++) { - for (size_t d2 = 0; d2 < D2; d2++) + for(size_t d2 = 0; d2 < D2; d2++) { size_t target_index = n * D1 * D2 + d1 * D2 + d2; - int t = target[target_index]; - size_t input_index = (n * C + t) * D1 * D2 + d1 * D2 + d2; + int32_t t = target[target_index]; + size_t input_index = (n * C + t) * D1 * D2 + d1 * D2 + d2; size_t weight_index = t; size_t output_index = target_index; - if (t < 0 || t == ignore_index || t >= C) + if(t < 0 || t == ignore_index || t >= C) { outputhost[output_index] = static_cast(0); } else { - outputhost[output_index] = static_cast(-1) - * static_cast(weight[weight_index]) - * static_cast(input[input_index]); + outputhost[output_index] = static_cast(-1) * + static_cast(weight[weight_index]) * + static_cast(input[input_index]); } } } diff --git a/driver/nllloss_driver.hpp b/driver/nllloss_driver.hpp index edf17f1344..dde8c8d73b 100644 --- a/driver/nllloss_driver.hpp +++ b/driver/nllloss_driver.hpp @@ -64,7 +64,7 @@ class NLLLossDriver : public Driver data_type = miopen_type{}; } - + int AddCmdLineArgs() override; int ParseCmdLineArgs(int argc, char* argv[]) override; InputFlags& GetInputFlags() override { return inflags; } @@ -88,12 +88,12 @@ class NLLLossDriver : public Driver miopenDestroyTensorDescriptor(weightDesc); miopenDestroyTensorDescriptor(outputDesc); } - + private: InputFlags inflags; int forw; - + miopenTensorDescriptor_t inputDesc; miopenTensorDescriptor_t targetDesc; miopenTensorDescriptor_t weightDesc; @@ -132,21 +132,21 @@ int NLLLossDriver::ParseCmdLineArgs(int argc, char* argv[]) template int NLLLossDriver::GetandSetData() { - N = inflags.GetValueInt("batchsize"); - C = inflags.GetValueInt("numclasses"); - D1 = inflags.GetValueInt("D1"); - D2 = inflags.GetValueInt("D2"); + N = inflags.GetValueInt("batchsize"); + C = inflags.GetValueInt("numclasses"); + D1 = inflags.GetValueInt("D1"); + D2 = inflags.GetValueInt("D2"); ignore_index = static_cast(inflags.GetValueInt("ignore_index")); - if (N<=0 || C<=0 || D1<=0 || D2<=0) + if(N <= 0 || C <= 0 || D1 <= 0 || D2 <= 0) { MIOPEN_THROW("Error Input Tensor Lengths"); } - std::vector in_len = {N, C, D1, D2}; + std::vector in_len = {N, C, D1, D2}; std::vector target_len = {N, D1, D2}; std::vector weight_len = {C}; - std::vector out_len = {N, D1, D2}; + std::vector out_len = {N, D1, D2}; SetTensorNd(inputDesc, in_len, data_type); SetTensorNd(targetDesc, target_len, data_type); @@ -159,16 +159,16 @@ int NLLLossDriver::GetandSetData() template int NLLLossDriver::AddCmdLineArgs() { - inflags.AddInputFlag("forw", 'F', "1", "Run only Forward NLLLoss (Default=1)", "int"); - inflags.AddInputFlag("batchsize", 'N', "1", "Batch size", "int"); - inflags.AddInputFlag("numclasses", 'C', "2", "Number of classes", "int"); - inflags.AddInputFlag("D1", 'd', "1", "Size D1", "int"); - inflags.AddInputFlag("D2", 'D', "1", "Size D2", "int"); + inflags.AddInputFlag("forw", 'F', "1", "Run only Forward NLLLoss (Default=1)", "int"); + inflags.AddInputFlag("batchsize", 'N', "5", "Batch size", "int"); + inflags.AddInputFlag("numclasses", 'C', "10", "Number of classes", "int"); + inflags.AddInputFlag("D1", 'd', "17", "Size D1", "int"); + inflags.AddInputFlag("D2", 'D', "19", "Size D2", "int"); inflags.AddInputFlag("ignore_index", 'g', "-1", "Ignore index", "int"); - inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int"); - inflags.AddInputFlag("verify", 'V', "1", "Verify (Default=1)", "int"); - inflags.AddInputFlag("time", 't', "1", "Time (Default=0)", "int"); + inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int"); + inflags.AddInputFlag("verify", 'V', "1", "Verify (Default=1)", "int"); + inflags.AddInputFlag("time", 't', "1", "Time (Default=0)", "int"); inflags.AddInputFlag( "wall", 'w', "1", "Wall-clock Time, Requires time == 1 (Default=0)", "int"); @@ -178,10 +178,10 @@ int NLLLossDriver::AddCmdLineArgs() template int NLLLossDriver::AllocateBuffersAndCopy() { - size_t in_sz = GetTensorSize(inputDesc); + size_t in_sz = GetTensorSize(inputDesc); size_t target_sz = GetTensorSize(targetDesc); size_t weight_sz = GetTensorSize(weightDesc); - size_t out_sz = GetTensorSize(outputDesc); + size_t out_sz = GetTensorSize(outputDesc); uint32_t ctx = 0; @@ -190,27 +190,27 @@ int NLLLossDriver::AllocateBuffersAndCopy() weight_dev = std::unique_ptr(new GPUMem(ctx, weight_sz, sizeof(Tgpu))); out_dev = std::unique_ptr(new GPUMem(ctx, out_sz, sizeof(Tgpu))); - in = std::vector(in_sz, static_cast(0)); - target = std::vector(target_sz, static_cast(0)); - weight = std::vector(weight_sz, static_cast(0)); - out = std::vector(out_sz, static_cast(0)); + in = std::vector(in_sz, static_cast(0)); + target = std::vector(target_sz, static_cast(0)); + weight = std::vector(weight_sz, static_cast(1)); + out = std::vector(out_sz, static_cast(0)); out_host = std::vector(out_sz, static_cast(0)); int status; - for (int i = 0; i < in_sz; i++) + for(int i = 0; i < in_sz; i++) { in[i] = prng::gen_A_to_B(static_cast(-10.0), static_cast(-(1e-2))); } status = in_dev->ToGPU(q, in.data()); - for (int i = 0; i < target_sz; i++) + for(int i = 0; i < target_sz; i++) { - target[i] = prng::gen_A_to_B(static_cast(0), static_cast(C-1)); + target[i] = prng::gen_A_to_B(static_cast(0), static_cast(C - 1)); } status |= target_dev->ToGPU(q, target.data()); - for (int i = 0; i < weight_sz; i++) + for(int i = 0; i < weight_sz; i++) { weight[i] = prng::gen_A_to_B(static_cast(-10.0), static_cast(10.0)); } @@ -273,12 +273,8 @@ int NLLLossDriver::RunForwardGPU() template int NLLLossDriver::RunForwardCPU() { - mloNLLLossForwardRunHost(inputDesc, - in.data(), - target.data(), - weight.data(), - out_host.data(), - ignore_index); + mloNLLLossForwardRunHost( + inputDesc, in.data(), target.data(), weight.data(), out_host.data(), ignore_index); return miopenStatusSuccess; } @@ -292,23 +288,14 @@ int NLLLossDriver::RunBackwardGPU() template Tref NLLLossDriver::GetTolerance() { - if(data_type == miopenHalf) - { - return 1e-3; - } - else if(data_type == miopenFloat) - { - return 5e-5; - } - else if(data_type == miopenDouble) - { - return 1e-10; - } - else if(data_type == miopenBFloat16) - { - return 5e-3; - } - return 0; + // Computation error of fp16 is ~2^13 (=8192) bigger than + // the one of fp32 because mantissa is shorter by 13 bits. + auto tolerance = std::is_same::value ? 1.5e-6 : 8.2e-3; + + // bf16 mantissa has 7 bits, by 3 bits shorter than fp16. + if(std::is_same::value) + tolerance *= 8.0; + return tolerance; } template diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index beecff0c14..a7b9f0b380 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -130,8 +130,6 @@ set( MIOpen_Source graphapi/variant_pack.cpp groupnorm_api.cpp groupnorm/problem_description.cpp - nllloss_api.cpp - nllloss/problem_description.cpp handle_api.cpp invoker_cache.cpp kernel_build_params.cpp @@ -144,6 +142,8 @@ set( MIOpen_Source lrn_api.cpp mha/mha_descriptor.cpp mha/problem_description.cpp + nllloss_api.cpp + nllloss/problem_description.cpp op_args.cpp operator.cpp performance_config.cpp @@ -262,11 +262,11 @@ set( MIOpen_Source solver/gemm_bwd.cpp solver/gemm_wrw.cpp solver/groupnorm/forward_groupnorm.cpp - solver/nllloss/forward_nllloss.cpp solver/layernorm/forward_layernorm.cpp solver/layernorm/forward_layernorm2d_ck.cpp solver/layernorm/forward_layernorm4d_ck.cpp solver/mha/mha_solver.cpp + solver/nllloss/forward_nllloss.cpp solver/pooling/forward2d.cpp solver/pooling/forwardNaive.cpp solver/pooling/forwardNd.cpp @@ -458,11 +458,11 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/MIOpenConvDirBatchNormActiv.cl kernels/MIOpenConvDirGenFwd.cl kernels/MIOpenGroupNorm.cpp - kernels/MIOpenNLLLoss.cpp kernels/MIOpenLayerNorm.cpp kernels/MIOpenLRNBwd.cl kernels/MIOpenLRNFwd.cl kernels/MIOpenNeuron.cl + kernels/MIOpenNLLLoss.cpp kernels/MIOpenPooling.cl kernels/MIOpenPoolingBwd.cl kernels/MIOpenPoolingBwdND.cl @@ -583,11 +583,11 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN argmax.cpp cat.cpp groupnorm.cpp - nllloss.cpp kernel_cache.cpp layer_norm.cpp lrn.cpp mlo_dir_conv.cpp + nllloss.cpp exec_utils.cpp ocl/activ_ocl.cpp ocl/batchnormocl.cpp diff --git a/src/include/miopen/nllloss/invoke_params.hpp b/src/include/miopen/nllloss/invoke_params.hpp index 830ba54762..cb05074288 100644 --- a/src/include/miopen/nllloss/invoke_params.hpp +++ b/src/include/miopen/nllloss/invoke_params.hpp @@ -38,13 +38,13 @@ struct InvokeParams : public miopen::InvokeParams InvokeParams() = default; - const TensorDescriptor* inputDesc = nullptr; + const TensorDescriptor* inputDesc = nullptr; const TensorDescriptor* outputDesc = nullptr; - - ConstData_t input = nullptr; + + ConstData_t input = nullptr; ConstData_t target = nullptr; ConstData_t weight = nullptr; - Data_t output = nullptr; + Data_t output = nullptr; int ignore_index = -1; diff --git a/src/include/miopen/nllloss/problem_description.hpp b/src/include/miopen/nllloss/problem_description.hpp index f720c89a46..3e0be451ee 100644 --- a/src/include/miopen/nllloss/problem_description.hpp +++ b/src/include/miopen/nllloss/problem_description.hpp @@ -53,7 +53,7 @@ struct ProblemDescription : ProblemDescriptionBase N_total(outputDesc_.GetElementSize()), N(inputDesc_.GetLengths()[0]), C(inputDesc_.GetLengths()[1]), - D1(inputDesc_.GetLengths()[2]), + D1(inputDesc_.GetLengths()[2]), D2(inputDesc_.GetLengths()[3]) { } @@ -69,30 +69,30 @@ struct ProblemDescription : ProblemDescriptionBase size_t GetD2() const { return D2; } /* input(input): [N, C, D1, D2], target(target): [N, D1, D2], - * weight(weight): [C], output(output): [N, D1, D2] */ + * weight(weight): [C], output(output): [N, D1, D2] */ bool IsRightDim() const - { - if (outputDesc.GetLengths()[0] != N || outputDesc.GetLengths()[1] != D1 || outputDesc.GetLengths()[2] != D2 || - targetDesc.GetLengths()[0] != N || targetDesc.GetLengths()[1] != D1 || targetDesc.GetLengths()[2] != D2 || - weightDesc.GetLengths()[0] != C) + { + if(outputDesc.GetLengths()[0] != N || outputDesc.GetLengths()[1] != D1 || + outputDesc.GetLengths()[2] != D2 || targetDesc.GetLengths()[0] != N || + targetDesc.GetLengths()[1] != D1 || targetDesc.GetLengths()[2] != D2 || + weightDesc.GetLengths()[0] != C) { #if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG - MIOPEN_THROW( - miopenStatusBadParm, - "NLLLoss: Tensors dimension do not match."); + MIOPEN_THROW(miopenStatusBadParm, "NLLLoss: Tensors dimension do not match."); #else return false; #endif } return true; } - + bool IsSameType() const { if(inputDesc.GetType() != outputDesc.GetType()) { #if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG - MIOPEN_THROW(miopenStatusBadParm, "NLLLoss: Tensor types of Input and Output do not match."); + MIOPEN_THROW(miopenStatusBadParm, + "NLLLoss: Tensor types of Input and Output do not match."); #else return false; #endif @@ -100,7 +100,8 @@ struct ProblemDescription : ProblemDescriptionBase if(outputDesc.GetType() != weightDesc.GetType()) { #if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG - MIOPEN_THROW(miopenStatusBadParm, "NLLLoss: Tensor types of Output and Weight do not match."); + MIOPEN_THROW(miopenStatusBadParm, + "NLLLoss: Tensor types of Output and Weight do not match."); #else return false; #endif @@ -110,7 +111,8 @@ struct ProblemDescription : ProblemDescriptionBase bool IsAllPacked() const { - if(!(inputDesc.IsPacked() && targetDesc.IsPacked() && weightDesc.IsPacked() && outputDesc.IsPacked())) + if(!(inputDesc.IsPacked() && targetDesc.IsPacked() && weightDesc.IsPacked() && + outputDesc.IsPacked())) { #if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG MIOPEN_THROW(miopenStatusBadParm, "NLLLoss: Unpacked tensors not supported."); @@ -128,7 +130,7 @@ struct ProblemDescription : ProblemDescriptionBase TensorDescriptor targetDesc; TensorDescriptor weightDesc; TensorDescriptor outputDesc; - + int ignore_index; size_t N_total; size_t N; diff --git a/src/include/miopen/nllloss/solvers.hpp b/src/include/miopen/nllloss/solvers.hpp index 18c166a9da..3c7c2d8735 100644 --- a/src/include/miopen/nllloss/solvers.hpp +++ b/src/include/miopen/nllloss/solvers.hpp @@ -42,12 +42,13 @@ namespace nllloss { using NormalizationSolver = NonTunableSolverBase; -struct NLLLossForward final : NormalizationSolver { +struct NLLLossForward final : NormalizationSolver +{ const std::string& SolverDbId() const override { return GetSolverDbId(); } bool IsApplicable(const ExecutionContext& context, const miopen::nllloss::ProblemDescription& problem) const override; - + ConvSolution GetSolution(const ExecutionContext& context, const miopen::nllloss::ProblemDescription& problem) const override; }; diff --git a/src/kernels/MIOpenNLLLoss.cpp b/src/kernels/MIOpenNLLLoss.cpp index 31218d97d3..bbb8f72a40 100644 --- a/src/kernels/MIOpenNLLLoss.cpp +++ b/src/kernels/MIOpenNLLLoss.cpp @@ -30,41 +30,35 @@ #include "float_types.h" -#if MIOPEN_USE_BFP16 == 1 -#define CVT_FLOAT2ACCUM(x) (bfloat16_to_float(x)) -#define CVT_ACCUM2FLOAT(x) (float_to_bfloat16(x)) -#define CVT_INTEGRAL2ACCUM(x) ((_FLOAT_ACCUM)(x)) -#define CVT_FP32_2FLOAT(x) (CVT_ACCUM2FLOAT(x)) -#define CVT_FP32_2ACCUM(x) (x) -#endif - /* input(input): [N, C, D1, D2], target(target): [N, D1, D2], * weight(weight): [C], output(output): [N, D1, D2] */ /* Each thread computes one output: output[n0][n1][n2] */ template -__device__ void nlllossUnreducedForward4dContiguous(const TI* __restrict__ input, - const int* __restrict__ target, +__device__ void nlllossUnreducedForward4dContiguous(const TI* __restrict__ input, + const int32_t* __restrict__ target, const TI* weight, - TO* __restrict__ output, - int ignore_index, + TO* __restrict__ output, + int32_t ignore_index, size_t N_total, size_t C, size_t D1, size_t D2) { uint64_t gid = threadIdx.x + blockIdx.x * blockDim.x; - - if (gid >= N_total) return; + + if(gid >= N_total) + return; size_t NWH[3]; - NWH[2] = (gid) % D2; + NWH[2] = (gid) % D2; size_t nc = (gid) / D2; - NWH[1] = nc % D1; - NWH[0] = nc / D1; + NWH[1] = nc % D1; + NWH[0] = nc / D1; - int t = target[gid]; + int32_t t = target[gid]; // t: Class index - if (t < 0 || t == ignore_index || t >= C) { + if(t < 0 || t == ignore_index || t >= C) + { output[gid] = static_cast(0); return; } @@ -72,21 +66,23 @@ __device__ void nlllossUnreducedForward4dContiguous(const TI* __restrict__ input FLOAT_ACCUM w = weight != nullptr ? CVT_FLOAT2ACCUM(weight[t]) : CVT_FP32_2ACCUM(1.0f); // FLOAT input_value = input[N][t][D1][D2]; - FLOAT_ACCUM input_value = CVT_FLOAT2ACCUM(input[(NWH[0] * C + t) * D1 * D2 + NWH[1] * D2 + NWH[2]]); + FLOAT_ACCUM input_value = + CVT_FLOAT2ACCUM(input[(NWH[0] * C + t) * D1 * D2 + NWH[1] * D2 + NWH[2]]); FLOAT_ACCUM val = CVT_FP32_2ACCUM(-1.0f) * w * input_value; - output[gid] = CVT_ACCUM2FLOAT(val); + output[gid] = CVT_ACCUM2FLOAT(val); } -extern "C" __global__ void NLLLossUnreducedForward4dContiguous(const INPUT_TYPE* __restrict__ input, - const int* __restrict__ target, +extern "C" __global__ void NLLLossUnreducedForward4dContiguous(const INPUT_TYPE* __restrict__ input, + const int32_t* __restrict__ target, const INPUT_TYPE* weight, - OUTPUT_TYPE* __restrict__ output, - int ignore_index, + OUTPUT_TYPE* __restrict__ output, + uint64_t ignore_index, size_t N_total, size_t C, size_t D1, size_t D2) { - nlllossUnreducedForward4dContiguous(input, target, weight, output, ignore_index, N_total, C, D1, D2); + nlllossUnreducedForward4dContiguous( + input, target, weight, output, ignore_index, N_total, C, D1, D2); } diff --git a/src/nllloss.cpp b/src/nllloss.cpp index be1fe4950d..77331b0621 100644 --- a/src/nllloss.cpp +++ b/src/nllloss.cpp @@ -42,25 +42,25 @@ miopenStatus_t NLLLossForward(Handle& handle, ConstData_t weight, const TensorDescriptor& outputDesc, Data_t output, - int ignore_index) + int32_t ignore_index) { - const auto problem = nllloss::ProblemDescription{ - inputDesc, targetDesc, weightDesc, outputDesc, ignore_index}; + const auto problem = + nllloss::ProblemDescription{inputDesc, targetDesc, weightDesc, outputDesc, ignore_index}; const auto invoke_params = [&]() { - auto tmp = nllloss::InvokeParams{}; - tmp.inputDesc = &inputDesc; + auto tmp = nllloss::InvokeParams{}; + tmp.inputDesc = &inputDesc; tmp.outputDesc = &outputDesc; - tmp.input = input; - tmp.target = target; - tmp.weight = weight; - tmp.output = output; + tmp.input = input; + tmp.target = target; + tmp.weight = weight; + tmp.output = output; tmp.ignore_index = ignore_index; return tmp; - }(); + }(); - const auto algo = AlgorithmName{"NLLLossForward"}; + const auto algo = AlgorithmName{"NLLLossForward"}; const auto solvers = solver::SolverContainer{}; solvers.ExecutePrimitive(handle, problem, algo, invoke_params); diff --git a/src/nllloss/problem_description.cpp b/src/nllloss/problem_description.cpp index 8190345e8d..fc5fcf091c 100644 --- a/src/nllloss/problem_description.cpp +++ b/src/nllloss/problem_description.cpp @@ -35,9 +35,9 @@ namespace nllloss { NetworkConfig ProblemDescription::MakeNetworkConfig() const { - auto dims = inputDesc.GetLengths(); - size_t numel = outputDesc.GetElementSize(); - size_t num_batches = dims[0]; + auto dims = inputDesc.GetLengths(); + size_t numel = outputDesc.GetElementSize(); + size_t num_batches = dims[0]; size_t num_classes = dims[1]; auto dtype = inputDesc.GetType(); diff --git a/src/solver/nllloss/forward_nllloss.cpp b/src/solver/nllloss/forward_nllloss.cpp index 16b4db901c..dc6c7374e0 100644 --- a/src/solver/nllloss/forward_nllloss.cpp +++ b/src/solver/nllloss/forward_nllloss.cpp @@ -44,7 +44,7 @@ namespace miopen { namespace solver { namespace nllloss { - + bool NLLLossForward::IsApplicable(const ExecutionContext&, const miopen::nllloss::ProblemDescription& problem) const { @@ -57,27 +57,26 @@ bool NLLLossForward::IsApplicable(const ExecutionContext&, return true; } -ConvSolution -NLLLossForward::GetSolution(const ExecutionContext& context, - const miopen::nllloss::ProblemDescription& problem) const +ConvSolution NLLLossForward::GetSolution(const ExecutionContext& context, + const miopen::nllloss::ProblemDescription& problem) const { std::ignore = context; - auto result = ConvSolution{miopenStatusSuccess}; + auto result = ConvSolution{miopenStatusSuccess}; auto input_dtype = miopen::GetDataType(problem.GetInputDesc().GetType()); auto output_dtype = miopen::GetDataType(problem.GetOutputDesc().GetType()); { - auto dtype = problem.GetInputDesc().GetType(); + auto dtype = problem.GetInputDesc().GetType(); long N_total = problem.GetNtotal(); size_t xlocalsize = LOCAL_SIZE; - size_t xgridsize = N_total; - + size_t xgridsize = N_total + LOCAL_SIZE - 1; + size_t ylocalsize = 1; - size_t ygridsize = 1; + size_t ygridsize = 1; size_t zlocalsize = 1; - size_t zgridsize = 1; + size_t zgridsize = 1; auto kernel = KernelInfo{}; @@ -113,14 +112,20 @@ NLLLossForward::GetSolution(const ExecutionContext& context, decltype(auto) params = raw_params.CastTo(); size_t N_total = params.outputDesc->GetElementSize(); - auto dims = params.inputDesc->GetLengths(); - size_t C = dims[1]; - size_t D1 = dims[2]; - size_t D2 = dims[3]; - - kernel(params.input, params.target, params.weight, params.output, - params.ignore_index, - N_total, C, D1, D2); + auto dims = params.inputDesc->GetLengths(); + size_t C = dims[1]; + size_t D1 = dims[2]; + size_t D2 = dims[3]; + + kernel(params.input, + params.target, + params.weight, + params.output, + params.ignore_index, + N_total, + C, + D1, + D2); }; }; diff --git a/test/cpu_nllloss.hpp b/test/cpu_nllloss.hpp index 38a58ab2e8..3dfc9297e0 100644 --- a/test/cpu_nllloss.hpp +++ b/test/cpu_nllloss.hpp @@ -29,37 +29,38 @@ #include "tensor_holder.hpp" template -void cpu_nllloss_forward_4d(tensor input, - tensor target, +void cpu_nllloss_forward_4d(tensor input, + tensor target, tensor weight, tensor& output, - int ignore_index) + int32_t ignore_index) { auto dims = input.desc.GetLengths(); - size_t N = dims[0]; - size_t C = dims[1]; + size_t N = dims[0]; + size_t C = dims[1]; size_t D1 = dims[2]; size_t D2 = dims[3]; - for (size_t n = 0; n < N; n++) + for(size_t n = 0; n < N; n++) { - for (size_t d1 = 0; d1 < D1; d1++) + for(size_t d1 = 0; d1 < D1; d1++) { - for (size_t d2 = 0; d2 < D2; d2++) + for(size_t d2 = 0; d2 < D2; d2++) { size_t target_index = n * D1 * D2 + d1 * D2 + d2; - int t = target[target_index]; - size_t input_index = (n * C + t) * D1 * D2 + d1 * D2 + d2; + int32_t t = target[target_index]; + size_t input_index = (n * C + t) * D1 * D2 + d1 * D2 + d2; size_t weight_index = t; size_t output_index = target_index; - if (t < 0 || t == ignore_index || t >= C) + if(t < 0 || t == ignore_index || t >= C) { - output[output_index] = 0; + output[output_index] = static_cast(0); } else { - output[output_index] = -1.0f * weight[weight_index] * input[input_index]; + output[output_index] = + static_cast(-1.0f) * weight[weight_index] * input[input_index]; } } } diff --git a/test/gtest/nllloss.cpp b/test/gtest/nllloss.cpp index 89f6be2815..9ac654ba60 100644 --- a/test/gtest/nllloss.cpp +++ b/test/gtest/nllloss.cpp @@ -46,20 +46,58 @@ struct NLLLossTestFloat : NLLLossTest { }; -} // namespace nllloss +struct NLLLossTestHalf : NLLLossTest +{ +}; + +struct NLLLossTestBFloat16 : NLLLossTest +{ +}; + +} // namespace nllloss using namespace nllloss; TEST_P(NLLLossTestFloat, NLLLossTestFw) { - // if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--float")) - // { + if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) || (GetFloatArg() == "--float")) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; + +TEST_P(NLLLossTestHalf, NLLLossTestFw) +{ + if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) || (GetFloatArg() == "--half")) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; + +TEST_P(NLLLossTestBFloat16, NLLLossTestFw) +{ + if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) || (GetFloatArg() == "--bfloat16")) + { RunTest(); Verify(); - // } - // else - // { - // GTEST_SKIP(); - // } + } + else + { + GTEST_SKIP(); + } }; INSTANTIATE_TEST_SUITE_P(NLLLossTestSet, NLLLossTestFloat, testing::ValuesIn(NLLLossTestConfigs())); +INSTANTIATE_TEST_SUITE_P(NLLLossTestSet, NLLLossTestHalf, testing::ValuesIn(NLLLossTestConfigs())); +INSTANTIATE_TEST_SUITE_P(NLLLossTestSet, + NLLLossTestBFloat16, + testing::ValuesIn(NLLLossTestConfigs())); diff --git a/test/gtest/nllloss.hpp b/test/gtest/nllloss.hpp index 828e53e5d9..4a014340d4 100644 --- a/test/gtest/nllloss.hpp +++ b/test/gtest/nllloss.hpp @@ -23,7 +23,6 @@ * SOFTWARE. * *******************************************************************************/ -#define MIOPEN_BETA_API 1 #include "../driver/tensor_driver.hpp" #include "cpu_nllloss.hpp" #include "get_handle.hpp" @@ -36,21 +35,21 @@ struct NLLLossTestCase { - size_t N=0; - size_t C=0; - size_t D1=0; - size_t D2=0; - bool weight_mode=false; - int ignore_index=-1; + size_t N = 0; + size_t C = 0; + size_t D1 = 0; + size_t D2 = 0; + bool weight_mode = false; + int32_t ignore_index = -1; std::vector input = {N, C, D1, D2}; friend std::ostream& operator<<(std::ostream& os, const NLLLossTestCase& tc) { - return os << " N:" << tc.N << " C:" << tc.C << " D1:" << tc.D1 << " D2:" << tc.D2 - << " weight_mode:" << tc.weight_mode << " ignore_index:" << tc.ignore_index; + return os << " N:" << tc.N << " C:" << tc.C << " D1:" << tc.D1 << " D2:" << tc.D2 + << " weight_mode:" << tc.weight_mode << " ignore_index:" << tc.ignore_index; } - std::vector GetInput() const { return input; } + std::vector GetInput() const { return input; } }; inline std::vector NLLLossTestConfigs() @@ -78,38 +77,36 @@ struct NLLLossTest : public ::testing::TestWithParam // 0 <= target < C // weight = 1 - /* input(input) : [N, C, D1, D2], - * target(target): [N, D1, D2], - * weight(weight): [C], - * output(output): [N, D1, D2] */ - + /* input(input) : [N, C, D1, D2], + * target(target): [N, D1, D2], + * weight(weight): [C], + * output(output): [N, D1, D2] */ + ignore_index = nllloss_config.ignore_index; - weight_mode = nllloss_config.weight_mode; + weight_mode = nllloss_config.weight_mode; - auto in_dim = nllloss_config.GetInput(); + auto in_dim = nllloss_config.GetInput(); std::vector target_dim = {in_dim[0], in_dim[2], in_dim[3]}; std::vector weight_dim = {in_dim[1]}; - std::vector out_dim = {in_dim[0], in_dim[2], in_dim[3]}; - - auto gen_input_value = [](auto...) -> T { static std::random_device rd; - static std::mt19937 gen(rd()); - std::uniform_real_distribution dis(-100, -1e-2); - return dis(gen); - }; - size_t numclass_C = in_dim[1]; - auto gen_target_value = [numclass_C](auto...) -> int { static std::random_device rd; - static std::mt19937 gen(rd()); - std::uniform_int_distribution dis(0, numclass_C-1); - return dis(gen); - }; - auto gen_weight_value = [](auto...) { return prng::gen_descreet_uniform_sign(1e-2, 10); }; - auto gen_weight_one = [](auto...) -> T { return 1; }; - + std::vector out_dim = {in_dim[0], in_dim[2], in_dim[3]}; + + auto gen_input_value = [](auto...) { + return prng::gen_A_to_B(static_cast(-100.0f), static_cast(-1e-2)); + }; + size_t numclass_C = in_dim[1]; + auto gen_target_value = [numclass_C](auto...) { + return prng::gen_A_to_B(0, numclass_C - 1); + }; + auto gen_weight_value = [](auto...) { + return prng::gen_A_to_B(static_cast(-10), static_cast(10)); + }; + auto gen_weight_one = [](auto...) { return static_cast(1); }; + input = tensor{in_dim}.generate(gen_input_value); - target = tensor{target_dim}.generate(gen_target_value); + target = tensor{target_dim}.generate(gen_target_value); - if (!weight_mode) + if(!weight_mode) weight = tensor{weight_dim}.generate(gen_weight_one); else weight = tensor{weight_dim}.generate(gen_weight_value); @@ -120,7 +117,7 @@ struct NLLLossTest : public ::testing::TestWithParam ref_output = tensor{out_dim}; std::fill(ref_output.begin(), ref_output.end(), std::numeric_limits::quiet_NaN()); - input_dev = handle.Write(input.data); + input_dev = handle.Write(input.data); target_dev = handle.Write(target.data); weight_dev = handle.Write(weight.data); output_dev = handle.Write(output.data); @@ -128,7 +125,7 @@ struct NLLLossTest : public ::testing::TestWithParam void RunTest() { - auto&& handle = get_handle(); + auto&& handle = get_handle(); cpu_nllloss_forward_4d(input, target, weight, ref_output, ignore_index); miopenStatus_t status = miopen::NLLLossForward(handle, @@ -142,7 +139,7 @@ struct NLLLossTest : public ::testing::TestWithParam output_dev.get(), ignore_index); fflush(stdout); - + EXPECT_EQ(status, miopenStatusSuccess); output.data = handle.Read(output_dev, output.data.size()); @@ -150,20 +147,23 @@ struct NLLLossTest : public ::testing::TestWithParam void Verify() { - auto error = miopen::rms_range(ref_output, output); + double threshold = std::numeric_limits::epsilon(); + auto error = miopen::rms_range(ref_output, output); + EXPECT_TRUE(miopen::range_distance(ref_output) == miopen::range_distance(output)); - EXPECT_TRUE(error == 0) << "Outputs do not match each other. Error:" << error; + EXPECT_TRUE(error < threshold * 10) << "Error output beyond tolerance Error:" << error + << ", Thresholdx10: " << threshold * 10; } NLLLossTestCase nllloss_config; tensor input; - tensor target; + tensor target; tensor weight; tensor output; tensor ref_output; bool weight_mode; - int ignore_index; + int32_t ignore_index; miopen::Allocator::ManageDataPtr input_dev; miopen::Allocator::ManageDataPtr target_dev; From 0cbacf37a1f26bf01ab1a0f9ec7d6355a5d1e1ae Mon Sep 17 00:00:00 2001 From: hieule88 Date: Mon, 15 Apr 2024 10:02:48 +0000 Subject: [PATCH 15/23] 2023->2024 --- driver/nllloss_driver.hpp | 2 +- src/include/miopen/nllloss.hpp | 2 +- src/include/miopen/nllloss/invoke_params.hpp | 2 +- .../miopen/nllloss/problem_description.hpp | 2 +- src/include/miopen/nllloss/solvers.hpp | 2 +- src/kernels/MIOpenNLLLoss.cpp | 2 +- src/nllloss.cpp | 2 +- src/nllloss/problem_description.cpp | 2 +- src/nllloss_api.cpp | 71 +++++++++---------- src/solver/nllloss/forward_nllloss.cpp | 2 +- test/cpu_nllloss.hpp | 2 +- test/gtest/nllloss.cpp | 2 +- test/gtest/nllloss.hpp | 2 +- 13 files changed, 46 insertions(+), 49 deletions(-) diff --git a/driver/nllloss_driver.hpp b/driver/nllloss_driver.hpp index dde8c8d73b..58b1970828 100644 --- a/driver/nllloss_driver.hpp +++ b/driver/nllloss_driver.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2024 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/include/miopen/nllloss.hpp b/src/include/miopen/nllloss.hpp index ea575b8f56..a29a5cf3fc 100644 --- a/src/include/miopen/nllloss.hpp +++ b/src/include/miopen/nllloss.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2024 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/include/miopen/nllloss/invoke_params.hpp b/src/include/miopen/nllloss/invoke_params.hpp index cb05074288..94ee1d816a 100644 --- a/src/include/miopen/nllloss/invoke_params.hpp +++ b/src/include/miopen/nllloss/invoke_params.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2024 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/include/miopen/nllloss/problem_description.hpp b/src/include/miopen/nllloss/problem_description.hpp index 3e0be451ee..1461cf66f1 100644 --- a/src/include/miopen/nllloss/problem_description.hpp +++ b/src/include/miopen/nllloss/problem_description.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2024 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/include/miopen/nllloss/solvers.hpp b/src/include/miopen/nllloss/solvers.hpp index 3c7c2d8735..579286c212 100644 --- a/src/include/miopen/nllloss/solvers.hpp +++ b/src/include/miopen/nllloss/solvers.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2024 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/kernels/MIOpenNLLLoss.cpp b/src/kernels/MIOpenNLLLoss.cpp index bbb8f72a40..31875902df 100644 --- a/src/kernels/MIOpenNLLLoss.cpp +++ b/src/kernels/MIOpenNLLLoss.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2024 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/nllloss.cpp b/src/nllloss.cpp index 77331b0621..b7888cb358 100644 --- a/src/nllloss.cpp +++ b/src/nllloss.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2024 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/nllloss/problem_description.cpp b/src/nllloss/problem_description.cpp index fc5fcf091c..2f53fe2c8b 100644 --- a/src/nllloss/problem_description.cpp +++ b/src/nllloss/problem_description.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2024 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/nllloss_api.cpp b/src/nllloss_api.cpp index 32dfe3d9c9..f10a98d5ec 100644 --- a/src/nllloss_api.cpp +++ b/src/nllloss_api.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2024 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -52,52 +52,49 @@ static void LogCmdNLLLoss(const miopenTensorDescriptor_t xDesc, bool is_fwd) int32_t size = {0}; miopenGetTensorDescriptorSize(xDesc, &size); ss << " -N " << miopen::deref(xDesc).GetLengths()[0]; - ss << " -C " << miopen::deref(xDesc).GetLengths()[1] - << " -d " << miopen::deref(xDesc).GetLengths()[2] - << " -D " << miopen::deref(xDesc).GetLengths()[3]; - + ss << " -C " << miopen::deref(xDesc).GetLengths()[1] << " -d " + << miopen::deref(xDesc).GetLengths()[2] << " -D " + << miopen::deref(xDesc).GetLengths()[3]; + ss << " -F " << ((is_fwd) ? "1" : "2"); MIOPEN_LOG_DRIVER_CMD(ss.str()); } } - extern "C" miopenStatus_t miopenNLLLossForward(miopenHandle_t handle, - const miopenTensorDescriptor_t inputDesc, - const void* input, - const miopenTensorDescriptor_t targetDesc, - const void* target, - const miopenTensorDescriptor_t weightDesc, - const void* weight, - const miopenTensorDescriptor_t outputDesc, - void* output, - int ignore_index) + const miopenTensorDescriptor_t inputDesc, + const void* input, + const miopenTensorDescriptor_t targetDesc, + const void* target, + const miopenTensorDescriptor_t weightDesc, + const void* weight, + const miopenTensorDescriptor_t outputDesc, + void* output, + int ignore_index) { - MIOPEN_LOG_FUNCTION( - handle, - inputDesc, - input, - targetDesc, - target, - weightDesc, - weight, - outputDesc, - output, - ignore_index); + MIOPEN_LOG_FUNCTION(handle, + inputDesc, + input, + targetDesc, + target, + weightDesc, + weight, + outputDesc, + output, + ignore_index); LogCmdNLLLoss(inputDesc, true); return miopen::try_([&] { - miopen::NLLLossForward( - miopen::deref(handle), - miopen::deref(inputDesc), - DataCast(input), - miopen::deref(targetDesc), - DataCast(target), - miopen::deref(weightDesc), - DataCast(weight), - miopen::deref(outputDesc), - DataCast(output), - ignore_index); + miopen::NLLLossForward(miopen::deref(handle), + miopen::deref(inputDesc), + DataCast(input), + miopen::deref(targetDesc), + DataCast(target), + miopen::deref(weightDesc), + DataCast(weight), + miopen::deref(outputDesc), + DataCast(output), + ignore_index); }); } diff --git a/src/solver/nllloss/forward_nllloss.cpp b/src/solver/nllloss/forward_nllloss.cpp index dc6c7374e0..f16d31544e 100644 --- a/src/solver/nllloss/forward_nllloss.cpp +++ b/src/solver/nllloss/forward_nllloss.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2024 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/cpu_nllloss.hpp b/test/cpu_nllloss.hpp index 3dfc9297e0..e4bae54b90 100644 --- a/test/cpu_nllloss.hpp +++ b/test/cpu_nllloss.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2024 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/gtest/nllloss.cpp b/test/gtest/nllloss.cpp index 9ac654ba60..023e420c0f 100644 --- a/test/gtest/nllloss.cpp +++ b/test/gtest/nllloss.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2024 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/gtest/nllloss.hpp b/test/gtest/nllloss.hpp index 4a014340d4..292967d46b 100644 --- a/test/gtest/nllloss.hpp +++ b/test/gtest/nllloss.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2024 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 762bd6ed38f8e3b3bfa9157c7957eb796e269515 Mon Sep 17 00:00:00 2001 From: hieule88 Date: Tue, 16 Apr 2024 07:23:08 +0000 Subject: [PATCH 16/23] fix condition gtest --- src/include/miopen/nllloss/problem_description.hpp | 13 ++----------- test/gtest/nllloss.cpp | 6 +++--- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/src/include/miopen/nllloss/problem_description.hpp b/src/include/miopen/nllloss/problem_description.hpp index 1461cf66f1..0a9574d977 100644 --- a/src/include/miopen/nllloss/problem_description.hpp +++ b/src/include/miopen/nllloss/problem_description.hpp @@ -88,20 +88,11 @@ struct ProblemDescription : ProblemDescriptionBase bool IsSameType() const { - if(inputDesc.GetType() != outputDesc.GetType()) + if(inputDesc.GetType() != weightDesc.GetType()) { #if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG MIOPEN_THROW(miopenStatusBadParm, - "NLLLoss: Tensor types of Input and Output do not match."); -#else - return false; -#endif - } - if(outputDesc.GetType() != weightDesc.GetType()) - { -#if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG - MIOPEN_THROW(miopenStatusBadParm, - "NLLLoss: Tensor types of Output and Weight do not match."); + "NLLLoss: Tensor types of Input and Weight do not match."); #else return false; #endif diff --git a/test/gtest/nllloss.cpp b/test/gtest/nllloss.cpp index 023e420c0f..ee5d11f666 100644 --- a/test/gtest/nllloss.cpp +++ b/test/gtest/nllloss.cpp @@ -59,7 +59,7 @@ using namespace nllloss; TEST_P(NLLLossTestFloat, NLLLossTestFw) { - if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) || (GetFloatArg() == "--float")) + if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--float")) { RunTest(); Verify(); @@ -72,7 +72,7 @@ TEST_P(NLLLossTestFloat, NLLLossTestFw) TEST_P(NLLLossTestHalf, NLLLossTestFw) { - if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) || (GetFloatArg() == "--half")) + if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--half")) { RunTest(); Verify(); @@ -85,7 +85,7 @@ TEST_P(NLLLossTestHalf, NLLLossTestFw) TEST_P(NLLLossTestBFloat16, NLLLossTestFw) { - if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) || (GetFloatArg() == "--bfloat16")) + if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--bfloat16")) { RunTest(); Verify(); From 07f03a69f9832e01dd4bddce997b64efd72c9e2d Mon Sep 17 00:00:00 2001 From: hieule88 Date: Tue, 16 Apr 2024 07:34:55 +0000 Subject: [PATCH 17/23] rm unused env params --- .githooks/pre-commit | 1 - src/solver/nllloss/forward_nllloss.cpp | 1 - 2 files changed, 2 deletions(-) diff --git a/.githooks/pre-commit b/.githooks/pre-commit index e166dadd03..d91036e8fc 100755 --- a/.githooks/pre-commit +++ b/.githooks/pre-commit @@ -40,4 +40,3 @@ do "$format" -i -style=file "$file" fi done - diff --git a/src/solver/nllloss/forward_nllloss.cpp b/src/solver/nllloss/forward_nllloss.cpp index f16d31544e..8ec45ce37a 100644 --- a/src/solver/nllloss/forward_nllloss.cpp +++ b/src/solver/nllloss/forward_nllloss.cpp @@ -90,7 +90,6 @@ ConvSolution NLLLossForward::GetSolution(const ExecutionContext& context, {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, {"INPUT_TYPE", input_dtype == "bfloat16" ? "ushort" : input_dtype}, {"OUTPUT_TYPE", output_dtype == "bfloat16" ? "ushort" : output_dtype}, - {"LOCAL_SIZE", LOCAL_SIZE}, }; kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); From a34106d6a9e0720a8bea8b0244efb8284065b941 Mon Sep 17 00:00:00 2001 From: hieule88 Date: Wed, 17 Apr 2024 02:27:12 +0000 Subject: [PATCH 18/23] rm unused comment lines and re-ordering --- driver/CMakeLists.txt | 2 +- driver/nllloss_driver.hpp | 2 +- include/miopen/miopen.h | 1 - src/kernels/MIOpenNLLLoss.cpp | 4 ++-- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 5771a260fe..13566589d8 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -45,13 +45,13 @@ add_executable(MIOpenDriver dm_groupnorm.cpp dm_layernorm.cpp dm_lrn.cpp + dm_nllloss.cpp dm_pool.cpp dm_reduce.cpp dm_rnn.cpp dm_softmax.cpp dm_sum.cpp dm_tensorop.cpp - dm_nllloss.cpp main.cpp registry_driver_maker.cpp rocrand_wrapper.cpp) diff --git a/driver/nllloss_driver.hpp b/driver/nllloss_driver.hpp index 58b1970828..742d8f995e 100644 --- a/driver/nllloss_driver.hpp +++ b/driver/nllloss_driver.hpp @@ -34,6 +34,7 @@ #include "timer.hpp" #include "util_driver.hpp" +#include <../test/tensor_holder.hpp> #include <../test/verify.hpp> #include @@ -49,7 +50,6 @@ #include #include #include -#include <../test/tensor_holder.hpp> template class NLLLossDriver : public Driver diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index fcf32b8705..0fc89c2a19 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -6583,7 +6583,6 @@ MIOPEN_EXPORT miopenStatus_t miopenBackendInitialize(miopenBackendDescriptor_t d #endif // MIOPEN_BETA_API #ifdef MIOPEN_BETA_API -//-------------------------------------------------------------------------------------------------// // NLLLoss APIs /** @addtogroup nllloss * diff --git a/src/kernels/MIOpenNLLLoss.cpp b/src/kernels/MIOpenNLLLoss.cpp index 31875902df..a9f30f0fcc 100644 --- a/src/kernels/MIOpenNLLLoss.cpp +++ b/src/kernels/MIOpenNLLLoss.cpp @@ -66,8 +66,8 @@ __device__ void nlllossUnreducedForward4dContiguous(const TI* __restrict__ input FLOAT_ACCUM w = weight != nullptr ? CVT_FLOAT2ACCUM(weight[t]) : CVT_FP32_2ACCUM(1.0f); // FLOAT input_value = input[N][t][D1][D2]; - FLOAT_ACCUM input_value = - CVT_FLOAT2ACCUM(input[(NWH[0] * C + t) * D1 * D2 + NWH[1] * D2 + NWH[2]]); + uint32_t input_offset = (NWH[0] * C + t) * D1 * D2 + NWH[1] * D2 + NWH[2]; + FLOAT_ACCUM input_value = CVT_FLOAT2ACCUM(input[input_offset]); FLOAT_ACCUM val = CVT_FP32_2ACCUM(-1.0f) * w * input_value; output[gid] = CVT_ACCUM2FLOAT(val); From 08e69d1ef588a9a524da6152fd527260a98bfe21 Mon Sep 17 00:00:00 2001 From: hieule88 Date: Wed, 17 Apr 2024 02:38:05 +0000 Subject: [PATCH 19/23] align up --- include/miopen/miopen.h | 2 -- src/kernels/MIOpenNLLLoss.cpp | 2 +- src/solver/nllloss/forward_nllloss.cpp | 4 ++-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 0fc89c2a19..f50619c6f2 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -6616,8 +6616,6 @@ MIOPEN_EXPORT miopenStatus_t miopenNLLLossForward(miopenHandle_t handle, // CLOSEOUT nllloss DOXYGEN GROUP #endif // MIOPEN_BETA_API - - #ifdef __cplusplus } #endif diff --git a/src/kernels/MIOpenNLLLoss.cpp b/src/kernels/MIOpenNLLLoss.cpp index a9f30f0fcc..047a042e8b 100644 --- a/src/kernels/MIOpenNLLLoss.cpp +++ b/src/kernels/MIOpenNLLLoss.cpp @@ -66,7 +66,7 @@ __device__ void nlllossUnreducedForward4dContiguous(const TI* __restrict__ input FLOAT_ACCUM w = weight != nullptr ? CVT_FLOAT2ACCUM(weight[t]) : CVT_FP32_2ACCUM(1.0f); // FLOAT input_value = input[N][t][D1][D2]; - uint32_t input_offset = (NWH[0] * C + t) * D1 * D2 + NWH[1] * D2 + NWH[2]; + uint32_t input_offset = (NWH[0] * C + t) * D1 * D2 + NWH[1] * D2 + NWH[2]; FLOAT_ACCUM input_value = CVT_FLOAT2ACCUM(input[input_offset]); FLOAT_ACCUM val = CVT_FP32_2ACCUM(-1.0f) * w * input_value; diff --git a/src/solver/nllloss/forward_nllloss.cpp b/src/solver/nllloss/forward_nllloss.cpp index 8ec45ce37a..a700fd61da 100644 --- a/src/solver/nllloss/forward_nllloss.cpp +++ b/src/solver/nllloss/forward_nllloss.cpp @@ -68,10 +68,10 @@ ConvSolution NLLLossForward::GetSolution(const ExecutionContext& context, { auto dtype = problem.GetInputDesc().GetType(); - long N_total = problem.GetNtotal(); + size_t N_total = problem.GetNtotal(); size_t xlocalsize = LOCAL_SIZE; - size_t xgridsize = N_total + LOCAL_SIZE - 1; + size_t xgridsize = (N_total + LOCAL_SIZE - 1) / LOCAL_SIZE * LOCAL_SIZE; size_t ylocalsize = 1; size_t ygridsize = 1; From b3493cff137f998148eab43657fbf8732b3a9588 Mon Sep 17 00:00:00 2001 From: hieule88 Date: Wed, 17 Apr 2024 03:09:38 +0000 Subject: [PATCH 20/23] use func align up --- src/solver/nllloss/forward_nllloss.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/solver/nllloss/forward_nllloss.cpp b/src/solver/nllloss/forward_nllloss.cpp index a700fd61da..cf137c17aa 100644 --- a/src/solver/nllloss/forward_nllloss.cpp +++ b/src/solver/nllloss/forward_nllloss.cpp @@ -67,11 +67,12 @@ ConvSolution NLLLossForward::GetSolution(const ExecutionContext& context, auto output_dtype = miopen::GetDataType(problem.GetOutputDesc().GetType()); { - auto dtype = problem.GetInputDesc().GetType(); + auto dtype = problem.GetInputDesc().GetType(); size_t N_total = problem.GetNtotal(); size_t xlocalsize = LOCAL_SIZE; - size_t xgridsize = (N_total + LOCAL_SIZE - 1) / LOCAL_SIZE * LOCAL_SIZE; + // size_t xgridsize = (N_total + LOCAL_SIZE - 1) / LOCAL_SIZE * LOCAL_SIZE; + size_t xgridsize = AlignUp(N_total, xlocalsize); size_t ylocalsize = 1; size_t ygridsize = 1; From f5468c543b8086b383c8d45efda505d4fd82ea9a Mon Sep 17 00:00:00 2001 From: hieule88 Date: Wed, 17 Apr 2024 04:04:47 +0000 Subject: [PATCH 21/23] change problem descript --- src/nllloss/problem_description.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/nllloss/problem_description.cpp b/src/nllloss/problem_description.cpp index 2f53fe2c8b..9704f268d2 100644 --- a/src/nllloss/problem_description.cpp +++ b/src/nllloss/problem_description.cpp @@ -40,14 +40,18 @@ NetworkConfig ProblemDescription::MakeNetworkConfig() const size_t num_batches = dims[0]; size_t num_classes = dims[1]; - auto dtype = inputDesc.GetType(); + auto input_dtype = inputDesc.GetType(); + auto output_dtype = outputDesc.GetType(); std::ostringstream ss; - ss << "dtype" << dtype; + ss << "input_dtype" << input_dtype; + ss << "output_dtype" << output_dtype; ss << "numel" << numel; ss << "num_batches" << num_batches; ss << "num_classes" << num_classes; + ss << "D1" << dims[2]; + ss << "D2" << dims[3]; return NetworkConfig{ss.str()}; } From ee7166af6c73ddbd6dda85578c99860bbcbcdfb7 Mon Sep 17 00:00:00 2001 From: hieule88 Date: Wed, 17 Apr 2024 15:18:08 +0000 Subject: [PATCH 22/23] rm comments --- src/kernels/MIOpenNLLLoss.cpp | 2 -- src/nllloss/problem_description.cpp | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/kernels/MIOpenNLLLoss.cpp b/src/kernels/MIOpenNLLLoss.cpp index 047a042e8b..e732cd7b92 100644 --- a/src/kernels/MIOpenNLLLoss.cpp +++ b/src/kernels/MIOpenNLLLoss.cpp @@ -56,7 +56,6 @@ __device__ void nlllossUnreducedForward4dContiguous(const TI* __restrict__ input NWH[0] = nc / D1; int32_t t = target[gid]; - // t: Class index if(t < 0 || t == ignore_index || t >= C) { output[gid] = static_cast(0); @@ -65,7 +64,6 @@ __device__ void nlllossUnreducedForward4dContiguous(const TI* __restrict__ input FLOAT_ACCUM w = weight != nullptr ? CVT_FLOAT2ACCUM(weight[t]) : CVT_FP32_2ACCUM(1.0f); - // FLOAT input_value = input[N][t][D1][D2]; uint32_t input_offset = (NWH[0] * C + t) * D1 * D2 + NWH[1] * D2 + NWH[2]; FLOAT_ACCUM input_value = CVT_FLOAT2ACCUM(input[input_offset]); diff --git a/src/nllloss/problem_description.cpp b/src/nllloss/problem_description.cpp index 9704f268d2..d152a3924c 100644 --- a/src/nllloss/problem_description.cpp +++ b/src/nllloss/problem_description.cpp @@ -40,7 +40,7 @@ NetworkConfig ProblemDescription::MakeNetworkConfig() const size_t num_batches = dims[0]; size_t num_classes = dims[1]; - auto input_dtype = inputDesc.GetType(); + auto input_dtype = inputDesc.GetType(); auto output_dtype = outputDesc.GetType(); std::ostringstream ss; From e471c837c4e44902514d38038e7936da0f4e1e1b Mon Sep 17 00:00:00 2001 From: hieule88 Date: Wed, 17 Apr 2024 15:21:34 +0000 Subject: [PATCH 23/23] rm comments --- src/include/miopen/nllloss/problem_description.hpp | 2 -- src/kernels/MIOpenNLLLoss.cpp | 3 --- test/gtest/nllloss.hpp | 9 --------- 3 files changed, 14 deletions(-) diff --git a/src/include/miopen/nllloss/problem_description.hpp b/src/include/miopen/nllloss/problem_description.hpp index 0a9574d977..aedc78616c 100644 --- a/src/include/miopen/nllloss/problem_description.hpp +++ b/src/include/miopen/nllloss/problem_description.hpp @@ -68,8 +68,6 @@ struct ProblemDescription : ProblemDescriptionBase size_t GetD1() const { return D1; } size_t GetD2() const { return D2; } - /* input(input): [N, C, D1, D2], target(target): [N, D1, D2], - * weight(weight): [C], output(output): [N, D1, D2] */ bool IsRightDim() const { if(outputDesc.GetLengths()[0] != N || outputDesc.GetLengths()[1] != D1 || diff --git a/src/kernels/MIOpenNLLLoss.cpp b/src/kernels/MIOpenNLLLoss.cpp index e732cd7b92..50c9476f1d 100644 --- a/src/kernels/MIOpenNLLLoss.cpp +++ b/src/kernels/MIOpenNLLLoss.cpp @@ -30,9 +30,6 @@ #include "float_types.h" -/* input(input): [N, C, D1, D2], target(target): [N, D1, D2], - * weight(weight): [C], output(output): [N, D1, D2] */ -/* Each thread computes one output: output[n0][n1][n2] */ template __device__ void nlllossUnreducedForward4dContiguous(const TI* __restrict__ input, const int32_t* __restrict__ target, diff --git a/test/gtest/nllloss.hpp b/test/gtest/nllloss.hpp index 292967d46b..fec6892297 100644 --- a/test/gtest/nllloss.hpp +++ b/test/gtest/nllloss.hpp @@ -73,15 +73,6 @@ struct NLLLossTest : public ::testing::TestWithParam auto&& handle = get_handle(); nllloss_config = GetParam(); - // input < 0 - // 0 <= target < C - // weight = 1 - - /* input(input) : [N, C, D1, D2], - * target(target): [N, D1, D2], - * weight(weight): [C], - * output(output): [N, D1, D2] */ - ignore_index = nllloss_config.ignore_index; weight_mode = nllloss_config.weight_mode;