From c586d27b3ab0e46b5c4dd02304166464a54fc0d1 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Fri, 28 Feb 2025 12:07:13 -0800 Subject: [PATCH 1/3] [ExecuTorch] Arm Ethos: Do not depend on `torch.testing._internal ` This can cuase issues with `disable_global_flags` and internal state of the library, this is something which is set when importing this. Differential Revision: [D70402061](https://our.internmc.facebook.com/intern/diff/D70402061/) [ghstack-poisoned] --- backends/arm/test/passes/test_rescale_pass.py | 5 ++-- backends/arm/test/runner_utils.py | 26 ++++++++++++++++--- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/backends/arm/test/passes/test_rescale_pass.py b/backends/arm/test/passes/test_rescale_pass.py index 90ad502378c..5725e1884b3 100644 --- a/backends/arm/test/passes/test_rescale_pass.py +++ b/backends/arm/test/passes/test_rescale_pass.py @@ -13,7 +13,6 @@ from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from parameterized import parameterized -from torch.testing._internal import optests def test_rescale_op(): @@ -64,7 +63,7 @@ def test_nonzero_zp_for_int32(): ), ] for sample_input in sample_inputs: - with pytest.raises(optests.generate_tests.OpCheckError): + with pytest.raises(Exception): torch.library.opcheck(torch.ops.tosa._rescale, sample_input) @@ -87,7 +86,7 @@ def test_zp_outside_range(): ), ] for sample_input in sample_inputs: - with pytest.raises(optests.generate_tests.OpCheckError): + with pytest.raises(Exception): torch.library.opcheck(torch.ops.tosa._rescale, sample_input) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 5a0bfe2c37c..0b1c5b05431 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -34,12 +34,32 @@ from torch.fx.node import Node from torch.overrides import TorchFunctionMode -from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict from tosa import TosaGraph logger = logging.getLogger(__name__) logger.setLevel(logging.CRITICAL) +# Copied from PyTorch. +# From torch/testing/_internal/common_utils.py:torch_to_numpy_dtype_dict +# To avoid a dependency on _internal stuff. +_torch_to_numpy_dtype_dict = { + torch.bool : np.bool_, + torch.uint8 : np.uint8, + torch.uint16 : np.uint16, + torch.uint32 : np.uint32, + torch.uint64 : np.uint64, + torch.int8 : np.int8, + torch.int16 : np.int16, + torch.int32 : np.int32, + torch.int64 : np.int64, + torch.float16 : np.float16, + torch.float32 : np.float32, + torch.float64 : np.float64, + torch.bfloat16 : np.float32, + torch.complex32 : np.complex64, + torch.complex64 : np.complex64, + torch.complex128: np.complex128 +} class QuantizationParams: __slots__ = ["node_name", "zp", "scale", "qmin", "qmax", "dtype"] @@ -335,7 +355,7 @@ def run_corstone( output_dtype = node.meta["val"].dtype tosa_ref_output = np.fromfile( os.path.join(intermediate_path, f"out-{i}.bin"), - torch_to_numpy_dtype_dict[output_dtype], + _torch_to_numpy_dtype_dict[output_dtype], ) output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape)) @@ -349,7 +369,7 @@ def prep_data_for_save( ): if isinstance(data, torch.Tensor): data_np = np.array(data.detach(), order="C").astype( - torch_to_numpy_dtype_dict[data.dtype] + _torch_to_numpy_dtype_dict[data.dtype] ) else: data_np = np.array(data) From 0a1a3c0499ad7657ced6a4dc6a214e3852004f8e Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Fri, 28 Feb 2025 12:08:03 -0800 Subject: [PATCH 2/3] Update on "[ExecuTorch] Arm Ethos: Do not depend on `torch.testing._internal `" This can cuase issues with `disable_global_flags` and internal state of the library, this is something which is set when importing this. Differential Revision: [D70402061](https://our.internmc.facebook.com/intern/diff/D70402061/) [ghstack-poisoned] From 8440cea2855ccead52a71378d9261794c2a14b33 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Fri, 28 Feb 2025 12:09:02 -0800 Subject: [PATCH 3/3] Update on "[ExecuTorch] Arm Ethos: Do not depend on `torch.testing._internal `" This can cuase issues with `disable_global_flags` and internal state of the library, this is something which is set when importing this. Differential Revision: [D70402061](https://our.internmc.facebook.com/intern/diff/D70402061/) [ghstack-poisoned]