diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 4b4e0680e2..ae102ad0eb 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -184,6 +184,19 @@ struct DropoutAttrs : public tvm::AttrsNode { } }; // struct DropoutAttrs +/*! \brief Attributes used in nll_loss operator */ +struct NLLLossAttrs : public tvm::AttrsNode { + String reduction; + int ignore_index; + + TVM_DECLARE_ATTRS(NLLLossAttrs, "relax.attrs.NLLLossAttrs") { + TVM_ATTR_FIELD(reduction).set_default("mean").describe( + "The reduction method to apply to the output. Can be" + "'none', 'mean' or 'sum'."); + TVM_ATTR_FIELD(ignore_index).describe("The target value to ignore."); + } +}; // struct NLLLossAttrs + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index a62cc7f997..fe3d27b183 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -327,7 +327,9 @@ def silu(data: Expr) -> Expr: def softmax(data: Expr, axis: int = -1) -> Expr: r"""Computes softmax. - .. math:: text{softmax}(x)_i = frac{exp(x_i)}{\sum_j exp(x_j)} + .. math:: + + \text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} Parameters ---------- @@ -351,6 +353,34 @@ def softmax(data: Expr, axis: int = -1) -> Expr: return _ffi_api.softmax(data, axis) # type: ignore +def log_softmax(data: Expr, axis: int = -1) -> Expr: + r"""Computes log softmax. + + .. math:: + + \text{log\_softmax}(x_i) = \log\left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}\right) + + .. note:: + This operator can be optimized away for inference. + + Parameters + ---------- + data: relax.Expr + The input data to the operator. + + axis: int + The axis to sum over when computing log softmax. + If not specified, it is by default the last axis of the input tensor. + Supports negative indexing. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.log_softmax(data, axis) # type: ignore + + def batch_norm( data: Expr, gamma: Expr, @@ -525,3 +555,98 @@ def dropout(data: Expr, rate: float = 0.5) -> Expr: mask tensor (1.0 where element not dropped, 0.0 where dropped) """ return _ffi_api.dropout(data, rate) # type: ignore + + +def cross_entropy_without_logits(predictions: Expr, labels: Expr) -> Expr: + r"""CrossEntropy without logits between the predictions and labels. + + The shape of predictions and labels must be the same. And when ndim >= 2, + the first dimension is regarded as the batch_size N. In this case the + computed result will divide by N to perform a mean reduction. + + .. math:: + + \text{cross\_entropy\_without\_logits}(x_i, y_i) = \frac{\sum_i -y_i \log x_i}{N} + + Parameters + ---------- + predictions : relax.Expr + The predictions. + + labels : relax.Expr + The labels (the ground truth values). + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.cross_entropy_without_logits(predictions, labels) # type: ignore + + +def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr: + r"""CrossEntropy with logits between the predictions and labels. + + The shape issue is the same with cross_entropy_without_logits. + + .. math:: + + \text{cross\_entropy\_with\_logits}(x_i, y_i) = \frac{\sum_i -x_i \cdot y_i}{N} + + Parameters + ---------- + predictions : relax.Expr + The predictions. + + labels : relax.Expr + The labels (the ground truth values). + + Returns + ------- + result : relax.Expr + The computed result. 
+ """ + return _ffi_api.cross_entropy_with_logits(predictions, labels) # type: ignore + + +def nll_loss( + predictions: Expr, + targets: Expr, + weights: Optional[Expr] = None, + reduction: str = "mean", + ignore_index: int = -100, +) -> Expr: + """Negative log likelihood loss. + + `output[n, i_1, i_2, ..., i_k] = -p * w`, where + - `p = predictions[n, t, i_1, i_2, i_k]`, + - `t = targets[n, i_1, i_2, ..., i_k]`, + - `w = weights[n, i_1, i_2, ..., i_k] if t != ignore_index else 0` + + result = reduction(output) + + Parameters + ---------- + predictions : relax.Expr + The predictions. + + targets : relax.Expr + The target value of each prediction. + + weights : Optional[relax.Expr] + The weight of each target value. + If not specified, it is treated as if having all ones. + + reduction : string + The reduction method to apply to the output. + Possible values are "mean", "sum" and "none". + + ignore_index : int + The target value to ignore. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.nll_loss(predictions, targets, weights, reduction, ignore_index) # type: ignore diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 1fa70dbdb2..fe9c678480 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -71,6 +71,22 @@ TVM_REGISTER_OP("relax.nn.softmax") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoSoftmax); +/* relax.nn.log_softmax */ +Expr log_softmax(Expr data, int axis) { + auto attrs = make_object(); + attrs->axis = axis; + static const Op& op = Op::Get("relax.nn.log_softmax"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.log_softmax").set_body_typed(log_softmax); + +TVM_REGISTER_OP("relax.nn.log_softmax") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoSoftmax); + bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, const Array& input_sinfo, Array axes) { Op op = Downcast(call->op); @@ -246,5 +262,304 @@ TVM_REGISTER_OP("relax.nn.dropout") .set_attr("FInferStructInfo", InferStructInfoDropout) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); +// infer structinfo for CrossEntropyWithoutLogits and CrossEntropyWithLogits +StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo pred_sinfo = input_sinfo[0]; + TensorStructInfo label_sinfo = input_sinfo[1]; + + // infer dtype + DataType dtype = InferBinaryArithOpOutDtype(call, ctx, pred_sinfo, label_sinfo); + + // infer ndim + if (!pred_sinfo->IsUnknownNdim() && !label_sinfo->IsUnknownNdim() && + pred_sinfo->ndim != label_sinfo->ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "CrossEntropy requires predictions and labels to have the same ndim. 
" + "However, the ndim of predictions is " + << pred_sinfo->ndim << " while the ndim of labels is " << label_sinfo->ndim); + } + + Optional> pred_shape_value; + if (pred_sinfo->shape.defined()) { + pred_shape_value = GetStructInfoAs(pred_sinfo->shape.value())->values; + } + + Optional> label_shape_value; + if (label_sinfo->shape.defined()) { + label_shape_value = GetStructInfoAs(label_sinfo->shape.value())->values; + } + + if (pred_shape_value.defined() && label_shape_value.defined()) { + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + for (size_t i = 0; i < pred_shape_value.value().size(); ++i) { + if (analyzer->CanProve(pred_shape_value.value()[i] != label_shape_value.value()[i])) { + ctx->ReportFatal(Diagnostic::Error(call) + << "CrossEntropy requires the predictions and labels to have " + "the same shape. However, the shape of predictions at dim " + << i << " is" << pred_shape_value.value()[i] + << " while the shape of labels at this dim is " + << label_shape_value.value()[i]); + } + } + } + return TensorStructInfo(ShapeExpr(Array()), dtype); +} + +/* relax.nn.cross_entropy_without_logits */ +Expr cross_entropy_without_logits(Expr predictions, Expr labels) { + static const Op& op = Op::Get("relax.nn.cross_entropy_without_logits"); + return Call(op, {std::move(predictions), std::move(labels)}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.cross_entropy_without_logits") + .set_body_typed(cross_entropy_without_logits); + +TVM_REGISTER_OP("relax.nn.cross_entropy_without_logits") + .set_num_inputs(2) + .add_argument("predictions", "Tensor", "The predictions.") + .add_argument("labels", "Tensor", "The labels.") + .set_attr("FInferStructInfo", InferStructInfoCrossEntropy); + +/* relax.nn.cross_entropy_with_logits */ +Expr cross_entropy_with_logits(Expr predictions, Expr labels) { + static const Op& op = Op::Get("relax.nn.cross_entropy_with_logits"); + return Call(op, {std::move(predictions), std::move(labels)}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.cross_entropy_with_logits") + .set_body_typed(cross_entropy_with_logits); + +TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits") + .set_num_inputs(2) + .add_argument("predictions", "Tensor", "The predictions.") + .add_argument("labels", "Tensor", "The labels.") + .set_attr("FInferStructInfo", InferStructInfoCrossEntropy); + +/* relax.nn.nll_loss */ +TVM_REGISTER_NODE_TYPE(NLLLossAttrs); + +Expr nll_loss(Expr predictions, Expr targets, Optional weights, String reduction, + int ignore_index) { + ObjectPtr attrs = make_object(); + + ICHECK(reduction == "none" || reduction == "sum" || reduction == "mean") + << "The argument reduction of NLLLoss should be one of the following " + "values: none, mean, sum. 
However, the given value is " + << reduction; + + attrs->reduction = std::move(reduction); + attrs->ignore_index = ignore_index; + + static const Op& op = Op::Get("relax.nn.nll_loss"); + if (weights.defined()) { + return Call(op, {std::move(predictions), std::move(targets), std::move(weights.value())}, + Attrs{attrs}, {}); + } else { + return Call(op, {std::move(predictions), std::move(targets)}, Attrs{attrs}, {}); + } +} + +TVM_REGISTER_GLOBAL("relax.op.nn.nll_loss").set_body_typed(nll_loss); + +StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() < 2 || call->args.size() > 3) { + ctx->ReportFatal(Diagnostic::Error(call) << "NLLLoss op should take 2 or 3 arguments"); + } + + const auto* pred_sinfo = GetStructInfoAs(call->args[0]); + const auto* tgt_sinfo = GetStructInfoAs(call->args[1]); + const TensorStructInfoNode* wgt_sinfo = nullptr; + if (call->args.size() == 3) { + wgt_sinfo = GetStructInfoAs(call->args[2]); + if (wgt_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "NLLLoss requires the argument weights to be Tensor. However, the given one is " + << call->args[1]->struct_info_->GetTypeKey()); + } + } + + if (pred_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "NLLLoss requires the argument preditions to be Tensor. However, the given one is " + << call->args[0]->struct_info_->GetTypeKey()); + } + if (tgt_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "NLLLoss requires the argument targets to be Tensor. However, the given one is " + << call->args[2]->struct_info_->GetTypeKey()); + } + + // infer dtype + DataType output_dtype; + if (wgt_sinfo != nullptr) { + output_dtype = InferBinaryArithOpOutDtype(call, ctx, GetRef(pred_sinfo), + GetRef(wgt_sinfo)); + } else { + output_dtype = pred_sinfo->dtype; + } + + // the type of targets must be int/uint. + if (!tgt_sinfo->IsUnknownDtype() && !tgt_sinfo->dtype.is_int() && !tgt_sinfo->dtype.is_uint()) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "NLLLoss expects the dtype of targets to be int/uint. However, the dtype of targets is " + << tgt_sinfo->dtype); + } + + // infer ndim + int K = kUnknownNDim; // k dim + if (!pred_sinfo->IsUnknownNdim()) { + if (pred_sinfo->ndim < 1) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "NLLLoss expects the ndim of predictions >= 1. However, the ndim of predictions is " + << pred_sinfo->ndim); + } + K = pred_sinfo->ndim <= 2 ? 0 : pred_sinfo->ndim - 2; + } + if (!tgt_sinfo->IsUnknownNdim()) { + int K_tgt = tgt_sinfo->ndim <= 1 ? 0 : tgt_sinfo->ndim - 1; + if (K != kUnknownNDim && K != K_tgt) { + ctx->ReportFatal(Diagnostic::Error(call) + << "NLLLoss expects number of dimensions K inferred from different " + "arguments to be equal. However, K from predictions is " + << K << " while K from targets is " << K_tgt); + } + } + if (wgt_sinfo != nullptr && !wgt_sinfo->IsUnknownNdim() && wgt_sinfo->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "NLLLoss expects the ndim of weights == 1. 
However, the ndim of weights is " + << wgt_sinfo->ndim); + } + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + Optional N; + Optional C; + Array output_shape; // N, d1, d2, ..., dk + + Optional> pred_shape_value; + if (pred_sinfo->shape.defined()) { + pred_shape_value = GetStructInfoAs(pred_sinfo->shape.value())->values; + } + if (pred_shape_value.defined()) { + if (pred_shape_value.value().size() == 1) { + // (C,) + ICHECK(pred_sinfo->ndim == 1); + C = pred_shape_value.value()[0]; + } else { + // (N, C, d1, d2, ..., dk) + ICHECK(pred_shape_value.value().size() >= 2); + ICHECK(pred_sinfo->ndim == static_cast(pred_shape_value.value().size())); + N = pred_shape_value.value()[0]; + C = pred_shape_value.value()[1]; + output_shape = Array(); + output_shape.push_back(N.value()); + for (size_t i = 2; i < pred_shape_value.value().size(); ++i) { + output_shape.push_back(pred_shape_value.value()[i]); + } + } + } + + Optional> tgt_shape_value; + if (tgt_sinfo->shape.defined()) { + tgt_shape_value = GetStructInfoAs(tgt_sinfo->shape.value())->values; + } + if (tgt_shape_value.defined()) { + if (tgt_shape_value.value().empty()) { + // () + ICHECK(tgt_sinfo->ndim == 0); + if (N.defined()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Shape mismatch for NLLLoss. Predictions shape is " + "(N, C, ...) while targets is a scalar"); + } + } else { + // (N,) or (N, d1, d2, ..., dk) + // check N + const PrimExpr& N_tgt = tgt_shape_value.value()[0]; + if (N.defined() && analyzer->CanProve(N.value() != N_tgt)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "NLLLoss expects minibatch size N inferred from different " + "arguments to be equal. However, N from predictions is " + << N << " while N from targets is " << N_tgt); + } + // only C case + if (!N.defined() && C.defined()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Shape mismatch for NLLLoss. Predictions shape is " + "(C,) while targets is not a scalar"); + } + + if (tgt_shape_value.value().size() == 1) { + // (N,) + ICHECK(tgt_sinfo->IsUnknownNdim() || tgt_sinfo->ndim == 1); + } else { + // (N, d1, d2, ..., dk) + ICHECK(tgt_shape_value.value().size() >= 2); + ICHECK(tgt_sinfo->IsUnknownNdim() || + tgt_sinfo->ndim == static_cast(tgt_shape_value.value().size())); + + if (pred_shape_value.defined()) { + // check (d1, d2, ..., dk) + for (size_t i = 1; i < tgt_shape_value.value().size(); ++i) { + if (analyzer->CanProve(output_shape[i] != tgt_shape_value.value()[i])) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Shape mismatch for NLLLoss. The prediction shape at this dim is " + << output_shape[i] << " while the target shape at this dim is " + << tgt_shape_value.value()[i]); + } + } + } + } + } + } + + if (wgt_sinfo != nullptr) { + Optional> wgt_shape_value; + if (wgt_sinfo->shape.defined()) { + wgt_shape_value = GetStructInfoAs(wgt_sinfo->shape.value())->values; + } + if (wgt_shape_value.defined()) { + ICHECK(wgt_shape_value.value().size() == 1); + ICHECK(wgt_sinfo->IsUnknownNdim() || wgt_sinfo->ndim == 1); + const PrimExpr& C_wgt = wgt_shape_value.value()[0]; + if (C.defined() && analyzer->CanProve(C.value() != C_wgt)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "NLLLoss expects number of classes C inferred from different " + "arguments to be equal. 
However, C from predictions is " + << C << " while C from weights is " << C_wgt); + } + } + } + + const auto* attrs = call->attrs.as(); + String reduction = attrs->reduction; + + if (reduction == "none") { + // () or (N,) or (N, d1, d2, ..., dk) + if (pred_sinfo->shape.as()) { + return TensorStructInfo(ShapeExpr(output_shape), output_dtype); + } else { + int output_ndim = pred_sinfo->ndim == kUnknownNDim ? kUnknownNDim : pred_sinfo->ndim - 1; + return TensorStructInfo(output_dtype, /*ndim=*/output_ndim); + } + } else { + // sum or mean. output is scalar + return TensorStructInfo(/*shape=*/ShapeExpr(Array()), output_dtype); + } +} + +TVM_REGISTER_OP("relax.nn.nll_loss") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("predictions", "Tensor", "The prediction tensor.") + .add_argument("targets", "Tensor", "The target tensor.") + .add_argument("weights", "Optional", "The weight of each target values.") + .set_attr("FInferStructInfo", InferStructInfoNLLLoss); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index 39f26a7237..ba34f5bb1f 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -44,6 +44,9 @@ Expr silu(Expr data); /*! \brief Softmax function. */ Expr softmax(Expr data, int axis); +/*! \brief LogSoftmax function. */ +Expr log_softmax(Expr data, int axis); + /*! \brief Compute batch normalization. */ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // int axis, double epsilon, bool center, bool scale); @@ -62,6 +65,16 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double ep */ Expr dropout(Expr data, double rate); +/*! \brief CrossEntropy without logits. */ +Expr cross_entropy_without_logits(Expr predictions, Expr labels); + +/*! \brief CrossEntropy with logits. */ +Expr cross_entropy_with_logits(Expr predictions, Expr labels); + +/*! \brief Negative log likelihood loss. 
*/ +Expr nll_loss(Expr predictions, Expr targets, Optional weights, String reduction, + int ignore_index); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index c952a5395d..7974839ad2 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -33,6 +33,11 @@ def test_op_correctness(): assert relax.op.reshape(x, (4, 5, 3)).op == Op.get("relax.reshape") assert relax.op.split(x, indices_or_sections=1).op == Op.get("relax.split") assert relax.op.squeeze(x).op == Op.get("relax.squeeze") + assert relax.op.broadcast_to(x, (3, 3, 4, 5)).op == Op.get("relax.broadcast_to") + assert relax.op.collapse_sum_to(x, (4, 5)).op == Op.get("relax.collapse_sum_to") + + y = relax.Var("x", R.Tensor((4, 5), "float32")) + assert relax.op.collapse_sum_like(x, y).op == Op.get("relax.collapse_sum_like") def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): @@ -2286,12 +2291,12 @@ def test_collapse_sum_like_infer_struct_info_shape_symbolic(): def test_collapse_sum_like_infer_struct_info_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s2 = relax.Var("s", relax.ShapeStructInfo()) - s3 = relax.Var("s", relax.ShapeStructInfo((3, 4))) - s4 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) - s5 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s2", relax.ShapeStructInfo()) + s3 = relax.Var("s3", relax.ShapeStructInfo((3, 4))) + s4 = relax.Var("s4", relax.ShapeStructInfo(ndim=2)) + s5 = relax.Var("s5", relax.ShapeStructInfo()) x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) @@ -2317,7 +2322,7 @@ def test_collapse_sum_like_infer_struct_info_more_input_dtype(): _check_inference(bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo((3, 4), "int8")) -def test_collapse_sum_like_wrong_input_type(): +def test_collapse_sum_like_infer_struct_info_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) x1 = relax.Var("x", relax.ShapeStructInfo((4, 5))) @@ -2330,7 +2335,7 @@ def test_collapse_sum_like_wrong_input_type(): bb.normalize(relax.op.collapse_sum_like(x2, x0)) -def test_collapse_sum_like_check_shape_failure(): +def test_collapse_sum_like_infer_struct_info_shape_mismatch(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) y0 = relax.Var("y", R.Tensor((3, 6, 5), "float32")) @@ -2339,13 +2344,13 @@ def test_collapse_sum_like_check_shape_failure(): x1 = relax.Var("z", R.Tensor((3, a, 5), "float32")) y1 = relax.Var("w", R.Tensor((3, b, 5), "float32")) - s0 = relax.Var("s", relax.ShapeStructInfo((3, 4, 5))) - s1 = relax.Var("s", relax.ShapeStructInfo((3, 6, 5))) + s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5))) + s1 = relax.Var("s1", relax.ShapeStructInfo((3, 6, 5))) x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) y2 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) - s2 = relax.Var("s", relax.ShapeStructInfo((3, a, 5))) - s3 = relax.Var("s", relax.ShapeStructInfo((3, b, 5))) + s2 = relax.Var("s2", relax.ShapeStructInfo((3, a, 5))) + s3 = relax.Var("s3", relax.ShapeStructInfo((3, b, 5))) x3 = relax.Var("x", 
relax.TensorStructInfo(s2, "float32")) y3 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) @@ -2402,9 +2407,9 @@ def test_collapse_sum_to_infer_struct_info_shape_symbolic(): def test_collapse_sum_to_infer_struct_info_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s2 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s2", relax.ShapeStructInfo()) x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) @@ -2432,7 +2437,7 @@ def test_collapse_sum_to_infer_struct_info_more_input_dtype(): ) -def test_collapse_sum_to_wrong_input_type(): +def test_collapse_sum_to_infer_struct_info_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) x1 = relax.Var("x", relax.ShapeStructInfo((4, 5))) @@ -2448,17 +2453,17 @@ def test_collapse_sum_to_wrong_input_type(): bb.normalize(relax.op.collapse_sum_to(x1, x1)) -def test_collapse_sum_to_check_shape_failure(): +def test_collapse_sum_to_infer_struct_info_shape_mismatch(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) a = tir.Var("a", "int64") b = tir.Var("b", "int64") x1 = relax.Var("x", R.Tensor((3, a, 5), "float32")) - s0 = relax.Var("s", relax.ShapeStructInfo((3, 4, 5))) + s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5))) x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - s1 = relax.Var("s", relax.ShapeStructInfo((3, a, 5))) + s1 = relax.Var("s1", relax.ShapeStructInfo((3, a, 5))) x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) with pytest.raises(TVMError): @@ -2474,24 +2479,24 @@ def test_collapse_sum_to_check_shape_failure(): bb.normalize(relax.op.collapse_sum_to(x3, (3, b, 5))) -def test_collapse_sum_to_struct_info_tgt_shape_var(): +def test_collapse_sum_to_infer_struct_info_struct_info_tgt_shape_var(): bb = relax.BlockBuilder() a = tir.Var("a", "int64") b = tir.Var("b", "int64") c = tir.Var("c", "int64") d = tir.Var("d", "int64") - s0 = relax.Var("s", relax.ShapeStructInfo((3, a, b))) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) - s2 = relax.Var("s", relax.ShapeStructInfo()) + s0 = relax.Var("s0", relax.ShapeStructInfo((3, a, b))) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s2", relax.ShapeStructInfo()) x0 = relax.Var("x", R.Tensor((3, a, b), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) x2 = relax.Var("x", R.Tensor("")) x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - stgt0 = relax.Var("stgt", relax.ShapeStructInfo((a, b))) - stgt1 = relax.Var("stgt", relax.ShapeStructInfo(ndim=2)) - stgt2 = relax.Var("stgt", relax.ShapeStructInfo()) + stgt0 = relax.Var("stgt0", relax.ShapeStructInfo((a, b))) + stgt1 = relax.Var("stgt1", relax.ShapeStructInfo(ndim=2)) + stgt2 = relax.Var("stgt2", relax.ShapeStructInfo()) _check_inference( bb, relax.op.collapse_sum_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32") diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index d047448309..bcd720d1c6 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -29,6 +29,7 @@ def 
test_op_correctness(): assert relax.op.nn.gelu(x).op == Op.get("relax.nn.gelu") assert relax.op.nn.silu(x).op == Op.get("relax.nn.silu") assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax") + assert relax.op.nn.log_softmax(x).op == Op.get("relax.nn.log_softmax") assert relax.op.nn.dropout(x).op == Op.get("relax.nn.dropout") x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) @@ -41,6 +42,20 @@ def test_op_correctness(): ) assert relax.op.nn.layer_norm(x, gamma, beta, axes=1).op == Op.get("relax.nn.layer_norm") + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + assert relax.op.nn.cross_entropy_without_logits(x, y).op == Op.get( + "relax.nn.cross_entropy_without_logits" + ) + assert relax.op.nn.cross_entropy_with_logits(x, y).op == Op.get( + "relax.nn.cross_entropy_with_logits" + ) + + x = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + y = relax.Var("y", R.Tensor((3, 10, 10), "int64")) + w = relax.Var("w", R.Tensor((5,), "float32")) + assert relax.op.nn.nll_loss(x, y, w).op == Op.get("relax.nn.nll_loss") + def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): ret = bb.normalize(call) @@ -117,7 +132,7 @@ def test_linear_unit_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.silu(x1)) -def test_softmax_infer_struct_info(): +def test_softmax_log_softmax_infer_struct_info(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) @@ -133,8 +148,20 @@ def test_softmax_infer_struct_info(): _check_inference(bb, relax.op.nn.softmax(x3, axis=-1), relax.TensorStructInfo((2, 3), dtype="")) _check_inference(bb, relax.op.nn.softmax(x4, axis=-2), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.nn.log_softmax(x1, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.nn.log_softmax(x2, axis=1), relax.TensorStructInfo(dtype="float32") + ) + _check_inference( + bb, relax.op.nn.log_softmax(x3, axis=-1), relax.TensorStructInfo((2, 3), dtype="") + ) + _check_inference(bb, relax.op.nn.log_softmax(x4, axis=-2), relax.TensorStructInfo(dtype="")) + -def test_softmax_infer_struct_info_shape_symbolic(): +def test_softmax_log_softmax_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() m = tir.Var("m", "int64") n = tir.Var("n", "int64") @@ -144,8 +171,13 @@ def test_softmax_infer_struct_info_shape_symbolic(): _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((m, n), "float32")) _check_inference(bb, relax.op.nn.softmax(x1, axis=0), relax.TensorStructInfo((4, n), "float32")) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo((m, n), "float32")) + _check_inference( + bb, relax.op.nn.log_softmax(x1, axis=0), relax.TensorStructInfo((4, n), "float32") + ) + -def test_softmax_infer_struct_info_shape_var(): +def test_softmax_log_softmax_infer_struct_info_shape_var(): bb = relax.BlockBuilder() s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) s1 = relax.Var("s", relax.ShapeStructInfo()) @@ -155,8 +187,11 @@ def test_softmax_infer_struct_info_shape_var(): _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo(s0, "float32")) _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo(s0, "float32")) + 
_check_inference(bb, relax.op.nn.log_softmax(x1), relax.TensorStructInfo(s1, "float32")) -def test_softmax_infer_struct_info_more_input_dtype(): + +def test_softmax_log_softmax_infer_struct_info_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float16")) x1 = relax.Var("x", R.Tensor((2, 3), "float64")) @@ -164,8 +199,11 @@ def test_softmax_infer_struct_info_more_input_dtype(): _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float16")) _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.nn.log_softmax(x1), relax.TensorStructInfo((2, 3), "float64")) + -def test_softmax_infer_struct_info_invalid_input_dtype(): +def test_softmax_log_softmax_infer_struct_info_invalid_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "int8")) x1 = relax.Var("x", R.Tensor((2, 3), "int64")) @@ -174,26 +212,40 @@ def test_softmax_infer_struct_info_invalid_input_dtype(): bb.normalize(relax.op.nn.softmax(x0)) with pytest.raises(TVMError): bb.normalize(relax.op.nn.softmax(x1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x1)) -def test_softmax_infer_struct_info_axis_out_of_range(): +def test_softmax_log_softmax_infer_struct_info_axis_out_of_range(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + with pytest.raises(TVMError): bb.normalize(relax.op.nn.softmax(x, axis=3)) with pytest.raises(TVMError): bb.normalize(relax.op.nn.softmax(x, axis=-4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x, axis=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x, axis=-4)) -def test_softmax_wrong_with_multiple_axes(): +def test_softmax_log_softmax_wrong_with_multiple_axes(): x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + with pytest.raises(TVMError): relax.op.nn.softmax(x, axis=[1, 2]) with pytest.raises(TVMError): relax.op.nn.softmax(x, axis=[-1, -2, -3]) + with pytest.raises(TVMError): + relax.op.nn.log_softmax(x, axis=[1, 2]) + with pytest.raises(TVMError): + relax.op.nn.log_softmax(x, axis=[-1, -2, -3]) -def test_softmax_infer_struct_info_wrong_input_type(): +def test_softmax_log_softmax_infer_struct_info_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) @@ -202,6 +254,10 @@ def test_softmax_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.softmax(x0)) with pytest.raises(TVMError): bb.normalize(relax.op.nn.softmax(x1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x1)) def test_batch_norm_infer_struct_info(): @@ -925,5 +981,576 @@ def test_dropout_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.dropout(x1)) +def test_cross_entropy_infer_struct_info(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=2)) + y2 = relax.Var("y", R.Tensor((2, 3))) + y3 = relax.Var("y", R.Tensor(ndim=2)) + + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x, y0), relax.TensorStructInfo((), "float32") + ) + _check_inference( 
+ bb, + relax.op.nn.cross_entropy_without_logits(x, y1), + relax.TensorStructInfo((), dtype="float32"), + ) + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x, y2), relax.TensorStructInfo((), dtype="") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x, y3), relax.TensorStructInfo((), dtype="") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y0), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, + relax.op.nn.cross_entropy_with_logits(x, y1), + relax.TensorStructInfo((), dtype="float32"), + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y2), relax.TensorStructInfo((), dtype="") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y3), relax.TensorStructInfo((), dtype="") + ) + + +def test_cross_entropy_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m0 = tir.Var("m", "int64") + m1 = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m0, n), "float32")) + x1 = relax.Var("x", R.Tensor((m1, n), "float32")) + y = relax.Var("y", R.Tensor((m0, n), "float32")) + + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x0, y), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x1, y), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x0, y), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x1, y), relax.TensorStructInfo((), "float32") + ) + + +def test_cross_entropy_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + x = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + y0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + y1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x, y0), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x, y1), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y0), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y1), relax.TensorStructInfo((), "float32") + ) + + +def test_cross_entropy_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float16")) + y0 = relax.Var("y", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + y1 = relax.Var("y", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int32")) + y2 = relax.Var("y", R.Tensor((2, 3), "int32")) + + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x0, y0), relax.TensorStructInfo((), "float16") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_without_logits(x1, y1), relax.TensorStructInfo((), "int8") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x0, y0), relax.TensorStructInfo((), "float16") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x1, y1), relax.TensorStructInfo((), "int8") + ) + + +def test_cross_entropy_infer_struct_info_wrong_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x2 = relax.Var("x", R.Tensor((2,), "float32")) + y0 = 
relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=4)) + y2 = relax.Var("y", R.Tensor("float32", ndim=-1)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_without_logits(x1, y0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_without_logits(x0, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x1, y0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x0, y1)) + + +def test_cross_entropy_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor((m, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 4), "float32")) + y1 = relax.Var("y", R.Tensor((m + 2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_without_logits(x0, y0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x0, y0)) + + +def test_cross_entropy_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_without_logits(x0, y)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_without_logits(x1, y)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x0, y)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x1, y)) + + +def test_nll_loss_infer_struct_info(): + bb = relax.BlockBuilder() + + x0 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((3, 5, 10, 10))) + x4 = relax.Var("x", R.Tensor((3, 5), "float32")) # (N, C) + x5 = relax.Var("x", R.Tensor((5,), "float32")) # (C,) + + y0 = relax.Var("y", R.Tensor((3, 10, 10), "int64")) + y1 = relax.Var("y", R.Tensor("int64", ndim=3)) + y2 = relax.Var("y", R.Tensor("int64")) + y3 = relax.Var("y", R.Tensor((3, 10, 10))) + y4 = relax.Var("y", R.Tensor((3,))) # (N,) + y5 = relax.Var("y", R.Tensor(())) # () + + w0 = relax.Var("w", R.Tensor((5,), "float32")) + w1 = relax.Var("w", R.Tensor("float32", ndim=1)) + w2 = relax.Var("w", R.Tensor("float32")) + w3 = relax.Var("w", R.Tensor((5,))) + + # reduction = mean + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x1, y0, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x2, y0, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x3, y0, w0, reduction="mean"), + relax.TensorStructInfo((), ""), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y1, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y2, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y3, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w1, reduction="mean"), + 
relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w2, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w3, reduction="mean"), + relax.TensorStructInfo((), ""), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x4, y4, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x5, y5, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + + # reduction=sum is totally the same as mean. Just need one test to ensure they behave the same + _check_inference( + bb, relax.op.nn.nll_loss(x0, y0, w0, reduction="sum"), relax.TensorStructInfo((), "float32") + ) + + # reduction=none + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w0, reduction="none"), + relax.TensorStructInfo((3, 10, 10), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x1, y0, w0, reduction="none"), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x2, y0, w0, reduction="none"), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x3, y0, w0, reduction="none"), + relax.TensorStructInfo((3, 10, 10), ""), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y1, w0, reduction="none"), + relax.TensorStructInfo((3, 10, 10), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y2, w0, reduction="none"), + relax.TensorStructInfo((3, 10, 10), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y3, w0, reduction="none"), + relax.TensorStructInfo((3, 10, 10), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w1, reduction="none"), + relax.TensorStructInfo((3, 10, 10), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w2, reduction="none"), + relax.TensorStructInfo((3, 10, 10), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w3, reduction="none"), + relax.TensorStructInfo((3, 10, 10), ""), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x4, y4, w0, reduction="none"), + relax.TensorStructInfo((3,), "float32"), # (N,) + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x5, y5, w0, reduction="none"), + relax.TensorStructInfo((), "float32"), # () + ) + + +def test_nll_loss_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + N = tir.Var("N", "int64") + C = tir.Var("C", "int64") + d1 = tir.Var("d", "int64") + d2 = tir.Var("d", "int64") + x0 = relax.Var("x", R.Tensor((N, C, d1, d2), "float32")) + x1 = relax.Var("x", R.Tensor((N, C), "float32")) + x2 = relax.Var("x", R.Tensor((C,), "float32")) + x3 = relax.Var("x", R.Tensor((3, C, d1, 2), "float32")) + y0 = relax.Var("y", R.Tensor((N, d1, d2), "int64")) + y1 = relax.Var("y", R.Tensor((N,), "int64")) + y2 = relax.Var("y", R.Tensor((), "int64")) + y3 = relax.Var("y", R.Tensor((3, d1, 2), "int64")) + w0 = relax.Var("w", R.Tensor((C,), "float32")) + w1 = relax.Var("w", R.Tensor((5,), "float32")) + + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w0, reduction="none"), + relax.TensorStructInfo((N, d1, d2), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x1, y1, w0, reduction="none"), + relax.TensorStructInfo((N,), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x2, y2, w0, reduction="none"), + 
relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x3, y3, w0, reduction="none"), + relax.TensorStructInfo((3, d1, 2), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x3, y3, w1, reduction="none"), + relax.TensorStructInfo((3, d1, 2), "float32"), + ) + + +def test_nll_loss_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + + s0 = relax.Var("s0", relax.ShapeStructInfo((3, 5, 10, 10))) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=4)) + s2 = relax.Var("s2", relax.ShapeStructInfo()) + s3 = relax.Var("s3", relax.ShapeStructInfo((3, 10, 10))) + s4 = relax.Var("s4", relax.ShapeStructInfo(ndim=3)) + s5 = relax.Var("s5", relax.ShapeStructInfo((5,))) + s6 = relax.Var("s6", relax.ShapeStructInfo(ndim=1)) + + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(s3, "int64")) + y1 = relax.Var("y", relax.TensorStructInfo(s4, "int64")) + w0 = relax.Var("w", relax.TensorStructInfo(s5, "float32")) + w1 = relax.Var("w", relax.TensorStructInfo(s6, "float32")) + + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w0, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w0, reduction="none"), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x1, y0, w0, reduction="none"), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x2, y0, w0, reduction="none"), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y1, w0, reduction="none"), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w1, reduction="none"), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + +def test_nll_loss_infer_struct_info_no_weights(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + y = relax.Var("x", R.Tensor((3, 10, 10), "int64")) + + _check_inference( + bb, + relax.op.nn.nll_loss(x, y, reduction="mean"), + relax.TensorStructInfo((), "float32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x, y, reduction="none"), + relax.TensorStructInfo((3, 10, 10), "float32"), + ) + + +def test_nll_loss_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + y0 = relax.Var("y", R.Tensor((3, 10, 10), "int64")) + y1 = relax.Var("y", relax.ShapeStructInfo((2, 3))) + y2 = relax.Var("y", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + w0 = relax.Var("w", R.Tensor((5,), "float32")) + w1 = relax.Var("w", relax.ShapeStructInfo((2, 3))) + w2 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x1, y0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x2, y0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y1, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y2, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y0, w1)) + with pytest.raises(TVMError): + 
bb.normalize(relax.op.nn.nll_loss(x0, y0, w2)) + + +def test_nll_loss_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float16")) + x1 = relax.Var("x", R.Tensor((3, 5, 10, 10), "int8")) + x2 = relax.Var("x", R.Tensor((3, 5, 10, 10), "int32")) + x3 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float64")) + y0 = relax.Var("y", R.Tensor((3, 10, 10), "int8")) + w0 = relax.Var("y", R.Tensor((5,), "float16")) + w1 = relax.Var("y", R.Tensor((5,), "int8")) + w2 = relax.Var("y", R.Tensor((5,), "int32")) + w3 = relax.Var("y", R.Tensor((5,), "float64")) + + _check_inference( + bb, + relax.op.nn.nll_loss(x0, y0, w0, reduction="mean"), + relax.TensorStructInfo((), "float16"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x1, y0, w1, reduction="mean"), + relax.TensorStructInfo((), "int8"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x2, y0, w2, reduction="mean"), + relax.TensorStructInfo((), "int32"), + ) + _check_inference( + bb, + relax.op.nn.nll_loss(x3, y0, w3, reduction="mean"), + relax.TensorStructInfo((), "float64"), + ) + + +def test_nll_loss_infer_struct_info_targets_dtype(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + w = relax.Var("w", R.Tensor((5,), "float32")) + targets0 = relax.Var("targets", R.Tensor((3, 10, 10), "float32")) + targets1 = relax.Var("targets", R.Tensor((3, 10, 10), "float64")) + targets2 = relax.Var("targets", R.Tensor((3, 10, 10), "bool")) + targets3 = relax.Var("targets", R.Tensor((3, 10, 10), "int32")) + targets4 = relax.Var("targets", R.Tensor((3, 10, 10), "int64")) + targets5 = relax.Var("targets", R.Tensor((3, 10, 10), "uint32")) + targets6 = relax.Var("targets", R.Tensor((3, 10, 10), "")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x, targets0, w)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x, targets1, w)) + + # correct cases + bb.normalize(relax.op.nn.nll_loss(x, targets2, w)) # bool is uint1 + bb.normalize(relax.op.nn.nll_loss(x, targets3, w)) + bb.normalize(relax.op.nn.nll_loss(x, targets4, w)) + bb.normalize(relax.op.nn.nll_loss(x, targets5, w)) + bb.normalize(relax.op.nn.nll_loss(x, targets6, w)) # unknwon dtype + + +def test_nll_loss_infer_struct_info_ndim_mismatch(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + x1 = relax.Var("x", R.Tensor((3, 5, 10, 10, 10), "float32")) + x2 = relax.Var("x", R.Tensor((3, 5, 10), "float32")) + y0 = relax.Var("x", R.Tensor((3, 10, 10), "int64")) + y1 = relax.Var("x", R.Tensor((3, 10, 10, 10), "int64")) + y2 = relax.Var("x", R.Tensor((3, 10), "int64")) + w0 = relax.Var("w", R.Tensor((5,), "float32")) + w1 = relax.Var("w", R.Tensor((5, 5), "float32")) + w2 = relax.Var("w", R.Tensor((), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x1, y0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x2, y0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y1, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y2, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y0, w1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y0, w2)) + + +def test_nll_loss_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + x1 = relax.Var("x", R.Tensor((3, 6, 10, 10), "float32")) + x2 = relax.Var("x", R.Tensor((4, 5, 
10, 10), "float32")) + x3 = relax.Var("x", R.Tensor((3, 5, 11, 10), "float32")) + y0 = relax.Var("x", R.Tensor((3, 10, 10), "int64")) + y1 = relax.Var("x", R.Tensor((4, 10, 10), "int64")) + y2 = relax.Var("x", R.Tensor((3, 11, 10), "int64")) + w0 = relax.Var("w", R.Tensor((5,), "float32")) + w1 = relax.Var("w", R.Tensor((4,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x1, y0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x2, y0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x3, y0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y1, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y2, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x0, y0, w1)) + + +def test_nll_loss_infer_struct_info_wrong_reduction(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32")) + y = relax.Var("x", R.Tensor((3, 10, 10), "int64")) + w = relax.Var("w", R.Tensor((5,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.nll_loss(x, y, w, reduction="foo")) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py index 6114eb04f3..e0e8bfee9d 100644 --- a/tests/python/relax/test_tvmscript_parser_op_nn.py +++ b/tests/python/relax/test_tvmscript_parser_op_nn.py @@ -114,6 +114,21 @@ def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): _check(foo, bb.get()["foo"]) +def test_log_softmax(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.log_softmax(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.log_softmax(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + def test_batch_norm(): @R.function def foo( @@ -188,5 +203,85 @@ def foo( _check(foo, bb.get()["foo"]) +def test_cross_entropy_without_logits(): + @R.function + def foo( + predictions: R.Tensor((2, 3), "float32"), labels: R.Tensor((2, 3), "float32") + ) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.nn.cross_entropy_without_logits(predictions, labels) + return gv + + predictions = relax.Var("predictions", R.Tensor((2, 3), "float32")) + labels = relax.Var("labels", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [predictions, labels]): + gv = bb.emit(relax.op.nn.cross_entropy_without_logits(predictions, labels)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_cross_entropy_with_logits(): + @R.function + def foo( + predictions: R.Tensor((2, 3), "float32"), labels: R.Tensor((2, 3), "float32") + ) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.nn.cross_entropy_with_logits(predictions, labels) + return gv + + predictions = relax.Var("predictions", R.Tensor((2, 3), "float32")) + labels = relax.Var("labels", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [predictions, labels]): + gv = bb.emit(relax.op.nn.cross_entropy_with_logits(predictions, labels)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_nll_loss(): + @R.function + def foo( + predictions: R.Tensor((3, 5, 10, 10), dtype="float32"), + targets: R.Tensor((3, 10, 10), dtype="int64"), + weights: R.Tensor((5,), 
dtype="float32"), + ) -> R.Tensor((), dtype="float32"): + gv: R.Tensor((), dtype="float32") = R.nn.nll_loss(predictions, targets, weights, "mean", -1) + return gv + + predictions = relax.Var("predictions", R.Tensor((3, 5, 10, 10), "float32")) + targets = relax.Var("targets", R.Tensor((3, 10, 10), "int64")) + weights = relax.Var("weights", R.Tensor((5,), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [predictions, targets, weights]): + gv = bb.emit( + relax.op.nn.nll_loss(predictions, targets, weights, reduction="mean", ignore_index=-1) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_nll_loss_no_weights(): + @R.function + def foo( + predictions: R.Tensor((3, 5, 10, 10), dtype="float32"), + targets: R.Tensor((3, 10, 10), dtype="int64"), + ) -> R.Tensor((), dtype="float32"): + gv: R.Tensor((), dtype="float32") = R.nn.nll_loss( + predictions, targets, reduction="mean", ignore_index=-1 + ) + return gv + + predictions = relax.Var("predictions", R.Tensor((3, 5, 10, 10), "float32")) + targets = relax.Var("targets", R.Tensor((3, 10, 10), "int64")) + bb = relax.BlockBuilder() + with bb.function("foo", [predictions, targets]): + gv = bb.emit(relax.op.nn.nll_loss(predictions, targets, reduction="mean", ignore_index=-1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + if __name__ == "__main__": tvm.testing.main()