@@ -2095,6 +2095,109 @@ TVM_REGISTER_OP("relax.index_put")
20952095 .set_attr<FInferStructInfo>(" FInferStructInfo" , InferStructInfoIndexPut)
20962096 .set_attr<Bool>(" FPurity" , Bool(true ));
20972097
2098+ /* relax.meshgrid */
2099+ TVM_REGISTER_NODE_TYPE (MeshgridAttrs);
2100+
2101+ Expr meshgrid (Expr tensors, Optional<String> indexing) {
2102+ ObjectPtr<MeshgridAttrs> attrs = make_object<MeshgridAttrs>();
2103+ attrs->indexing = indexing;
2104+ static const Op& op = Op::Get (" relax.meshgrid" );
2105+ return Call (op, {std::move (tensors)}, Attrs (attrs), {});
2106+ }
2107+
2108+ TVM_REGISTER_GLOBAL (" relax.op.meshgrid" ).set_body_typed(meshgrid);
2109+
2110+ StructInfo InferStructInfoMeshgrid (const Call& call, const BlockBuilder& ctx) {
2111+ if (call->args .size () != 1 ) {
2112+ ctx->ReportFatal (Diagnostic::Error (call) << " meshgrid op expects 1 Tuple input argument." );
2113+ }
2114+ Array<TensorStructInfo> input_sinfo = GetTensorStructInfoFromTuple (call, ctx, call->args [0 ]);
2115+
2116+ int n_inputs = input_sinfo.size ();
2117+
2118+ if (n_inputs == 0 ) {
2119+ ctx->ReportFatal (Diagnostic::Error (call)
2120+ << " meshgrid expects at least one 1D tensor in the input Tuple." );
2121+ }
2122+
2123+ std::vector<PrimExpr> lengths;
2124+ DataType common_dtype = DataType::Void ();
2125+ bool shape_unknown = false ;
2126+ Optional<VDevice> vdev = NullOpt;
2127+ bool vdevice_unknown = false ;
2128+
2129+ for (int i = 0 ; i < n_inputs; ++i) {
2130+ const TensorStructInfo& sinfo = input_sinfo[i];
2131+
2132+ if (sinfo->ndim != 1 ) {
2133+ ctx->ReportFatal (Diagnostic::Error (call)
2134+ << " meshgrid expects each input tensor to be 1D. Got ndim = " << sinfo->ndim
2135+ << " at index " << i);
2136+ }
2137+
2138+ if (sinfo->dtype .is_void ()) {
2139+ continue ;
2140+ } else if (common_dtype.is_void ()) {
2141+ common_dtype = sinfo->dtype ;
2142+ } else if (sinfo->dtype != common_dtype) {
2143+ ctx->ReportFatal (Diagnostic::Error (call)
2144+ << " meshgrid expects all input tensors to have the same dtype. Found "
2145+ << sinfo->dtype << " and " << common_dtype);
2146+ }
2147+
2148+ const auto * shape_expr = sinfo->shape .as <ShapeExprNode>();
2149+ if (shape_expr && shape_expr->values .size () == 1 ) {
2150+ lengths.push_back (shape_expr->values [0 ]);
2151+ } else {
2152+ shape_unknown = true ;
2153+ }
2154+
2155+ if (!vdevice_unknown) {
2156+ if (sinfo->vdevice .defined ()) {
2157+ if (!vdev.defined ()) {
2158+ vdev = sinfo->vdevice .value ();
2159+ } else if (sinfo->vdevice .value () != vdev) {
2160+ vdevice_unknown = true ;
2161+ }
2162+ }
2163+ }
2164+ }
2165+
2166+ Array<PrimExpr> out_shape;
2167+ if (!shape_unknown && lengths.size () == static_cast <size_t >(n_inputs)) {
2168+ for (const PrimExpr& dim : lengths) {
2169+ out_shape.push_back (dim);
2170+ }
2171+ }
2172+
2173+ Array<StructInfo> out_fields;
2174+ for (int i = 0 ; i < n_inputs; ++i) {
2175+ if (!out_shape.empty ()) {
2176+ if (!vdevice_unknown) {
2177+ out_fields.push_back (TensorStructInfo (ShapeExpr (out_shape), common_dtype, vdev));
2178+ } else {
2179+ out_fields.push_back (TensorStructInfo (ShapeExpr (out_shape), common_dtype));
2180+ }
2181+ } else {
2182+ if (!vdevice_unknown) {
2183+ out_fields.push_back (TensorStructInfo (common_dtype, n_inputs, vdev));
2184+ } else {
2185+ out_fields.push_back (TensorStructInfo (common_dtype, n_inputs));
2186+ }
2187+ }
2188+ }
2189+
2190+ return TupleStructInfo (out_fields);
2191+ }
2192+
2193+ TVM_REGISTER_OP (" relax.meshgrid" )
2194+ .set_attrs_type<MeshgridAttrs>()
2195+ .set_num_inputs(1 )
2196+ .add_argument(" tensors" , " Tuple of Tensors" , " The input list of tensors." )
2197+ .set_attr<FInferStructInfo>(" FInferStructInfo" , InferStructInfoMeshgrid)
2198+ .set_attr<TMixedPrecisionPolicy>(" TMixedPrecisionPolicy" , MixedPrecisionPolicyKind::kFollow )
2199+ .set_attr<Bool>(" FPurity" , Bool(true ));
2200+
20982201/* relax.scatter_elements */
20992202TVM_REGISTER_NODE_TYPE (ScatterElementsAttrs);
21002203
0 commit comments