From 97bcded98f0e86f6796dd2795510e6042354ac39 Mon Sep 17 00:00:00 2001 From: Chaosfan Date: Sun, 15 Jan 2023 22:41:10 +0800 Subject: [PATCH] [Op][NN] cross_entropy, log_softmax, nll_loss (#94) After discussing about the loss, a good way is `log_softmax` + `nll_loss`. This PR introduces these two operators and tests them. As for `nll_loss`, here are some basic shape descriptions which may help review. And an important reference: https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html#torch.nn.NLLLoss ``` def nll_loss( predictions: Expr, targets: Expr, weights: Optional[Expr] = None, reduction: str = "mean", ignore_index: int = -100, ) -> Expr: Notations: N: minibatch size C: number of classes K: number of input dimensions Shape: weights: (C,) (always) without minibatch: predictions: (C,) targets: () output: () with minibatch N: predictions: (N, C) targets: (N,) output: (N,) (reduction=none) output: () (reduction=mean/sum) with minibatch N and high dimension input d1, d2, ..., dk: predictions: (N, C, d1, d2, ..., dk) targets: (N, d1, d2, ..., dk) output: (N, d1, d2, ..., dk) (reduction=none) output: () (reduction=mean/sum) ``` Our inference rule is trusting `predictions`, do equal assertion if other arguments have enough information and do best effort inference. Please check the code for details. This PR also introduces cross entropy operator since it is dropped when rebasing onto tlc. Given that torch has different definitions with our cross entropy, here we use the names `cross_entropy_without_logits` and `cross_entropy_with_logits` to make it less confused and align with relay. --- include/tvm/relax/attrs/nn.h | 13 + python/tvm/relax/op/nn/nn.py | 127 +++- src/relax/op/nn/nn.cc | 315 +++++++++ src/relax/op/nn/nn.h | 13 + tests/python/relax/test_op_manipulate.py | 57 +- tests/python/relax/test_op_nn.py | 643 +++++++++++++++++- .../relax/test_tvmscript_parser_op_nn.py | 95 +++ 7 files changed, 1228 insertions(+), 35 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 4b4e0680e2..ae102ad0eb 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -184,6 +184,19 @@ struct DropoutAttrs : public tvm::AttrsNode { } }; // struct DropoutAttrs +/*! \brief Attributes used in nll_loss operator */ +struct NLLLossAttrs : public tvm::AttrsNode { + String reduction; + int ignore_index; + + TVM_DECLARE_ATTRS(NLLLossAttrs, "relax.attrs.NLLLossAttrs") { + TVM_ATTR_FIELD(reduction).set_default("mean").describe( + "The reduction method to apply to the output. Can be" + "'none', 'mean' or 'sum'."); + TVM_ATTR_FIELD(ignore_index).describe("The target value to ignore."); + } +}; // struct NLLLossAttrs + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index a62cc7f997..fe3d27b183 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -327,7 +327,9 @@ def silu(data: Expr) -> Expr: def softmax(data: Expr, axis: int = -1) -> Expr: r"""Computes softmax. - .. math:: text{softmax}(x)_i = frac{exp(x_i)}{\sum_j exp(x_j)} + .. math:: + + \text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} Parameters ---------- @@ -351,6 +353,34 @@ def softmax(data: Expr, axis: int = -1) -> Expr: return _ffi_api.softmax(data, axis) # type: ignore +def log_softmax(data: Expr, axis: int = -1) -> Expr: + r"""Computes log softmax. + + .. math:: + + \text{log\_softmax}(x_i) = \log\left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}\right) + + .. note:: + This operator can be optimized away for inference. + + Parameters + ---------- + data: relax.Expr + The input data to the operator. + + axis: int + The axis to sum over when computing log softmax. + If not specified, it is by default the last axis of the input tensor. + Supports negative indexing. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.log_softmax(data, axis) # type: ignore + + def batch_norm( data: Expr, gamma: Expr, @@ -525,3 +555,98 @@ def dropout(data: Expr, rate: float = 0.5) -> Expr: mask tensor (1.0 where element not dropped, 0.0 where dropped) """ return _ffi_api.dropout(data, rate) # type: ignore + + +def cross_entropy_without_logits(predictions: Expr, labels: Expr) -> Expr: + r"""CrossEntropy without logits between the predictions and labels. + + The shape of predictions and labels must be the same. And when ndim >= 2, + the first dimension is regarded as the batch_size N. In this case the + computed result will divide by N to perform a mean reduction. + + .. math:: + + \text{cross\_entropy\_without\_logits}(x_i, y_i) = \frac{\sum_i -y_i \log x_i}{N} + + Parameters + ---------- + predictions : relax.Expr + The predictions. + + labels : relax.Expr + The labels (the ground truth values). + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.cross_entropy_without_logits(predictions, labels) # type: ignore + + +def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr: + r"""CrossEntropy with logits between the predictions and labels. + + The shape issue is the same with cross_entropy_without_logits. + + .. math:: + + \text{cross\_entropy\_with\_logits}(x_i, y_i) = \frac{\sum_i -x_i \cdot y_i}{N} + + Parameters + ---------- + predictions : relax.Expr + The predictions. + + labels : relax.Expr + The labels (the ground truth values). + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.cross_entropy_with_logits(predictions, labels) # type: ignore + + +def nll_loss( + predictions: Expr, + targets: Expr, + weights: Optional[Expr] = None, + reduction: str = "mean", + ignore_index: int = -100, +) -> Expr: + """Negative log likelihood loss. + + `output[n, i_1, i_2, ..., i_k] = -p * w`, where + - `p = predictions[n, t, i_1, i_2, i_k]`, + - `t = targets[n, i_1, i_2, ..., i_k]`, + - `w = weights[n, i_1, i_2, ..., i_k] if t != ignore_index else 0` + + result = reduction(output) + + Parameters + ---------- + predictions : relax.Expr + The predictions. + + targets : relax.Expr + The target value of each prediction. + + weights : Optional[relax.Expr] + The weight of each target value. + If not specified, it is treated as if having all ones. + + reduction : string + The reduction method to apply to the output. + Possible values are "mean", "sum" and "none". + + ignore_index : int + The target value to ignore. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.nll_loss(predictions, targets, weights, reduction, ignore_index) # type: ignore diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 1fa70dbdb2..fe9c678480 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -71,6 +71,22 @@ TVM_REGISTER_OP("relax.nn.softmax") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoSoftmax); +/* relax.nn.log_softmax */ +Expr log_softmax(Expr data, int axis) { + auto attrs = make_object(); + attrs->axis = axis; + static const Op& op = Op::Get("relax.nn.log_softmax"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.log_softmax").set_body_typed(log_softmax); + +TVM_REGISTER_OP("relax.nn.log_softmax") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoSoftmax); + bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, const Array& input_sinfo, Array axes) { Op op = Downcast(call->op); @@ -246,5 +262,304 @@ TVM_REGISTER_OP("relax.nn.dropout") .set_attr("FInferStructInfo", InferStructInfoDropout) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); +// infer structinfo for CrossEntropyWithoutLogits and CrossEntropyWithLogits +StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo pred_sinfo = input_sinfo[0]; + TensorStructInfo label_sinfo = input_sinfo[1]; + + // infer dtype + DataType dtype = InferBinaryArithOpOutDtype(call, ctx, pred_sinfo, label_sinfo); + + // infer ndim + if (!pred_sinfo->IsUnknownNdim() && !label_sinfo->IsUnknownNdim() && + pred_sinfo->ndim != label_sinfo->ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "CrossEntropy requires predictions and labels to have the same ndim. " + "However, the ndim of predictions is " + << pred_sinfo->ndim << " while the ndim of labels is " << label_sinfo->ndim); + } + + Optional> pred_shape_value; + if (pred_sinfo->shape.defined()) { + pred_shape_value = GetStructInfoAs(pred_sinfo->shape.value())->values; + } + + Optional> label_shape_value; + if (label_sinfo->shape.defined()) { + label_shape_value = GetStructInfoAs(label_sinfo->shape.value())->values; + } + + if (pred_shape_value.defined() && label_shape_value.defined()) { + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + for (size_t i = 0; i < pred_shape_value.value().size(); ++i) { + if (analyzer->CanProve(pred_shape_value.value()[i] != label_shape_value.value()[i])) { + ctx->ReportFatal(Diagnostic::Error(call) + << "CrossEntropy requires the predictions and labels to have " + "the same shape. However, the shape of predictions at dim " + << i << " is" << pred_shape_value.value()[i] + << " while the shape of labels at this dim is " + << label_shape_value.value()[i]); + } + } + } + return TensorStructInfo(ShapeExpr(Array()), dtype); +} + +/* relax.nn.cross_entropy_without_logits */ +Expr cross_entropy_without_logits(Expr predictions, Expr labels) { + static const Op& op = Op::Get("relax.nn.cross_entropy_without_logits"); + return Call(op, {std::move(predictions), std::move(labels)}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.cross_entropy_without_logits") + .set_body_typed(cross_entropy_without_logits); + +TVM_REGISTER_OP("relax.nn.cross_entropy_without_logits") + .set_num_inputs(2) + .add_argument("predictions", "Tensor", "The predictions.") + .add_argument("labels", "Tensor", "The labels.") + .set_attr("FInferStructInfo", InferStructInfoCrossEntropy); + +/* relax.nn.cross_entropy_with_logits */ +Expr cross_entropy_with_logits(Expr predictions, Expr labels) { + static const Op& op = Op::Get("relax.nn.cross_entropy_with_logits"); + return Call(op, {std::move(predictions), std::move(labels)}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.cross_entropy_with_logits") + .set_body_typed(cross_entropy_with_logits); + +TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits") + .set_num_inputs(2) + .add_argument("predictions", "Tensor", "The predictions.") + .add_argument("labels", "Tensor", "The labels.") + .set_attr("FInferStructInfo", InferStructInfoCrossEntropy); + +/* relax.nn.nll_loss */ +TVM_REGISTER_NODE_TYPE(NLLLossAttrs); + +Expr nll_loss(Expr predictions, Expr targets, Optional weights, String reduction, + int ignore_index) { + ObjectPtr attrs = make_object(); + + ICHECK(reduction == "none" || reduction == "sum" || reduction == "mean") + << "The argument reduction of NLLLoss should be one of the following " + "values: none, mean, sum. However, the given value is " + << reduction; + + attrs->reduction = std::move(reduction); + attrs->ignore_index = ignore_index; + + static const Op& op = Op::Get("relax.nn.nll_loss"); + if (weights.defined()) { + return Call(op, {std::move(predictions), std::move(targets), std::move(weights.value())}, + Attrs{attrs}, {}); + } else { + return Call(op, {std::move(predictions), std::move(targets)}, Attrs{attrs}, {}); + } +} + +TVM_REGISTER_GLOBAL("relax.op.nn.nll_loss").set_body_typed(nll_loss); + +StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() < 2 || call->args.size() > 3) { + ctx->ReportFatal(Diagnostic::Error(call) << "NLLLoss op should take 2 or 3 arguments"); + } + + const auto* pred_sinfo = GetStructInfoAs(call->args[0]); + const auto* tgt_sinfo = GetStructInfoAs(call->args[1]); + const TensorStructInfoNode* wgt_sinfo = nullptr; + if (call->args.size() == 3) { + wgt_sinfo = GetStructInfoAs(call->args[2]); + if (wgt_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "NLLLoss requires the argument weights to be Tensor. However, the given one is " + << call->args[1]->struct_info_->GetTypeKey()); + } + } + + if (pred_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "NLLLoss requires the argument preditions to be Tensor. However, the given one is " + << call->args[0]->struct_info_->GetTypeKey()); + } + if (tgt_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "NLLLoss requires the argument targets to be Tensor. However, the given one is " + << call->args[2]->struct_info_->GetTypeKey()); + } + + // infer dtype + DataType output_dtype; + if (wgt_sinfo != nullptr) { + output_dtype = InferBinaryArithOpOutDtype(call, ctx, GetRef(pred_sinfo), + GetRef(wgt_sinfo)); + } else { + output_dtype = pred_sinfo->dtype; + } + + // the type of targets must be int/uint. + if (!tgt_sinfo->IsUnknownDtype() && !tgt_sinfo->dtype.is_int() && !tgt_sinfo->dtype.is_uint()) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "NLLLoss expects the dtype of targets to be int/uint. However, the dtype of targets is " + << tgt_sinfo->dtype); + } + + // infer ndim + int K = kUnknownNDim; // k dim + if (!pred_sinfo->IsUnknownNdim()) { + if (pred_sinfo->ndim < 1) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "NLLLoss expects the ndim of predictions >= 1. However, the ndim of predictions is " + << pred_sinfo->ndim); + } + K = pred_sinfo->ndim <= 2 ? 0 : pred_sinfo->ndim - 2; + } + if (!tgt_sinfo->IsUnknownNdim()) { + int K_tgt = tgt_sinfo->ndim <= 1 ? 0 : tgt_sinfo->ndim - 1; + if (K != kUnknownNDim && K != K_tgt) { + ctx->ReportFatal(Diagnostic::Error(call) + << "NLLLoss expects number of dimensions K inferred from different " + "arguments to be equal. However, K from predictions is " + << K << " while K from targets is " << K_tgt); + } + } + if (wgt_sinfo != nullptr && !wgt_sinfo->IsUnknownNdim() && wgt_sinfo->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "NLLLoss expects the ndim of weights == 1. However, the ndim of weights is " + << wgt_sinfo->ndim); + } + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + Optional N; + Optional C; + Array output_shape; // N, d1, d2, ..., dk + + Optional> pred_shape_value; + if (pred_sinfo->shape.defined()) { + pred_shape_value = GetStructInfoAs(pred_sinfo->shape.value())->values; + } + if (pred_shape_value.defined()) { + if (pred_shape_value.value().size() == 1) { + // (C,) + ICHECK(pred_sinfo->ndim == 1); + C = pred_shape_value.value()[0]; + } else { + // (N, C, d1, d2, ..., dk) + ICHECK(pred_shape_value.value().size() >= 2); + ICHECK(pred_sinfo->ndim == static_cast(pred_shape_value.value().size())); + N = pred_shape_value.value()[0]; + C = pred_shape_value.value()[1]; + output_shape = Array(); + output_shape.push_back(N.value()); + for (size_t i = 2; i < pred_shape_value.value().size(); ++i) { + output_shape.push_back(pred_shape_value.value()[i]); + } + } + } + + Optional> tgt_shape_value; + if (tgt_sinfo->shape.defined()) { + tgt_shape_value = GetStructInfoAs(tgt_sinfo->shape.value())->values; + } + if (tgt_shape_value.defined()) { + if (tgt_shape_value.value().empty()) { + // () + ICHECK(tgt_sinfo->ndim == 0); + if (N.defined()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Shape mismatch for NLLLoss. Predictions shape is " + "(N, C, ...) while targets is a scalar"); + } + } else { + // (N,) or (N, d1, d2, ..., dk) + // check N + const PrimExpr& N_tgt = tgt_shape_value.value()[0]; + if (N.defined() && analyzer->CanProve(N.value() != N_tgt)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "NLLLoss expects minibatch size N inferred from different " + "arguments to be equal. However, N from predictions is " + << N << " while N from targets is " << N_tgt); + } + // only C case + if (!N.defined() && C.defined()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Shape mismatch for NLLLoss. Predictions shape is " + "(C,) while targets is not a scalar"); + } + + if (tgt_shape_value.value().size() == 1) { + // (N,) + ICHECK(tgt_sinfo->IsUnknownNdim() || tgt_sinfo->ndim == 1); + } else { + // (N, d1, d2, ..., dk) + ICHECK(tgt_shape_value.value().size() >= 2); + ICHECK(tgt_sinfo->IsUnknownNdim() || + tgt_sinfo->ndim == static_cast(tgt_shape_value.value().size())); + + if (pred_shape_value.defined()) { + // check (d1, d2, ..., dk) + for (size_t i = 1; i < tgt_shape_value.value().size(); ++i) { + if (analyzer->CanProve(output_shape[i] != tgt_shape_value.value()[i])) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Shape mismatch for NLLLoss. The prediction shape at this dim is " + << output_shape[i] << " while the target shape at this dim is " + << tgt_shape_value.value()[i]); + } + } + } + } + } + } + + if (wgt_sinfo != nullptr) { + Optional> wgt_shape_value; + if (wgt_sinfo->shape.defined()) { + wgt_shape_value = GetStructInfoAs(wgt_sinfo->shape.value())->values; + } + if (wgt_shape_value.defined()) { + ICHECK(wgt_shape_value.value().size() == 1); + ICHECK(wgt_sinfo->IsUnknownNdim() || wgt_sinfo->ndim == 1); + const PrimExpr& C_wgt = wgt_shape_value.value()[0]; + if (C.defined() && analyzer->CanProve(C.value() != C_wgt)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "NLLLoss expects number of classes C inferred from different " + "arguments to be equal. However, C from predictions is " + << C << " while C from weights is " << C_wgt); + } + } + } + + const auto* attrs = call->attrs.as(); + String reduction = attrs->reduction; + + if (reduction == "none") { + // () or (N,) or (N, d1, d2, ..., dk) + if (pred_sinfo->shape.as()) { + return TensorStructInfo(ShapeExpr(output_shape), output_dtype); + } else { + int output_ndim = pred_sinfo->ndim == kUnknownNDim ? kUnknownNDim : pred_sinfo->ndim - 1; + return TensorStructInfo(output_dtype, /*ndim=*/output_ndim); + } + } else { + // sum or mean. output is scalar + return TensorStructInfo(/*shape=*/ShapeExpr(Array()), output_dtype); + } +} + +TVM_REGISTER_OP("relax.nn.nll_loss") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("predictions", "Tensor", "The prediction tensor.") + .add_argument("targets", "Tensor", "The target tensor.") + .add_argument("weights", "Optional", "The weight of each target values.") + .set_attr("FInferStructInfo", InferStructInfoNLLLoss); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index 39f26a7237..ba34f5bb1f 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -44,6 +44,9 @@ Expr silu(Expr data); /*! \brief Softmax function. */ Expr softmax(Expr data, int axis); +/*! \brief LogSoftmax function. */ +Expr log_softmax(Expr data, int axis); + /*! \brief Compute batch normalization. */ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // int axis, double epsilon, bool center, bool scale); @@ -62,6 +65,16 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double ep */ Expr dropout(Expr data, double rate); +/*! \brief CrossEntropy without logits. */ +Expr cross_entropy_without_logits(Expr predictions, Expr labels); + +/*! \brief CrossEntropy with logits. */ +Expr cross_entropy_with_logits(Expr predictions, Expr labels); + +/*! \brief Negative log likelihood loss. */ +Expr nll_loss(Expr predictions, Expr targets, Optional weights, String reduction, + int ignore_index); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index c952a5395d..7974839ad2 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -33,6 +33,11 @@ def test_op_correctness(): assert relax.op.reshape(x, (4, 5, 3)).op == Op.get("relax.reshape") assert relax.op.split(x, indices_or_sections=1).op == Op.get("relax.split") assert relax.op.squeeze(x).op == Op.get("relax.squeeze") + assert relax.op.broadcast_to(x, (3, 3, 4, 5)).op == Op.get("relax.broadcast_to") + assert relax.op.collapse_sum_to(x, (4, 5)).op == Op.get("relax.collapse_sum_to") + + y = relax.Var("x", R.Tensor((4, 5), "float32")) + assert relax.op.collapse_sum_like(x, y).op == Op.get("relax.collapse_sum_like") def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): @@ -2286,12 +2291,12 @@ def test_collapse_sum_like_infer_struct_info_shape_symbolic(): def test_collapse_sum_like_infer_struct_info_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - s3 = relax.Var("s", relax.ShapeStructInfo((3, 4))) - s4 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s5 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s2", relax.ShapeStructInfo()) + s3 = relax.Var("s3", relax.ShapeStructInfo((3, 4))) + s4 = relax.Var("s4", relax.ShapeStructInfo(ndim=2)) + s5 = relax.Var("s5", relax.ShapeStructInfo()) x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) @@ -2317,7 +2322,7 @@ def test_collapse_sum_like_infer_struct_info_more_input_dtype(): _check_inference(bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo((3, 4), "int8")) -def test_collapse_sum_like_wrong_input_type(): +def test_collapse_sum_like_infer_struct_info_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) x1 = relax.Var("x", relax.ShapeStructInfo((4, 5))) @@ -2330,7 +2335,7 @@ def test_collapse_sum_like_wrong_input_type(): bb.normalize(relax.op.collapse_sum_like(x2, x0)) -def test_collapse_sum_like_check_shape_failure(): +def test_collapse_sum_like_infer_struct_info_shape_mismatch(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) y0 = relax.Var("y", R.Tensor((3, 6, 5), "float32")) @@ -2339,13 +2344,13 @@ def test_collapse_sum_like_check_shape_failure(): x1 = relax.Var("z", R.Tensor((3, a, 5), "float32")) y1 = relax.Var("w", R.Tensor((3, b, 5), "float32")) - s0 = relax.Var("s", relax.ShapeStructInfo((3, 4, 5))) - s1 = relax.Var("s", relax.ShapeStructInfo((3, 6, 5))) + s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5))) + s1 = relax.Var("s1", relax.ShapeStructInfo((3, 6, 5))) x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) y2 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) - s2 = relax.Var("s", relax.ShapeStructInfo((3, a, 5))) - s3 = relax.Var("s", relax.ShapeStructInfo((3, b, 5))) + s2 = relax.Var("s2", relax.ShapeStructInfo((3, a, 5))) + s3 = relax.Var("s3", relax.ShapeStructInfo((3, b, 5))) x3 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) y3 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) @@ -2402,9 +2407,9 @@ def test_collapse_sum_to_infer_struct_info_shape_symbolic(): def test_collapse_sum_to_infer_struct_info_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s2 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s2", relax.ShapeStructInfo()) x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) @@ -2432,7 +2437,7 @@ def test_collapse_sum_to_infer_struct_info_more_input_dtype(): ) -def test_collapse_sum_to_wrong_input_type(): +def test_collapse_sum_to_infer_struct_info_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) x1 = relax.Var("x", relax.ShapeStructInfo((4, 5))) @@ -2448,17 +2453,17 @@ def test_collapse_sum_to_wrong_input_type(): bb.normalize(relax.op.collapse_sum_to(x1, x1)) -def test_collapse_sum_to_check_shape_failure(): +def test_collapse_sum_to_infer_struct_info_shape_mismatch(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) a = tir.Var("a", "int64") b = tir.Var("b", "int64") x1 = relax.Var("x", R.Tensor((3, a, 5), "float32")) - s0 = relax.Var("s", relax.ShapeStructInfo((3, 4, 5))) + s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5))) x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - s1 = relax.Var("s", relax.ShapeStructInfo((3, a, 5))) + s1 = relax.Var("s1", relax.ShapeStructInfo((3, a, 5))) x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) with pytest.raises(TVMError): @@ -2474,24 +2479,24 @@ def test_collapse_sum_to_check_shape_failure(): bb.normalize(relax.op.collapse_sum_to(x3, (3, b, 5))) -def test_collapse_sum_to_struct_info_tgt_shape_var(): +def test_collapse_sum_to_infer_struct_info_struct_info_tgt_shape_var(): bb = relax.BlockBuilder() a = tir.Var("a", "int64") b = tir.Var("b", "int64") c = tir.Var("c", "int64") d = tir.Var("d", "int64") - s0 = relax.Var("s", relax.ShapeStructInfo((3, a, b))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s2 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s0", relax.ShapeStructInfo((3, a, b))) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s2", relax.ShapeStructInfo()) x0 = relax.Var("x", R.Tensor((3, a, b), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) x2 = relax.Var("x", R.Tensor("")) x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - stgt0 = relax.Var("stgt", relax.ShapeStructInfo((a, b))) - stgt1 = relax.Var("stgt", relax.ShapeStructInfo(ndim=2)) - stgt2 = relax.Var("stgt", relax.ShapeStructInfo()) + stgt0 = relax.Var("stgt0", relax.ShapeStructInfo((a, b))) + stgt1 = relax.Var("stgt1", relax.ShapeStructInfo(ndim=2)) + stgt2 = relax.Var("stgt2", relax.ShapeStructInfo()) _check_inference( bb, relax.op.collapse_sum_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32") diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index d047448309..bcd720d1c6 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -29,6 +29,7 @@ def test_op_correctness(): assert relax.op.nn.gelu(x).op == Op.get("relax.nn.gelu") assert relax.op.nn.silu(x).op == Op.get("relax.nn.silu") assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax") + assert relax.op.nn.log_softmax(x).op == Op.get("relax.nn.log_softmax") assert relax.op.nn.dropout(x).op == Op.get("relax.nn.dropout") x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) @@ -41,6 +42,20 @@ def test_op_correctness(): ) assert relax.op.nn.layer_norm(x, gamma, beta, axes=1).op == Op.get("relax.nn.layer_norm") + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + assert relax.op.nn.cross_entropy_without_logits(x, y).op == Op.get( + "relax.nn.cross_entropy_without_logits" + ) + assert relax.op.nn.cross_entropy_with_logits(x, y).op == Op.get( + "relax.nn.cross_entropy_with_logits" + ) + + x = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + y = relax.Var("y", R.Tensor((3, 10, 10), "int64")) + w = relax.Var("w", R.Tensor((5,), "float32")) + assert relax.op.nn.nll_loss(x, y, w).op == Op.get("relax.nn.nll_loss") + def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): ret = bb.normalize(call) @@ -117,7 +132,7 @@ def test_linear_unit_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.silu(x1)) -def test_softmax_infer_struct_info(): +def test_softmax_log_softmax_infer_struct_info(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) @@ -133,8 +148,20 @@ def test_softmax_infer_struct_info(): _check_inference(bb, relax.op.nn.softmax(x3, axis=-1), relax.TensorStructInfo((2, 3), dtype="")) _check_inference(bb, relax.op.nn.softmax(x4, axis=-2), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.nn.log_softmax(x1, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.nn.log_softmax(x2, axis=1), relax.TensorStructInfo(dtype="float32") + ) + _check_inference( + bb, relax.op.nn.log_softmax(x3, axis=-1), relax.TensorStructInfo((2, 3), dtype="") + ) + _check_inference(bb, relax.op.nn.log_softmax(x4, axis=-2), relax.TensorStructInfo(dtype="")) + -def test_softmax_infer_struct_info_shape_symbolic(): +def test_softmax_log_softmax_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() m = tir.Var("m", "int64") n = tir.Var("n", "int64") @@ -144,8 +171,13 @@ def test_softmax_infer_struct_info_shape_symbolic(): _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((m, n), "float32")) _check_inference(bb, relax.op.nn.softmax(x1, axis=0), relax.TensorStructInfo((4, n), "float32")) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo((m, n), "float32")) + _check_inference( + bb, relax.op.nn.log_softmax(x1, axis=0), relax.TensorStructInfo((4, n), "float32") + ) + -def test_softmax_infer_struct_info_shape_var(): +def test_softmax_log_softmax_infer_struct_info_shape_var(): bb = relax.BlockBuilder() s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) s1 = relax.Var("s", relax.ShapeStructInfo()) @@ -155,8 +187,11 @@ def test_softmax_infer_struct_info_shape_var(): _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo(s0, "float32")) _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.nn.log_softmax(x1), relax.TensorStructInfo(s1, "float32")) -def test_softmax_infer_struct_info_more_input_dtype(): + +def test_softmax_log_softmax_infer_struct_info_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float16")) x1 = relax.Var("x", R.Tensor((2, 3), "float64")) @@ -164,8 +199,11 @@ def test_softmax_infer_struct_info_more_input_dtype(): _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float16")) _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.nn.log_softmax(x1), relax.TensorStructInfo((2, 3), "float64")) + -def test_softmax_infer_struct_info_invalid_input_dtype(): +def test_softmax_log_softmax_infer_struct_info_invalid_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "int8")) x1 = relax.Var("x", R.Tensor((2, 3), "int64")) @@ -174,26 +212,40 @@ def test_softmax_infer_struct_info_invalid_input_dtype(): bb.normalize(relax.op.nn.softmax(x0)) with pytest.raises(TVMError): bb.normalize(relax.op.nn.softmax(x1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x1)) -def test_softmax_infer_struct_info_axis_out_of_range(): +def test_softmax_log_softmax_infer_struct_info_axis_out_of_range(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + with pytest.raises(TVMError): bb.normalize(relax.op.nn.softmax(x, axis=3)) with pytest.raises(TVMError): bb.normalize(relax.op.nn.softmax(x, axis=-4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x, axis=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x, axis=-4)) -def test_softmax_wrong_with_multiple_axes(): +def test_softmax_log_softmax_wrong_with_multiple_axes(): x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + with pytest.raises(TVMError): relax.op.nn.softmax(x, axis=[1, 2]) with pytest.raises(TVMError): relax.op.nn.softmax(x, axis=[-1, -2, -3]) + with pytest.raises(TVMError): + relax.op.nn.log_softmax(x, axis=[1, 2]) + with pytest.raises(TVMError): + relax.op.nn.log_softmax(x, axis=[-1, -2, -3]) -def test_softmax_infer_struct_info_wrong_input_type(): +def test_softmax_log_softmax_infer_struct_info_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) @@ -202,6 +254,10 @@ def test_softmax_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.softmax(x0)) with pytest.raises(TVMError): bb.normalize(relax.op.nn.softmax(x1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x1)) def test_batch_norm_infer_struct_info(): @@ -925,5 +981,576 @@ def test_dropout_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.dropout(x1)) +def test_cross_entropy_infer_struct_info(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=2)) + y2 = relax.Var("y", R.Tensor((2, 3))) + y3 = relax.Var("y", R.Tensor(ndim=2)) + + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x, y0), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, + relax.op.nn.cross_entropy_without_logits(x, y1), + relax.TensorStructInfo((), dtype="float32"), + ) + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x, y2), relax.TensorStructInfo((), dtype="") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x, y3), relax.TensorStructInfo((), dtype="") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y0), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, + relax.op.nn.cross_entropy_with_logits(x, y1), + relax.TensorStructInfo((), dtype="float32"), + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y2), relax.TensorStructInfo((), dtype="") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y3), relax.TensorStructInfo((), dtype="") + ) + + +def test_cross_entropy_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m0 = tir.Var("m", "int64") + m1 = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m0, n), "float32")) + x1 = relax.Var("x", R.Tensor((m1, n), "float32")) + y = relax.Var("y", R.Tensor((m0, n), "float32")) + + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x0, y), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x1, y), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x0, y), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x1, y), relax.TensorStructInfo((), "float32") + ) + + +def test_cross_entropy_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + x = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + y0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + y1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x, y0), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x, y1), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y0), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y1), relax.TensorStructInfo((), "float32") + ) + + +def test_cross_entropy_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float16")) + y0 = relax.Var("y", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + y1 = relax.Var("y", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int32")) + y2 = relax.Var("y", R.Tensor((2, 3), "int32")) + + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x0, y0), relax.TensorStructInfo((), "float16") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x1, y1), relax.TensorStructInfo((), "int8") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x0, y0), relax.TensorStructInfo((), "float16") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x1, y1), relax.TensorStructInfo((), "int8") + ) + + +def test_cross_entropy_infer_struct_info_wrong_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x2 = relax.Var("x", R.Tensor((2,), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=4)) + y2 = relax.Var("y", R.Tensor("float32", ndim=-1)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_without_logits(x1, y0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_without_logits(x0, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x1, y0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x0, y1)) + + +def test_cross_entropy_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor((m, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 4), "float32")) + y1 = relax.Var("y", R.Tensor((m + 2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_without_logits(x0, y0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x0, y0)) + + +def test_cross_entropy_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_without_logits(x0, y)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_without_logits(x1, y)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x0, y)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x1, y)) + + +def test_nll_loss_infer_struct_info(): + bb = relax.BlockBuilder() + + x0 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((3, 5, 10, 10))) + x4 = relax.Var("x", R.Tensor((3, 5), "float32")) # (N, C) + x5 = relax.Var("x", R.Tensor((5,), "float32")) # (C,) + + y0 = relax.Var("y", R.Tensor((3, 10, 10), "int64")) + y1 = relax.Var("y", R.Tensor("int64", ndim=3)) + y2 = relax.Var("y", R.Tensor("int64")) + y3 = relax.Var("y", R.Tensor((3, 10, 10))) + y4 = relax.Var("y", R.Tensor((3,))) # (N,) + y5 = relax.Var("y", R.Tensor(())) # () + + w0 = relax.Var("w", R.Tensor((5,), "float32")) + w1 = relax.Var("w", R.Tensor("float32", ndim=1)) + w2 = relax.Var("w", R.Tensor("float32")) + w3 = relax.Var("w", R.Tensor((5,))) + + # reduction = mean + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x1, y0, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x2, y0, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x3, y0, w0, reduction="mean"), + relax.TensorStructInfo((), ""), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y1, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y2, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y3, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w1, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w2, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w3, reduction="mean"), + relax.TensorStructInfo((), ""), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x4, y4, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x5, y5, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + + # reduction=sum is totally the same as mean. Just need one test to ensure they behave the same + _check_inference( + bb, relax.op.nn.nll_loss(x0, y0, w0, reduction="sum"), relax.TensorStructInfo((), "float32") + ) + + # reduction=none + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w0, reduction="none"), + relax.TensorStructInfo((3, 10, 10), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x1, y0, w0, reduction="none"), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x2, y0, w0, reduction="none"), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x3, y0, w0, reduction="none"), + relax.TensorStructInfo((3, 10, 10), ""), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y1, w0, reduction="none"), + relax.TensorStructInfo((3, 10, 10), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y2, w0, reduction="none"), + relax.TensorStructInfo((3, 10, 10), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y3, w0, reduction="none"), + relax.TensorStructInfo((3, 10, 10), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w1, reduction="none"), + relax.TensorStructInfo((3, 10, 10), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w2, reduction="none"), + relax.TensorStructInfo((3, 10, 10), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w3, reduction="none"), + relax.TensorStructInfo((3, 10, 10), ""), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x4, y4, w0, reduction="none"), + relax.TensorStructInfo((3,), "float32"), # (N,) + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x5, y5, w0, reduction="none"), + relax.TensorStructInfo((), "float32"), # () + ) + + +def test_nll_loss_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + N = tir.Var("N", "int64") + C = tir.Var("C", "int64") + d1 = tir.Var("d", "int64") + d2 = tir.Var("d", "int64") + x0 = relax.Var("x", R.Tensor((N, C, d1, d2), "float32")) + x1 = relax.Var("x", R.Tensor((N, C), "float32")) + x2 = relax.Var("x", R.Tensor((C,), "float32")) + x3 = relax.Var("x", R.Tensor((3, C, d1, 2), "float32")) + y0 = relax.Var("y", R.Tensor((N, d1, d2), "int64")) + y1 = relax.Var("y", R.Tensor((N,), "int64")) + y2 = relax.Var("y", R.Tensor((), "int64")) + y3 = relax.Var("y", R.Tensor((3, d1, 2), "int64")) + w0 = relax.Var("w", R.Tensor((C,), "float32")) + w1 = relax.Var("w", R.Tensor((5,), "float32")) + + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w0, reduction="none"), + relax.TensorStructInfo((N, d1, d2), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x1, y1, w0, reduction="none"), + relax.TensorStructInfo((N,), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x2, y2, w0, reduction="none"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x3, y3, w0, reduction="none"), + relax.TensorStructInfo((3, d1, 2), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x3, y3, w1, reduction="none"), + relax.TensorStructInfo((3, d1, 2), "float32"), + ) + + +def test_nll_loss_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + + s0 = relax.Var("s0", relax.ShapeStructInfo((3, 5, 10, 10))) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=4)) + s2 = relax.Var("s2", relax.ShapeStructInfo()) + s3 = relax.Var("s3", relax.ShapeStructInfo((3, 10, 10))) + s4 = relax.Var("s4", relax.ShapeStructInfo(ndim=3)) + s5 = relax.Var("s5", relax.ShapeStructInfo((5,))) + s6 = relax.Var("s6", relax.ShapeStructInfo(ndim=1)) + + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(s3, "int64")) + y1 = relax.Var("y", relax.TensorStructInfo(s4, "int64")) + w0 = relax.Var("w", relax.TensorStructInfo(s5, "float32")) + w1 = relax.Var("w", relax.TensorStructInfo(s6, "float32")) + + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w0, reduction="none"), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x1, y0, w0, reduction="none"), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x2, y0, w0, reduction="none"), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y1, w0, reduction="none"), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w1, reduction="none"), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + +def test_nll_loss_infer_struct_info_no_weights(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + y = relax.Var("x", R.Tensor((3, 10, 10), "int64")) + + _check_inference( + bb, + relax.op.nn.nll_loss(x, y, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x, y, reduction="none"), + relax.TensorStructInfo((3, 10, 10), "float32"), + ) + + +def test_nll_loss_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + y0 = relax.Var("y", R.Tensor((3, 10, 10), "int64")) + y1 = relax.Var("y", relax.ShapeStructInfo((2, 3))) + y2 = relax.Var("y", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + w0 = relax.Var("w", R.Tensor((5,), "float32")) + w1 = relax.Var("w", relax.ShapeStructInfo((2, 3))) + w2 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x1, y0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x2, y0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y1, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y2, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y0, w1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y0, w2)) + + +def test_nll_loss_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float16")) + x1 = relax.Var("x", R.Tensor((3, 5, 10, 10), "int8")) + x2 = relax.Var("x", R.Tensor((3, 5, 10, 10), "int32")) + x3 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float64")) + y0 = relax.Var("y", R.Tensor((3, 10, 10), "int8")) + w0 = relax.Var("y", R.Tensor((5,), "float16")) + w1 = relax.Var("y", R.Tensor((5,), "int8")) + w2 = relax.Var("y", R.Tensor((5,), "int32")) + w3 = relax.Var("y", R.Tensor((5,), "float64")) + + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w0, reduction="mean"), + relax.TensorStructInfo((), "float16"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x1, y0, w1, reduction="mean"), + relax.TensorStructInfo((), "int8"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x2, y0, w2, reduction="mean"), + relax.TensorStructInfo((), "int32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x3, y0, w3, reduction="mean"), + relax.TensorStructInfo((), "float64"), + ) + + +def test_nll_loss_infer_struct_info_targets_dtype(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + w = relax.Var("w", R.Tensor((5,), "float32")) + targets0 = relax.Var("targets", R.Tensor((3, 10, 10), "float32")) + targets1 = relax.Var("targets", R.Tensor((3, 10, 10), "float64")) + targets2 = relax.Var("targets", R.Tensor((3, 10, 10), "bool")) + targets3 = relax.Var("targets", R.Tensor((3, 10, 10), "int32")) + targets4 = relax.Var("targets", R.Tensor((3, 10, 10), "int64")) + targets5 = relax.Var("targets", R.Tensor((3, 10, 10), "uint32")) + targets6 = relax.Var("targets", R.Tensor((3, 10, 10), "")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x, targets0, w)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x, targets1, w)) + + # correct cases + bb.normalize(relax.op.nn.nll_loss(x, targets2, w)) # bool is uint1 + bb.normalize(relax.op.nn.nll_loss(x, targets3, w)) + bb.normalize(relax.op.nn.nll_loss(x, targets4, w)) + bb.normalize(relax.op.nn.nll_loss(x, targets5, w)) + bb.normalize(relax.op.nn.nll_loss(x, targets6, w)) # unknwon dtype + + +def test_nll_loss_infer_struct_info_ndim_mismatch(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + x1 = relax.Var("x", R.Tensor((3, 5, 10, 10, 10), "float32")) + x2 = relax.Var("x", R.Tensor((3, 5, 10), "float32")) + y0 = relax.Var("x", R.Tensor((3, 10, 10), "int64")) + y1 = relax.Var("x", R.Tensor((3, 10, 10, 10), "int64")) + y2 = relax.Var("x", R.Tensor((3, 10), "int64")) + w0 = relax.Var("w", R.Tensor((5,), "float32")) + w1 = relax.Var("w", R.Tensor((5, 5), "float32")) + w2 = relax.Var("w", R.Tensor((), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x1, y0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x2, y0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y1, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y2, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y0, w1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y0, w2)) + + +def test_nll_loss_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + x1 = relax.Var("x", R.Tensor((3, 6, 10, 10), "float32")) + x2 = relax.Var("x", R.Tensor((4, 5, 10, 10), "float32")) + x3 = relax.Var("x", R.Tensor((3, 5, 11, 10), "float32")) + y0 = relax.Var("x", R.Tensor((3, 10, 10), "int64")) + y1 = relax.Var("x", R.Tensor((4, 10, 10), "int64")) + y2 = relax.Var("x", R.Tensor((3, 11, 10), "int64")) + w0 = relax.Var("w", R.Tensor((5,), "float32")) + w1 = relax.Var("w", R.Tensor((4,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x1, y0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x2, y0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x3, y0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y1, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y2, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y0, w1)) + + +def test_nll_loss_infer_struct_info_wrong_reduction(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + y = relax.Var("x", R.Tensor((3, 10, 10), "int64")) + w = relax.Var("w", R.Tensor((5,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x, y, w, reduction="foo")) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py index 6114eb04f3..e0e8bfee9d 100644 --- a/tests/python/relax/test_tvmscript_parser_op_nn.py +++ b/tests/python/relax/test_tvmscript_parser_op_nn.py @@ -114,6 +114,21 @@ def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): _check(foo, bb.get()["foo"]) +def test_log_softmax(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.log_softmax(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.log_softmax(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + def test_batch_norm(): @R.function def foo( @@ -188,5 +203,85 @@ def foo( _check(foo, bb.get()["foo"]) +def test_cross_entropy_without_logits(): + @R.function + def foo( + predictions: R.Tensor((2, 3), "float32"), labels: R.Tensor((2, 3), "float32") + ) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.nn.cross_entropy_without_logits(predictions, labels) + return gv + + predictions = relax.Var("predictions", R.Tensor((2, 3), "float32")) + labels = relax.Var("labels", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [predictions, labels]): + gv = bb.emit(relax.op.nn.cross_entropy_without_logits(predictions, labels)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_cross_entropy_with_logits(): + @R.function + def foo( + predictions: R.Tensor((2, 3), "float32"), labels: R.Tensor((2, 3), "float32") + ) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.nn.cross_entropy_with_logits(predictions, labels) + return gv + + predictions = relax.Var("predictions", R.Tensor((2, 3), "float32")) + labels = relax.Var("labels", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [predictions, labels]): + gv = bb.emit(relax.op.nn.cross_entropy_with_logits(predictions, labels)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_nll_loss(): + @R.function + def foo( + predictions: R.Tensor((3, 5, 10, 10), dtype="float32"), + targets: R.Tensor((3, 10, 10), dtype="int64"), + weights: R.Tensor((5,), dtype="float32"), + ) -> R.Tensor((), dtype="float32"): + gv: R.Tensor((), dtype="float32") = R.nn.nll_loss(predictions, targets, weights, "mean", -1) + return gv + + predictions = relax.Var("predictions", R.Tensor((3, 5, 10, 10), "float32")) + targets = relax.Var("targets", R.Tensor((3, 10, 10), "int64")) + weights = relax.Var("weights", R.Tensor((5,), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [predictions, targets, weights]): + gv = bb.emit( + relax.op.nn.nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-1) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_nll_loss_no_weights(): + @R.function + def foo( + predictions: R.Tensor((3, 5, 10, 10), dtype="float32"), + targets: R.Tensor((3, 10, 10), dtype="int64"), + ) -> R.Tensor((), dtype="float32"): + gv: R.Tensor((), dtype="float32") = R.nn.nll_loss( + predictions, targets, reduction="mean", ignore_index=-1 + ) + return gv + + predictions = relax.Var("predictions", R.Tensor((3, 5, 10, 10), "float32")) + targets = relax.Var("targets", R.Tensor((3, 10, 10), "int64")) + bb = relax.BlockBuilder() + with bb.function("foo", [predictions, targets]): + gv = bb.emit(relax.op.nn.nll_loss(predictions, targets, reduction="mean", ignore_index=-1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + if __name__ == "__main__": tvm.testing.main()