
Commit eff0949

AleksKnezevic authored and pfk-beta committed

Fix default pytorch divide behaviour (apache#10727)

Co-authored-by: Aleks Knezevic <[email protected]>

1 parent d3d7f87 commit eff0949

File tree: 1 file changed (+14, -0)

python/tvm/relay/frontend/pytorch.py

Lines changed: 14 additions & 0 deletions
@@ -254,6 +254,20 @@ def is_quantized_tensor(self, data):
     # Operator implementations
     def make_elemwise(self, name):
         def elemwise(inputs, input_types):
+            if name == "divide":
+                # https://pytorch.org/docs/stable/generated/torch.div.html#torch.div
+                # None - default behavior. Performs no rounding and, if both input and
+                # other are integer types, promotes the inputs to the default scalar type.
+                if all(["int" in input_type for input_type in input_types[:2]]):
+                    input_types[:2] = ["float32"] * 2
+                    cast_inputs = []
+                    for inp in inputs[:2]:
+                        if np.isscalar(inp):
+                            cast_inputs.append(_expr.const(inp, dtype="float32"))
+                        else:
+                            cast_inputs.append(_op.cast(inp, "float32"))
+                    inputs[:2] = cast_inputs
             data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2])
             return get_relay_op(name)(data0, data1)

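A hedged sketch of how this conversion path would typically be exercised (the module, tensor, and input names below are illustrative, not from the commit): trace a PyTorch module that divides two integer tensors and import it with relay.frontend.from_pytorch; with this fix both operands are cast to float32 before the Relay divide is emitted, matching torch.div's true-division result.

# Illustrative only: exercising the "divide" elemwise path in the PyTorch frontend.
import torch
from tvm import relay

class IntDiv(torch.nn.Module):
    def forward(self, x, y):
        return x / y  # aten::div with default rounding_mode=None

x = torch.tensor([3, 5], dtype=torch.int32)
y = torch.tensor([2, 2], dtype=torch.int32)
scripted = torch.jit.trace(IntDiv(), (x, y))

# (name, (shape, dtype)) pairs describing the graph inputs.
input_infos = [("x", ((2,), "int32")), ("y", ((2,), "int32"))]
mod, params = relay.frontend.from_pytorch(scripted, input_infos)
# With this commit, the integer operands are cast to float32, so the Relay
# graph computes 1.5 and 2.5 instead of truncating to integers.
print(mod)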