Skip to content

Commit 90ac839

Browse files
committed
Rebase and fix the review commits
1 parent f3da9af commit 90ac839

File tree

3 files changed

+37
-159
lines changed

3 files changed

+37
-159
lines changed

python/tvm/relay/backend/contrib/ethosu/legalize.py

Lines changed: 24 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -141,14 +141,14 @@ def get_lut_from_func(ifm_scale, ifm_zp, ofm_scale, ofm_zp, func):
141141
return lut_values
142142

143143

144-
class TanhRewriter(DFPatternCallback):
145-
"""This pass adds tanh as a LUT to the identity operator"""
144+
class LutActivationRewriter(DFPatternCallback):
145+
"""A class to create an identity operator with the LUT"""
146146

147-
def __init__(self):
147+
def __init__(self, params_class, activation_type, calc_func):
148148
super().__init__(require_type=True, rewrite_once=True)
149-
self.pattern = (
150-
wildcard().has_attr({"Composite": ethosu_patterns.TanhParams.composite_name})
151-
)(wildcard())
149+
self.pattern = (wildcard().has_attr({"Composite": params_class.composite_name}))(wildcard())
150+
self.activation_type = activation_type
151+
self.calc_func = calc_func
152152

153153
def callback(self, pre, post, node_map):
154154
id_input = post.args[0]
@@ -161,7 +161,9 @@ def callback(self, pre, post, node_map):
161161
input_scale = float(dequantize_args[1].data.asnumpy())
162162
input_zp = int(dequantize_args[2].data.asnumpy())
163163

164-
lut_values = get_lut_from_func(input_scale, input_zp, output_scale, output_zp, math.tanh)
164+
lut_values = get_lut_from_func(
165+
input_scale, input_zp, output_scale, output_zp, self.calc_func
166+
)
165167
lut = relay.const(lut_values, dtype="uint8")
166168

167169
# We baked the requantization into the LUT, so we don't requantize the identity operator
@@ -172,12 +174,21 @@ def callback(self, pre, post, node_map):
172174
ifm_zero_point=input_zp,
173175
ofm_scale=input_scale,
174176
ofm_zero_point=input_zp,
175-
activation="TANH",
177+
activation=self.activation_type,
176178
)
177179

178180
return identity
179181

180182

183+
class TanhRewriter(LutActivationRewriter):
184+
"""This pass adds tanh as a LUT to the identity operator"""
185+
186+
def __init__(self):
187+
super().__init__(
188+
params_class=ethosu_patterns.TanhParams, activation_type="TANH", calc_func=math.tanh
189+
)
190+
191+
181192
@ir.transform.module_pass(opt_level=1)
182193
class LegalizeTanh:
183194
"""This is the pass that wraps TanhRewriter"""
@@ -209,44 +220,16 @@ def sigmoid_calc_func(x):
209220
return y
210221

211222

212-
class SigmoidRewriter(DFPatternCallback):
223+
class SigmoidRewriter(LutActivationRewriter):
213224
"""This pass adds sigmoid as a LUT for identity op"""
214225

215226
def __init__(self):
216-
super().__init__(require_type=True, rewrite_once=True)
217-
self.pattern = (
218-
wildcard().has_attr({"Composite": ethosu_patterns.SigmoidParams.composite_name})
219-
)(wildcard())
220-
221-
def callback(self, pre, post, node_map):
222-
inp = post.args[0]
223-
224-
quantize_args = post.op.body.args
225-
output_scale = float(quantize_args[1].data.asnumpy())
226-
output_zp = int(quantize_args[2].data.asnumpy())
227-
228-
dequantize_args = quantize_args[0].args[0].args
229-
input_scale = float(dequantize_args[1].data.asnumpy())
230-
input_zp = int(dequantize_args[2].data.asnumpy())
231-
232-
lut_values = get_lut_from_func(
233-
input_scale, input_zp, output_scale, output_zp, sigmoid_calc_func
234-
)
235-
lut = relay.const(lut_values, dtype="uint8")
236-
237-
# We baked the requantization into the LUT, so we don't requantize the identity operator
238-
identity = ethosu_ops.ethosu_identity(
239-
ifm=inp,
240-
lut=lut,
241-
ifm_scale=input_scale,
242-
ifm_zero_point=input_zp,
243-
ofm_scale=input_scale,
244-
ofm_zero_point=input_zp,
245-
activation="SIGMOID",
227+
super().__init__(
228+
params_class=ethosu_patterns.SigmoidParams,
229+
activation_type="SIGMOID",
230+
calc_func=sigmoid_calc_func,
246231
)
247232

248-
return identity
249-
250233

251234
@ir.transform.module_pass(opt_level=1)
252235
class LegalizeSigmoid:

python/tvm/relay/op/contrib/ethosu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1049,7 +1049,7 @@ def __init__(self, func_body: Call):
10491049

10501050
def is_valid(self):
10511051
"""
1052-
This function checks whether reshape has compatible attributes with the NPU
1052+
This function checks whether sigmoid has compatible attributes with the NPU
10531053
"""
10541054
if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
10551055
return False

tests/python/contrib/test_ethosu/test_codegen.py

Lines changed: 12 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -826,131 +826,26 @@ def clz_comp(n):
826826

827827
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
828828
def test_tflite_tanh(accel_type):
829-
dtype = "int8"
830829
ifm_shape = [1, 115, 32, 7]
831830

832-
def create_tflite_graph():
833-
class Model(tf.Module):
834-
@tf.function
835-
def tanh_function(self, x):
836-
op = tf.nn.tanh(x)
837-
return op
838-
839-
model = Model()
840-
concrete_func = model.tanh_function.get_concrete_function(
841-
tf.TensorSpec(ifm_shape, dtype=tf.float32)
842-
)
843-
844-
# Convert the model
845-
def representative_dataset():
846-
for _ in range(100):
847-
data = np.random.rand(*tuple(ifm_shape))
848-
yield [data.astype(np.float32)]
849-
850-
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
851-
converter.optimizations = [tf.lite.Optimize.DEFAULT]
852-
converter.representative_dataset = representative_dataset
853-
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
854-
converter.inference_input_type = tf.int8
855-
converter.inference_output_type = tf.int8
856-
tflite_model = converter.convert()
857-
return tflite_model
858-
859-
tflite_graph = create_tflite_graph()
860-
861-
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
862-
863-
relay_module, params = relay.frontend.from_tflite(
864-
tflite_model,
865-
shape_dict={"input": ifm_shape},
866-
dtype_dict={"input": dtype},
867-
)
868-
mod = partition_for_ethosu(relay_module, params)
869-
870-
# Generate reference data
871-
input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
872-
873-
compiled_models = infra.build_source(
874-
mod,
875-
input_data,
876-
output_data,
877-
accel_type,
878-
)
879-
880-
# Assumes only two runtime.Modules are created -- i.e. single offload module
881-
ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
831+
@tf.function
832+
def tanh_func(x):
833+
op = tf.nn.tanh(x)
834+
return op
882835

883-
# Verify generated C source
884-
get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
885-
compilation_artifacts = get_artifacts(ethosu_module)
886-
cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
887-
infra.print_payload(cmms)
888-
infra.verify_source(compiled_models, accel_type)
836+
_compare_tvm_with_tflite(tanh_func, [ifm_shape], accel_type)
889837

890838

891839
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
892-
@pytest.mark.parametrize("ifm_shape", [[1, 115, 32, 7], [1, 4, 5, 2]])
893-
def test_tflite_sigmoid(accel_type, ifm_shape):
894-
dtype = "int8"
895-
896-
def create_tflite_graph():
897-
tf.config.run_functions_eagerly(True)
898-
899-
class Model(tf.Module):
900-
@tf.function
901-
def tanh_function(self, x):
902-
op = tf.nn.sigmoid(x)
903-
return op
904-
905-
model = Model()
906-
concrete_func = model.tanh_function.get_concrete_function(
907-
tf.TensorSpec(ifm_shape, dtype=tf.float32)
908-
)
909-
910-
# Convert the model
911-
def representative_dataset():
912-
for _ in range(100):
913-
data = np.random.rand(*tuple(ifm_shape))
914-
yield [data.astype(np.float32)]
915-
916-
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
917-
converter.optimizations = [tf.lite.Optimize.DEFAULT]
918-
converter.representative_dataset = representative_dataset
919-
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
920-
converter.inference_input_type = tf.int8
921-
converter.inference_output_type = tf.int8
922-
tflite_model = converter.convert()
923-
return tflite_model
924-
925-
tflite_graph = create_tflite_graph()
926-
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
927-
928-
relay_module, params = relay.frontend.from_tflite(
929-
tflite_model,
930-
shape_dict={"input": ifm_shape},
931-
dtype_dict={"input": dtype},
932-
)
933-
mod = partition_for_ethosu(relay_module, params)
934-
935-
# Generate reference data
936-
input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
937-
938-
compiled_models = infra.build_source(
939-
mod,
940-
input_data,
941-
output_data,
942-
accel_type,
943-
)
840+
def test_tflite_sigmoid(accel_type):
841+
ifm_shape = [1, 135, 41, 6]
944842

945-
# Assumes only two runtime.Modules are created -- i.e. single offload module
946-
ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
843+
@tf.function
844+
def sigmoid_function(x):
845+
op = tf.nn.sigmoid(x)
846+
return op
947847

948-
# Verify generated C source
949-
get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
950-
compilation_artifacts = get_artifacts(ethosu_module)
951-
cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
952-
infra.print_payload(cmms)
953-
infra.verify_source(compiled_models, accel_type)
848+
_compare_tvm_with_tflite(sigmoid_function, [ifm_shape], accel_type)
954849

955850

956851
if __name__ == "__main__":

0 commit comments

Comments
 (0)