Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,31 @@ def square(self, inputs, input_types):
(dtype,) = input_types
return _op.power(inputs[0], _expr.const(2, dtype))

def tril(self, inputs, input_types):
data = inputs[0]
if len(inputs) == 2:
k_value = inputs[1]
else:
k_value = 0
input_shape = self.infer_shape(data)
k1, k2 = input_shape[-2:]
k1 = k_value + 1
diag_input = _op.zeros(input_shape, dtype=input_types[0])
return _op.matrix_set_diag(data, diag_input, k=(k1, k2))

def triu(self, inputs, input_types):
data = inputs[0]
if len(inputs) == 2:
k_value = inputs[1]
else:
k_value = 0
input_shape = self.infer_shape(data)
k1, k2 = input_shape[-2:]
k1 = (k1 * -1) - 1
k2 = k_value - 1
diag_input = _op.zeros(input_shape, dtype=input_types[0])
return _op.matrix_set_diag(data, diag_input, k=(k1, k2))

def arange(self, inputs, input_types):
def _get_value(val, dtype):
# dtype is a tvm dtype
Expand Down Expand Up @@ -3328,6 +3353,8 @@ def create_convert_map(self):
"aten::sqrt": self.make_unary("sqrt"),
"aten::rsqrt": self.make_unary("rsqrt"),
"aten::square": self.square,
"aten::tril": self.tril,
"aten::triu": self.triu,
"aten::ceil": self.make_unary("ceil"),
"aten::floor": self.make_unary("floor"),
"aten::round": self.make_unary("round"),
Expand Down
81 changes: 74 additions & 7 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,21 @@ def visit(op):
torch.cuda.empty_cache()


def verify_model_with_input(test_func, input_data, input_dict={}):
def verify_model_with_input(
test_func,
input_data,
*,
input_dict={},
custom_convert_map={},
rtol=1e-5,
atol=1e-5,
assert_shape_only=False,
):
baseline_outputs = test_func(*input_data)
trace = torch.jit.trace(test_func, [input.clone() for input in input_data])
input_names = ["input{}".format(idx) for idx, inp in enumerate(input_data)]
input_shapes = list(zip(input_names, [inp.shape for inp in input_data]))
mod, params = relay.frontend.from_pytorch(trace, input_shapes, {})
mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
with tvm.transform.PassContext(opt_level=3):
for target in ["llvm", "cuda"]:
if not tvm.runtime.enabled(target):
Expand All @@ -218,7 +227,8 @@ def verify_model_with_input(test_func, input_data, input_dict={}):

compiled_output = relay_model.get_output(0).numpy()
assert_shapes_match(baseline_outputs, compiled_output)
tvm.testing.assert_allclose(baseline_outputs, compiled_output, rtol=1e-5, atol=1e-5)
if assert_shape_only == False:
tvm.testing.assert_allclose(baseline_outputs, compiled_output, rtol=rtol, atol=atol)


# Single operator tests
Expand Down Expand Up @@ -1304,7 +1314,7 @@ def test_func(input_tensor, other_tensor):

input_data = [torch.rand([2, 1, 10, 1, 10]), torch.rand([2, 1, 10, 10])]

verify_model_with_input(test_func, input_data, {"input0": input_data[0]})
verify_model_with_input(test_func, input_data, input_dict={"input0": input_data[0]})


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -3423,6 +3433,64 @@ def forward(self, *args):
verify_model(Neg1().float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_tril():
torch.set_grad_enabled(False)

def test_func(input_data):
return torch.tril(input_data)

input_data = torch.rand([3, 3]).float()
verify_model(test_func, input_data=input_data)
input_data = torch.rand([1, 3, 10, 10]).float()
verify_model(test_func, input_data=input_data)

def test_func1(input_data):
return torch.tril(input_data, 1)

input_data = torch.rand([3, 3]).float()
verify_model(test_func1, input_data=input_data)
input_data = torch.rand([1, 3, 10, 10]).float()
verify_model(test_func1, input_data=input_data)

def test_func2(input_data):
return torch.tril(input_data, -1)

input_data = torch.rand([3, 3]).float()
verify_model(test_func2, input_data=input_data)
input_data = torch.rand([1, 3, 10, 10]).float()
verify_model(test_func2, input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_triu():
torch.set_grad_enabled(False)

def test_func(input_data):
return torch.triu(input_data)

input_data = torch.rand([3, 3]).float()
verify_model(test_func, input_data=input_data)
input_data = torch.rand([1, 3, 10, 10]).float()
verify_model(test_func, input_data=input_data)

def test_func1(input_data):
return torch.triu(input_data, 1)

input_data = torch.rand([3, 3]).float()
verify_model(test_func1, input_data=input_data)
input_data = torch.rand([1, 3, 10, 10]).float()
verify_model(test_func1, input_data=input_data)

def test_func2(input_data):
return torch.triu(input_data, -1)

input_data = torch.rand([3, 3]).float()
verify_model(test_func2, input_data=input_data)
input_data = torch.rand([1, 3, 10, 10]).float()
verify_model(test_func2, input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_where():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -3817,15 +3885,14 @@ def test_empty():
def test_func():
return torch.empty([1, 3, 10, 10])

verify_model_with_input(test_func, [])
verify_model_with_input(test_func, [], assert_shape_only=True)


@pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/11967")
def test_empty_like():
def test_func(data):
return torch.empty_like(data)

verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()])
verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()], assert_shape_only=True)


def test_forward_pretrained_bert_base_uncased():
Expand Down