
Commit fab67a9

[Relax][PyTorch] Support tensor manipulation and creation ops for ExportedProgram importer (#17429)
* support cat and concat
* support cumsum
* support expand
* support permute
* support squeeze
* support tile
* support transpose
* support unsqueeze
* add test for flatten
* support repeat
* add test for reshape
* support select and slice
* support arange
* support empty
* support fill
* support new_ones
* support _to_copy
* support split
* add test for unbind
* support clone
1 parent 4f94890 commit fab67a9
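With these converters in place, models that use the ops listed above can be imported through the ExportedProgram frontend. A minimal usage sketch; the module, shapes, and variable names here are illustrative and not part of this commit:

    import torch
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class CatCumsum(torch.nn.Module):
        def forward(self, x, y):
            # exports as cat.default, handled by the new _cat converter
            z = torch.cat([x, y], dim=1)
            # exports as cumsum.default, handled by _cumsum
            return torch.cumsum(z, dim=1)

    example_args = (torch.randn(2, 3), torch.randn(2, 3))
    exported_program = export(CatCumsum(), args=example_args)
    mod = from_exported_program(exported_program)
    mod.show()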

File tree

4 files changed: +981 −139 lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 161 additions & 0 deletions
@@ -730,6 +730,51 @@ def convert(node: fx.Node):
 
     ########## Manipulation ##########
 
+    def _cat(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+        return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
+
+    def _cumsum(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        if "dtype" in node.kwargs:
+            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        else:
+            dtype = None
+        if "out" in node.kwargs:
+            raise ValueError("specifying out for cumsum is not supported yet")
+
+        return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
+
+    def _expand(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        sizes = args[1:] if len(args) > 2 else args[1]
+        broadcast_shape, in_shape = [], self.shape_of(args[0])
+        for idx, i in enumerate(sizes):
+            if isinstance(i, int) and i == -1:
+                broadcast_shape.append(in_shape[idx])
+            else:
+                broadcast_shape.append(i)
+        return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
+
+    def _permute(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.permute_dims(x, dims))
+
+    def _repeat(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.tile(x, dims))
+
     def _reshape(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore

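A note on the -1 handling in _expand above: torch.Tensor.expand uses -1 to mean "keep this dimension's size", while relax.op.broadcast_to expects a fully specified target shape, so the converter substitutes the input extent before emitting. In PyTorch terms, with illustrative shapes:

    import torch

    x = torch.randn(1, 3)
    y = x.expand(4, -1)  # -1 keeps the existing size of dim 1
    print(y.shape)       # torch.Size([4, 3])
    # the converter rewrites the target (4, -1) into (4, 3) for broadcast_to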
@@ -738,6 +783,122 @@ def _reshape(self, node: fx.Node) -> relax.Var:
         dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
         return self.block_builder.emit(relax.op.reshape(x, dims))
 
+    def _split(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        split_size = node.args[1]
+        dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
+        if isinstance(split_size, (list, tuple)):
+            n_section = []
+            for s in split_size[:-1]:
+                cum_sum = 0 if not n_section else n_section[-1]
+                n_section.append(s + cum_sum)
+        else:
+            n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
+        return self.block_builder.emit(relax.op.split(x, n_section, dim))
+
+    def _squeeze(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        return self.block_builder.emit(relax.op.squeeze(x, dim))
+
+    def _tile(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.tile(x, dims))
+
+    def _transpose(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        full_idx = list(range(len(self.shape_of(args[0]))))
+        full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
+        return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))
+
+    ########## Creation ##########
+
+    def _to_copy(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        x = self.env[node.args[0]]
+        if len(node.args) == 2:
+            if isinstance(node.args[1], torch.dtype):
+                dtype = self._convert_data_type(node.args[1], self.env)
+                return self.block_builder.emit(relax.op.astype(x, dtype))
+        elif "dtype" in node.kwargs:
+            dtype = self._convert_data_type(node.kwargs["dtype"], self.env)
+            return self.block_builder.emit(relax.op.astype(x, dtype))
+        return x
+
+    def _arange(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        start_end_step = [None, None, None]
+        if "start" in node.kwargs:
+            start_end_step[0] = node.kwargs["start"]
+        if "end" in node.kwargs:
+            start_end_step[1] = node.kwargs["end"]
+        if "step" in node.kwargs:
+            start_end_step[2] = node.kwargs["step"]
+
+        if len(node.args) == 1:
+            assert start_end_step[1] is None
+            start_end_step[1] = node.args[0]
+        elif len(node.args) == 2:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+        elif len(node.args) == 3:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            assert start_end_step[2] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+            start_end_step[2] = node.args[2]
+
+        if start_end_step[0] is None:
+            start_end_step[0] = 0
+        if start_end_step[2] is None:
+            start_end_step[2] = 1
+
+        if "dtype" in node.kwargs:
+            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        elif any([isinstance(x, float) for x in start_end_step]):
+            dtype = self._convert_data_type(torch.get_default_dtype())
+        else:
+            dtype = "int64"
+        start_end_step = [
+            self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step
+        ]
+        return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype))
+
+    def _empty(self, node: fx.Node) -> relax.Var:
+        dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
+
+    def _fill(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
+        return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
+
+    def _new_ones(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        self_var = args[0]
+        size = args[1] if isinstance(args[1], (list, tuple)) else args[1:]
+        if not isinstance(size, (list, tuple)):
+            size = (size,)
+        size = relax.ShapeExpr(size)
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                relax.const(1, self_var.struct_info.dtype),
+                self_var.struct_info.dtype,
+            )
+        )
+
     ########## Others ##########
 
     def _getitem(self, node: fx.Node) -> relax.Var:
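_split above bridges a semantic gap: torch.split takes chunk sizes, whereas relax.op.split takes either a number of equal sections or the indices at which to cut. For a list of sizes, the loop accumulates a running sum to produce cut indices; for an integer chunk size, it computes a ceiling-divided section count. The index math in isolation, with illustrative values:

    # torch.split(x, [2, 3, 5], dim=0) on a length-10 axis cuts at indices 2 and 5
    split_size = [2, 3, 5]
    n_section = []
    for s in split_size[:-1]:
        cum_sum = 0 if not n_section else n_section[-1]
        n_section.append(s + cum_sum)
    print(n_section)  # [2, 5] -> pieces of length 2, 3, and 5

    # torch.split(x, 4, dim=0) on a length-10 axis yields ceil(10 / 4) sections
    print((10 + 4 - 1) // 4)  # 3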

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 39 additions & 0 deletions
@@ -162,6 +162,22 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var:
         scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None)
         return self._upsample_impl(x, size, align_corners, scale_factor, "nearest_neighbor")
 
+    ########## Manipulation ##########
+
+    def _select(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1]
+        index = relax.const(node.args[2], "int64")
+        return self.block_builder.emit(relax.op.take(x, index, dim))
+
+    def _slice(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        axes = [node.args[1]]
+        begin = [node.args[2]]
+        end = [node.args[3]]
+        stride = [node.args[4] if len(node.args) > 4 else 1]
+        return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
+
     def create_convert_map(
         self,
     ) -> Dict[str, Callable[[fx.Node], relax.Var]]:
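select.int and slice.Tensor are what basic tensor indexing decomposes into under torch.export, so these two converters cover subscript expressions such as x[1] and x[:, 2:8:2]. Note that relax.op.take with a scalar constant index drops the indexed dimension, matching torch.select's semantics. A quick way to see the decomposition, using an illustrative module:

    import torch
    from torch.export import export

    class Index(torch.nn.Module):
        def forward(self, x):
            return x[1], x[:, 2:8:2]  # select.int and slice.Tensor after export

    ep = export(Index(), args=(torch.randn(4, 10),))
    print(ep.graph)  # contains aten.select.int and aten.slice.Tensor nodes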
@@ -249,7 +265,30 @@ def create_convert_map(
             "argmax.default": self._argmax_argmin(relax.op.argmax),
             "argmin.default": self._argmax_argmin(relax.op.argmin),
             # tensor manipulation
+            "cat.default": self._cat,
+            "concat.default": self._cat,
+            "cumsum.default": self._cumsum,
+            "expand.default": self._expand,
+            "permute.default": self._permute,
+            "repeat.default": self._repeat,
+            "select.int": self._select,
+            "slice.Tensor": self._slice,
+            "split.Tensor": self._split,
+            "squeeze.default": self._squeeze,
+            "squeeze.dim": self._squeeze,
+            "tile.default": self._tile,
+            "transpose.int": self._transpose,
+            "unsqueeze.default": lambda node: self.block_builder.emit(
+                relax.op.expand_dims(self.env[node.args[0]], node.args[1])
+            ),
             "view.default": self._reshape,
+            # tensor creation
+            "_to_copy.default": self._to_copy,
+            "arange.start": self._arange,
+            "clone.default": lambda node: self.env[node.args[0]],
+            "empty.memory_format": self._empty,
+            "fill.Scalar": self._fill,
+            "new_ones.default": self._new_ones,
             # other
             "getitem": self._getitem,
         }
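The keys in this map are ATen overload names ("op.overload"), so a single converter often serves several Python-level spellings: torch.cat and torch.concat arrive as cat.default and concat.default, Tensor.view arrives as view.default, and clone.default is a pure no-op that forwards the already-bound Relax var. To check which keys a given model actually needs, the exported graph can be inspected directly; an illustrative module:

    import torch
    from torch.export import export

    class Ops(torch.nn.Module):
        def forward(self, x):
            return x.permute(1, 0).unsqueeze(0).expand(2, -1, -1)

    ep = export(Ops(), args=(torch.randn(3, 4),))
    for node in ep.graph.nodes:
        if node.op == "call_function":
            # prints targets such as aten.permute.default, aten.unsqueeze.default,
            # and aten.expand.default, which must appear in the map above
            print(node.target)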
