Skip to content

Commit e6af874

Browse files
authored
[PyTorch] Fix rsub type (#10090)
* [PyTorch] Fix rsub type * fix
1 parent fa317ed commit e6af874

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1972,10 +1972,7 @@ def stack(self, inputs, input_types):
19721972
return self.tensor_array_stack(inputs, input_types)
19731973

19741974
def rsub(self, inputs, input_types):
1975-
data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2])
1976-
1977-
# TODO (t-vi): should this also be part of the type promotion?
1978-
alpha = _expr.const(float(inputs[2]))
1975+
data0, data1, alpha = self.pytorch_promote_types(inputs, input_types)
19791976

19801977
# note: rsub means data0 and data1 swap places
19811978
return get_relay_op("subtract")(data1, alpha * data0)

tests/python/frontend/pytorch/test_forward.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2691,6 +2691,13 @@ def forward(self, *args):
26912691
verify_model(Rsub2().float().eval(), input_data=[d1, d2])
26922692
verify_model(Rsub2().float().eval(), input_data=[d1, d3])
26932693

2694+
d1 = torch.rand([1, 3]).half()
2695+
d2 = torch.rand([1, 3]).half()
2696+
verify_model(Rsub1().half().eval(), input_data=[d1, d2])
2697+
verify_model(Rsub1().half().eval(), input_data=[d1, d3])
2698+
verify_model(Rsub2().half().eval(), input_data=[d1, d2])
2699+
verify_model(Rsub2().half().eval(), input_data=[d1, d3])
2700+
26942701

26952702
@tvm.testing.uses_gpu
26962703
def test_forward_embedding():

0 commit comments

Comments
 (0)