Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2759,6 +2759,20 @@ def aten_div(self: TFloat, other: TFloat) -> TFloat:
return op.Div(self, other)


@torch_op(
(
"aten::true_divide.Tensor",
"aten::true_divide.Scalar",
"_operator::truediv",
)
Comment thread
xadupre marked this conversation as resolved.
Outdated
)
def aten_div_int(self: TInt, other: TInt) -> TFloat:
Comment thread
xadupre marked this conversation as resolved.
Outdated
"""div.Tensor(Tensor self, Tensor other) -> Tensor"""

Comment thread
justinchuby marked this conversation as resolved.
Outdated
# Int inputs will be promoted to float by PyTorch
Comment thread
justinchuby marked this conversation as resolved.
Outdated
return op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype))


@torch_op(
(
"aten::div.Tensor",
Expand Down Expand Up @@ -3605,12 +3619,12 @@ def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:


@torch_op(("aten::floor_divide", "_operator::floordiv"), traceable=True)
def aten_floor_divide_int(self: TInt, other: TInt) -> TInt:
def aten_floor_divide_int(self: TInt, other: TInt) -> FLOAT:
"""floor_divide(Tensor self, Tensor other) -> Tensor"""

# We implement floor_divide only for positive inputs (using integer division)
# because that is the usual intended case and is the most efficient.
return op.Div(self, other)
return op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype))
Comment thread
justinchuby marked this conversation as resolved.
Outdated


def aten_fmax(self: TensorType, other: TensorType) -> TensorType:
Expand Down