
Commit 8cf733a

Author: Sergey Shtin
Message: Restored previous state.
Parent: 08563ad

File tree: 1 file changed, +15 / -33 lines


python/tvm/topi/x86/concat.py (15 additions, 33 deletions)
@@ -22,12 +22,12 @@
 from ..utils import get_const_int, const_vector
 
 
-def _concat(a_tuple, axis=0):
-    """Join a sequence of arrays along an existing axis.
+def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0):
+    """Join a sequence of arrays along an existing axis. Optimized for CPU exeution.
 
     Parameters
     ----------
-    a_tuple : tuple of tvm.te.Tensor
+    data : tuple of tvm.te.Tensor
         The arrays to concatenate
 
     axis : int, optional
@@ -45,7 +45,7 @@ def gen_ir_1d(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf):
         out_buf = i_b.buffer_ptr(out_buf)
         outers = i_b.buffer_ptr(in_outers_tensor)
         cumsum = i_b.buffer_ptr(in_cumsum_tensor)
-        for i in range(len(a_tuple)):
+        for i in range(len(data)):
             with i_b.for_range(0, outers[i], name="j") as j:
                 out_buf[cumsum[i] + j] = data_bufs1[i][j]
         return i_b.get()
@@ -60,39 +60,39 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer)
         if inner > 1:
             with i_b.for_range(0, inner, name="inn", kind="parallel") as inn:
                 pos = inn * outer
-                for i in range(len(a_tuple)):
+                for i in range(len(data)):
                     offset = inn * outers[i]
                     with i_b.for_range(0, outers[i], name="j") as j:
                         out_buf[pos + cumsum[i] + j] = data_bufs1[i][offset + j]
         else:
-            for i in range(len(a_tuple)):
+            for i in range(len(data)):
                 with i_b.for_range(0, outers[i], name="j", kind="parallel") as j:
                     out_buf[cumsum[i] + j] = data_bufs1[i][j]
         return i_b.get()
 
     if axis < 0:
-        axis += len(a_tuple[0].shape)
-    concat_axis_sizes = [int(t.shape[axis]) for t in a_tuple]
+        axis += len(data[0].shape)
+    concat_axis_sizes = [int(t.shape[axis]) for t in data]
     join_size = int(np.sum(concat_axis_sizes))
-    in_outers = [int(np.prod(i.shape[axis:])) for i in a_tuple]
+    in_outers = [int(np.prod(i.shape[axis:])) for i in data]
     in_outers_cumsum = [0, *np.cumsum(in_outers, dtype="int64")[0:-1]]
-    dtype = a_tuple[0].dtype
-    out_shape = a_tuple[0].shape[:axis] + [join_size] + a_tuple[0].shape[axis + 1 :]
+    dtype = data[0].dtype
+    out_shape = data[0].shape[:axis] + [join_size] + data[0].shape[axis + 1 :]
     in_outers_tensor = const_vector(in_outers)
     in_cumsum_tensor = const_vector(in_outers_cumsum, name="cumsum")
     right_val = np.prod(out_shape[axis:])
     left_val = np.prod(out_shape[:axis])
 
     if (
-        len(a_tuple[0].shape) == 1
+        len(data[0].shape) == 1
         or right_val == 1
-        or (left_val == 1 and axis == len(a_tuple[0].shape) - 1)
+        or (left_val == 1 and axis == len(data[0].shape) - 1)
         or (left_val == 1 and right_val == 1)
     ):
         # badly parallelized case
         return te.extern(
             [out_shape],
-            list(a_tuple) + [in_outers_tensor, in_cumsum_tensor],
+            list(data) + [in_outers_tensor, in_cumsum_tensor],
             lambda ins, outs: gen_ir_1d(ins, ins[-2], ins[-1], outs[0]),
             dtype=dtype,
             name="concatenate_ext",
@@ -102,26 +102,8 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer)
     outer = get_const_int(int(right_val))
     return te.extern(
         [out_shape],
-        list(a_tuple) + [in_outers_tensor, in_cumsum_tensor],
+        list(data) + [in_outers_tensor, in_cumsum_tensor],
         lambda ins, outs: gen_ir(ins, ins[-2], ins[-1], outs[0], inner, outer),
         dtype=dtype,
         name="concatenate_ext",
     )
-
-
-def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0):
-    """Join a sequence of arrays along an existing axis. Optimized for CPU exeution.
-
-    Parameters
-    ----------
-    data : tuple of tvm.te.Tensor
-        The arrays to concatenate
-
-    axis : int, optional
-        The axis along which the arrays will be joined. Default is 0.
-
-    Returns
-    -------
-    ret : tvm.te.Tensor
-    """
-    return _concat(data, axis=axis)
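
For reference, the restored concatenate is the x86 TOPI entry point that emits the copy loops through the IR builder and wraps them in te.extern. A minimal usage sketch follows; it assumes a local TVM build where the file above is importable as tvm.topi.x86.concat, and the placeholder shapes and llvm target are illustrative only:

# Hypothetical usage sketch; not part of the commit.
import tvm
from tvm import te
from tvm.topi.x86.concat import concatenate

a = te.placeholder((2, 3), name="a", dtype="float32")
b = te.placeholder((4, 3), name="b", dtype="float32")

# Join along axis 0; the result is the te.extern output tensor of shape (6, 3).
out = concatenate((a, b), axis=0)

# The extern op carries its own lowered loops, so a default schedule suffices.
s = te.create_schedule(out.op)
mod = tvm.build(s, [a, b, out], target="llvm")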
