@@ -140,14 +140,11 @@ def compute(shape, fcompute, name="compute"):
140140 return op_node .output (0 )
141141
142142
143- def scan (axis , init , update , state_placeholder , name = "scan" ):
143+ def scan (init , update , state_placeholder , name = "scan" ):
144144 """Construct new tensors by scanning over axis.
145145
146146 Parameters
147147 ----------
148- axis: IterVar
149- The scanning axis.
150-
151148 init: Tensor or list of Tensor
152149 The initial condition of first init.shape[0] timestamps
153150
@@ -170,12 +167,11 @@ def scan(axis, init, update, state_placeholder, name="scan"):
170167 # The following code is equivalent to numpy.cumsum
171168 m = tvm.Var("m")
172169 n = tvm.Var("n")
173- t = tvm.IterVar((1, m), name="t")
174170 X = tvm.placeholder((m, n), name="X")
175171 s_state = tvm.placeholder((m, n))
176172 s_init = tvm.compute((1, n), lambda _, i: X[0, i])
177- s_update = tvm.compute((n, ), lambda i: s_state[t-1, i] + X[t, i])
178- res = tvm.scan(t, s_init, s_update, s_state)
173+ s_update = tvm.compute((m, n ), lambda t, i: s_state[t-1, i] + X[t, i])
174+ res = tvm.scan(s_init, s_update, s_state)
179175 """
180176 if isinstance (init , _tensor .Tensor ):
181177 init = [init ]
@@ -185,6 +181,7 @@ def scan(axis, init, update, state_placeholder, name="scan"):
185181 state_placeholder = [state_placeholder ]
186182 if len (init ) != len (update ) or len (init ) != len (state_placeholder ):
187183 raise ValueError ("init, update, state_placeholder must have same length" )
184+ axis = IterVar ((init [0 ].shape [0 ], update [0 ].shape [0 ]), "%s.idx" % name )
188185 op = _api_internal ._ScanOp (name , axis , init , update , state_placeholder )
189186 res = [op .output (i ) for i in range (len (update ))]
190187 return (res [0 ] if len (res ) == 1 else res )
0 commit comments