Skip to content

Commit 43eb1a1

Browse files
committed
Create a lookup table (LUT) for unary ops so that the op values are computed at compile time and a take op is used to fetch them at runtime
1 parent 2c9af0f commit 43eb1a1

File tree

3 files changed

+576
-0
lines changed

3 files changed

+576
-0
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=missing-docstring, invalid-name, unnecessary-comprehension, unused-argument
18+
19+
import tvm
20+
import tvm.testing
21+
from tvm import relax
22+
from tvm.contrib.hexagon import hexagon_unary_ops
23+
24+
25+
def op_replace(call_node):
    """Report whether ``call_node`` is a ``relax.call_tir`` call to a supported unary op.

    A node qualifies when it is a ``relax.Call`` whose operator is
    ``relax.call_tir`` and whose callee (``args[0]``) has a ``name_hint``
    containing one of the supported op names as a substring.

    Parameters
    ----------
    call_node : object
        The expression to inspect; anything that is not a ``relax.Call``
        yields ``False``.

    Returns
    -------
    bool
        ``True`` if the callee name contains a supported unary op name.
    """
    supported_ops = (
        "tanh",
        "sqrt",
        "rsqrt",
        "exp",
        "erf",
        "sigmoid",
        "hardswish",
        "log",
        "abs",
    )

    # Guard clauses: only call_tir calls carry a callee GlobalVar in args[0].
    if not isinstance(call_node, relax.Call):
        return False
    if call_node.op != tvm.ir.Op.get("relax.call_tir"):
        return False

    # Substring match, so e.g. a callee named "tanh1" still counts as tanh.
    callee_name = call_node.args[0].name_hint
    return any(candidate in callee_name for candidate in supported_ops)
40+
41+
42+
@relax.expr_functor.mutator
class Tanh2TakeReplace(tvm.relax.PyExprMutator):
    """Relax mutator that replaces quantized unary ``call_tir`` nodes with a LUT lookup.

    For every supported unary op (see ``op_replace``) operating on a ``uint8``
    input, the 256 possible results are precomputed with
    ``hexagon_unary_ops.LUT_generation`` and the call is rewritten into a
    ``call_tir`` of a generated "take" PrimFunc that indexes that table.
    """

    def __init__(self, mod: tvm.IRModule) -> None:
        super().__init__(mod)
        # Kept so transform() can iterate over the original module's functions.
        self.mod_ = mod

    def transform(self) -> tvm.IRModule:
        """Rewrite every Relax function of the module and return the updated IRModule."""
        for global_var, func in self.mod_.functions.items():
            # Skip non-relax functions (e.g. PrimFuncs).
            if not isinstance(func, relax.Function):
                continue
            updated_func = self.visit_expr(func)
            # NOTE(review): the value returned by normalize() is discarded here;
            # confirm whether the normalized function should be registered instead.
            self.builder_.normalize(updated_func)
            self.builder_.update_func(global_var, updated_func)
        # The BlockBuilder holds the updated functions plus the added take PrimFuncs.
        return self.builder_.get()

    def visit_call_(self, call_node: relax.Call) -> relax.Call:
        """Replace a supported quantized unary ``call_tir`` with a take-from-LUT call.

        Any call that does not match the expected pattern is returned unchanged.
        """
        # op_replace() confirms this is a call_tir to a supported unary op. It
        # has to run before args[1] is inspected, because other calls may have
        # fewer arguments or a non-tuple second argument.
        if not op_replace(call_node):
            return call_node

        tir_args = call_node.args[1]
        # Expected layout: (input, input scale, input zp, output scale, output zp).
        if len(tir_args) != 5:
            return call_node
        inp, inp_scale, inp_zp, out_scale, out_zp = tir_args

        # The table can only be built at compile time from constant quant params.
        quant_params = (inp_scale, inp_zp, out_scale, out_zp)
        if not all(isinstance(param, relax.Constant) for param in quant_params):
            return call_node

        # A 256-entry table only covers 8-bit unsigned inputs.
        if inp.struct_info.dtype != "uint8":
            return call_node

        # LUT node creation
        LUT = hexagon_unary_ops.LUT_generation(
            inp_scale, inp_zp, out_scale, out_zp, call_node.args[0].name_hint
        )
        # Take operation node creation
        take_func = hexagon_unary_ops.generate_take_primfunc(inp, call_node.struct_info)
        take_func = take_func.without_attr("global_symbol")
        take_func_gv = self.builder_.add_func(take_func, "take")
        take_node = relax.call_tir(
            take_func_gv,
            relax.expr.Tuple([inp, relax.expr.Constant(tvm.nd.array(LUT))]),
            call_node.struct_info,
        )
        return take_node
81+
82+
83+
@tvm.ir.transform.module_pass(opt_level=2, name="replace_tanh_take")
class PassReplaceWithTakeOpPrimFuncs:
    """Module pass that runs ``Tanh2TakeReplace`` over an IRModule."""

    def transform_module(self, mod, ctx):
        """Return the module produced by ``Tanh2TakeReplace``; ``ctx`` is not used."""
        mutator = Tanh2TakeReplace(mod)
        rewritten_mod = mutator.transform()
        return rewritten_mod
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=missing-docstring, invalid-name
18+
import logging
19+
import numpy as np
20+
from scipy import special
21+
from tvm import te
22+
23+
logger = logging.getLogger(__name__)
24+
25+
######################################################################
26+
#################### PRIMFUNC FOR LUT and Take Op ####################
27+
######################################################################
28+
29+
30+
def saturate(x: te.Tensor, dtype: str):
    """Clamp ``x`` between ``te.min_value(dtype)`` and ``te.max_value(dtype)``."""
    lower_bound = te.min_value(dtype)
    upper_bound = te.max_value(dtype)
    capped = te.min(x, upper_bound)
    return te.max(lower_bound, capped)
33+
34+
35+
def hardswish_func(x):
    """Evaluate hard-swish, ``x * clip(x + 3, 0, 6) / 6``, using NumPy.

    Accepts a scalar or a NumPy array and returns a value of matching shape.
    """
    gate = np.clip(np.add(x, 3.0), 0.0, 6.0)
    return x * gate / 6.0
39+
40+
41+
def _scalar_quant_param(const):
    """Return the single numeric value held by a quantization-parameter constant."""
    values = np.asarray(const.data.numpy())
    if values.size != 1:
        raise ValueError(
            "LUT generation needs single-element quantization parameters, "
            f"got shape {values.shape}"
        )
    return values.reshape(())[()]


def _apply_unary_op(op_name, x):
    """Evaluate the unary op whose name is contained in ``op_name`` on ``x``."""
    # Substring matching: "rsqrt" has to be tested before "sqrt".
    if "tanh" in op_name:
        return np.tanh(x)
    if "rsqrt" in op_name:
        return 1 / np.sqrt(x)
    if "sqrt" in op_name:
        return np.sqrt(x)
    if "exp" in op_name:
        return np.exp(x)
    if "erf" in op_name:
        return special.erf(x)
    if "sigmoid" in op_name:
        return 1 / (1 + np.exp(np.negative(x)))
    if "hardswish" in op_name:
        return hardswish_func(x)
    if "log" in op_name:
        return np.log(x)
    if "abs" in op_name:
        return np.abs(x)
    logger.error("Error op is other than unary op")
    raise ValueError(f"Unsupported unary op for LUT generation: {op_name}")


def LUT_generation(inp_scale, inp_zp, out_scale, out_zp, op_name) -> list:
    """Build the 256-entry uint8 lookup table of a quantized unary op.

    Every possible uint8 input level is dequantized with the input scale and
    zero point, passed through the op, requantized with the output scale and
    zero point, and clamped to ``[0, 255]``.

    Parameters
    ----------
    inp_scale, inp_zp, out_scale, out_zp :
        Constants exposing ``.data.numpy()``; each must hold exactly one value.
    op_name : str
        Name containing one of: tanh, rsqrt, sqrt, exp, erf, sigmoid,
        hardswish, log, abs.

    Returns
    -------
    list
        256 ``numpy.uint8`` values; entry ``i`` is the result for input level ``i``.

    Raises
    ------
    ValueError
        If a quantization parameter holds more than one value, or ``op_name``
        does not name a supported unary op.
    """
    # The quantization parameters do not depend on the input level, so they
    # are resolved once, before any table entry is computed.
    i_zp = _scalar_quant_param(inp_zp)
    i_scale = _scalar_quant_param(inp_scale)
    o_zp = _scalar_quant_param(out_zp)
    o_scale = _scalar_quant_param(out_scale)

    # All uint8 input levels, held as int32 so subtracting the zero point
    # cannot wrap around.
    levels = np.arange(256, dtype=np.int32)

    # Dequantization followed by computing the op value
    dequant = (levels - i_zp) * i_scale
    # NOTE(review): sqrt/rsqrt/log yield NaN for negative dequantized inputs;
    # those entries pass through the clamp as NaN before the uint8 cast below.
    # Confirm which table value is wanted for out-of-domain levels.
    op_val = _apply_unary_op(op_name, dequant)

    # Quantizing the generated values and clamping them to the uint8 range
    quant = np.round(op_val / o_scale) + o_zp
    lut = np.clip(quant, 0, 255).astype(np.uint8)
    return list(lut)
82+
83+
84+
def generate_take_primfunc(inp, struct_info):
    """Create the PrimFunc that looks up each input element in a 256-entry LUT.

    Parameters
    ----------
    inp :
        Input expression; its ``struct_info`` supplies the shape and dtype of
        the ``data`` buffer.
    struct_info :
        Struct info of the replaced call; supplies the output shape and dtype.

    Returns
    -------
    The PrimFunc built by ``te.create_prim_func`` with parameters
    ``(data, LUT, take_op)``.
    """
    # The lookup is element-wise, so the input may have any rank; the shape
    # is taken as-is instead of being unpacked into exactly four dimensions.
    in_shape = tuple(inp.struct_info.shape)
    # `data` stands for `inp`, so it carries the input's dtype.
    data = te.placeholder(in_shape, dtype=inp.struct_info.dtype, name="data")
    LUT_func = te.placeholder((256,), dtype="uint8", name="LUT")
    take = te.compute(
        struct_info.shape,
        lambda *indices: saturate(
            LUT_func[data[indices].astype("uint8")], struct_info.dtype
        ).astype(struct_info.dtype),
        name="take_op",
    )
    return te.create_prim_func([data, LUT_func, take])

0 commit comments

Comments
 (0)