161 changes: 161 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -730,6 +730,51 @@ def convert(node: fx.Node):

########## Manipulation ##########

def _cat(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
return self.block_builder.emit(relax.op.concat(args[0], axis=axis))

def _cumsum(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]

dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
if "dtype" in node.kwargs:
dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
else:
dtype = None
if "out" in node.kwargs:
raise ValueError("specifying out for cumsum is not supported yet")

return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))

def _expand(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
sizes = args[1:] if len(args) > 2 else args[1]
broadcast_shape, in_shape = [], self.shape_of(args[0])
for idx, i in enumerate(sizes):
if isinstance(i, int) and i == -1:
broadcast_shape.append(in_shape[idx])
else:
broadcast_shape.append(i)
return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
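
A note on the `-1` handling above: `expand` keeps a `-1` entry bound to the corresponding input dimension before lowering to `broadcast_to`. A minimal standalone sketch of that substitution, with a made-up shape for illustration:

def resolve_expand_shape(sizes, in_shape):
    # -1 means "keep this input dimension"; anything else is the broadcast target.
    return [dim if size == -1 else size for size, dim in zip(sizes, in_shape)]

assert resolve_expand_shape([-1, 4, -1], [2, 1, 3]) == [2, 4, 3]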

def _permute(self, node: fx.Node) -> relax.Var:
import torch # type: ignore

args = self.retrieve_args(node)
x = args[0]
dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
return self.block_builder.emit(relax.op.permute_dims(x, dims))

def _repeat(self, node: fx.Node) -> relax.Var:
import torch # type: ignore

args = self.retrieve_args(node)
x = args[0]
dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
return self.block_builder.emit(relax.op.tile(x, dims))

def _reshape(self, node: fx.Node) -> relax.Var:
import torch # type: ignore

@@ -738,6 +783,122 @@ def _reshape(self, node: fx.Node) -> relax.Var:
dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
return self.block_builder.emit(relax.op.reshape(x, dims))

def _split(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
split_size = node.args[1]
dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
if isinstance(split_size, (list, tuple)):
n_section = []
for s in split_size[:-1]:
cum_sum = 0 if not n_section else n_section[-1]
n_section.append(s + cum_sum)
else:
n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
return self.block_builder.emit(relax.op.split(x, n_section, dim))
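
The handler above bridges two conventions: PyTorch's `split_size` is either a single chunk length or an explicit list of chunk lengths, while `relax.op.split` takes either a section count or the cumulative indices at which to cut. A minimal standalone sketch of that conversion, assuming a static dimension length:

def to_sections(split_size, dim_len):
    if isinstance(split_size, (list, tuple)):
        # Chunk lengths -> cumulative cut points (the last chunk needs no cut).
        sections, cum = [], 0
        for s in split_size[:-1]:
            cum += s
            sections.append(cum)
        return sections
    # Single chunk length -> number of sections, rounded up.
    return (dim_len + split_size - 1) // split_size

assert to_sections([2, 3, 5], 10) == [2, 5]
assert to_sections(4, 10) == 3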

def _squeeze(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
return self.block_builder.emit(relax.op.squeeze(x, dim))

def _tile(self, node: fx.Node) -> relax.Var:
import torch # type: ignore

args = self.retrieve_args(node)
x = args[0]
dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
return self.block_builder.emit(relax.op.tile(x, dims))

def _transpose(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
full_idx = list(range(len(self.shape_of(args[0]))))
full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))

########## Creation ##########

def _to_copy(self, node: fx.Node) -> relax.Var:
import torch # type: ignore

x = self.env[node.args[0]]
if len(node.args) == 2:
if isinstance(node.args[1], torch.dtype):
dtype = self._convert_data_type(node.args[1], self.env)
return self.block_builder.emit(relax.op.astype(x, dtype))
elif "dtype" in node.kwargs:
dtype = self._convert_data_type(node.kwargs["dtype"], self.env)
return self.block_builder.emit(relax.op.astype(x, dtype))
return x

def _arange(self, node: fx.Node) -> relax.Var:
import torch # type: ignore

start_end_step = [None, None, None]
if "start" in node.kwargs:
start_end_step[0] = node.kwargs["start"]
if "end" in node.kwargs:
start_end_step[1] = node.kwargs["end"]
if "step" in node.kwargs:
start_end_step[2] = node.kwargs["step"]

if len(node.args) == 1:
assert start_end_step[1] is None
start_end_step[1] = node.args[0]
elif len(node.args) == 2:
assert start_end_step[0] is None
assert start_end_step[1] is None
start_end_step[0] = node.args[0]
start_end_step[1] = node.args[1]
elif len(node.args) == 3:
assert start_end_step[0] is None
assert start_end_step[1] is None
assert start_end_step[2] is None
start_end_step[0] = node.args[0]
start_end_step[1] = node.args[1]
start_end_step[2] = node.args[2]

if start_end_step[0] is None:
start_end_step[0] = 0
if start_end_step[2] is None:
start_end_step[2] = 1

if "dtype" in node.kwargs:
dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
elif any([isinstance(x, float) for x in start_end_step]):
dtype = self._convert_data_type(torch.get_default_dtype())
else:
dtype = "int64"
start_end_step = [
self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step
]
return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype))
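
The dtype choice above mirrors PyTorch's own `arange` rule: an explicit `dtype` kwarg wins, any float among start/end/step promotes the result to the default float dtype, and an all-integer call stays `int64`. A small sketch of that rule in isolation, assuming the stock default of `torch.float32`:

import torch

def arange_dtype(start, end, step, explicit=None):
    if explicit is not None:
        return explicit
    if any(isinstance(v, float) for v in (start, end, step)):
        return str(torch.get_default_dtype()).replace("torch.", "")
    return "int64"

assert arange_dtype(0, 10, 1) == "int64"
assert arange_dtype(0.0, 10, 1) == "float32"  # under the stock default dtype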

def _empty(self, node: fx.Node) -> relax.Var:
dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))

def _fill(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
dtype = x.struct_info.dtype
value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))

def _new_ones(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
self_var = args[0]
size = args[1] if isinstance(args[1], (list, tuple)) else args[1:]
if not isinstance(size, (list, tuple)):
size = (size,)
size = relax.ShapeExpr(size)
return self.block_builder.emit(
relax.op.full(
size,
relax.const(1, self_var.struct_info.dtype),
self_var.struct_info.dtype,
)
)

########## Others ##########

def _getitem(self, node: fx.Node) -> relax.Var:
39 changes: 39 additions & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -162,6 +162,22 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var:
scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None)
return self._upsample_impl(x, size, align_corners, scale_factor, "nearest_neighbor")

########## Manipulation ##########

def _select(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dim = node.args[1]
index = relax.const(node.args[2], "int64")
return self.block_builder.emit(relax.op.take(x, index, dim))

def _slice(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
axes = [node.args[1]]
begin = [node.args[2]]
end = [node.args[3]]
stride = [node.args[4] if len(node.args) > 4 else 1]
return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
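
`slice.Tensor` nodes produced by `torch.export` carry one dimension each as `(input, dim, start, end[, step])`, so a multi-axis slice arrives as a chain of such nodes, and each is lowered to `strided_slice` on a single axis. A hedged sketch of how one node's arguments map; the helper name and return layout are illustrative only:

def slice_args_to_strided_slice(args):
    # args layout: (input, dim, start, end[, step]); step defaults to 1.
    _, dim, start, end = args[:4]
    step = args[4] if len(args) > 4 else 1
    return dict(axes=[dim], begin=[start], end=[end], strides=[step])

assert slice_args_to_strided_slice(("x", 1, 1, 7, 2)) == {
    "axes": [1], "begin": [1], "end": [7], "strides": [2]
}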

def create_convert_map(
self,
) -> Dict[str, Callable[[fx.Node], relax.Var]]:
@@ -249,7 +265,30 @@ def create_convert_map(
"argmax.default": self._argmax_argmin(relax.op.argmax),
"argmin.default": self._argmax_argmin(relax.op.argmin),
# tensor manipulation
"cat.default": self._cat,
"concat.default": self._cat,
"cumsum.default": self._cumsum,
"expand.default": self._expand,
"permute.default": self._permute,
"repeat.default": self._repeat,
"select.int": self._select,
"slice.Tensor": self._slice,
"split.Tensor": self._split,
"squeeze.default": self._squeeze,
"squeeze.dim": self._squeeze,
"tile.default": self._tile,
"transpose.int": self._transpose,
"unsqueeze.default": lambda node: self.block_builder.emit(
relax.op.expand_dims(self.env[node.args[0]], node.args[1])
),
"view.default": self._reshape,
# tensor creation
"_to_copy.default": self._to_copy,
"arange.start": self._arange,
"clone.default": lambda node: self.env[node.args[0]],
"empty.memory_format": self._empty,
"fill.Scalar": self._fill,
"new_ones.default": self._new_ones,
# other
"getitem": self._getitem,
}
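
With these entries registered, a module that exercises the newly mapped ops can be exported and translated end to end. A minimal usage sketch, assuming the frontend's `from_exported_program` entry point and a static input shape:

import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class Demo(torch.nn.Module):
    def forward(self, x):
        a, b = torch.split(x, 2, dim=1)   # split.Tensor (+ getitem)
        y = torch.cat([a, b], dim=1)      # cat.default
        y = y.permute(0, 2, 1)            # permute.default
        return torch.cumsum(y, dim=-1)    # cumsum.default

exported = export(Demo(), (torch.randn(1, 4, 8),))
mod = from_exported_program(exported)     # Relax IRModule containing the ops above
mod.show()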