diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index 9ed40f85c3bc..378b0c38ff64 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -19,4 +19,4 @@ # transformation passes from .transform import * from .recast import recast -from . import fake_quantization_to_integer +from . import fake_quantization_to_integer, mixed_precision diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index 6f8ecb970221..1e982a0f18a4 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -18,7 +18,6 @@ """Default behavior for ops in mixed_precision pass. Import this file to use.""" from typing import List -from tvm import relay from tvm.relay.op import register_mixed_precision_conversion # MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory @@ -141,7 +140,7 @@ def decorator(func): return decorator -def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]: +def get_generic_out_dtypes(call_node: "relay.Call", mixed_precision_type: str) -> List[str]: """A function which returns output dtypes in a way which works for most ops. Parameters @@ -174,15 +173,15 @@ def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> # Take in CallNodes and a DType and returns a conversion type, # an accumulation dtype, and an output_dtype. @register_func_to_op_list(list_ops=DEFAULT_ALWAYS_LIST) -def generic_always_op(call_node: relay.Call, mixed_precision_type: str) -> List: +def generic_always_op(call_node: "relay.Call", mixed_precision_type: str) -> List: return [MIXED_PRECISION_ALWAYS] + get_generic_out_dtypes(call_node, mixed_precision_type) @register_func_to_op_list(list_ops=DEFAULT_FOLLOW_LIST) -def generic_follow_op(call_node: relay.Call, mixed_precision_type: str) -> List: +def generic_follow_op(call_node: "relay.Call", mixed_precision_type: str) -> List: return [MIXED_PRECISION_FOLLOW] + get_generic_out_dtypes(call_node, mixed_precision_type) @register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST) -def generic_never_op(call_node: relay.Call, mixed_precision_type: str) -> List: +def generic_never_op(call_node: "relay.Call", mixed_precision_type: str) -> List: return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type)