Skip to content

Commit 35522cb

Browse files
[data] Small (typing) overhaul of valid torch device types (#56743)
## Why are these changes needed? Follow-up to RLlib PR: #55291 Torch `device` accepts input types of `str | int | torch.device`; this PR unifies these in a single type alias (`TorchDeviceType`) and allows for the `int` type as well. Upstream air PR: #56745 ## Checks - [x] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run `scripts/format.sh` to lint the changes in this PR. - [x] I've included any doc changes needed for https://docs.ray.io/en/master/. - [x] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [x] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [x] This PR is not tested :( --------- Signed-off-by: Daniel Sperber <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent cf2331b commit 35522cb

File tree

3 files changed

+15
-7
lines changed

3 files changed

+15
-7
lines changed

python/ray/data/collate_fn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,18 @@
1414

1515
import numpy as np
1616

17-
from ray.data.block import DataBatch
1817
from ray.util.annotations import DeveloperAPI
1918

2019
if TYPE_CHECKING:
2120
import pandas
2221
import pyarrow
2322
import torch
2423

25-
from ray.data.dataset import CollatedData
24+
from ray.data.block import DataBatch
25+
from ray.data.dataset import CollatedData, TorchDeviceType
2626

2727

28-
DataBatchType = TypeVar("DataBatchType", bound=DataBatch)
28+
DataBatchType = TypeVar("DataBatchType", bound="DataBatch")
2929

3030
TensorSequenceType = Union[
3131
List["torch.Tensor"],
@@ -226,7 +226,7 @@ class DefaultCollateFn(ArrowBatchCollateFn):
226226
def __init__(
227227
self,
228228
dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
229-
device: Optional[Union[str, "torch.device"]] = None,
229+
device: Optional["TorchDeviceType"] = None,
230230
pin_memory: bool = False,
231231
):
232232
"""Initialize the collate function.
@@ -242,7 +242,7 @@ def __init__(
242242

243243
super().__init__()
244244
self.dtypes = dtypes
245-
if isinstance(device, str):
245+
if isinstance(device, (str, int)):
246246
self.device = torch.device(device)
247247
else:
248248
self.device = device

python/ray/data/dataset.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@
152152
CollatedData = TypeVar("CollatedData")
153153
TorchBatchType = Union[Dict[str, "torch.Tensor"], CollatedData]
154154

155+
TorchDeviceType = Union[str, "torch.device", int]
156+
"""
157+
A device identifier, which can be a string (e.g. 'cpu', 'cuda:0'),
158+
a torch.device object, or an integer (e.g. 0 for 'cuda:0').
159+
"""
160+
155161
BT_API_GROUP = "Basic Transformations"
156162
SSR_API_GROUP = "Sorting, Shuffling and Repartitioning"
157163
SMJ_API_GROUP = "Splitting, Merging, Joining datasets"
@@ -5318,7 +5324,7 @@ def iter_torch_batches(
53185324
prefetch_batches: int = 1,
53195325
batch_size: Optional[int] = 256,
53205326
dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
5321-
device: str = "auto",
5327+
device: Union[TorchDeviceType, Literal["auto"]] = "auto",
53225328
collate_fn: Optional[Callable[[Dict[str, np.ndarray]], CollatedData]] = None,
53235329
drop_last: bool = False,
53245330
local_shuffle_buffer_size: Optional[int] = None,

python/ray/data/iterator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Iterable,
1010
Iterator,
1111
List,
12+
Literal,
1213
Optional,
1314
Tuple,
1415
TypeVar,
@@ -47,6 +48,7 @@
4748
Schema,
4849
TensorFlowTensorBatchType,
4950
TorchBatchType,
51+
TorchDeviceType,
5052
)
5153

5254

@@ -272,7 +274,7 @@ def iter_torch_batches(
272274
prefetch_batches: int = 1,
273275
batch_size: Optional[int] = 256,
274276
dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
275-
device: str = "auto",
277+
device: Union["TorchDeviceType", Literal["auto"]] = "auto",
276278
collate_fn: Optional[
277279
Union[Callable[[Dict[str, np.ndarray]], "CollatedData"], CollateFn]
278280
] = None,

0 commit comments

Comments
 (0)