
[Bug][Pytorch] Torch divide default behaviour not obeyed #10720

@AleksKnezevic

Description

I believe that data type promotion is happening incorrectly for aten::div. As documented here: https://pytorch.org/docs/stable/generated/torch.div.html#torch.div, when both inputs are integers they should be promoted to the default scalar type so that a true division is performed.
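
For reference, the documented promotion is easy to confirm in eager PyTorch:

import torch

a = torch.arange(0, 64, 2)                   # int64 tensor
b = a / 64                                   # both operands are integers, so true
                                             # division promotes to the default scalar type
print(b.dtype)                               # torch.float32
c = torch.div(a, 64, rounding_mode="floor")  # integer division must be requested
print(c.dtype)                               # explicitly: torch.int64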

Expected behavior

Running the following code should produce a float division:

inv_freq = torch.arange(0, 64, 2) / 64

Actual behavior

Instead, it produces an integer division:

%2 = divide(%0, 64 /* ty=int64 */, Tensor[(32), int64], int64) /* ty=Tensor[(32), int64] */;

Steps to reproduce

Run the following through the PyTorch frontend:

inv_freq = torch.arange(0, 64, 2) / 64
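
For a self-contained reproduction, something like the following should work (the module name and input shape are placeholders, and torch.jit.script is used so the arange and divide are not folded into constants before conversion):

import torch
from tvm import relay

class IntDiv(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inv_freq = torch.arange(0, 64, 2) / 64  # int64 / int, should become float32
        return x + inv_freq

scripted = torch.jit.script(IntDiv().eval())
mod, _ = relay.frontend.from_pytorch(scripted, [("x", ((32,), "float32"))])
print(mod)  # the divide is emitted with int64 operands instead of float32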

Code proposal

You didn't ask for it in the bug report, but I already fixed it locally, so I'll share :)


diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 1e621ff31..b8a6776a8 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -255,6 +255,20 @@ class PyTorchOpConverter:
     # Operator implementations
     def make_elemwise(self, name):
         def elemwise(inputs, input_types):
+            if name == "divide":
+                # https://pytorch.org/docs/stable/generated/torch.div.html#torch.div
+                # None - default behavior. Performs no rounding and, if both input and 
+                # other are integer types, promotes the inputs to the default scalar type.
+                if all(["int" in input_type for input_type in input_types[:2]]):
+                    input_types[:2] = ["float32"] * 2
+                    cast_inputs = []
+                    for inp in inputs[:2]:
+                        if np.isscalar(inp):
+                            cast_inputs.append(_expr.const(inp, dtype="float32"))
+                        else:
+                            cast_inputs.append(_op.cast(inp, "float32"))
+                    inputs[:2] = cast_inputs
+
             data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2])
             return get_relay_op(name)(data0, data1)
 
 

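One caveat with the patch above (my observation, not part of the diff): torch.div promotes to the default scalar type, which is float32 only until someone calls torch.set_default_dtype, so hard-coding "float32" is an approximation. A variant sketch that takes the target dtype as a parameter (promote_int_div is a hypothetical helper; _expr, _op and np are the same names the frontend already imports):

import numpy as np
from tvm.relay import expr as _expr
from tvm.relay import op as _op

def promote_int_div(inputs, input_types, default_dtype="float32"):
    """Cast both operands of a true division to default_dtype when both are
    integer-typed, mirroring torch.div's rounding_mode=None promotion rule."""
    if all("int" in t for t in input_types[:2]):
        input_types[:2] = [default_dtype] * 2
        inputs[:2] = [
            _expr.const(inp, dtype=default_dtype) if np.isscalar(inp)
            else _op.cast(inp, default_dtype)
            for inp in inputs[:2]
        ]
    return inputs, input_types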