diff --git a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py
index d12904bbcb9..2f160474c5b 100644
--- a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py
+++ b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py
@@ -10,7 +10,7 @@
 
 from executorch.backends.arm._passes.arm_pass import ArmPass
 from executorch.backends.arm._passes.quant_args import QuantArgs
-from executorch.backends.arm.tosa.specification import get_context_spec, Tosa_1_00
+from executorch.backends.arm.tosa.specification import get_context_spec
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
@@ -40,9 +40,7 @@ def call_operator(self, op, args, kwargs, meta):
         if args[0].data.dtype == torch.int8:
             return super().call_operator(op, args, kwargs, meta)
         elif args[0].data.dtype == torch.int16:
-            if isinstance(tosa_spec, Tosa_1_00) and not tosa_spec.support_extension(
-                "int16"
-            ):
+            if not tosa_spec.support_extension("int16"):
                 raise ValueError(
                     "int16 activation for convolution requires TOSA int16 extension"
                 )
diff --git a/backends/arm/tosa/specification.py b/backends/arm/tosa/specification.py
index 6fca2163d41..c6c79f9ad9a 100644
--- a/backends/arm/tosa/specification.py
+++ b/backends/arm/tosa/specification.py
@@ -105,6 +105,18 @@ def support_float(self) -> bool:
         """Return True if floating-point operations are supported."""
         raise NotImplementedError
 
+    def support_extension(self, extension: str) -> bool:
+        """Return True if an extension is supported and enabled.
+
+        Args:
+            extension (str): Extension name (for example ``int4``, ``bf16``).
+
+        Returns:
+            bool: True if the extension is valid for the active profiles and selected.
+
+        """
+        raise NotImplementedError
+
     def __init__(self, version: Version, extras: List[str]):
         """Initialize the base specification.