From 2cb2df001ae6a9fc1615287d4cc876f885e2a5d8 Mon Sep 17 00:00:00 2001 From: ubospica Date: Tue, 21 Feb 2023 04:17:44 +0000 Subject: [PATCH] [TOPI] Support unbatched (1-D) predictions in nll_loss --- include/tvm/topi/nn.h | 28 ++++++++++++++++++++++ tests/python/topi/python/test_topi_loss.py | 11 +++++++-- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 90c1c09a070b..27c1043dde7c 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -660,6 +660,32 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights, std::string reduction = "mean", int ignore_index = -100, const std::string name = "nll_loss", const std::string tag = kBroadcast) { + if (predictions.ndim() == 1) { + // corner case: no batch in shape + // predictions->shape = (C,), targets->shape = (), weights->shape = (C,) + auto T = tvm::te::compute( + {}, + [&](const tvm::Array<tvm::tir::Var>& target_indices) { + auto c = targets(); + return tvm::tir::Select(c != ignore_index, -predictions(c) * weights(c), + tvm::tir::make_const(predictions->dtype, 0)); + }, + name, tag); + if (reduction == "mean") { + auto W = tvm::te::compute( + {}, + [&](const tvm::Array<tvm::tir::Var>& target_indices) { + auto c = targets(); + return tvm::tir::Select(c != ignore_index, weights(c), + tvm::tir::make_const(predictions->dtype, 0)); + }, + name, tag); + return topi::divide(T, W); + } else { + return T; + } + } + auto T = tvm::te::compute( targets->shape, [&](const tvm::Array<tvm::tir::Var>& target_indices) { @@ -674,6 +700,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T tvm::tir::make_const(predictions->dtype, 0)); }, name, tag); + ICHECK(T->shape.size() != 0); if (reduction == "mean") { auto W = tvm::te::compute( targets->shape, @@ -690,6 +717,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T return T; } } + } // namespace topi } // namespace tvm #endif // 
TVM_TOPI_NN_H_ diff --git a/tests/python/topi/python/test_topi_loss.py b/tests/python/topi/python/test_topi_loss.py index 53960139dd2e..969beb7d28f7 100644 --- a/tests/python/topi/python/test_topi_loss.py +++ b/tests/python/topi/python/test_topi_loss.py @@ -32,12 +32,19 @@ ((10, 5), "none", -100, "float32"), ((10, 5), "mean", 3, "float32"), ((10, 5), "mean", -100, "float64"), + ((5,), "mean", -100, "float32"), + ((5,), "mean", 3, "float32"), + ((5,), "none", -100, "float32"), ) def test_nll_loss(target, dev, prediction_shape, reduction, ignore_index, dtype): - C = prediction_shape[1] - target_shape = prediction_shape[:1] + prediction_shape[2:] + if len(prediction_shape) == 1: + C = prediction_shape[0] + target_shape = [] + else: + C = prediction_shape[1] + target_shape = prediction_shape[:1] + prediction_shape[2:] predictions = te.placeholder(shape=prediction_shape, name="predictions", dtype=dtype) targets = te.placeholder(shape=target_shape, name="targets", dtype="int32") weights = te.placeholder(shape=(C,), name="weights", dtype=dtype)