Commit e84f163
[TE] Optimized version of concatenation layer (#11341)
* [TE] Optimized version of concatenation layer:
  1. Concat implemented using extern_op.
  2. New tests added.
  3. Workaround to allow inlining extern_ops with other layers.
* test fix
* test_any.py fix.
* test_forward.py from tensorflow fix.
* lint fix.
* Fixes after code review.
* New comment added.
* Lint fix.
* Another lint fix.
* Comments added.
* rebase issue fix.
* Restored previous state.
* Update after code review.
* After code review changes.
* lint review.
* Change strategy for cuda to fix tests.
* Rebase to main.
* Comments changes after review.
* Some more comments fixes.
* One more error fix in comments.
* restart build
1 parent a1d95ec commit e84f163

File tree: 11 files changed, +359 −30 lines changed

python/tvm/relay/op/_transform.py

Lines changed: 6 additions & 1 deletion

@@ -68,7 +68,12 @@


 # concatenate
-_reg.register_schedule("concatenate", strategy.schedule_concatenate)
+@_reg.register_compute("concatenate")
+def compute_concat(attrs, inputs, output_type):
+    return [topi.concatenate(inputs, attrs.axis)]
+
+
+_reg.register_strategy("concatenate", strategy.concatenate_strategy)

 # sliding_window
 @_reg.register_compute("sliding_window")
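
With the compute and strategy now registered from Python, the op can be exercised end to end. A minimal sketch, not part of the diff (target, shapes, and variable names are illustrative):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor

    # Lowering of this graph now goes through the registered compute
    # ("concatenate") and the concatenate_strategy registered above.
    x = relay.var("x", shape=(2, 3), dtype="float32")
    y = relay.var("y", shape=(2, 3), dtype="float32")
    func = relay.Function([x, y], relay.concatenate([x, y], axis=0))
    mod = tvm.IRModule.from_expr(func)

    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm")

    dev = tvm.cpu(0)
    rt = graph_executor.GraphModule(lib["default"](dev))
    rt.set_input("x", np.ones((2, 3), "float32"))
    rt.set_input("y", np.zeros((2, 3), "float32"))
    rt.run()
    print(rt.get_output(0).numpy().shape)  # (4, 3)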

python/tvm/relay/op/strategy/cuda.py

Lines changed: 9 additions & 5 deletions

@@ -42,11 +42,15 @@ def schedule_reduce_cuda(attrs, outs, target):
         return topi.cuda.schedule_reduce(outs)


-@schedule_concatenate.register(["cuda", "gpu"])
-def schedule_concatenate_cuda(attrs, outs, target):
-    """schedule concatenate for cuda"""
-    with target:
-        return topi.cuda.schedule_injective(outs)
+@concatenate_strategy.register(["cuda", "gpu"])
+def concatenate_strategy_cuda(attrs, inputs, out_type, target):
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_concat(topi.transform.concatenate),
+        wrap_topi_schedule(topi.cuda.schedule_injective),
+        name="concatenate.cuda",
+    )
+    return strategy


 @schedule_pool.register(["cuda", "gpu"])
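
For context, wrap_topi_schedule is just an adapter from a topi schedule function to the (attrs, outs, target) callback that OpStrategy expects; a sketch of the pattern (the real helper lives in python/tvm/relay/op/strategy/generic.py):

    def wrap_topi_schedule(topi_schedule):
        """Wrap a topi schedule as an (attrs, outs, target) callback."""

        def wrapper(attrs, outs, target):
            with target:
                return topi_schedule(outs)

        return wrapper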

python/tvm/relay/op/strategy/generic.py

Lines changed: 21 additions & 0 deletions

@@ -1781,6 +1781,15 @@ def _compute_scanop(attrs, inputs, _):
     return _compute_scanop


+def wrap_compute_concat(topi_compute):
+    """Wrap concatenate topi compute"""
+
+    def _compute_concat(attrs, inputs, _):
+        return [topi_compute(inputs, attrs.axis)]
+
+    return _compute_concat
+
+
 @override_native_generic_func("cumsum_strategy")
 def cumsum_strategy(attrs, inputs, out_type, target):
     """cumsum generic strategy"""
@@ -1793,6 +1802,18 @@ def cumsum_strategy(attrs, inputs, out_type, target):
     return strategy


+@override_native_generic_func("concat_strategy")
+def concatenate_strategy(attrs, inputs, out_type, target):
+    """concatenate generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_concat(topi.concatenate),
+        wrap_topi_schedule(topi.generic.schedule_injective),
+        name="concatenate",
+    )
+    return strategy
+
+
 @override_native_generic_func("cumprod_strategy")
 def cumprod_strategy(attrs, inputs, out_type, target):
     """cumprod generic strategy"""

python/tvm/relay/op/strategy/x86.py

Lines changed: 32 additions & 8 deletions

@@ -19,7 +19,7 @@
 import logging

 import re
-from tvm import topi
+from tvm import topi, tir
 from tvm.topi.x86.utils import target_has_vnni
 from tvm.auto_scheduler import is_auto_scheduler_enabled
 from tvm.te import SpecializedCondition
@@ -48,13 +48,6 @@ def schedule_reduce_cpu(attrs, outs, target):
         return topi.x86.schedule_reduce(outs)


-@schedule_concatenate.register("cpu")
-def schedule_concatenate_cpu(attrs, outs, target):
-    """schedule concatenate op for x86"""
-    with target:
-        return topi.x86.schedule_concatenate(outs)
-
-
 @schedule_pool.register("cpu")
 def schedule_pool_cpu(attrs, outs, target):
     """schedule pooling ops for x86"""
@@ -741,3 +734,34 @@ def conv2d_winograd_without_weight_transfrom_strategy_cpu(attrs, inputs, out_typ
         "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout)
     )
     return strategy
+
+
+@concatenate_strategy.register(["cpu"])
+def concatenate_strategy_cpu(attrs, inputs, out_type, target):
+    """concatenate x86 strategy"""
+    strategy = _op.OpStrategy()
+    use_only_old_concat = False
+    for inpt in inputs:
+        shape = inpt.shape
+        for i in shape:
+            if not isinstance(i, tir.expr.IntImm):
+                use_only_old_concat = True
+                break
+    if use_only_old_concat:
+        strategy.add_implementation(
+            wrap_compute_concat(topi.transform.concatenate),
+            wrap_topi_schedule(topi.x86.injective.schedule_concatenate),
+            name="concatenate.generic",
+        )
+    else:
+        strategy.add_implementation(
+            wrap_compute_concat(topi.x86.concatenate),
+            wrap_topi_schedule(topi.x86.schedule_concatenate_cpu),
+            name="concatenate.cpu",
+        )
+        strategy.add_implementation(
+            wrap_compute_concat(topi.transform.concatenate),
+            wrap_topi_schedule(topi.x86.injective.schedule_concatenate),
+            name="concatenate.generic",
+        )
+    return strategy
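
The tir.expr.IntImm check above is what routes dynamic shapes away from the optimized kernel: the extern-based concat builds its copy loops from compile-time extents, so any symbolic dimension forces the generic fallback. A small sketch of the distinction (shapes are illustrative):

    from tvm import te, tir

    # Static extents are tir.IntImm, so the optimized x86 concat is eligible.
    static = te.placeholder((2, 3), name="static")
    print([type(d).__name__ for d in static.shape])  # ['IntImm', 'IntImm']

    # A symbolic extent (e.g. from a dynamic relay shape) is not an IntImm,
    # so the strategy registers only the generic implementation.
    n = te.var("n")
    dynamic = te.placeholder((n, 3), name="dynamic")
    print(isinstance(dynamic.shape[0], tir.expr.IntImm))  # False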

python/tvm/topi/x86/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -43,3 +43,4 @@
 from .scatter import *
 from .group_conv2d import *
 from .math_alter_op import *
+from .concat import *

python/tvm/topi/x86/concat.py

Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"concatenate related operators"
+from typing import Optional
+import tvm
+from tvm import te
+import numpy as np
+from ..utils import get_const_int, const_vector
+
+
+def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0):
+    """Join a sequence of arrays along an existing axis. Optimized for CPU execution.
+
+    Parameters
+    ----------
+    data : tuple of tvm.te.Tensor
+        The arrays to concatenate.
+
+    axis : int, optional
+        The axis along which the arrays will be joined. Default is 0.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+    """
+
+    def gen_ir_1d(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf):
+        """Custom concatenation execution."""
+        i_b = tvm.tir.ir_builder.create()
+        data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs]
+        out_buf = i_b.buffer_ptr(out_buf)
+        outers = i_b.buffer_ptr(in_outers_tensor)
+        cumsum = i_b.buffer_ptr(in_cumsum_tensor)
+        for i in range(len(data)):
+            with i_b.for_range(0, outers[i], name="j") as j:
+                out_buf[cumsum[i] + j] = data_bufs1[i][j]
+        return i_b.get()
+
+    def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer):
+        """Common case of concatenation execution."""
+        i_b = tvm.tir.ir_builder.create()
+        data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs]
+        out_buf = i_b.buffer_ptr(out_buf)
+        outers = i_b.buffer_ptr(in_outers_tensor)
+        cumsum = i_b.buffer_ptr(in_cumsum_tensor)
+        if inner > 1:
+            with i_b.for_range(0, inner, name="inn", kind="parallel") as inn:
+                pos = inn * outer
+                for i in range(len(data)):
+                    offset = inn * outers[i]
+                    with i_b.for_range(0, outers[i], name="j") as j:
+                        out_buf[pos + cumsum[i] + j] = data_bufs1[i][offset + j]
+        else:
+            for i in range(len(data)):
+                with i_b.for_range(0, outers[i], name="j", kind="parallel") as j:
+                    out_buf[cumsum[i] + j] = data_bufs1[i][j]
+        return i_b.get()
+
+    if axis < 0:
+        axis += len(data[0].shape)
+    concat_axis_sizes = [int(t.shape[axis]) for t in data]
+    join_size = int(np.sum(concat_axis_sizes))
+    in_outers = [int(np.prod(i.shape[axis:])) for i in data]
+    in_outers_cumsum = [0, *np.cumsum(in_outers, dtype="int64")[0:-1]]
+    dtype = data[0].dtype
+    out_shape = data[0].shape[:axis] + [join_size] + data[0].shape[axis + 1 :]
+    in_outers_tensor = const_vector(in_outers)
+    in_cumsum_tensor = const_vector(in_outers_cumsum, name="cumsum")
+    right_val = np.prod(out_shape[axis:])
+    left_val = np.prod(out_shape[:axis])
+
+    if (
+        len(data[0].shape) == 1
+        or right_val == 1
+        or (left_val == 1 and axis == len(data[0].shape) - 1)
+        or (left_val == 1 and right_val == 1)
+    ):
+        # badly parallelized case
+        return te.extern(
+            [out_shape],
+            list(data) + [in_outers_tensor, in_cumsum_tensor],
+            lambda ins, outs: gen_ir_1d(ins, ins[-2], ins[-1], outs[0]),
+            dtype=dtype,
+            name="concatenate_ext",
+        )
+
+    inner = get_const_int(int(left_val))
+    outer = get_const_int(int(right_val))
+    return te.extern(
+        [out_shape],
+        list(data) + [in_outers_tensor, in_cumsum_tensor],
+        lambda ins, outs: gen_ir(ins, ins[-2], ins[-1], outs[0], inner, outer),
+        dtype=dtype,
+        name="concatenate_ext",
+    )
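
A hedged usage sketch for the new kernel (not from the diff; shapes and the llvm target are illustrative). The function returns a te.extern stage whose IR copies each input into the right offset of the output; it is scheduled here with the schedule_concatenate_cpu added in python/tvm/topi/x86/injective.py below:

    import numpy as np
    import tvm
    from tvm import te, topi

    a = te.placeholder((2, 3), name="a", dtype="float32")
    b = te.placeholder((4, 3), name="b", dtype="float32")
    out = topi.x86.concatenate([a, b], axis=0)  # te.extern, name="concatenate_ext"

    s = topi.x86.schedule_concatenate_cpu([out])
    f = tvm.build(s, [a, b, out], target="llvm")

    dev = tvm.cpu(0)
    a_np = np.random.rand(2, 3).astype("float32")
    b_np = np.random.rand(4, 3).astype("float32")
    c_nd = tvm.nd.empty((6, 3), "float32", dev)
    f(tvm.nd.array(a_np, dev), tvm.nd.array(b_np, dev), c_nd)
    np.testing.assert_allclose(c_nd.numpy(), np.concatenate([a_np, b_np], axis=0))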

python/tvm/topi/x86/injective.py

Lines changed: 36 additions & 6 deletions

@@ -17,20 +17,22 @@
 # pylint: disable=invalid-name
 """x86 declaration and schedules."""
 from tvm import te
+from tvm.topi import tag
 from tvm.tir import IntImm
+from tvm.topi.generic.injective import (
+    schedule_injective_from_existing as schedule_injective_for_concat,
+)
 from ..utils import is_empty_shape


 def schedule_injective_from_existing(sch, out):
     """Schedule for injective op from existing schedule.
-
     Parameters
     ----------
     sch: Schedule
         The schedule to update.
     out: Tensor
         The tensor representing the injective op.
-
     Returns
     -------
     sch: Schedule
@@ -61,13 +63,11 @@ def schedule_injective_from_existing(sch, out):


 def schedule_injective(outs):
     """X86 schedule for injective op.
-
     Parameters
     ----------
     outs: Array of Tensor
         The computation graph description of injective in the format
         of an array of tensors.
-
     Returns
     -------
     sch: Schedule
@@ -85,13 +85,11 @@ def schedule_injective(outs):


 def schedule_concatenate(outs):
     """X86 schedule for concatenate op.
-
     Parameters
     ----------
     outs: Array of Tensor
         The computation graph description of injective in the format
         of an array of tensors.
-
     Returns
     -------
     sch: Schedule
@@ -132,5 +130,37 @@ def vectorize(sch, tensor, vectorize_limit):
     return s


+def schedule_concatenate_cpu(outs):
+    """X86 schedule for concatenate op.
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description in the format
+        of an array of tensors.
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+
+    def traverse(op):
+        if tag.is_injective(op.tag):
+            schedule_injective_for_concat(s, op.output(0))
+
+        for tensor in op.input_tensors:
+            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                traverse(tensor.op)
+        scheduled_ops.append(op)
+
+    for out in outs:
+        traverse(out.op)
+
+    return s
+
+
 schedule_elemwise = schedule_injective
 schedule_broadcast = schedule_injective
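
The traversal matters when injective producers are fused around the extern concat: each injective stage gets the generic injective schedule, while the extern stage keeps its own IR. A sketch under those assumptions (ops and shapes are illustrative):

    import tvm
    from tvm import te, topi

    a = te.placeholder((8, 16), name="a")
    b = topi.add(a, 1.0)       # injective producer of the concat
    c = topi.subtract(a, 1.0)  # another injective producer
    out = topi.x86.concatenate([b, c], axis=1)

    # traverse() walks back from the extern output and applies
    # schedule_injective_for_concat to the injective stages b and c.
    s = topi.x86.schedule_concatenate_cpu([out])
    f = tvm.build(s, [a, out], target="llvm")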

src/relay/op/tensor/transform.cc

Lines changed: 0 additions & 1 deletion

@@ -346,7 +346,6 @@ RELAY_REGISTER_OP("concatenate")
     .set_support_level(1)
     .add_type_rel("Concatenate", ConcatenateRel<ConcatenateAttrs>)
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConcatenateLayout)
-    .set_attr<FTVMCompute>("FTVMCompute", ConcatenateCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);

 TVM_REGISTER_NODE_TYPE(StackAttrs);

src/te/schedule/schedule_dataflow_rewrite.cc

Lines changed: 29 additions & 1 deletion

@@ -511,6 +511,29 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) {
   std::vector<bool> changed(sch->stages.size(), false);
   std::vector<Stmt> new_hybrid_body(sch->stages.size());
   std::vector<bool> hybrid_changed(sch->stages.size(), false);
+  // (sshtin): this workaround allows extern ops to be inlined into their consumers.
+  // None of an extern op's inputs should be inlined, because inlining may happen
+  // before TE generation for that particular extern op, which can lead to a
+  // crash during the lowering or building stages.
+  // The problem description:
+  // when operations are fused, inlining the arguments prevents the creation
+  // of a ProducerNode for the extern operation.
+  // The operation argument is supposed to be used as an inlined buffer instead,
+  // but extern_op TIR generation can be performed after the inlining procedure,
+  // so the newly generated TIR has no reference to the input data at all.
+  std::unordered_map<Operation, Operation> ext_ops;
+  for (size_t i = 0; i < sch->stages.size(); i++) {
+    Stage stage = sch->stages[i];
+    auto ext_op = stage->op.as<ExternOpNode>();
+    if (ext_op) {
+      auto inps = ext_op->InputTensors();
+      for (size_t ii = 0; ii < inps.size(); ++ii) {
+        if (ext_ops.find(inps[ii]->op) == ext_ops.end()) {
+          ext_ops[inps[ii]->op] = stage->op;
+        }
+      }
+    }
+  }
   // inline all the ops
   for (size_t i = sch->stages.size(); i != 0; --i) {
     Stage stage = sch->stages[i - 1];
@@ -525,8 +548,13 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) {
        for (auto iv : compute->axis) {
          args.push_back(iv->var);
        }
+       if (ext_ops.find(stage->op) != ext_ops.end()) {
+         // sshtin: the extern op may try to access the input tensors as raw data,
+         // which can lead to an error in the IR builder.
+         stage->attach_type = kGroupRoot;
+         continue;
+       }
        ICHECK_EQ(compute->body.size(), 1U) << "can only inline compute op with 1 output";
-
       if (feature_extraction_mode && compute->attrs.count("const_matrix")) {
         // Use constant value to replace access of const matrices.
         // This produces wrong IR but is good enough for feature extraction purposes.
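
The effect is observable from TE: marking a producer of an extern op as inline no longer breaks lowering, because InjectInline restores that stage to kGroupRoot instead. A hedged Python sketch (names, shapes, and the trivial extern kernel are illustrative):

    import tvm
    from tvm import te

    a = te.placeholder((16,), name="a", dtype="float32")
    b = te.compute((16,), lambda i: a[i] + 1.0, name="b")  # feeds the extern op

    def copy_ir(ins, outs):
        # trivial extern kernel that reads its input through a buffer pointer
        i_b = tvm.tir.ir_builder.create()
        src = i_b.buffer_ptr(ins[0])
        dst = i_b.buffer_ptr(outs[0])
        with i_b.for_range(0, 16, name="i") as i:
            dst[i] = src[i]
        return i_b.get()

    c = te.extern([(16,)], [b], lambda ins, outs: copy_ir(ins, outs),
                  dtype="float32", name="copy_ext")

    s = te.create_schedule(c.op)
    s[b].compute_inline()  # b feeds an extern op, so InjectInline now keeps it
    f = tvm.build(s, [a, c], target="llvm")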
