Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions grain/_src/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ py_library(
"@abseil-py//absl/logging",
"//grain/_src/core:usage_logging",
"//grain/_src/python/dataset:stats",
"//grain/_src/python/dataset:base",
] + select({
"@platforms//os:windows": [],
"//conditions:default": ["@pypi//array_record:pkg"],
Expand Down
32 changes: 3 additions & 29 deletions grain/_src/python/data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@
import os
import threading
import time
import typing
from typing import Any, Generic, Optional, Protocol, SupportsIndex, TypeVar, Union
from typing import Any, Optional, SupportsIndex, TypeVar, Union

from absl import logging
from etils import epath

from grain._src.core import monitoring as grain_monitoring
from grain._src.core import usage_logging
from grain._src.python.dataset import base
from grain._src.python.dataset import stats as dataset_stats

from grain._src.core import monitoring # pylint: disable=g-bad-import-order
Expand Down Expand Up @@ -127,33 +127,7 @@ def paths(self) -> ArrayRecordDataSourcePaths:
return self._paths


@typing.runtime_checkable
class RandomAccessDataSource(Protocol, Generic[T]):
  """Protocol for data sources whose storage allows efficient random access.

  Implementations additionally have to provide `__repr__`; checkpointing does
  not work with the source otherwise.
  """

  def __len__(self) -> int:
    """Returns how many records the data source contains in total."""
    ...

  def __getitem__(self, record_key: SupportsIndex) -> T:
    """Looks up and returns the record stored under `record_key`.

    Implementations must be threadsafe and are expected to be deterministic.
    With multiprocessing (worker_count>0) PyGrain pickles the data source
    (invoking __getstate__()) and sends a copy to every worker process, where
    __setstate__() is called. From then on each worker process owns an
    independent data source object.

    Args:
      record_key: An integer in [0, len(self)-1].

    Returns:
      The record for `record_key`. File data sources often return the raw
      bytes, but a record can be any Python object.
    """
    ...
# Alias for the protocol defined in `grain._src.python.dataset.base`.
RandomAccessDataSource = base.RandomAccessDataSource


class RangeDataSource:
Expand Down
26 changes: 23 additions & 3 deletions grain/_src/python/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,40 @@
import dataclasses
import enum
import typing
from typing import Generic, Protocol, TypeVar
from typing import Generic, Protocol, SupportsIndex, TypeVar


T = TypeVar("T")


@typing.runtime_checkable
class RandomAccessDataSource(Protocol[T]):
"""Interface for datasets where storage supports efficient random access."""
"""Interface for datasources where storage supports efficient random access.

Note that `__repr__` has to be additionally implemented to make checkpointing
work with this source.
"""

def __len__(self) -> int:
"""Returns the total number of records in the data source."""
...

def __getitem__(self, index: int) -> T:
def __getitem__(self, record_key: int | SupportsIndex) -> T:
"""Returns the value for the given record_key.

This method must be threadsafe. It's also expected to be deterministic.
When using multiprocessing (worker_count>0) PyGrain will pickle the data
source, which invokes __getstate__(), and send a copy to each worker
process, where __setstate__() is called. After that each worker process
has its own independent data source object.

Arguments:
record_key: This will be an integer in [0, len(self)-1].

Returns:
The corresponding record. File data sources often return the raw bytes but
records can be any Python object.
"""
...


Expand Down