diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py index d47665d6d509..f2cb62fc8396 100755 --- a/python/mxnet/module/executor_group.py +++ b/python/mxnet/module/executor_group.py @@ -308,8 +308,16 @@ def decide_slices(self, data_shapes): def _collect_arrays(self): """Collect internal arrays from executors.""" # convenient data structures - self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in enumerate(self.execs)] - for name, _ in self.data_shapes] + + # check if self.slices is populated, if not then that means that there is no batch size + if self.slices: + # based on batch size, slice up data for the given contexts (self.execs) + self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in enumerate(self.execs)] + for name, _ in self.data_shapes] + else: + # just use the context index as index into the data + self.data_arrays = [[(slice(i, i+1), e.arg_dict[name]) for i, e in enumerate(self.execs)] + for name, _ in self.data_shapes] self.state_arrays = [[e.arg_dict[name] for e in self.execs] for name in self.state_names]