2222from ..utils import get_const_int , const_vector
2323
2424
25- def _concat ( a_tuple , axis = 0 ):
26- """Join a sequence of arrays along an existing axis.
25+ def concatenate ( data : tvm . te . Tensor , axis : Optional [ int ] = 0 ):
26+ """Join a sequence of arrays along an existing axis. Optimized for CPU exeution.
2727
2828 Parameters
2929 ----------
30- a_tuple : tuple of tvm.te.Tensor
30+ data : tuple of tvm.te.Tensor
3131 The arrays to concatenate
3232
3333 axis : int, optional
@@ -45,7 +45,7 @@ def gen_ir_1d(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf):
4545 out_buf = i_b .buffer_ptr (out_buf )
4646 outers = i_b .buffer_ptr (in_outers_tensor )
4747 cumsum = i_b .buffer_ptr (in_cumsum_tensor )
48- for i in range (len (a_tuple )):
48+ for i in range (len (data )):
4949 with i_b .for_range (0 , outers [i ], name = "j" ) as j :
5050 out_buf [cumsum [i ] + j ] = data_bufs1 [i ][j ]
5151 return i_b .get ()
@@ -60,39 +60,39 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer)
6060 if inner > 1 :
6161 with i_b .for_range (0 , inner , name = "inn" , kind = "parallel" ) as inn :
6262 pos = inn * outer
63- for i in range (len (a_tuple )):
63+ for i in range (len (data )):
6464 offset = inn * outers [i ]
6565 with i_b .for_range (0 , outers [i ], name = "j" ) as j :
6666 out_buf [pos + cumsum [i ] + j ] = data_bufs1 [i ][offset + j ]
6767 else :
68- for i in range (len (a_tuple )):
68+ for i in range (len (data )):
6969 with i_b .for_range (0 , outers [i ], name = "j" , kind = "parallel" ) as j :
7070 out_buf [cumsum [i ] + j ] = data_bufs1 [i ][j ]
7171 return i_b .get ()
7272
7373 if axis < 0 :
74- axis += len (a_tuple [0 ].shape )
75- concat_axis_sizes = [int (t .shape [axis ]) for t in a_tuple ]
74+ axis += len (data [0 ].shape )
75+ concat_axis_sizes = [int (t .shape [axis ]) for t in data ]
7676 join_size = int (np .sum (concat_axis_sizes ))
77- in_outers = [int (np .prod (i .shape [axis :])) for i in a_tuple ]
77+ in_outers = [int (np .prod (i .shape [axis :])) for i in data ]
7878 in_outers_cumsum = [0 , * np .cumsum (in_outers , dtype = "int64" )[0 :- 1 ]]
79- dtype = a_tuple [0 ].dtype
80- out_shape = a_tuple [0 ].shape [:axis ] + [join_size ] + a_tuple [0 ].shape [axis + 1 :]
79+ dtype = data [0 ].dtype
80+ out_shape = data [0 ].shape [:axis ] + [join_size ] + data [0 ].shape [axis + 1 :]
8181 in_outers_tensor = const_vector (in_outers )
8282 in_cumsum_tensor = const_vector (in_outers_cumsum , name = "cumsum" )
8383 right_val = np .prod (out_shape [axis :])
8484 left_val = np .prod (out_shape [:axis ])
8585
8686 if (
87- len (a_tuple [0 ].shape ) == 1
87+ len (data [0 ].shape ) == 1
8888 or right_val == 1
89- or (left_val == 1 and axis == len (a_tuple [0 ].shape ) - 1 )
89+ or (left_val == 1 and axis == len (data [0 ].shape ) - 1 )
9090 or (left_val == 1 and right_val == 1 )
9191 ):
9292 # badly parallelized case
9393 return te .extern (
9494 [out_shape ],
95- list (a_tuple ) + [in_outers_tensor , in_cumsum_tensor ],
95+ list (data ) + [in_outers_tensor , in_cumsum_tensor ],
9696 lambda ins , outs : gen_ir_1d (ins , ins [- 2 ], ins [- 1 ], outs [0 ]),
9797 dtype = dtype ,
9898 name = "concatenate_ext" ,
@@ -102,26 +102,8 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer)
102102 outer = get_const_int (int (right_val ))
103103 return te .extern (
104104 [out_shape ],
105- list (a_tuple ) + [in_outers_tensor , in_cumsum_tensor ],
105+ list (data ) + [in_outers_tensor , in_cumsum_tensor ],
106106 lambda ins , outs : gen_ir (ins , ins [- 2 ], ins [- 1 ], outs [0 ], inner , outer ),
107107 dtype = dtype ,
108108 name = "concatenate_ext" ,
109109 )
110-
111-
112- def concatenate (data : tvm .te .Tensor , axis : Optional [int ] = 0 ):
113- """Join a sequence of arrays along an existing axis. Optimized for CPU exeution.
114-
115- Parameters
116- ----------
117- data : tuple of tvm.te.Tensor
118- The arrays to concatenate
119-
120- axis : int, optional
121- The axis along which the arrays will be joined. Default is 0.
122-
123- Returns
124- -------
125- ret : tvm.te.Tensor
126- """
127- return _concat (data , axis = axis )
0 commit comments