Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,15 @@ struct IndexPutAttrs : public tvm::AttrsNode<IndexPutAttrs> {
}
}; // struct IndexPutAttrs

/*! \brief Attribute used in meshgrid operator */
struct MeshgridAttrs : public tvm::AttrsNode<MeshgridAttrs> {
Optional<String> indexing;

TVM_DECLARE_ATTRS(MeshgridAttrs, "relax.attrs.MeshgridAttrs") {
TVM_ATTR_FIELD(indexing).describe("Specifies how the grid dimensions are ordered.");
}
};

/*! \brief Attributes used in scatter_elements operators */
struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
Integer axis;
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,26 @@ def _index_tensor(self, node: fx.Node) -> relax.Var:
indices = args[1]
return self.block_builder.emit(relax.op.index_tensor(args[0], indices))

def _meshgrid(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
indexing = args[1] if len(node.args) > 1 else node.kwargs.get("indexing", "ij")
input_list = args[0]

# Single input: return as-is, meshgrid not applicable.
if len(input_list) == 1:
return input_list
new_inputs = []
for i, item in enumerate(input_list):
if item.struct_info.ndim == 1:
new_inputs.append(item)
elif item.struct_info.ndim == 0: # Change scalar value into 1D
const_tensor = relax.op.reshape(item, (1,))
new_inputs.append(const_tensor)
else:
raise TypeError(f"Unsupported meshgrid input type at index {i}: {type(item)}")

return self.block_builder.emit(relax.op.meshgrid(new_inputs, indexing=indexing))

def _permute(self, node: fx.Node) -> relax.Var:
import torch # type: ignore

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,8 @@ def create_convert_map(
"gather.default": self._gather,
"index.Tensor": self._index_tensor,
"index_put_.default": self._index_put,
"meshgrid.indexing": self._meshgrid,
"meshgrid.default": self._meshgrid,
"narrow.default": self._narrow,
"permute.default": self._permute,
"repeat.default": self._repeat,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ def create_convert_map(
"flip": self._flip,
"gather": self._gather,
"index_put_": self._index_put,
"meshgrid": self._meshgrid,
"narrow": self._narrow,
"numel": self._numel,
"permute": self._permute,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
gather_nd,
index_put,
index_tensor,
meshgrid,
layout_transform,
one_hot,
permute_dims,
Expand Down
23 changes: 23 additions & 0 deletions python/tvm/relax/op/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,29 @@ def index_put(
return _ffi_api.index_put(data, indices, values, accumulate) # type: ignore


def meshgrid(tensors: Union[Expr, List[Expr]], indexing: Optional[str] = "ij") -> Expr:
"""Generate coordinate grids from input tensors.

Parameters
----------
tensors : Union[relax.Expr, List[relax.Expr]]
An Expr in Tuple type, containing 1D tensors (or scalars promoted to 1D)
to generate coordinate grids from, or a list of such tensors.

indexing : Optional[str]
The indexing mode, either "ij" (matrix indexing) or "xy" (Cartesian indexing).
Defaults to "ij".

Returns
-------
result : relax.Expr
A Tuple of tensors representing the coordinate grids.
"""
if isinstance(tensors, (list, tuple)):
tensors = RxTuple(tensors)
return _ffi_api.meshgrid(tensors, indexing)


def scatter_elements(
data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update"
):
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/relax/transform/legalize_ops/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,25 @@ def _index_put(bb: BlockBuilder, call: Call) -> Expr:
)


@register_legalize("relax.meshgrid")
def _meshgrid(bb: BlockBuilder, call: Call) -> Expr:
t = call.args[0]
n_field = len(t.struct_info.fields)
while isinstance(t, Var):
binding = bb.lookup_binding(t)
if not isinstance(binding, (Tuple, Var)):
break
t = binding

assert isinstance(t, (Tuple, Var))
fields = (
t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
)
return bb.call_te(
topi.meshgrid, fields, "ij" if call.attrs.indexing is None else call.attrs.indexing
)


@register_legalize("relax.scatter_elements")
def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
maximum,
mean,
memory,
meshgrid,
min,
minimum,
mod,
Expand Down Expand Up @@ -811,6 +812,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"maximum",
"mean",
"memory",
"meshgrid",
"metal",
"min",
"minimum",
Expand Down
103 changes: 103 additions & 0 deletions src/relax/op/tensor/manipulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2095,6 +2095,109 @@ TVM_REGISTER_OP("relax.index_put")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoIndexPut)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.meshgrid */
TVM_REGISTER_NODE_TYPE(MeshgridAttrs);

Expr meshgrid(Expr tensors, Optional<String> indexing) {
ObjectPtr<MeshgridAttrs> attrs = make_object<MeshgridAttrs>();
attrs->indexing = indexing;
static const Op& op = Op::Get("relax.meshgrid");
return Call(op, {std::move(tensors)}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.meshgrid").set_body_typed(meshgrid);

StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) {
if (call->args.size() != 1) {
ctx->ReportFatal(Diagnostic::Error(call) << "meshgrid op expects 1 Tuple input argument.");
}
Array<TensorStructInfo> input_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]);

int n_inputs = input_sinfo.size();

if (n_inputs == 0) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "meshgrid expects at least one 1D tensor in the input Tuple.");
}

std::vector<PrimExpr> lengths;
DataType common_dtype = DataType::Void();
bool shape_unknown = false;
Optional<VDevice> vdev = NullOpt;
bool vdevice_unknown = false;

for (int i = 0; i < n_inputs; ++i) {
const TensorStructInfo& sinfo = input_sinfo[i];

if (sinfo->ndim != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "meshgrid expects each input tensor to be 1D. Got ndim = " << sinfo->ndim
<< " at index " << i);
}

if (sinfo->dtype.is_void()) {
continue;
} else if (common_dtype.is_void()) {
common_dtype = sinfo->dtype;
} else if (sinfo->dtype != common_dtype) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "meshgrid expects all input tensors to have the same dtype. Found "
<< sinfo->dtype << " and " << common_dtype);
}

const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
if (shape_expr && shape_expr->values.size() == 1) {
lengths.push_back(shape_expr->values[0]);
} else {
shape_unknown = true;
}

if (!vdevice_unknown) {
if (sinfo->vdevice.defined()) {
if (!vdev.defined()) {
vdev = sinfo->vdevice.value();
} else if (sinfo->vdevice.value() != vdev) {
vdevice_unknown = true;
}
}
}
}

Array<PrimExpr> out_shape;
if (!shape_unknown && lengths.size() == static_cast<size_t>(n_inputs)) {
for (const PrimExpr& dim : lengths) {
out_shape.push_back(dim);
}
}

Array<StructInfo> out_fields;
for (int i = 0; i < n_inputs; ++i) {
if (!out_shape.empty()) {
if (!vdevice_unknown) {
out_fields.push_back(TensorStructInfo(ShapeExpr(out_shape), common_dtype, vdev));
} else {
out_fields.push_back(TensorStructInfo(ShapeExpr(out_shape), common_dtype));
}
} else {
if (!vdevice_unknown) {
out_fields.push_back(TensorStructInfo(common_dtype, n_inputs, vdev));
} else {
out_fields.push_back(TensorStructInfo(common_dtype, n_inputs));
}
}
}

return TupleStructInfo(out_fields);
}

TVM_REGISTER_OP("relax.meshgrid")
.set_attrs_type<MeshgridAttrs>()
.set_num_inputs(1)
.add_argument("tensors", "Tuple of Tensors", "The input list of tensors.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMeshgrid)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.scatter_elements */
TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs);

Expand Down
8 changes: 8 additions & 0 deletions src/relax/op/tensor/manipulate.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ Expr index_tensor(Expr data, Expr indices);
*/
Expr index_put(Expr data, Expr indices, Expr values, bool accumulate = false);

/*!
* \brief Generate coordinate grids from input 1D tensors.
* \param tensors A tuple of 1D tensors representing coordinate vectors.
* \param indexing Indexing mode, either "ij" (matrix indexing) or "xy" (Cartesian indexing).
* \return A tuple of tensors representing the coordinate grids.
*/
Expr meshgrid(Expr tensors, Optional<String> indexing = String("ij"));

/*!
* \brief Scatter updates into an array according to indices.
* \param data The input tensor.
Expand Down
53 changes: 53 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -2883,6 +2883,59 @@ def main(
verify_model(Flatten(), example_args, {}, expected1)


def test_meshgrid():
class Meshgrid1(Module):
def forward(self, input1, input2):
return torch.meshgrid((input1, input2), indexing="ij")

class Meshgrid2(Module):
def forward(self, input1, input2):
return torch.meshgrid((input1, input2), indexing="xy")

@tvm.script.ir_module
class expected1:
@R.function
def main(
input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,), dtype="float32")
) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")):
with R.dataflow():
lv: R.Tuple(
R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
) = R.meshgrid((input1, input2), indexing="ij")
lv1: R.Tensor((3, 3), dtype="float32") = lv[0]
lv2: R.Tensor((3, 3), dtype="float32") = lv[1]
gv: R.Tuple(
R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
) = (lv1, lv2)
R.output(gv)
return gv

@tvm.script.ir_module
class expected2:
@R.function
def main(
input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,), dtype="float32")
) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")):
with R.dataflow():
lv: R.Tuple(
R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
) = R.meshgrid((input1, input2), indexing="xy")
lv1: R.Tensor((3, 3), dtype="float32") = lv[0]
lv2: R.Tensor((3, 3), dtype="float32") = lv[1]
gv: R.Tuple(
R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
) = (lv1, lv2)
R.output(gv)
return gv

example_args = (
torch.randn(3, dtype=torch.float32),
torch.randn(3, dtype=torch.float32),
)
verify_model(Meshgrid1(), example_args, {}, expected1)
verify_model(Meshgrid2(), example_args, {}, expected2)


def test_permute():
class Permute1(Module):
def forward(self, x):
Expand Down
60 changes: 60 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3441,6 +3441,66 @@ def forward(self, x):
verify_model(AsType(), input_info, {}, expected1)


def test_meshgrid():
input_infos = [
(
[
3,
],
"float32",
),
(
[
3,
],
"float32",
),
]

class Meshgrid1(Module):
def forward(self, input1, input2):
return torch.meshgrid((input1, input2), indexing="ij")

class Meshgrid2(Module):
def forward(self, input1, input2):
return torch.meshgrid((input1, input2), indexing="xy")

@tvm.script.ir_module
class expected1:
@R.function
def main(
inp_0: R.Tensor((3,), dtype="float32"), inp_1: R.Tensor((3,), dtype="float32")
) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")):
with R.dataflow():
lv: R.Tuple(
R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
) = R.meshgrid((inp_0, inp_1), indexing="ij")
gv: R.Tuple(
R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
) = lv
R.output(gv)
return gv

@tvm.script.ir_module
class expected2:
@R.function
def main(
inp_0: R.Tensor((3,), dtype="float32"), inp_1: R.Tensor((3,), dtype="float32")
) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")):
with R.dataflow():
lv: R.Tuple(
R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
) = R.meshgrid((inp_0, inp_1), indexing="xy")
gv: R.Tuple(
R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
) = lv
R.output(gv)
return gv

verify_model(Meshgrid1(), input_infos, {}, expected1)
verify_model(Meshgrid2(), input_infos, {}, expected2)


def test_permute():
input_info = [([1, 2, 3, 4], "float32")]

Expand Down
Loading