Skip to content

Commit

Permalink
[AMP] Add default op attribute registration to __init__.py (#8460)
Browse files Browse the repository at this point in the history
* add attribute registration to init

* blackify

* remove unused improt

* jostle ci

* avoid circular import

* change order to match orig

* other things

Co-authored-by: Andrew Zhao Luo <[email protected]>
  • Loading branch information
AndrewZhaoLuo and Andrew Zhao Luo authored Jul 15, 2021
1 parent 11c5b6d commit 8a8c9b2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
# transformation passes
from .transform import *
from .recast import recast
from . import fake_quantization_to_integer
from . import fake_quantization_to_integer, mixed_precision
9 changes: 4 additions & 5 deletions python/tvm/relay/transform/mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""Default behavior for ops in mixed_precision pass. Import this file to use."""
from typing import List

from tvm import relay
from tvm.relay.op import register_mixed_precision_conversion

# MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
Expand Down Expand Up @@ -141,7 +140,7 @@ def decorator(func):
return decorator


def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]:
def get_generic_out_dtypes(call_node: "relay.Call", mixed_precision_type: str) -> List[str]:
"""A function which returns output dtypes in a way which works for most ops.
Parameters
Expand Down Expand Up @@ -174,15 +173,15 @@ def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) ->
# Take in CallNodes and a DType and returns a conversion type,
# an accumulation dtype, and an output_dtype.
@register_func_to_op_list(list_ops=DEFAULT_ALWAYS_LIST)
def generic_always_op(call_node: relay.Call, mixed_precision_type: str) -> List:
def generic_always_op(call_node: "relay.Call", mixed_precision_type: str) -> List:
return [MIXED_PRECISION_ALWAYS] + get_generic_out_dtypes(call_node, mixed_precision_type)


@register_func_to_op_list(list_ops=DEFAULT_FOLLOW_LIST)
def generic_follow_op(call_node: relay.Call, mixed_precision_type: str) -> List:
def generic_follow_op(call_node: "relay.Call", mixed_precision_type: str) -> List:
return [MIXED_PRECISION_FOLLOW] + get_generic_out_dtypes(call_node, mixed_precision_type)


@register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST)
def generic_never_op(call_node: relay.Call, mixed_precision_type: str) -> List:
def generic_never_op(call_node: "relay.Call", mixed_precision_type: str) -> List:
return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type)

0 comments on commit 8a8c9b2

Please sign in to comment.