From ef2eadd144007dd3dc48af7a55e09290e815f981 Mon Sep 17 00:00:00 2001 From: Jim Lin Date: Thu, 28 Aug 2025 16:07:00 -0700 Subject: [PATCH] Consolidates `base.RandomAccessDataSource` and `data_sources.RandomAccessDataSource` to avoid confusion PiperOrigin-RevId: 800651478 --- grain/_src/python/BUILD | 1 + grain/_src/python/data_sources.py | 32 +++---------------------------- grain/_src/python/dataset/base.py | 26 ++++++++++++++++++++++--- 3 files changed, 27 insertions(+), 32 deletions(-) diff --git a/grain/_src/python/BUILD b/grain/_src/python/BUILD index 3b7228f4..b946d805 100644 --- a/grain/_src/python/BUILD +++ b/grain/_src/python/BUILD @@ -14,6 +14,7 @@ py_library( "@abseil-py//absl/logging", "//grain/_src/core:usage_logging", "//grain/_src/python/dataset:stats", + "//grain/_src/python/dataset:base", ] + select({ "@platforms//os:windows": [], "//conditions:default": ["@pypi//array_record:pkg"], diff --git a/grain/_src/python/data_sources.py b/grain/_src/python/data_sources.py index 18ed54e8..bc81ac12 100644 --- a/grain/_src/python/data_sources.py +++ b/grain/_src/python/data_sources.py @@ -29,14 +29,14 @@ import os import threading import time -import typing -from typing import Any, Generic, Optional, Protocol, SupportsIndex, TypeVar, Union +from typing import Any, Optional, SupportsIndex, TypeVar, Union from absl import logging from etils import epath from grain._src.core import monitoring as grain_monitoring from grain._src.core import usage_logging +from grain._src.python.dataset import base from grain._src.python.dataset import stats as dataset_stats from grain._src.core import monitoring # pylint: disable=g-bad-import-order @@ -127,33 +127,7 @@ def paths(self) -> ArrayRecordDataSourcePaths: return self._paths -@typing.runtime_checkable -class RandomAccessDataSource(Protocol, Generic[T]): - """Interface for datasources where storage supports efficient random access. - - Note that `__repr__` has to be additionally implemented to make checkpointing - work with this source. - """ - - def __len__(self) -> int: - """Returns the total number of records in the data source.""" - - def __getitem__(self, record_key: SupportsIndex) -> T: - """Returns the value for the given record_key. - - This method must be threadsafe. It's also expected to be deterministic. - When using multiprocessing (worker_count>0) PyGrain will pickle the data - source, which invokes __getstate__(), and send a copy to each worker - process, where __setstate__() is called. After that each worker process - has its own independent data source object. - - Arguments: - record_key: This will be an integer in [0, len(self)-1]. - - Returns: - The corresponding record. File data sources often return the raw bytes but - records can be any Python object. - """ +RandomAccessDataSource = base.RandomAccessDataSource class RangeDataSource: diff --git a/grain/_src/python/dataset/base.py b/grain/_src/python/dataset/base.py index e6e9c219..32600a39 100644 --- a/grain/_src/python/dataset/base.py +++ b/grain/_src/python/dataset/base.py @@ -22,7 +22,7 @@ import dataclasses import enum import typing -from typing import Generic, Protocol, TypeVar +from typing import Generic, Protocol, SupportsIndex, TypeVar T = TypeVar("T") @@ -30,12 +30,32 @@ @typing.runtime_checkable class RandomAccessDataSource(Protocol[T]): - """Interface for datasets where storage supports efficient random access.""" + """Interface for datasources where storage supports efficient random access. + + Note that `__repr__` has to be additionally implemented to make checkpointing + work with this source. + """ def __len__(self) -> int: + """Returns the total number of records in the data source.""" ... - def __getitem__(self, index: int) -> T: + def __getitem__(self, record_key: int | SupportsIndex) -> T: + """Returns the value for the given record_key. + + This method must be threadsafe. It's also expected to be deterministic. + When using multiprocessing (worker_count>0) PyGrain will pickle the data + source, which invokes __getstate__(), and send a copy to each worker + process, where __setstate__() is called. After that each worker process + has its own independent data source object. + + Arguments: + record_key: This will be an integer in [0, len(self)-1]. + + Returns: + The corresponding record. File data sources often return the raw bytes but + records can be any Python object. + """ ...