From a48b96c77e16855780e1aeabb9a3594d4d1ca269 Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Wed, 28 Dec 2022 21:24:56 +0530 Subject: [PATCH 1/5] [CLML][TEST] Codegen test cases for ops Codegen verification test cases for all the ops (convolution, concat, pad, pool ..etc.) that are supported by clml BYOC path. Fix depthwise conv2d issue with layout --- python/tvm/relay/op/contrib/clml.py | 53 +-- src/relay/backend/contrib/clml/codegen.cc | 4 + src/runtime/contrib/clml/clml_runtime.cc | 4 +- .../contrib/test_clml/infrastructure.py | 60 ++- .../python/contrib/test_clml/test_network.py | 15 +- tests/python/contrib/test_clml/test_ops.py | 372 ++++++++++++++++-- 6 files changed, 432 insertions(+), 76 deletions(-) diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py index 6453b8a06c9f..736ec93ec075 100644 --- a/python/tvm/relay/op/contrib/clml.py +++ b/python/tvm/relay/op/contrib/clml.py @@ -92,38 +92,36 @@ def preprocess_module(mod): preprocessed_mod : The processed module. """ - def convert_layout_conv2d(conv2d_function): - def convert_conv(attrs, inputs, tinfos, desired_layouts): - new_attrs = dict(attrs) - data_info = tinfos[0] - weight_info = tinfos[1] - desired_data_layout, desired_kernel_layout = map(str, desired_layouts) - new_attrs["data_layout"] = desired_data_layout - new_attrs["kernel_layout"] = desired_kernel_layout - - if is_depthwise_conv2d( - data_info.shape, - attrs["data_layout"], - weight_info.shape, - attrs["kernel_layout"], - attrs["groups"], - ): - dkl = desired_kernel_layout - new_attrs["kernel_layout"] = dkl[1] + dkl[0] + dkl[2] + dkl[3] - return conv2d_function(*inputs, **new_attrs) - - return convert_conv - - with OpAttrContext( - "nn.conv2d", "FTVMConvertOpLayout", convert_layout_conv2d(tvm.relay.nn.conv2d) - ): + def alter_conv(attrs, inputs, tinfos, out_type): + data, weight = inputs + new_attrs = dict(attrs) + data_info = tinfos[0] + weight_info = tinfos[1] + (desired_data_layout, desired_kernel_layout) = ("NCHW", "OIHW") + new_attrs["data_layout"] = desired_data_layout + new_attrs["kernel_layout"] = desired_kernel_layout + + if is_depthwise_conv2d( + data_info.shape, + attrs["data_layout"], + weight_info.shape, + attrs["kernel_layout"], + attrs["groups"], + ): + dkl = desired_kernel_layout + new_attrs["kernel_layout"] = dkl[1] + dkl[0] + dkl[2] + dkl[3] + return relay.nn.conv2d(*inputs, **new_attrs) + + with OpAttrContext( "nn.conv2d", "FTVMAlterOpLayout", alter_conv): seq = tvm.transform.Sequential( [ transform.ConvertLayout({"nn.conv2d": ["NCHW", "OIHW"]}), + transform.AlterOpLayout(), transform.FoldConstant(), ] ) - preprocessed_mod = seq(mod) + with tvm.transform.PassContext(opt_level=3): + preprocessed_mod = seq(mod) return preprocessed_mod @@ -275,6 +273,9 @@ def check_default_op(extract): ("clml.add", is_op("add")(wildcard(), wildcard()), check_binary_op), ("clml.subtract", is_op("subtract")(wildcard(), wildcard()), check_binary_op), ("clml.multiply", is_op("multiply")(wildcard(), wildcard()), check_binary_op), + ("clml.divide", is_op("divide")(wildcard(), wildcard()), check_binary_op), + ("clml.minimum", is_op("minimum")(wildcard(), wildcard()), check_binary_op), + ("clml.maximum", is_op("maximum")(wildcard(), wildcard()), check_binary_op), ("clml.softmax", is_op("nn.softmax")(wildcard()), check_softmax_op), ("clml.reshape", is_op("reshape")(wildcard()), check_default_op), ("clml.avg_pool2d", is_op("nn.avg_pool2d")(wildcard()), check_default_op), diff --git a/src/relay/backend/contrib/clml/codegen.cc b/src/relay/backend/contrib/clml/codegen.cc index 167c48e1baf5..0e0d2f482b07 100644 --- a/src/relay/backend/contrib/clml/codegen.cc +++ b/src/relay/backend/contrib/clml/codegen.cc @@ -332,6 +332,10 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer { bias = dense; dense = dense->args[0].as(); } + if (backend::IsOp(dense, "nn.bias_add")) { + bias = dense; + dense = dense->args[0].as(); + } ICHECK(backend::IsOp(dense, "nn.dense")); const auto* dense_op = dense->op.as(); ICHECK(dense_op); diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index a667caaafcd8..b78712e6564e 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -400,7 +400,7 @@ class CLMLRuntime : public JSONRuntimeBase { this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); this->layer_.func_outs.push_back(out); } else if ("add" == op_name || "subtract" == op_name || "multiply" == op_name || - "minimum" == op_name || "maximum" == op_name) { + "minimum" == op_name || "maximum" == op_name || "divide" == op_name) { auto out = CreateBinaryLayer(&layer_, node); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); this->layer_.func_outs.push_back(out); @@ -1236,6 +1236,8 @@ class CLMLRuntime : public JSONRuntimeBase { binary_op = CL_TENSOR_OP_SUB_QCOM; else if (op_name == "multiply") binary_op = CL_TENSOR_OP_MUL_QCOM; + else if (op_name == "divide") + binary_op = CL_TENSOR_OP_DIV_QCOM; else if (op_name == "minimum") binary_op = CL_TENSOR_OP_MIN_QCOM; else if (op_name == "maximum") diff --git a/tests/python/contrib/test_clml/infrastructure.py b/tests/python/contrib/test_clml/infrastructure.py index 89c22255d77d..81aebc62fca7 100644 --- a/tests/python/contrib/test_clml/infrastructure.py +++ b/tests/python/contrib/test_clml/infrastructure.py @@ -39,9 +39,9 @@ class Device: Configuration for CLML tests. Check tests/python/contrib/clml/ for the presence of an test_config.json file. - This file can be used to override the default configuration here which will attempt to run the Arm - Compute Library runtime tests locally if the runtime is available. Changing the configuration - will allow these runtime tests to be offloaded to a remote Arm device via a tracker for example. + This file can be used to override the default configuration here which will attempt to run the + Open CLML runtime tests locally if the runtime is available. Changing the configuration + will allow these runtime tests to be offloaded to a remote Snapdragon device via a tracker for example. Notes ----- @@ -101,6 +101,25 @@ def _get_remote(cls): return device +def get_cpu_op_count(mod): + """Traverse graph counting ops offloaded to TVM.""" + + class Counter(tvm.relay.ExprVisitor): + def __init__(self): + super().__init__() + self.count = 0 + + def visit_call(self, call): + if isinstance(call.op, tvm.ir.Op): + self.count += 1 + + super().visit_call(call) + + c = Counter() + c.visit(mod["main"]) + return c.count + + def skip_codegen_test(): """Skip test if it requires the CLML codegen and it's not present.""" if not tvm.get_global_func("relay.ext.clml", True): @@ -130,7 +149,6 @@ def build_and_run( try: libm = build_module(mod, device.target, device.target_host, params, enable_clml, tune_log) - clml_modules = extract_clml_modules(libm) for mod in clml_modules: source = mod.get_source("json") @@ -155,9 +173,9 @@ def build_and_run( for _ in range(no_runs): gen_module.run() out.append([gen_module.get_output(i) for i in range(outputs)]) - time_f = gen_module.module.time_evaluator("run", device.device.cl(0), number=1) - cost = time_f().mean - print("%g secs/iteration\n" % cost) + #time_f = gen_module.module.time_evaluator("run", device.device.cl(0), number=1) + #cost = time_f().mean + #print("%g secs/iteration\n" % cost) return out @@ -181,16 +199,36 @@ def extract_clml_modules(module): def verify_codegen( - module, + mod, known_good_codegen, + device, + params, num_clml_modules=1, tvm_ops=0, - target="llvm -mtriple=aarch64-linux-gnu", ): """Check clml codegen against a known good output.""" - module = build_module(module, target, tvm_ops=tvm_ops, clml_partitions=num_clml_modules) - clml_modules = extract_clml_modules(module) + if isinstance(mod, tvm.relay.expr.Call): + mod = tvm.IRModule.from_expr(mod) + with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): + mod = clml.partition_for_clml(mod, params) + tvm_op_count = get_cpu_op_count(mod) + assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected {}".format( + tvm_op_count, tvm_ops + ) + partition_count = 0 + for global_var in mod.get_global_vars(): + if "clml" in global_var.name_hint: + partition_count += 1 + + assert ( + num_clml_modules == partition_count + ), "Got {} Open CLML partitions, expected {}".format( + partition_count, num_clml_modules + ) + relay.backend.te_compiler.get().clear() + module = relay.build(mod, target=device.target, target_host=device.target_host, params=params) + clml_modules = extract_clml_modules(module) assert len(clml_modules) == num_clml_modules, ( f"The number of CLML modules produced ({len(clml_modules)}) does not " f"match the expected value ({num_clml_modules})." diff --git a/tests/python/contrib/test_clml/test_network.py b/tests/python/contrib/test_clml/test_network.py index 8d740d6dce4d..177359d9b18a 100644 --- a/tests/python/contrib/test_clml/test_network.py +++ b/tests/python/contrib/test_clml/test_network.py @@ -91,13 +91,8 @@ def get_model(): mod, params, inputs, input_data, device=device, atol=1e-5, rtol=1e-5 ) - # test - print("OpenCL:", outputs[0].asnumpy().shape) - print("CLML:", outputs[1].asnumpy().shape) - opencl_sort = np.argsort(outputs[1].asnumpy()).flatten() clml_sort = np.argsort(outputs[0].asnumpy()).flatten() - tvm.testing.assert_allclose(opencl_sort[:10], clml_sort[:10], rtol=1e-5, atol=1e-5) @@ -134,7 +129,6 @@ def get_model(): opencl_sort = np.argsort(outputs[1].asnumpy()).flatten() clml_sort = np.argsort(outputs[0].asnumpy()).flatten() - tvm.testing.assert_allclose(opencl_sort[:5], clml_sort[:5], rtol=1e-5, atol=1e-5) @@ -176,11 +170,10 @@ def get_model(): mod, params, inputs, input_data, device=device, atol=1e-5, rtol=1e-5 ) - # test - print("OpenCL:", outputs[0].asnumpy().shape) - print("CLML:", outputs[1].asnumpy().shape) - opencl_sort = np.argsort(outputs[1].asnumpy()).flatten() clml_sort = np.argsort(outputs[0].asnumpy()).flatten() - tvm.testing.assert_allclose(opencl_sort[:10], clml_sort[:10], rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/contrib/test_clml/test_ops.py b/tests/python/contrib/test_clml/test_ops.py index da09715fbe4c..1f8f333a4325 100644 --- a/tests/python/contrib/test_clml/test_ops.py +++ b/tests/python/contrib/test_clml/test_ops.py @@ -14,15 +14,23 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""CLML integration conv2d tests.""" +"""CLML integration operator tests.""" import tvm import numpy as np from tvm import relay +from tvm.relay.op.contrib import clml from tvm.relay import testing from tvm.ir import IRModule from tvm.contrib import utils -from test_clml.infrastructure import build_and_run, Device, skip_codegen_test +from test_clml.infrastructure import ( + build_and_run, + Device, + skip_codegen_test, + verify_codegen, + build_module, + get_cpu_op_count, +) import pytest @@ -54,11 +62,8 @@ def _get_conv_model( shape = (shape[0], shape[1], shape[2] + padding[0] * 2, shape[3] + padding[1] * 2) is_depthwise = shape[1] == channels == groups - weight_format = "OIHW" if is_depthwise else "OIHW" - if weight_format == "IOHW": - weight_shape = (shape[1] // groups, channels, kernel_h, kernel_w) - else: - weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w) + weight_format = "OIHW" + weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w) w = tvm.nd.array(np.random.uniform(-1, 1, weight_shape).astype(dtype)) weights = relay.const(w, dtype) @@ -77,7 +82,7 @@ def _get_conv_model( ) params = {"w": w} if has_bias: - bias_shape = weight_shape[2] if is_depthwise else weight_shape[0] + bias_shape = (weight_shape[0], ) b = tvm.nd.array(np.random.uniform(-1, 1, bias_shape).astype(dtype)) biasc = relay.const(b, dtype) out = relay.nn.bias_add(out, biasc, axis=1) @@ -86,31 +91,122 @@ def _get_conv_model( if has_activation: out = relay.nn.relu(out) - print("Out:", out) - return out, params +def _get_conv_expected_codegen( + shape, + kernel_h, + kernel_w, + padding, + strides, + dilation, + groups, + dtype, + channels, + has_bias=False, + has_activation=False, +): + if len(padding) == 2: + padding = (padding[0], padding[1], padding[0], padding[1]) + output_height = ((shape[2] - kernel_h + padding[0] + padding[2]) / strides[0]) + 1 + output_width = ((shape[3] - kernel_w + padding[1] + padding[3]) / strides[1]) + 1 + output_shape = (1, channels, int(output_height), int(output_width)) + out_dtype = dtype + is_depthwise = shape[1] == channels == groups + + weight_format = "IOHW" if is_depthwise else "OIHW" + if weight_format == "OIHW": + weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w) + else: + weight_shape = (shape[1] // groups, channels, kernel_h, kernel_w) + #weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w) + + if is_depthwise: + name = "nn.depthwise_conv2d" + else: + name = "nn.conv2d" + + node = { + "op": "kernel", + "name": name, + "inputs": [], + "attrs": { + "groups": [[str(groups)]], + "num_outputs": "1", + "data_layout": [["NCHW"]], + "kernel_layout": [[weight_format]], + "channels": [[str(channels)]], + "dilation": [[str(dilation[0]), str(dilation[1])]], + "out_layout": [[""]], + "out_dtype": [[out_dtype]], + "kernel_size": [[str(kernel_h), str(kernel_w)]], + "shape": [[list(output_shape)]], + "dtype": [[dtype]], + "padding": [[str(p) for p in padding]], + "strides": [[str(s) for s in strides]], + }, + } + + if has_activation: + node["attrs"]["activation_type"] = [["relu"]] + + inputs = [ + {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[str(dtype)]]}}, + { + "op": "const", + "name": "", + "attrs": {"shape": [[list(weight_shape)]], "dtype": [[str(dtype)]]}, + }, + ] + + if has_bias: + bias_dtype = dtype + inputs.append( + { + "op": "const", + "name": "", + "attrs": { + "shape": [[[1, weight_shape[1] if is_depthwise else weight_shape[0], 1, 1]]], + "dtype": [[bias_dtype]], + }, + } + ) + + input_idx = 0 + for _ in range(len(inputs)): + node["inputs"].append([input_idx, 0, 0]) + input_idx += 1 + node["attrs"]["num_inputs"] = str(len(inputs)) + inputs.append(node) + return inputs + + @pytest.mark.parametrize("dtype", ["float32"]) @tvm.testing.requires_openclml def test_conv2d(device, dtype): trials = [ # Normal convolution - [3, 3, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False)], - [2, 1, (2, 2), (1, 1), (1, 1), 7, (15, 16, 12), (True, False, True)], - [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False)], - [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, True)], - # Normal convolution - [2, 2, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False)], - [2, 1, (2, 2), (1, 1), (1, 1), 7, (16, 12, 15), (False, False, True)], - [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False)], - [3, 3, (1, 1), (1, 1), (1, 1), 16, (16, 12, 15), (False, False, False)], - [5, 5, (1, 1), (2, 2), (1, 1), 4, (14, 10, 10), (False, False, False)], - [1, 3, (1, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, True)], - [2, 2, (2, 2), (1, 1), (1, 1), 4, (20, 20, 20), (False, True, False)], - [5, 5, (1, 1), (2, 2), (1, 1), 4, (14, 10, 10), (False, False, False)], - [3, 3, (2, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, False)], - [3, 3, (1, 1), (2, 2), (1, 1), 16, (14, 10, 10), (False, True, True)], + [3, 3, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False), False], + [2, 1, (2, 2), (1, 1), (1, 1), 7, (15, 16, 12), (True, False, True), False], + [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False), False], + [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, True), False], + [2, 2, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False), False], + [2, 1, (2, 2), (1, 1), (1, 1), 7, (16, 12, 15), (False, False, True), False], + [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False), False], + [3, 3, (1, 1), (1, 1), (1, 1), 16, (16, 12, 15), (False, False, False), False], + [5, 5, (1, 1), (2, 2), (1, 1), 4, (14, 10, 10), (False, False, False), False], + [1, 3, (1, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, True), False], + [2, 2, (2, 2), (1, 1), (1, 1), 4, (20, 20, 20), (False, True, False), False], + [5, 5, (1, 1), (2, 2), (1, 1), 4, (14, 10, 10), (False, False, False), False], + [3, 3, (2, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, False), False], + [3, 3, (1, 1), (2, 2), (1, 1), 16, (14, 10, 10), (False, True, True), False], + # Depth-wise convolution + [3, 3, (1, 1), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, True), True], + [5, 5, (2, 2), (1, 1), (1, 1), 20, (20, 20, 20), (False, True, False), True], + [3, 3, (2, 2), (2, 2), (1, 1), 14, (14, 10, 10), (False, False, False), True], + [5, 5, (0, 0), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, False), True], + [3, 3, (1, 1), (2, 2), (1, 1), 14, (14, 10, 10), (False, True, True), True], ] for ( @@ -122,9 +218,13 @@ def test_conv2d(device, dtype): out_channels, shape, composite, + is_depthwise, ) in trials: shape = (1, *shape) - groups = 1 + if is_depthwise: + groups = shape[1] + else: + groups = 1 outputs = [] inputs = { "a": tvm.nd.array(np.random.uniform(-1, 1, shape).astype(dtype)), @@ -151,6 +251,11 @@ def test_conv2d(device, dtype): tvm.testing.assert_allclose( clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-5, atol=1e-5 ) + args = (shape, kernel_h, kernel_w, pad, stride, dilation, groups, dtype, out_channels) + exp_codegen = _get_conv_expected_codegen( + *args, has_bias=composite[1], has_activation=composite[2] + ) + verify_codegen(func, exp_codegen, device, params) @pytest.mark.parametrize("dtype", ["float16"]) @@ -211,11 +316,79 @@ def test_concat(device, dtype): tvm.testing.assert_allclose( clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3 ) + exp_codegen = [ + { + "attrs": { + "dtype": [[dtype]], + "shape": [[list(in_shape_1)]], + }, + "name": "", + "op": "input" + }, + { + "attrs": { + "dtype": [[dtype]], + "shape": [[list(in_shape_2)]], + }, + "name": "", + "op": "input" + }, + { + "attrs": { + "axis": [["1"]], + "dtype": [[dtype]], + "num_inputs": "2", + "num_outputs": "1", + "shape": [[list(clml_out[0].shape)]], + }, + "inputs": [[0, 0, 0], [1, 0, 0]], + "name": "concatenate", + "op": "kernel" + } + ] + verify_codegen(func, exp_codegen, device, params) + + +def _get_pool_expected_codegen(input_shape, pool_size, stride, padding, pool_type, dtype): + import math + pool_height = math.floor(((input_shape[2] + padding[2] - pool_size[0]) / stride[0]) + 1) + pool_width = math.floor(((input_shape[3] + padding[3] - pool_size[1]) / stride[1]) + 1) + output_shape = [input_shape[0], input_shape[1], pool_height, pool_width] + attrs = { + "ceil_mode": [["0"]], + "dilation": [["1", "1"]], + "layout": [["NCHW"]], + "num_inputs": "1", + "num_outputs": "1", + "out_layout": [[""]], + "padding": [[str(p) for p in padding]], + "pool_size": [[str(p) for p in pool_size]], + "shape": [[list(output_shape)]], + "dtype": [[dtype]], + "strides": [[str(s) for s in stride]], + } + if sum(padding): + attrs["count_include_pad"] = [["0"]] + + exp_codegen = [ + { + "op": "input", + "name": "", + "attrs": {"shape": [[list(input_shape)]], "dtype": [[str(dtype)]]}, + }, + { + "op": "kernel", + "name": "nn.avg_pool2d" if pool_type == "avg" else "nn.max_pool2d", + "inputs": [[0, 0, 0]], + "attrs": attrs, + } + ] + return exp_codegen @pytest.mark.parametrize("dtype", ["float16"]) @tvm.testing.requires_openclml -def test_avgpool(device, dtype): +def test_pool(device, dtype): trials = [ # input size pool_size stride paading [(1, 64, 147, 147), (3, 3), (2, 2), (0, 0, 0, 0), "max"], @@ -251,7 +424,152 @@ def test_avgpool(device, dtype): opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0] clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0] + tvm.testing.assert_allclose( + clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3 + ) + + args = (input_shape, pool_size, stride, padding, pooling_type, dtype) + exp_codegen = _get_pool_expected_codegen(*args) + verify_codegen(func, exp_codegen, device, params) + + +@pytest.mark.parametrize("dtype", ["float32"]) +@tvm.testing.requires_openclml +def test_dense(device, dtype): + def _get_model(x_shape, k_shape, has_bias=False): + x = relay.var("x", shape=(x_shape), dtype=dtype) + kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) + out = relay.nn.dense(x, kernel, units=k_shape[0]) + params = {"kernel": tvm.nd.array(np.random.uniform(-1, 1, k_shape).astype(dtype))} + inputs = {"x": tvm.nd.array(np.random.uniform(-1, 1, x_shape).astype(dtype))} + exp_codegen = [ + { + "attrs": { + "dtype": [[dtype]], + "shape": [[list(x_shape)]], + }, + "name": "", + "op": "input" + }, + { + "attrs": { + "dtype": [[dtype]], + "shape": [[list(k_shape)]], + }, + "name": "", + "op": "const" + }, + ] + if has_bias: + bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype) + out = relay.nn.bias_add(out, bias) + bias_node = { + "attrs": { + "dtype": [[dtype]], + "shape": [[list((1, k_shape[0]))]], + }, + "name": "", + "op": "const" + } + exp_codegen.append(bias_node) + params["bias"] = tvm.nd.array(np.random.uniform(-1, 1, (k_shape[0],)).astype(dtype)) + + dense_node = { + "attrs": { + "num_inputs": "3" if has_bias else "2", + "num_outputs": "1", + "dtype": [[dtype]], + "out_dtype": [[""]], + "shape": [[[x_shape[0], k_shape[0]]]], + "units": [[str(k_shape[0])]] + }, + "inputs": [[0, 0, 0], [1, 0, 0], [2, 0, 0]] if has_bias else [[0, 0, 0], [1, 0, 0]], + "name": "nn.dense", + "op": "kernel" + } + exp_codegen.append(dense_node) + return out, params, inputs, exp_codegen + + def _verify(out, params, inputs, exp_codegen): + mod = IRModule.from_expr(out) + opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0] + clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0] + tvm.testing.assert_allclose( + clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3 + ) + verify_codegen(out, exp_codegen, device, params) + + _verify(*(_get_model((1, 16), (32, 16)))) + _verify(*(_get_model((1, 16), (32, 16), True))) + + +@pytest.mark.parametrize("dtype", ["float32"]) +@tvm.testing.requires_openclml +def test_binary_ops(device, dtype): + def _get_model(a_shape, b_shape, op): + a = relay.var("a", shape=(a_shape), dtype=dtype) + b = relay.var("b", shape=(b_shape), dtype=dtype) + out = op(a, b) + inputs = {"a": tvm.nd.array(np.random.uniform(-1, 1, a_shape).astype(dtype)), + "b": tvm.nd.array(np.random.uniform(-1, 1, b_shape).astype(dtype))} + params = {} + return out, params, inputs + + def _verify(out, params, inputs): + mod = IRModule.from_expr(out) + opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0] + clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0] + tvm.testing.assert_allclose( + clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3 + ) + + # Check to make sure these ops are offloaded to CLML instead of TVM. + with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): + mod = clml.partition_for_clml(mod, params) + tvm_op_count = get_cpu_op_count(mod) + assert ( + tvm_op_count == 0 + ), "Got {} TVM Native Compute partitions, expected 0".format(tvm_op_count) + + + _verify(*(_get_model((1, 16), (1, 16), relay.add))) + _verify(*(_get_model((1, 16), (1, 16), relay.subtract))) + _verify(*(_get_model((1, 16), (1, 16), relay.multiply))) + _verify(*(_get_model((1, 16), (1, 16), relay.divide))) + _verify(*(_get_model((1, 16), (1, 16), relay.minimum))) + _verify(*(_get_model((1, 16), (1, 16), relay.maximum))) + +@pytest.mark.parametrize("dtype", ["float32"]) +@tvm.testing.requires_openclml +def test_unary_ops(device, dtype): + def _get_model(a_shape, op): + a = relay.var("a", shape=(a_shape), dtype=dtype) + out = op(a) + inputs = {"a": tvm.nd.array(np.random.uniform(-1, 1, a_shape).astype(dtype))} + params = {} + return out, params, inputs + + def _verify(out, params, inputs): + mod = IRModule.from_expr(out) + opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0] + clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0] tvm.testing.assert_allclose( clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3 ) + + # Check to make sure these ops are offloaded to CLML instead of TVM. + with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): + mod = clml.partition_for_clml(mod, params) + tvm_op_count = get_cpu_op_count(mod) + assert ( + tvm_op_count == 0 + ), "Got {} TVM Native Compute partitions, expected 0".format(tvm_op_count) + + + _verify(*(_get_model((1, 16), relay.nn.softmax))) + _verify(*(_get_model((1, 16), relay.nn.relu))) + + +if __name__ == "__main__": + tvm.testing.main() From 32ce3bf3ebeab4d95983752dadb3cc3a7146b772 Mon Sep 17 00:00:00 2001 From: srk Date: Thu, 29 Dec 2022 09:46:20 +0530 Subject: [PATCH 2/5] * lint errors --- python/tvm/relay/op/contrib/clml.py | 3 +- .../contrib/test_clml/infrastructure.py | 10 +-- tests/python/contrib/test_clml/test_ops.py | 81 ++++++++++--------- 3 files changed, 46 insertions(+), 48 deletions(-) diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py index 736ec93ec075..3e3c2b13a34b 100644 --- a/python/tvm/relay/op/contrib/clml.py +++ b/python/tvm/relay/op/contrib/clml.py @@ -93,7 +93,6 @@ def preprocess_module(mod): """ def alter_conv(attrs, inputs, tinfos, out_type): - data, weight = inputs new_attrs = dict(attrs) data_info = tinfos[0] weight_info = tinfos[1] @@ -112,7 +111,7 @@ def alter_conv(attrs, inputs, tinfos, out_type): new_attrs["kernel_layout"] = dkl[1] + dkl[0] + dkl[2] + dkl[3] return relay.nn.conv2d(*inputs, **new_attrs) - with OpAttrContext( "nn.conv2d", "FTVMAlterOpLayout", alter_conv): + with OpAttrContext("nn.conv2d", "FTVMAlterOpLayout", alter_conv): seq = tvm.transform.Sequential( [ transform.ConvertLayout({"nn.conv2d": ["NCHW", "OIHW"]}), diff --git a/tests/python/contrib/test_clml/infrastructure.py b/tests/python/contrib/test_clml/infrastructure.py index 81aebc62fca7..be2bbc7f8a71 100644 --- a/tests/python/contrib/test_clml/infrastructure.py +++ b/tests/python/contrib/test_clml/infrastructure.py @@ -173,9 +173,9 @@ def build_and_run( for _ in range(no_runs): gen_module.run() out.append([gen_module.get_output(i) for i in range(outputs)]) - #time_f = gen_module.module.time_evaluator("run", device.device.cl(0), number=1) - #cost = time_f().mean - #print("%g secs/iteration\n" % cost) + # time_f = gen_module.module.time_evaluator("run", device.device.cl(0), number=1) + # cost = time_f().mean + # print("%g secs/iteration\n" % cost) return out @@ -222,9 +222,7 @@ def verify_codegen( assert ( num_clml_modules == partition_count - ), "Got {} Open CLML partitions, expected {}".format( - partition_count, num_clml_modules - ) + ), "Got {} Open CLML partitions, expected {}".format(partition_count, num_clml_modules) relay.backend.te_compiler.get().clear() module = relay.build(mod, target=device.target, target_host=device.target_host, params=params) diff --git a/tests/python/contrib/test_clml/test_ops.py b/tests/python/contrib/test_clml/test_ops.py index 1f8f333a4325..1b2d586c8209 100644 --- a/tests/python/contrib/test_clml/test_ops.py +++ b/tests/python/contrib/test_clml/test_ops.py @@ -29,7 +29,7 @@ skip_codegen_test, verify_codegen, build_module, - get_cpu_op_count, + get_cpu_op_count, ) import pytest @@ -82,7 +82,7 @@ def _get_conv_model( ) params = {"w": w} if has_bias: - bias_shape = (weight_shape[0], ) + bias_shape = (weight_shape[0],) b = tvm.nd.array(np.random.uniform(-1, 1, bias_shape).astype(dtype)) biasc = relay.const(b, dtype) out = relay.nn.bias_add(out, biasc, axis=1) @@ -120,7 +120,7 @@ def _get_conv_expected_codegen( weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w) else: weight_shape = (shape[1] // groups, channels, kernel_h, kernel_w) - #weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w) + # weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w) if is_depthwise: name = "nn.depthwise_conv2d" @@ -323,7 +323,7 @@ def test_concat(device, dtype): "shape": [[list(in_shape_1)]], }, "name": "", - "op": "input" + "op": "input", }, { "attrs": { @@ -331,7 +331,7 @@ def test_concat(device, dtype): "shape": [[list(in_shape_2)]], }, "name": "", - "op": "input" + "op": "input", }, { "attrs": { @@ -341,32 +341,33 @@ def test_concat(device, dtype): "num_outputs": "1", "shape": [[list(clml_out[0].shape)]], }, - "inputs": [[0, 0, 0], [1, 0, 0]], - "name": "concatenate", - "op": "kernel" - } + "inputs": [[0, 0, 0], [1, 0, 0]], + "name": "concatenate", + "op": "kernel", + }, ] verify_codegen(func, exp_codegen, device, params) def _get_pool_expected_codegen(input_shape, pool_size, stride, padding, pool_type, dtype): import math + pool_height = math.floor(((input_shape[2] + padding[2] - pool_size[0]) / stride[0]) + 1) pool_width = math.floor(((input_shape[3] + padding[3] - pool_size[1]) / stride[1]) + 1) output_shape = [input_shape[0], input_shape[1], pool_height, pool_width] attrs = { - "ceil_mode": [["0"]], - "dilation": [["1", "1"]], - "layout": [["NCHW"]], - "num_inputs": "1", - "num_outputs": "1", - "out_layout": [[""]], - "padding": [[str(p) for p in padding]], - "pool_size": [[str(p) for p in pool_size]], - "shape": [[list(output_shape)]], - "dtype": [[dtype]], - "strides": [[str(s) for s in stride]], - } + "ceil_mode": [["0"]], + "dilation": [["1", "1"]], + "layout": [["NCHW"]], + "num_inputs": "1", + "num_outputs": "1", + "out_layout": [[""]], + "padding": [[str(p) for p in padding]], + "pool_size": [[str(p) for p in pool_size]], + "shape": [[list(output_shape)]], + "dtype": [[dtype]], + "strides": [[str(s) for s in stride]], + } if sum(padding): attrs["count_include_pad"] = [["0"]] @@ -381,7 +382,7 @@ def _get_pool_expected_codegen(input_shape, pool_size, stride, padding, pool_typ "name": "nn.avg_pool2d" if pool_type == "avg" else "nn.max_pool2d", "inputs": [[0, 0, 0]], "attrs": attrs, - } + }, ] return exp_codegen @@ -449,7 +450,7 @@ def _get_model(x_shape, k_shape, has_bias=False): "shape": [[list(x_shape)]], }, "name": "", - "op": "input" + "op": "input", }, { "attrs": { @@ -457,7 +458,7 @@ def _get_model(x_shape, k_shape, has_bias=False): "shape": [[list(k_shape)]], }, "name": "", - "op": "const" + "op": "const", }, ] if has_bias: @@ -469,23 +470,23 @@ def _get_model(x_shape, k_shape, has_bias=False): "shape": [[list((1, k_shape[0]))]], }, "name": "", - "op": "const" + "op": "const", } exp_codegen.append(bias_node) params["bias"] = tvm.nd.array(np.random.uniform(-1, 1, (k_shape[0],)).astype(dtype)) - dense_node = { + dense_node = { "attrs": { "num_inputs": "3" if has_bias else "2", "num_outputs": "1", "dtype": [[dtype]], "out_dtype": [[""]], "shape": [[[x_shape[0], k_shape[0]]]], - "units": [[str(k_shape[0])]] + "units": [[str(k_shape[0])]], }, - "inputs": [[0, 0, 0], [1, 0, 0], [2, 0, 0]] if has_bias else [[0, 0, 0], [1, 0, 0]], - "name": "nn.dense", - "op": "kernel" + "inputs": [[0, 0, 0], [1, 0, 0], [2, 0, 0]] if has_bias else [[0, 0, 0], [1, 0, 0]], + "name": "nn.dense", + "op": "kernel", } exp_codegen.append(dense_node) return out, params, inputs, exp_codegen @@ -510,8 +511,10 @@ def _get_model(a_shape, b_shape, op): a = relay.var("a", shape=(a_shape), dtype=dtype) b = relay.var("b", shape=(b_shape), dtype=dtype) out = op(a, b) - inputs = {"a": tvm.nd.array(np.random.uniform(-1, 1, a_shape).astype(dtype)), - "b": tvm.nd.array(np.random.uniform(-1, 1, b_shape).astype(dtype))} + inputs = { + "a": tvm.nd.array(np.random.uniform(-1, 1, a_shape).astype(dtype)), + "b": tvm.nd.array(np.random.uniform(-1, 1, b_shape).astype(dtype)), + } params = {} return out, params, inputs @@ -527,10 +530,9 @@ def _verify(out, params, inputs): with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): mod = clml.partition_for_clml(mod, params) tvm_op_count = get_cpu_op_count(mod) - assert ( - tvm_op_count == 0 - ), "Got {} TVM Native Compute partitions, expected 0".format(tvm_op_count) - + assert tvm_op_count == 0, "Got {} TVM Native Compute partitions, expected 0".format( + tvm_op_count + ) _verify(*(_get_model((1, 16), (1, 16), relay.add))) _verify(*(_get_model((1, 16), (1, 16), relay.subtract))) @@ -562,10 +564,9 @@ def _verify(out, params, inputs): with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): mod = clml.partition_for_clml(mod, params) tvm_op_count = get_cpu_op_count(mod) - assert ( - tvm_op_count == 0 - ), "Got {} TVM Native Compute partitions, expected 0".format(tvm_op_count) - + assert tvm_op_count == 0, "Got {} TVM Native Compute partitions, expected 0".format( + tvm_op_count + ) _verify(*(_get_model((1, 16), relay.nn.softmax))) _verify(*(_get_model((1, 16), relay.nn.relu))) From a85e777491b6a700527a801ae36268dba1c40b6a Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Thu, 29 Dec 2022 18:23:45 +0530 Subject: [PATCH 3/5] * version compatilibility changes. --- cmake/modules/contrib/CLML.cmake | 16 ++++++++++- python/tvm/relay/op/contrib/clml.py | 6 ++++ src/runtime/contrib/clml/clml_runtime.cc | 33 ++++++++++++++-------- tests/python/contrib/test_clml/test_ops.py | 5 +++- tests/scripts/task_build_adreno_bins.sh | 2 +- tests/scripts/task_config_build_adreno.sh | 4 ++- 6 files changed, 51 insertions(+), 15 deletions(-) diff --git a/cmake/modules/contrib/CLML.cmake b/cmake/modules/contrib/CLML.cmake index 30e60423b03b..2fde0de65b4b 100644 --- a/cmake/modules/contrib/CLML.cmake +++ b/cmake/modules/contrib/CLML.cmake @@ -22,7 +22,21 @@ if(USE_CLML) if(NOT USE_CLML_GRAPH_EXECUTOR) list(APPEND COMPILER_SRCS ${CLML_RUNTIME_MODULE}) endif() - message(STATUS "Build with CLML support...") + message(STATUS "Build with CLML support : " ${USE_CLML}) + if (NOT USE_CLML STREQUAL "ON") + set(CLML_VERSION_HEADER "${USE_CLML}/CL/cl_qcom_ml_ops.h") + if(EXISTS ${CLML_VERSION_HEADER}) + file(READ ${CLML_VERSION_HEADER} ver) + string(REGEX MATCH "CL_QCOM_ML_OPS_H_MAJOR_VERSION ([0-9]*)" _ ${ver}) + set(CLML_VERSION_MAJOR ${CMAKE_MATCH_1}) + else() + set(CLML_VERSION_MAJOR "2") + endif() + else() + set(CLML_VERSION_MAJOR "2") + endif() + add_definitions(-DTVM_CLML_VERSION=${CLML_VERSION_MAJOR}) + message(STATUS "CLML SDK Version :" ${CLML_VERSION_MAJOR}) endif() if(USE_CLML_GRAPH_EXECUTOR) diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py index 3e3c2b13a34b..02e4f62bed24 100644 --- a/python/tvm/relay/op/contrib/clml.py +++ b/python/tvm/relay/op/contrib/clml.py @@ -28,6 +28,12 @@ from ..strategy.generic import is_depthwise_conv2d +def clml_sdk_version(): + """Utility function to get clml version version""" + + return tvm.support.libinfo().get("TVM_CLML_VERSION", 2) + + def is_clml_runtime_enabled(): """Check if the CLML graph runtime is present. diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index b78712e6564e..d03b8a318656 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -153,13 +153,24 @@ class CLMLRuntime : public JSONRuntimeBase { ICHECK(result == CL_SUCCESS) << "clQueryMLInterfaceVersionsQCOM:" << result; for (cl_uint i = 0; i < numVersions; ++i) { +#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 2 if (majorVersions[i] == 2) { - LOG(WARNING) << "CLML Version Selected:" << majorVersions[i] << " : " << majorVersions[i]; h_ClmlIntf = clGetMLInterfaceV2QCOM(0); - ICHECK(h_ClmlIntf != NULL) << "clGetMLInterfaceV2QCOM:" << result; + LOG(WARNING) << "CLML Target version:" << majorVersions[i]; break; } +#endif +#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 3 + if (majorVersions[i] == 3) { + h_ClmlIntf = clGetMLInterfaceV3QCOM(0); + LOG(WARNING) << "CLML Target version:" << majorVersions[i]; + break; + } +#endif } + ICHECK(h_ClmlIntf != NULL) + << "clGetMLInterfaceVxQCOM:" << result + << " Perhaps there is mispatch between CLML SDK version to target supported version"; char* tune_flag; if ((tune_flag = getenv("CLML_IS_TUNNING_RUN"))) this->is_tuning_run = std::stoi(tune_flag); @@ -523,7 +534,7 @@ class CLMLRuntime : public JSONRuntimeBase { } cl_ml_tensor_qcom DeviceMakeCLMLTensor( - void* pClmlIntf, cl_context context, tensor_dims_t dims, + cl_context context, tensor_dims_t dims, cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_channel_type dtype = CL_FLOAT) { cl_ml_tensor_qcom tensor; @@ -531,8 +542,7 @@ class CLMLRuntime : public JSONRuntimeBase { cl_ml_tensor_desc_qcom desc = { dtype, layout, dims.n, dims.c, dims.h, dims.w, 0, CL_TENSOR_DIMENSIONS_4D_QCOM, { 0 }}; - CLMLInterfaceV2QCOM* clmlIntf = reinterpret_cast(pClmlIntf); - result = clmlIntf->clCreateMLTensorQCOM(workspace->context, NULL, &desc, &tensor); + result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->context, NULL, &desc, &tensor); ICHECK(tensor && result == CL_SUCCESS) << "clCreateMLTensorQCOM:" << result; (void)result; return tensor; @@ -544,9 +554,8 @@ class CLMLRuntime : public JSONRuntimeBase { cl_int result = CL_OUT_OF_HOST_MEMORY; cl_mem buffer = NULL; - CLMLInterfaceV2QCOM* clmlIntf = reinterpret_cast(pClmlIntf); result = - clmlIntf->clGetMLTensorMemorySizeQCOM(workspace->context, pTensorMemDesc->tensor, &size); + h_ClmlIntf->clGetMLTensorMemorySizeQCOM(workspace->context, pTensorMemDesc->tensor, &size); ICHECK(result == CL_SUCCESS) << "clGetMLTensorMemorySizeQCOM:" << result; buffer = clCreateBuffer(workspace->context, CL_MEM_READ_WRITE, size, NULL, &result); @@ -612,8 +621,7 @@ class CLMLRuntime : public JSONRuntimeBase { cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); auto tensor_dsc = std::make_shared(); - tensor_dsc->tensor = - DeviceMakeCLMLTensor(h_ClmlIntf, workspace->context, dims, layout, cl_dtype); + tensor_dsc->tensor = DeviceMakeCLMLTensor(workspace->context, dims, layout, cl_dtype); return tensor_dsc; } @@ -901,7 +909,6 @@ class CLMLRuntime : public JSONRuntimeBase { auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); - auto in_dims = get_tensor_dims(nodes_[node.GetInputs()[0].id_]); std::vector windows = node.GetAttr>("pool_size"); std::vector strides = node.GetAttr>("strides"); @@ -1103,7 +1110,6 @@ class CLMLRuntime : public JSONRuntimeBase { cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype); int inputSize = input_.size(); - int axis = std::stoi(node.GetAttr>("axis")[0]); auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); cl_ml_tensor_qcom* concatInputs = new cl_ml_tensor_qcom[inputSize]; for (int i = 0; i < inputSize; i++) { @@ -1262,7 +1268,12 @@ class CLMLRuntime : public JSONRuntimeBase { CachedLayer layer_; // CLML Context +#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 2 CLMLInterfaceV2QCOM* h_ClmlIntf = NULL; +#endif +#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 3 + CLMLInterfaceV3QCOM* h_ClmlIntf = NULL; +#endif cl::OpenCLWorkspace* workspace = NULL; cl::OpenCLThreadEntry* tentry = NULL; cl_ml_tuningcache_qcom tuning_cache = NULL; diff --git a/tests/python/contrib/test_clml/test_ops.py b/tests/python/contrib/test_clml/test_ops.py index 1b2d586c8209..d489b1d4d507 100644 --- a/tests/python/contrib/test_clml/test_ops.py +++ b/tests/python/contrib/test_clml/test_ops.py @@ -260,7 +260,10 @@ def test_conv2d(device, dtype): @pytest.mark.parametrize("dtype", ["float16"]) @tvm.testing.requires_openclml -def _test_batchnorm(device, dtype): +def test_batchnorm(device, dtype): + if tvm.support.libinfo().get("TVM_CLML_VERSION", 2) < 3: + print("Skip due to unsupported CLML version") + return in_shape = (1, 8, 64, 64) channels = 8 diff --git a/tests/scripts/task_build_adreno_bins.sh b/tests/scripts/task_build_adreno_bins.sh index 6b43d7cbc421..9db3a547a752 100755 --- a/tests/scripts/task_build_adreno_bins.sh +++ b/tests/scripts/task_build_adreno_bins.sh @@ -29,7 +29,7 @@ cd ${output_directory} cp ../cmake/config.cmake . echo set\(USE_MICRO OFF\) >> config.cmake -echo set\(USE_CLML ON\) >> config.cmake +echo set\(USE_CLML "${ADRENO_OPENCL}"\) >> config.cmake echo set\(USE_CLML_GRAPH_EXECUTOR "${ADRENO_OPENCL}"\) >> config.cmake echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_CPP_RPC ON\) >> config.cmake diff --git a/tests/scripts/task_config_build_adreno.sh b/tests/scripts/task_config_build_adreno.sh index d45c5e8b7dcf..fa7a7c309a08 100755 --- a/tests/scripts/task_config_build_adreno.sh +++ b/tests/scripts/task_config_build_adreno.sh @@ -23,8 +23,10 @@ mkdir -p "$BUILD_DIR" cd "$BUILD_DIR" cp ../cmake/config.cmake . +[[ -z "${ADRENO_OPENCL}" ]] && CLML_PATH='ON' || CLML_PATH="${ADRENO_OPENCL}" + echo set\(USE_OPENCL ON\) >> config.cmake -echo set\(USE_CLML ON\) >> config.cmake +echo set\(USE_CLML ${CLML_PATH}\) >> config.cmake echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake From 9469b05bef255f3cba81f37903bd098d4275e6f4 Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Mon, 2 Jan 2023 09:51:14 +0530 Subject: [PATCH 4/5] * review comments --- src/relay/backend/contrib/clml/codegen.cc | 6 +----- src/runtime/contrib/clml/clml_runtime.cc | 3 ++- tests/python/contrib/test_clml/test_ops.py | 1 - tests/scripts/task_config_build_adreno.sh | 2 -- 4 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/relay/backend/contrib/clml/codegen.cc b/src/relay/backend/contrib/clml/codegen.cc index 0e0d2f482b07..d8ca791ad8c4 100644 --- a/src/relay/backend/contrib/clml/codegen.cc +++ b/src/relay/backend/contrib/clml/codegen.cc @@ -328,11 +328,7 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer { const auto* dense = fn->body.as(); const CallNode* bias = nullptr; - if (backend::IsOp(dense, "add")) { - bias = dense; - dense = dense->args[0].as(); - } - if (backend::IsOp(dense, "nn.bias_add")) { + if (backend::IsOp(dense, "add") || backend::IsOp(dense, "nn.bias_add")) { bias = dense; dense = dense->args[0].as(); } diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index d03b8a318656..6396fce4858b 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -170,7 +170,8 @@ class CLMLRuntime : public JSONRuntimeBase { } ICHECK(h_ClmlIntf != NULL) << "clGetMLInterfaceVxQCOM:" << result - << " Perhaps there is mispatch between CLML SDK version to target supported version"; + << " Perhaps there is mispatch between CLML SDK version to target supported version:" + << majorVersions[numVersions - 1]; char* tune_flag; if ((tune_flag = getenv("CLML_IS_TUNNING_RUN"))) this->is_tuning_run = std::stoi(tune_flag); diff --git a/tests/python/contrib/test_clml/test_ops.py b/tests/python/contrib/test_clml/test_ops.py index d489b1d4d507..c4ec2603249b 100644 --- a/tests/python/contrib/test_clml/test_ops.py +++ b/tests/python/contrib/test_clml/test_ops.py @@ -120,7 +120,6 @@ def _get_conv_expected_codegen( weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w) else: weight_shape = (shape[1] // groups, channels, kernel_h, kernel_w) - # weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w) if is_depthwise: name = "nn.depthwise_conv2d" diff --git a/tests/scripts/task_config_build_adreno.sh b/tests/scripts/task_config_build_adreno.sh index fa7a7c309a08..b17e168312c4 100755 --- a/tests/scripts/task_config_build_adreno.sh +++ b/tests/scripts/task_config_build_adreno.sh @@ -23,8 +23,6 @@ mkdir -p "$BUILD_DIR" cd "$BUILD_DIR" cp ../cmake/config.cmake . -[[ -z "${ADRENO_OPENCL}" ]] && CLML_PATH='ON' || CLML_PATH="${ADRENO_OPENCL}" - echo set\(USE_OPENCL ON\) >> config.cmake echo set\(USE_CLML ${CLML_PATH}\) >> config.cmake echo set\(USE_RPC ON\) >> config.cmake From 6950d4bcdd046a5ca2fe884a1b0c380bd3e87dc1 Mon Sep 17 00:00:00 2001 From: Siva Rama Krishna Reddy B Date: Mon, 2 Jan 2023 14:46:45 +0530 Subject: [PATCH 5/5] * Make the adreno container compatible w/ and w/o CLML SDK availability --- tests/scripts/task_build_adreno_bins.sh | 4 ++++ tests/scripts/task_config_build_adreno.sh | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/scripts/task_build_adreno_bins.sh b/tests/scripts/task_build_adreno_bins.sh index 9db3a547a752..187ca7f815df 100755 --- a/tests/scripts/task_build_adreno_bins.sh +++ b/tests/scripts/task_build_adreno_bins.sh @@ -29,8 +29,12 @@ cd ${output_directory} cp ../cmake/config.cmake . echo set\(USE_MICRO OFF\) >> config.cmake +if [ -f "${ADRENO_OPENCL}/CL/cl_qcom_ml_ops.h" ] ; then echo set\(USE_CLML "${ADRENO_OPENCL}"\) >> config.cmake echo set\(USE_CLML_GRAPH_EXECUTOR "${ADRENO_OPENCL}"\) >> config.cmake +else +echo set\(USE_OPENCL ON\) >> config.cmake +fi echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_CPP_RPC ON\) >> config.cmake echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake diff --git a/tests/scripts/task_config_build_adreno.sh b/tests/scripts/task_config_build_adreno.sh index b17e168312c4..d378b5f842b5 100755 --- a/tests/scripts/task_config_build_adreno.sh +++ b/tests/scripts/task_config_build_adreno.sh @@ -24,7 +24,9 @@ cd "$BUILD_DIR" cp ../cmake/config.cmake . echo set\(USE_OPENCL ON\) >> config.cmake -echo set\(USE_CLML ${CLML_PATH}\) >> config.cmake +if [ -f "${ADRENO_OPENCL}/CL/cl_qcom_ml_ops.h" ] ; then +echo set\(USE_CLML ${ADRENO_OPENCL}\) >> config.cmake +fi echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake