Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.quant_args import QuantArgs

from executorch.backends.arm.tosa.specification import get_context_spec, Tosa_1_00
from executorch.backends.arm.tosa.specification import get_context_spec
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand Down Expand Up @@ -40,9 +40,7 @@ def call_operator(self, op, args, kwargs, meta):
if args[0].data.dtype == torch.int8:
return super().call_operator(op, args, kwargs, meta)
elif args[0].data.dtype == torch.int16:
if isinstance(tosa_spec, Tosa_1_00) and not tosa_spec.support_extension(
"int16"
):
if not tosa_spec.support_extension("int16"):
raise ValueError(
"int16 activation for convolution requires TOSA int16 extension"
)
Expand Down
12 changes: 12 additions & 0 deletions backends/arm/tosa/specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,18 @@ def support_float(self) -> bool:
"""Return True if floating-point operations are supported."""
raise NotImplementedError

def support_extension(self, extension: str) -> bool:
"""Return True if an extension is supported and enabled.

Args:
extension (str): Extension name (for example ``int4``, ``bf16``).

Returns:
bool: True if the extension is valid for the active profiles and selected.

"""
raise NotImplementedError

def __init__(self, version: Version, extras: List[str]):
"""Initialize the base specification.

Expand Down
Loading