Skip to content

Commit

Permalink
add logic for no batch size while getting data arrays from executors (a…
Browse files Browse the repository at this point in the history
…pache#17772) (apache#18075)

Co-authored-by: Ubuntu <[email protected]>

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
mseth10 and Ubuntu authored Apr 26, 2020
1 parent 0e7dd91 commit 63e2b19
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions python/mxnet/module/executor_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,16 @@ def decide_slices(self, data_shapes):
def _collect_arrays(self):
"""Collect internal arrays from executors."""
# convenient data structures
self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in enumerate(self.execs)]
for name, _ in self.data_shapes]

# check if self.slices is populated, if not then that means that there is no batch size
if self.slices:
# based on batch size, slice up data for the given contexts (self.execs)
self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in enumerate(self.execs)]
for name, _ in self.data_shapes]
else:
# just use the context index as index into the data
self.data_arrays = [[(slice(i, i+1), e.arg_dict[name]) for i, e in enumerate(self.execs)]
for name, _ in self.data_shapes]

self.state_arrays = [[e.arg_dict[name] for e in self.execs]
for name in self.state_names]
Expand Down

0 comments on commit 63e2b19

Please sign in to comment.