Skip to content

Commit fa26a05

Browse files
Deivanayaki-Sdeivanayakisankaralingam
andauthored
[Relax][PyTorch] Add Meshgrid Op Support for Exported Program and FX graph (#17904)
* add torch.meshgrid op support into torch frontends * remove trailing whitespaces * fix lint issues * fix space issue in test script * fix func definition issue * set relax var shape to fix the unity issue * fix format issue in input declaration * fix lint issue * fix cpp lints * ix cpp lint issue in manipulate file * fix wrong input in struct info test script * add one more mapping for meshgrid in exported program --------- Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki.>
1 parent 4ef582a commit fa26a05

File tree

13 files changed

+346
-0
lines changed

13 files changed

+346
-0
lines changed

include/tvm/relax/attrs/manipulate.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,15 @@ struct IndexPutAttrs : public tvm::AttrsNode<IndexPutAttrs> {
196196
}
197197
}; // struct IndexPutAttrs
198198

199+
/*! \brief Attribute used in meshgrid operator */
200+
struct MeshgridAttrs : public tvm::AttrsNode<MeshgridAttrs> {
201+
Optional<String> indexing;
202+
203+
TVM_DECLARE_ATTRS(MeshgridAttrs, "relax.attrs.MeshgridAttrs") {
204+
TVM_ATTR_FIELD(indexing).describe("Specifies how the grid dimensions are ordered.");
205+
}
206+
};
207+
199208
/*! \brief Attributes used in scatter_elements operators */
200209
struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
201210
Integer axis;

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,6 +1179,26 @@ def _index_tensor(self, node: fx.Node) -> relax.Var:
11791179
indices = args[1]
11801180
return self.block_builder.emit(relax.op.index_tensor(args[0], indices))
11811181

1182+
def _meshgrid(self, node: fx.Node) -> relax.Var:
1183+
args = self.retrieve_args(node)
1184+
indexing = args[1] if len(node.args) > 1 else node.kwargs.get("indexing", "ij")
1185+
input_list = args[0]
1186+
1187+
# Single input: return as-is, meshgrid not applicable.
1188+
if len(input_list) == 1:
1189+
return input_list
1190+
new_inputs = []
1191+
for i, item in enumerate(input_list):
1192+
if item.struct_info.ndim == 1:
1193+
new_inputs.append(item)
1194+
elif item.struct_info.ndim == 0: # Change scalar value into 1D
1195+
const_tensor = relax.op.reshape(item, (1,))
1196+
new_inputs.append(const_tensor)
1197+
else:
1198+
raise TypeError(f"Unsupported meshgrid input type at index {i}: {type(item)}")
1199+
1200+
return self.block_builder.emit(relax.op.meshgrid(new_inputs, indexing=indexing))
1201+
11821202
def _permute(self, node: fx.Node) -> relax.Var:
11831203
import torch # type: ignore
11841204

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,8 @@ def create_convert_map(
439439
"gather.default": self._gather,
440440
"index.Tensor": self._index_tensor,
441441
"index_put_.default": self._index_put,
442+
"meshgrid.indexing": self._meshgrid,
443+
"meshgrid.default": self._meshgrid,
442444
"narrow.default": self._narrow,
443445
"permute.default": self._permute,
444446
"repeat.default": self._repeat,

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,7 @@ def create_convert_map(
804804
"flip": self._flip,
805805
"gather": self._gather,
806806
"index_put_": self._index_put,
807+
"meshgrid": self._meshgrid,
807808
"narrow": self._narrow,
808809
"numel": self._numel,
809810
"permute": self._permute,

python/tvm/relax/op/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
gather_nd,
9898
index_put,
9999
index_tensor,
100+
meshgrid,
100101
layout_transform,
101102
one_hot,
102103
permute_dims,

python/tvm/relax/op/manipulate.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,29 @@ def index_put(
646646
return _ffi_api.index_put(data, indices, values, accumulate) # type: ignore
647647

648648

649+
def meshgrid(tensors: Union[Expr, List[Expr]], indexing: Optional[str] = "ij") -> Expr:
650+
"""Generate coordinate grids from input tensors.
651+
652+
Parameters
653+
----------
654+
tensors : Union[relax.Expr, List[relax.Expr]]
655+
An Expr in Tuple type, containing 1D tensors (or scalars promoted to 1D)
656+
to generate coordinate grids from, or a list of such tensors.
657+
658+
indexing : Optional[str]
659+
The indexing mode, either "ij" (matrix indexing) or "xy" (Cartesian indexing).
660+
Defaults to "ij".
661+
662+
Returns
663+
-------
664+
result : relax.Expr
665+
A Tuple of tensors representing the coordinate grids.
666+
"""
667+
if isinstance(tensors, (list, tuple)):
668+
tensors = RxTuple(tensors)
669+
return _ffi_api.meshgrid(tensors, indexing)
670+
671+
649672
def scatter_elements(
650673
data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update"
651674
):

python/tvm/relax/transform/legalize_ops/manipulate.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,25 @@ def _index_put(bb: BlockBuilder, call: Call) -> Expr:
215215
)
216216

217217

218+
@register_legalize("relax.meshgrid")
219+
def _meshgrid(bb: BlockBuilder, call: Call) -> Expr:
220+
t = call.args[0]
221+
n_field = len(t.struct_info.fields)
222+
while isinstance(t, Var):
223+
binding = bb.lookup_binding(t)
224+
if not isinstance(binding, (Tuple, Var)):
225+
break
226+
t = binding
227+
228+
assert isinstance(t, (Tuple, Var))
229+
fields = (
230+
t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
231+
)
232+
return bb.call_te(
233+
topi.meshgrid, fields, "ij" if call.attrs.indexing is None else call.attrs.indexing
234+
)
235+
236+
218237
@register_legalize("relax.scatter_elements")
219238
def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
220239
return bb.call_te(

python/tvm/script/ir_builder/relax/ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@
125125
maximum,
126126
mean,
127127
memory,
128+
meshgrid,
128129
min,
129130
minimum,
130131
mod,
@@ -811,6 +812,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
811812
"maximum",
812813
"mean",
813814
"memory",
815+
"meshgrid",
814816
"metal",
815817
"min",
816818
"minimum",

src/relax/op/tensor/manipulate.cc

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2095,6 +2095,109 @@ TVM_REGISTER_OP("relax.index_put")
20952095
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoIndexPut)
20962096
.set_attr<Bool>("FPurity", Bool(true));
20972097

2098+
/* relax.meshgrid */
2099+
TVM_REGISTER_NODE_TYPE(MeshgridAttrs);
2100+
2101+
Expr meshgrid(Expr tensors, Optional<String> indexing) {
2102+
ObjectPtr<MeshgridAttrs> attrs = make_object<MeshgridAttrs>();
2103+
attrs->indexing = indexing;
2104+
static const Op& op = Op::Get("relax.meshgrid");
2105+
return Call(op, {std::move(tensors)}, Attrs(attrs), {});
2106+
}
2107+
2108+
TVM_REGISTER_GLOBAL("relax.op.meshgrid").set_body_typed(meshgrid);
2109+
2110+
StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) {
2111+
if (call->args.size() != 1) {
2112+
ctx->ReportFatal(Diagnostic::Error(call) << "meshgrid op expects 1 Tuple input argument.");
2113+
}
2114+
Array<TensorStructInfo> input_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]);
2115+
2116+
int n_inputs = input_sinfo.size();
2117+
2118+
if (n_inputs == 0) {
2119+
ctx->ReportFatal(Diagnostic::Error(call)
2120+
<< "meshgrid expects at least one 1D tensor in the input Tuple.");
2121+
}
2122+
2123+
std::vector<PrimExpr> lengths;
2124+
DataType common_dtype = DataType::Void();
2125+
bool shape_unknown = false;
2126+
Optional<VDevice> vdev = NullOpt;
2127+
bool vdevice_unknown = false;
2128+
2129+
for (int i = 0; i < n_inputs; ++i) {
2130+
const TensorStructInfo& sinfo = input_sinfo[i];
2131+
2132+
if (sinfo->ndim != 1) {
2133+
ctx->ReportFatal(Diagnostic::Error(call)
2134+
<< "meshgrid expects each input tensor to be 1D. Got ndim = " << sinfo->ndim
2135+
<< " at index " << i);
2136+
}
2137+
2138+
if (sinfo->dtype.is_void()) {
2139+
continue;
2140+
} else if (common_dtype.is_void()) {
2141+
common_dtype = sinfo->dtype;
2142+
} else if (sinfo->dtype != common_dtype) {
2143+
ctx->ReportFatal(Diagnostic::Error(call)
2144+
<< "meshgrid expects all input tensors to have the same dtype. Found "
2145+
<< sinfo->dtype << " and " << common_dtype);
2146+
}
2147+
2148+
const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
2149+
if (shape_expr && shape_expr->values.size() == 1) {
2150+
lengths.push_back(shape_expr->values[0]);
2151+
} else {
2152+
shape_unknown = true;
2153+
}
2154+
2155+
if (!vdevice_unknown) {
2156+
if (sinfo->vdevice.defined()) {
2157+
if (!vdev.defined()) {
2158+
vdev = sinfo->vdevice.value();
2159+
} else if (sinfo->vdevice.value() != vdev) {
2160+
vdevice_unknown = true;
2161+
}
2162+
}
2163+
}
2164+
}
2165+
2166+
Array<PrimExpr> out_shape;
2167+
if (!shape_unknown && lengths.size() == static_cast<size_t>(n_inputs)) {
2168+
for (const PrimExpr& dim : lengths) {
2169+
out_shape.push_back(dim);
2170+
}
2171+
}
2172+
2173+
Array<StructInfo> out_fields;
2174+
for (int i = 0; i < n_inputs; ++i) {
2175+
if (!out_shape.empty()) {
2176+
if (!vdevice_unknown) {
2177+
out_fields.push_back(TensorStructInfo(ShapeExpr(out_shape), common_dtype, vdev));
2178+
} else {
2179+
out_fields.push_back(TensorStructInfo(ShapeExpr(out_shape), common_dtype));
2180+
}
2181+
} else {
2182+
if (!vdevice_unknown) {
2183+
out_fields.push_back(TensorStructInfo(common_dtype, n_inputs, vdev));
2184+
} else {
2185+
out_fields.push_back(TensorStructInfo(common_dtype, n_inputs));
2186+
}
2187+
}
2188+
}
2189+
2190+
return TupleStructInfo(out_fields);
2191+
}
2192+
2193+
TVM_REGISTER_OP("relax.meshgrid")
2194+
.set_attrs_type<MeshgridAttrs>()
2195+
.set_num_inputs(1)
2196+
.add_argument("tensors", "Tuple of Tensors", "The input list of tensors.")
2197+
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMeshgrid)
2198+
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
2199+
.set_attr<Bool>("FPurity", Bool(true));
2200+
20982201
/* relax.scatter_elements */
20992202
TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs);
21002203

src/relax/op/tensor/manipulate.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,14 @@ Expr index_tensor(Expr data, Expr indices);
231231
*/
232232
Expr index_put(Expr data, Expr indices, Expr values, bool accumulate = false);
233233

234+
/*!
235+
* \brief Generate coordinate grids from input 1D tensors.
236+
* \param tensors A tuple of 1D tensors representing coordinate vectors.
237+
* \param indexing Indexing mode, either "ij" (matrix indexing) or "xy" (Cartesian indexing).
238+
* \return A tuple of tensors representing the coordinate grids.
239+
*/
240+
Expr meshgrid(Expr tensors, Optional<String> indexing = String("ij"));
241+
234242
/*!
235243
* \brief Scatter updates into an array according to indices.
236244
* \param data The input tensor.

0 commit comments

Comments
 (0)