From b0cb49ca93fbd2dc66983f3e192d9a90d33a6a3a Mon Sep 17 00:00:00 2001
From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
Date: Fri, 3 Dec 2021 13:51:25 -0600
Subject: [PATCH] Add scalar type promotion for mul and div (#454)

---
 e2e_testing/torchscript/basic.py          |  1 +
 e2e_testing/torchscript/elementwise.py    | 76 ++++++++++++++++++-
 e2e_testing/torchscript/type_promotion.py |  2 +
 .../TorchToLinalg/TorchToLinalg.cpp       | 32 ++++----
 4 files changed, 96 insertions(+), 15 deletions(-)

diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index ea43fc4c46db..fb41476018c4 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -784,6 +784,7 @@ def AddCDivModule_basic(module, tu: TestUtils):
 
 
 # ==============================================================================
+
 class DropoutModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py
index d1872dd03de6..af9734299945 100644
--- a/e2e_testing/torchscript/elementwise.py
+++ b/e2e_testing/torchscript/elementwise.py
@@ -363,6 +363,8 @@ def RsubModule_noalpha_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
 # ==============================================================================
+
+
 class ElementwiseMulScalarModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -378,7 +380,52 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
 def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
-
+
+
+
+class ElementwiseMulTensorFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float64, True),
+    ])
+    def forward(self, a, b):
+        return torch.mul(a, b)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseMulTensorFloatModule())
+def ElementwiseMulTensorFloatModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(4),
+        tu.rand(4).type(torch.float64))
+
+class ElementwiseMulTensorIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.int32, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, a, b):
+        return torch.mul(a, b)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseMulTensorIntModule())
+def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
+    module.forward(
+        torch.randint(10, [4]).type(torch.int32),
+        torch.randint(10, [4]))
+
+
 # ==============================================================================
 class ElementwiseLogModule(torch.nn.Module):
     def __init__(self):
@@ -553,7 +600,32 @@ def forward(self, x):
 def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
+class ElementwiseDivTensorFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float64, True),
+    ])
+    def forward(self, a, b):
+        return torch.div(a, b)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseDivTensorFloatModule())
+def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(4),
+        tu.rand(4).type(torch.float64))
+
+
 # ==============================================================================
+
+
 class ElementwiseAndIntegerModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -573,3 +645,5 @@ def forward(self, x, y):
 def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(-10, 10, (3, 4)).to(torch.int32),
                    torch.randint(-10, 10, (3, 4)))
+
+
diff --git a/e2e_testing/torchscript/type_promotion.py b/e2e_testing/torchscript/type_promotion.py
index 6cad4ef03005..a7a5491c5b2e 100644
--- a/e2e_testing/torchscript/type_promotion.py
+++ b/e2e_testing/torchscript/type_promotion.py
@@ -111,3 +111,5 @@ def forward(self, a, b):
 @register_test_case(module_factory=lambda: TypePromotionAlphaWiderModule())
 def TypePromotionAlphaWiderModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4), tu.rand())
+
+
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 1439654cac7f..4515989011fe 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1531,24 +1531,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     }
   }
   if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
-    if (!mul.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
-      mul.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
+    AtenMulTensorOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(mul.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    if (dtype.isa<mlir::FloatType>()) {
+      return b.create<arith::MulFOp>(loc, lhs, rhs);
+    } else {
+      return b.create<arith::MulIOp>(loc, lhs, rhs);
     }
-    return b.create<arith::MulFOp>(loc, payloadArgs[0], payloadArgs[1]);
   }
   if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
-    if (!div.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
+    AtenDivTensorOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(div.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    if (!dtype.isa<mlir::FloatType>())
       div.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
-    return b.create<arith::DivFOp>(loc, payloadArgs[0], payloadArgs[1]);
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    return b.create<arith::DivFOp>(loc, lhs, rhs);
   }
   if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
     if (!pow.getType()