Commit a5e883e

[RUNTIME][CLML] Fix for Softmax op for 4D tensors (#16328)
Fixed the softmax layer for 4D tensors to support the NCHW and NHWC layout types, and enabled the relevant test cases for the softmax layer.
1 parent 7ef521f commit a5e883e
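
For reference, a 4D softmax such as the one below is now eligible for CLML offload. This is a minimal illustrative sketch (not part of the commit), assuming the usual Relay front-end API:

# Hypothetical sketch: a 4D NCHW-style softmax that the relaxed CLML check now accepts.
import tvm
from tvm import relay

dtype = "float32"
a = relay.var("a", shape=(1, 5, 3, 4), dtype=dtype)  # 4D input, matches one of the new test shapes
out = relay.nn.softmax(a, axis=1)                     # softmax over the channel axis
mod = tvm.IRModule.from_expr(relay.Function([a], out))
# tvm.relay.op.contrib.clml.partition_for_clml(mod) should now keep this softmax
# in a CLML region instead of rejecting it for being more than 2D.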

3 files changed: +98 -53 lines changed

python/tvm/relay/op/contrib/clml.py

Lines changed: 2 additions & 1 deletion
@@ -437,7 +437,8 @@ def check_pad_op(extract):
 
 def check_softmax_op(extract):
     call = extract
-    if len(call.args[0].checked_type.shape) > 2:
+    # supports 2D and 4D tensors
+    if len(call.args[0].checked_type.shape) not in [2, 4]:
         return False
     return True
 
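
In effect, the predicate above reduces to a rank check on the softmax input. A standalone sketch of the same rule (illustration only, not code from the repository):

# Standalone illustration of the relaxed rank check (assumption: shapes given as plain tuples).
def softmax_rank_supported(shape):
    # 2D and 4D inputs are offloaded to CLML; any other rank falls back to TVM.
    return len(shape) in (2, 4)

assert softmax_rank_supported((1, 1000))        # 2D: supported
assert softmax_rank_supported((1, 64, 5, 32))   # 4D: supported
assert not softmax_rank_supported((1, 3, 5))    # 3D: still rejected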

src/runtime/contrib/clml/clml_runtime.cc

Lines changed: 49 additions & 13 deletions
@@ -511,6 +511,7 @@ class CLMLRuntime : public JSONRuntimeBase {
 
   /*!
    * \brief Create an CLML tensor from JSON node entry. Lookup storage map before creation.
+   *        Update input placeholder for NHWC layout
    *
    * \param nid The node index of graph JSON.
    * \param shape shape information of tensor
@@ -528,15 +529,22 @@ class CLMLRuntime : public JSONRuntimeBase {
       uint32_t eid = EntryID(nid, 0);
       node_data = data_entry_[eid]->data;
     }
+
     auto clml_tensor = MakeCLMLTensorFromJSONNode(node, layout, dtype, node_data, shape);
+
     this->layer_.storage_map.insert({nid, std::make_pair(clml_tensor, node)});
 
     if ("input" == node.GetOpType()) {
       this->layer_.inputs.insert({nid, this->layer_.storage_map[nid].first});
       // Input copy placeholder Tensor
-      this->layer_.in_placeholder.insert(
-          {nid, MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, dtype, node_data,
-                                           shape)});
+      if (layout == CL_TENSOR_LAYOUT_OPTIMAL_QCOM) {
+        this->layer_.in_placeholder.insert(
+            {nid, MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, dtype, node_data,
+                                             shape)});
+      } else {
+        this->layer_.in_placeholder.insert(
+            {nid, MakeCLMLTensorFromJSONNode(node, layout, dtype, node_data, shape)});
+      }
     }
 
     return clml_tensor;
@@ -559,6 +567,7 @@ class CLMLRuntime : public JSONRuntimeBase {
       const auto& node = nodes_[nid];
       if ("nn.dense" == node.GetOpName()) CreateDenseLayerTensor(&layer_, node, nid);
       if ("nn.batch_matmul" == node.GetOpName()) CreateBatchMatmulLayerTensor(&layer_, node, nid);
+      if ("nn.softmax" == node.GetOpName()) CreateSoftmaxLayerTensor(&layer_, node, nid);
     }
 
     for (nid = 0; nid < nodes_.size(); ++nid) {
@@ -1092,6 +1101,37 @@ class CLMLRuntime : public JSONRuntimeBase {
     return;
   }
 
+  /*!
+   * \brief Create a Softmax layer Tensors with supported layout.
+   * \param layer The CLML layer to build. Containing inputs, outputs and the CLML function.
+   * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this operator.
+   */
+
+  void CreateSoftmaxLayerTensor(CachedLayer* layer, const JSONGraphNode& node, size_t nid) {
+    cl_ml_tensor_layout_qcom layout;
+    cl_int result = 0;
+    cl_ml_op_qcom op = nullptr;
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    auto out_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]);
+    int axis = std::stoi(node.GetAttr<std::vector<std::string>>("axis")[0]);
+    // enabling NHWC layout && NCHW layout for 4D, basis the axis value
+    if (out_dims.h >= 1 && out_dims.w >= 1) {
+      if (axis == 3 || axis == -1) {
+        layout = CL_TENSOR_LAYOUT_NHWC_QCOM;
+      } else {
+        layout = CL_TENSOR_LAYOUT_NCHW_QCOM;
+      }
+    } else {  // default layout for 2D
+      layout = CL_TENSOR_LAYOUT_OPTIMAL_QCOM;
+    }
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, layout, cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {}, layout, cl_dtype);
+
+    return;
+  }
+
   /*!
    * \brief Create a SoftMax layer.
    *
@@ -1100,24 +1140,20 @@ class CLMLRuntime : public JSONRuntimeBase {
    * \param nid The node index of JSON graph node, which points to this operator.
    */
   void CreateSoftMaxLayer(CachedLayer* layer, const JSONGraphNode& node, size_t nid) {
+    cl_ml_tensor_layout_qcom layout;
+    cl_softmax_mode_qcom mode = CL_SOFTMAX_MODE_SPATIAL_QCOM;
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, cl_dtype);
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {},
-                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
-    auto out_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]);
-    auto output = MakeCLMLTensorFromJSONEntry(nid, {out_dims.n, out_dims.c, 1, 1},
-                                              CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
-
-    cl_ml_op_softmax_desc_qcom softmax_desc = {CL_SOFTMAX_ALGORITHM_ACCURATE_QCOM,
-                                               CL_SOFTMAX_MODE_INSTANCE_QCOM, cl_arithmetic_mode};
-
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, layout, cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {}, layout, cl_dtype);
+    cl_ml_op_softmax_desc_qcom softmax_desc = {CL_SOFTMAX_ALGORITHM_ACCURATE_QCOM, mode,
+                                               cl_arithmetic_mode};
     result = CLML_INTF->clCreateMLOpSoftmaxQCOM(CLML_CTX, nullptr, &softmax_desc, input->tensor,
                                                 output->tensor, &op, layer_.tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "SoftMax Error:" << result;
-
     layer->function.push_back(op);
     return;
   }
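
The layout decision in CreateSoftmaxLayerTensor can be read as a small rule; here is a hedged Python paraphrase of the C++ branch above (illustration only, the function name is made up):

# Python paraphrase of the layout selection in CreateSoftmaxLayerTensor (illustrative, not runtime code).
def choose_softmax_layout(h, w, axis):
    # 4D tensors with H and W dimensions: NHWC when softmax runs over the last axis,
    # NCHW otherwise; everything else keeps the default optimal layout.
    if h >= 1 and w >= 1:
        return "NHWC" if axis in (3, -1) else "NCHW"
    return "OPTIMAL"

assert choose_softmax_layout(64, 100, 3) == "NHWC"   # matches the new NHWC test cases
assert choose_softmax_layout(64, 100, 1) == "NCHW"   # matches the new NCHW test cases

Because this tensor-creation pass runs before CreateSoftMaxLayer, the later MakeCLMLTensorFromJSONEntry calls find the softmax tensors already registered in the storage map with the chosen layout, which appears to be why CreateSoftMaxLayer no longer hard-codes CL_TENSOR_LAYOUT_OPTIMAL_QCOM.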

tests/python/contrib/test_clml/test_ops.py

Lines changed: 47 additions & 39 deletions
@@ -280,9 +280,9 @@ def test_conv2d(remote, dtype, target, trials, executor_type):
         has_activation=composite[2],
     )
     outputs = _build_and_run_network(remote, func, params, inputs, target, executor_type)
-    out_rtol = 1e-1 if dtype == "float16" else 1e-5
+    out_tol = 1e-1 if dtype == "float16" else 1e-5
     tvm.testing.assert_allclose(
-        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
     )
     args = (shape, kernel_h, kernel_w, pad, stride, dilation, groups, dtype, out_channels)
     exp_codegen = _get_conv_expected_codegen(
@@ -373,9 +373,9 @@ def test_conv2d_transpose(remote, dtype, target, trials, executor_type):
     func = relay.Function([x, w], y)
     mod = IRModule.from_expr(func)
     outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-    out_rtol = 1e-1 if dtype == "float16" else 1e-5
+    out_tol = 1e-1 if dtype == "float16" else 1e-5
     tvm.testing.assert_allclose(
-        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
     )
     args = (
         dshape,
@@ -425,9 +425,9 @@ def test_batchnorm(remote, dtype, target, trials, executor_type):
         "a": input_arr,
     }
     outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-    out_rtol = 1e-3 if dtype == "float16" else 1e-5
+    out_tol = 1e-3 if dtype == "float16" else 1e-5
     tvm.testing.assert_allclose(
-        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
     )
     exp_codegen = [
         {
@@ -485,9 +485,9 @@ def test_concat(remote, dtype, target, trials, executor_type):
     func = relay.concatenate((a, b), axis=1)
 
     outputs = _build_and_run_network(remote, func, params, inputs, target, executor_type)
-    out_rtol = 1e-2 if dtype == "float16" else 1e-5
+    out_tol = 1e-2 if dtype == "float16" else 1e-5
     tvm.testing.assert_allclose(
-        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
     )
 
     exp_codegen = [
@@ -601,9 +601,9 @@ def test_pool(remote, dtype, target, trials, executor_type):
         func = relay.nn.avg_pool2d(a, pool_size=pool_size, strides=stride, padding=padding)
 
     outputs = _build_and_run_network(remote, func, params, inputs, target, executor_type)
-    out_rtol = 1e-2 if dtype == "float16" else 1e-5
+    out_tol = 1e-2 if dtype == "float16" else 1e-5
     tvm.testing.assert_allclose(
-        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
     )
     args = (input_shape, pool_size, stride, padding, pooling_type, dtype)
     exp_codegen = _get_pool_expected_codegen(*args)
@@ -690,9 +690,9 @@ def _get_model(x_shape, k_shape, has_bias=False):
     def _verify(out, params, inputs, exp_codegen):
         mod = IRModule.from_expr(out)
         outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-1 if dtype == "float16" else 1e-5
+        out_tol = 1e-1 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
         )
         verify_codegen(remote, mod, params, exp_codegen, target)
 
@@ -718,9 +718,9 @@ def _get_model(a_shape, b_shape, op_func):
    def _verify(out, params, inputs):
        mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-2 if dtype == "float16" else 1e-5
+        out_tol = 1e-2 if dtype == "float16" else 1e-5
        tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
        )
        exp_codegen = [
            {
@@ -776,9 +776,9 @@ def _get_model(a_shape, op):
    def _verify(out, params, inputs):
        mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-2 if dtype == "float16" else 1e-5
+        out_tol = 1e-2 if dtype == "float16" else 1e-5
        tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
        )
 
        exp_codegen = [
@@ -823,12 +823,11 @@ def _get_model(a_shape, block_size):
    def _verify(out, params, inputs):
        mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-2 if dtype == "float16" else 1e-5
+        out_tol = 1e-2 if dtype == "float16" else 1e-5
        tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
        )
 
-        # Check to make sure these ops are offloaded to CLML instead of TVM.
        exp_codegen = [
            {
                "attrs": {
@@ -877,12 +876,11 @@ def _get_model(a_shape, scale, align_corners):
    def _verify(out, params, inputs):
        mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-2 if dtype == "float16" else 1e-5
+        out_tol = 1e-2 if dtype == "float16" else 1e-5
        tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
        )
 
-        # Check to make sure these ops are offloaded to CLML instead of TVM.
        exp_codegen = [
            {
                "attrs": {
@@ -944,12 +942,11 @@ def _get_model(a_shape, b_shape, a_transpose, b_transpose):
    def _verify(out, params, inputs):
        mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-1 if dtype == "float16" else 1e-5
+        out_tol = 1e-1 if dtype == "float16" else 1e-5
        tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
        )
 
-        # Check to make sure these ops are offloaded to CLML instead of TVM.
        exp_codegen = [
            {
                "attrs": {
@@ -1026,20 +1023,30 @@ def _get_model(a_shape, axis):
        params = {}
        return out, params, inputs, axis
 
-    def _verify(out, params, inputs, axis):
+    def _verify(out, params, inputs, axis, out_tol):
        mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-1 if dtype == "float16" else 1e-5
        tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].numpy(), rtol=out_tol, atol=out_tol
        )
        args = (inputs, dtype, outputs[0].shape, axis)
        exp_codegen = _get_softmax_exp_codegen(*args)
        verify_codegen(remote, mod, params, exp_codegen, target)
 
-    _verify(*(_get_model((1, 5), 1)))
-    _verify(*(_get_model((1, 1000), 1)))
-    _verify(*(_get_model((1, 3), 1)))
+    # 2D Tensor TEST CASES
+    _verify(*(_get_model((1, 5), 1)), 1e-3)
+    _verify(*(_get_model((1, 16), 1)), 1e-3)
+    _verify(*(_get_model((1, 1000), -1)), 1e-3)
+
+    # 4D Tensor TEST CASES layout = NCHW
+    _verify(*(_get_model((1, 100, 64, 100), 1)), 1e-3)
+    _verify(*(_get_model((1, 64, 64, 64), 1)), 1e-3)
+    _verify(*(_get_model((1, 5, 3, 4), 1)), 1e-3)
+
+    # 4D Tensor TEST CASES layout = NHWC
+    _verify(*(_get_model((1, 64, 100, 100), 3)), 1e-1)
+    _verify(*(_get_model((1, 100, 100, 100), 3)), 1e-1)
+    _verify(*(_get_model((1, 64, 5, 32), -1)), 1e-1)
 
 
 @pytest.mark.parametrize("dtype", ["float32", "float16"])
@@ -1066,9 +1073,9 @@ def _verify(in_shape, scale_h, scale_w):
        )
        mod = IRModule.from_expr(func)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-2 if dtype == "float16" else 1e-5
+        out_tol = 1e-2 if dtype == "float16" else 1e-5
        tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
        )
        exp_codegen = [
            {
@@ -1124,9 +1131,9 @@ def _verify(shape, newshape):
        params = {}
        mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-3 if dtype == "float16" else 1e-5
+        out_tol = 1e-3 if dtype == "float16" else 1e-5
        tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
        )
        exp_codegen = [
            {
@@ -1223,9 +1230,9 @@ def test_pool_global(remote, dtype, target, executor_type, trials):
    func = relay.nn.global_avg_pool2d(a)
    mod = IRModule.from_expr(func)
    outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-    out_rtol = 1e-3 if dtype == "float16" else 1e-5
+    out_tol = 1e-3 if dtype == "float16" else 1e-5
    tvm.testing.assert_allclose(
-        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
    )
    args = (input_shape, pooling_type, dtype, outputs[0].shape)
    exp_codegen = _get_pool_global_expected_codegen(*args)
@@ -1241,6 +1248,7 @@ def _get_model(a_shape):
        # Defined the test case with unary operator
        # Single batch_flatten op is failing in native OpenCL
        # Empty TVM mod in VM doesn't pick appropriate cross compiler
+        np.random.seed(0)
        out = relay.nn.relu(a)
        out = relay.nn.batch_flatten(out)
        inputs = {"a": tvm.nd.array(np.random.uniform(-1, 1, a_shape).astype(dtype))}
@@ -1250,9 +1258,9 @@ def _get_model(a_shape):
    def _verify(out, params, inputs):
        mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-3 if dtype == "float16" else 1e-5
+        out_tol = 1e-3 if dtype == "float16" else 1e-5
        tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
        )
        exp_codegen = [
            {
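
For intuition on the comparison the updated _verify performs, here is a small self-contained NumPy check in the spirit of the new 4D cases; the shape and tolerance values come from the test list above, while the reference softmax itself is an illustrative assumption:

# Illustrative NumPy reference check (not part of the test suite).
import numpy as np

def ref_softmax(x, axis):
    # Numerically stable softmax reference.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

x = np.random.uniform(-1, 1, (1, 64, 5, 32)).astype("float32")
fp32 = ref_softmax(x, axis=-1)
fp16 = ref_softmax(x.astype("float16"), axis=-1).astype("float32")
# _verify asserts rtol == atol == out_tol; the NHWC cases above use the looser 1e-1 bound.
np.testing.assert_allclose(fp16, fp32, rtol=1e-1, atol=1e-1)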
