-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[Relax] Clean up scatter_elements unknown dtype handling #18577
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2398,7 +2398,6 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& | |
| if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) { | ||
| auto diag_dtype = [&](const TensorStructInfoNode* sinfo, ffi::String name) { | ||
| if (sinfo->IsUnknownDtype()) { | ||
| // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning? | ||
| LOG(WARNING) << "Data type of " << name | ||
| << " has not been specified. Assume it has an integer type."; | ||
| } | ||
|
|
@@ -2415,7 +2414,6 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& | |
| } | ||
|
|
||
| if (indices_sinfo->IsUnknownDtype()) { | ||
| // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning? | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| LOG(WARNING) << "Data type of indice has not been specified. Assume it has an integer type."; | ||
| } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { | ||
| ctx->ReportFatal( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3417,6 +3417,20 @@ def test_scatter_elements_infer_struct_info(): | |
| relax.op.scatter_elements(d2, i3, u0, 0, "updates"), | ||
| relax.TensorStructInfo(dtype="float32", ndim=-1), | ||
| ) | ||
| # Test with unknown dtype for data | ||
| d_unknown = relax.Var("data", R.Tensor((4, 4))) | ||
| _check_inference( | ||
| bb, | ||
| relax.op.scatter_elements(d_unknown, i0, u0, 0, "updates"), | ||
| relax.TensorStructInfo((4, 4), dtype=""), | ||
| ) | ||
| # Test with unknown dtype for updates | ||
| u_unknown = relax.Var("updates", R.Tensor((2, 2))) | ||
| _check_inference( | ||
| bb, | ||
| relax.op.scatter_elements(d0, i0, u_unknown, 0, "updates"), | ||
| relax.TensorStructInfo((4, 4), dtype="float32"), | ||
| ) | ||
|
Comment on lines
+3420
to
+3433
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are good additions for testing unknown dtypes for # Test with unknown dtype for data
d_unknown = relax.Var("data", R.Tensor((4, 4)))
_check_inference(
bb,
relax.op.scatter_elements(d_unknown, i0, u0, 0, "updates"),
relax.TensorStructInfo((4, 4), dtype=""),
)
# Test with unknown dtype for updates
u_unknown = relax.Var("updates", R.Tensor((2, 2)))
_check_inference(
bb,
relax.op.scatter_elements(d0, i0, u_unknown, 0, "updates"),
relax.TensorStructInfo((4, 4), dtype="float32"),
)
# Test with unknown dtype for indices
i_unknown = relax.Var("indices", R.Tensor((2, 2)))
_check_inference(
bb,
relax.op.scatter_elements(d0, i_unknown, u0, 0, "updates"),
relax.TensorStructInfo((4, 4), dtype="float32"),
) |
||
|
|
||
|
|
||
| def test_scatter_elements_infer_struct_info_symbolic_shape(): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While you're cleaning up this section, I noticed a potential bug in the usage of the
diag_dtypelambda. On line 2406 (in the full file), it seemsdata_sinfois used for both "data" and "updates", but it should probably beupdates_sinfofor "updates". This looks like a copy-paste error. It should bediag_dtype(updates_sinfo, "updates");to correctly check theupdatestensor.