Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions python/ray/air/_internal/torch_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

import numpy as np
import pandas as pd
Expand All @@ -19,6 +19,9 @@
_is_tensor_sequence_mapping,
)

if TYPE_CHECKING:
from ray.data.dataset import TorchDeviceType

# Default non-blocking transfer for tensors.
DEFAULT_TENSOR_NON_BLOCKING_TRANSFER = env_bool(
"RAY_AIR_DEFAULT_TENSOR_NON_BLOCKING_TRANSFER",
Expand Down Expand Up @@ -148,7 +151,7 @@ def get_tensor_for_columns(columns, dtype):
def convert_ndarray_to_torch_tensor(
ndarray: np.ndarray,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[str, "torch.device"]] = None,
device: Optional[TorchDeviceType] = None,
pin_memory: bool = False,
) -> torch.Tensor:
"""Convert a NumPy ndarray to a Torch Tensor.
Expand Down Expand Up @@ -197,7 +200,7 @@ def convert_ndarray_to_torch_tensor(
def convert_ndarray_batch_to_torch_tensor_batch(
ndarrays: Union[np.ndarray, Dict[str, np.ndarray]],
dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
device: Optional[Union[str, "torch.device"]] = None,
device: Optional[TorchDeviceType] = None,
pin_memory: bool = False,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""Convert a NumPy ndarray batch to a Torch Tensor batch.
Expand Down Expand Up @@ -333,7 +336,7 @@ def consume_prefix_in_state_dict_if_present_not_in_place(
def convert_ndarray_list_to_torch_tensor_list(
ndarrays: Dict[str, List[np.ndarray]],
dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
device: Optional[Union[str, "torch.device"]] = None,
device: Optional[TorchDeviceType] = None,
pin_memory: bool = False,
) -> Dict[str, List[torch.Tensor]]:
"""Convert a dict mapping column names to lists of ndarrays to Torch Tensors.
Expand Down Expand Up @@ -411,7 +414,7 @@ def arrow_batch_to_tensors(
@torch.no_grad()
def concat_tensors_to_device(
tensor_sequence: Sequence[torch.Tensor],
device: Optional[Union[str, "torch.device"]] = None,
device: Optional[TorchDeviceType] = None,
non_blocking: bool = DEFAULT_TENSOR_NON_BLOCKING_TRANSFER,
) -> torch.Tensor:
"""Stack sequence of tensors into a contiguous GPU tensor.
Expand Down Expand Up @@ -493,7 +496,7 @@ def _get_type_str(batch: Any) -> str:
@torch.no_grad()
def move_tensors_to_device(
batch: TensorBatchType,
device: Optional[Union[str, "torch.device"]] = None,
device: Optional[TorchDeviceType] = None,
non_blocking: bool = DEFAULT_TENSOR_NON_BLOCKING_TRANSFER,
) -> TensorBatchReturnType:
"""Move tensors to the specified device.
Expand Down