Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add logic for no batch size while getting data arrays from executors (#…
Browse files Browse the repository at this point in the history
…17772) (#18122)

Co-authored-by: Ubuntu <[email protected]>

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
mseth10 and Ubuntu committed Apr 23, 2020
1 parent 4392b4c commit 8695537
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions python/mxnet/module/executor_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,16 @@ def decide_slices(self, data_shapes):
def _collect_arrays(self):
"""Collect internal arrays from executors."""
# convenient data structures
self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in enumerate(self.execs)]
for name, _ in self.data_shapes]

# check if self.slices is populated, if not then that means that there is no batch size
if self.slices:
# based on batch size, slice up data for the given contexts (self.execs)
self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in enumerate(self.execs)]
for name, _ in self.data_shapes]
else:
# just use the context index as index into the data
self.data_arrays = [[(slice(i, i+1), e.arg_dict[name]) for i, e in enumerate(self.execs)]
for name, _ in self.data_shapes]

self.state_arrays = [[e.arg_dict[name] for e in self.execs]
for name in self.state_names]
Expand Down

0 comments on commit 8695537

Please sign in to comment.