Skip to content

Commit d6cad0c

Browse files
committed
Responding to the reviews
1 parent 01cbf26 commit d6cad0c

File tree

5 files changed: +16 additions, −17 deletions

python/tvm/relay/backend/contrib/ethosu/codegen.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,12 @@ def __init__(self):
4646
def create_op_with_lut(self, call):
4747
"""Extract the parameters and attributes from the NPU operator and create
4848
a new operator with LUT.
49+
50+
Parameters
4951
----------
5052
call : tvm.relay.expr.Call
5153
The current call node being visited.
54+
5255
Returns
5356
-------
5457
tvm.relay.expr.Call
@@ -63,8 +66,7 @@ def create_op_with_lut(self, call):
6366
new_attrs["activation"] = activation
6467

6568
# Assume that LUT is always the last argument
66-
new_args = [ethosu_op.args[n] for n in range(len(ethosu_op.args) - 1)]
67-
new_args.append(lut)
69+
new_args = ethosu_op.args[:-1] + [lut]
6870
assert ethosu_op.op.name in self.lut_ops.keys()
6971

7072
return self.lut_ops[ethosu_op.op.name](*new_args, **new_attrs)
@@ -73,10 +75,12 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
7375
"""Recursively visit call nodes in the input graph and if an ethosu.identity
7476
operator with LUT is found and the preceding operator has a LUT attribute, create
7577
a new NPU operator.
78+
7679
Parameters
7780
----------
7881
call : tvm.relay.expr.Call
7982
The current call node being visited.
83+
8084
Returns
8185
-------
8286
tvm.relay.expr.Call
@@ -104,24 +108,26 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
104108
return new_call
105109

106110

107-
@relay.transform.function_pass(opt_level=1, name="LutOptimizer")
111+
@relay.transform.function_pass(opt_level=1, name="LUTsOptimizer")
108112
class LUTsOptimizer(Pass):
109-
"""Register LutOptimizer as a relay pass."""
113+
"""Register LUTsOptimizer as a relay pass."""
110114

111115
def transform_function(
112116
self, func: tvm.relay.function.Function, mod: tvm.IRModule, _
113117
) -> tvm.IRModule:
114118
"""Visit relay nodes in the given module.
119+
115120
Parameters
116121
----------
117122
func : tvm.relay.function.Function
118-
The function to apply the layout optimization pass to.
123+
The function to apply the optimization pass for multiple LUTs to.
119124
mod : tvm.IRModule
120-
The module to apply the layout optimization pass to.
125+
The module to apply the optimization pass for multiple LUTs to.
126+
121127
Returns
122128
-------
123129
mod : tvm.IRModule
124-
New module with augmented layouts.
130+
New module with optimized LUTs.
125131
"""
126132
assert len(mod.functions.items()) == 1, "Module can only contain one function."
127133
return OptimizeLUTs().visit(func)

python/tvm/relay/backend/contrib/ethosu/legalize.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from tvm.relay.backend.contrib.ethosu import op as ethosu_ops # type: ignore
3333
from tvm.relay.backend.contrib.ethosu.errors import UnsupportedLayout # type: ignore
3434
from tvm.relay.backend.contrib.ethosu import vela_api
35+
from tvm.relay.backend.contrib.ethosu import util
3536
from tvm.relay.op.contrib import ethosu as ethosu_patterns # type: ignore
3637

3738

@@ -124,11 +125,6 @@ def __call__(self, *args, **kwargs):
124125
pass
125126

126127

127-
def round_away_zero(f):
128-
r = -0.5 if (f < 0) else 0.5
129-
return np.trunc(f + r)
130-
131-
132128
def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
133129
"""Method to calculate the values of the tanh lookup table"""
134130
lut_values = list()
@@ -138,7 +134,7 @@ def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
138134
for x in range(qmin, qmax + 1):
139135
x_real = ifm_scale * (x - ifm_zp)
140136
out_real = math.tanh(x_real)
141-
lut_result = int(round_away_zero(ofm_zp + out_real / ofm_scale))
137+
lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale))
142138
lut_result = min(qmax, max(qmin, lut_result))
143139
lut_values.append(lut_result)
144140

tests/python/contrib/test_ethosu/test_codegen.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -970,7 +970,6 @@ def create_graph_single(input_tensor_name, input_tensor_shape, input_tensor_dtyp
970970

971971

972972
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
973-
974973
def test_tflite_tanh(accel_type):
975974
dtype = "int8"
976975
ifm_shape = [1, 115, 32, 7]

tests/python/contrib/test_ethosu/test_legalize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1036,7 +1036,7 @@ def representative_dataset():
10361036
converter.inference_output_type = tf.int8
10371037
tflite_model = converter.convert()
10381038
return tflite_model
1039-
1039+
10401040
tflite_graph = create_tflite_graph()
10411041
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
10421042

tests/python/contrib/test_ethosu/test_lookup_table.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ def test_tflite_lut_activations(accel_type):
4040
ifm_shape = (1, 55, 55, 3)
4141

4242
def create_tflite_graph():
43-
tf.config.run_functions_eagerly(True)
44-
4543
class Model(tf.Module):
4644
@tf.function
4745
def tf_func(self, x):

Commit comments (0)