Skip to content
Merged
11 changes: 10 additions & 1 deletion python/paddle/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Callable, Generator, TypeVar

_T = TypeVar('_T')
__all__ = []


def batch(reader, batch_size, drop_last=False):
def batch(
reader: Callable[[], Generator[_T, None, None]],
batch_size: int,
drop_last: bool = False,
) -> Callable[[], Generator[list[_T], None, None]]:
"""
This operator creates a batched reader which combines the data from the
input reader to batched data.
Expand Down