Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/relax/op/tensor/manipulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2398,7 +2398,6 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder&
if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) {
auto diag_dtype = [&](const TensorStructInfoNode* sinfo, ffi::String name) {
if (sinfo->IsUnknownDtype()) {
// TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

While you're cleaning up this section, I noticed a potential bug in the usage of the diag_dtype lambda. On line 2406 (in the full file), it seems data_sinfo is used for both "data" and "updates", but it should probably be updates_sinfo for "updates". This looks like a copy-paste error. It should be diag_dtype(updates_sinfo, "updates"); to correctly check the updates tensor.

LOG(WARNING) << "Data type of " << name
<< " has not been specified. Assume it has an integer type.";
}
Expand All @@ -2415,7 +2414,6 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder&
}

if (indices_sinfo->IsUnknownDtype()) {
// TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While removing this TODO, I noticed a small typo in the following log message on line 2419. It says "indice" but it should probably be "indices" to be consistent with the variable name indices_sinfo and other similar log messages in the codebase (e.g., in InferStructInfoScatterND).

LOG(WARNING) << "Data type of indice has not been specified. Assume it has an integer type.";
} else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) {
ctx->ReportFatal(
Expand Down
14 changes: 14 additions & 0 deletions tests/python/relax/test_op_manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3417,6 +3417,20 @@ def test_scatter_elements_infer_struct_info():
relax.op.scatter_elements(d2, i3, u0, 0, "updates"),
relax.TensorStructInfo(dtype="float32", ndim=-1),
)
# Test with unknown dtype for data
d_unknown = relax.Var("data", R.Tensor((4, 4)))
_check_inference(
bb,
relax.op.scatter_elements(d_unknown, i0, u0, 0, "updates"),
relax.TensorStructInfo((4, 4), dtype=""),
)
# Test with unknown dtype for updates
u_unknown = relax.Var("updates", R.Tensor((2, 2)))
_check_inference(
bb,
relax.op.scatter_elements(d0, i0, u_unknown, 0, "updates"),
relax.TensorStructInfo((4, 4), dtype="float32"),
)
Comment on lines +3420 to +3433

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

These are good additions for testing unknown dtypes for data and updates. To improve test coverage further, could you also add a test case for when indices has an unknown dtype? This would cover the warning log path for indices_sinfo->IsUnknownDtype() in InferStructInfoScatterElements.

    # Test with unknown dtype for data
    d_unknown = relax.Var("data", R.Tensor((4, 4)))
    _check_inference(
        bb,
        relax.op.scatter_elements(d_unknown, i0, u0, 0, "updates"),
        relax.TensorStructInfo((4, 4), dtype=""),
    )
    # Test with unknown dtype for updates
    u_unknown = relax.Var("updates", R.Tensor((2, 2)))
    _check_inference(
        bb,
        relax.op.scatter_elements(d0, i0, u_unknown, 0, "updates"),
        relax.TensorStructInfo((4, 4), dtype="float32"),
    )
    # Test with unknown dtype for indices
    i_unknown = relax.Var("indices", R.Tensor((2, 2)))
    _check_inference(
        bb,
        relax.op.scatter_elements(d0, i_unknown, u0, 0, "updates"),
        relax.TensorStructInfo((4, 4), dtype="float32"),
    )



def test_scatter_elements_infer_struct_info_symbolic_shape():
Expand Down
Loading