Merged

Commits (30)
aeb5387
Overhaul typing submodule
Daraan Aug 6, 2025
e97c7b0
Use better torch DeviceType in data
Daraan Aug 6, 2025
81d7271
Update Runner interface with more precise Device typing
Daraan Aug 6, 2025
1b89882
fixup old style
Daraan Aug 6, 2025
f27bf95
Update DeviceType doc
Daraan Aug 6, 2025
1e0822d
Remove duplicate comment
Daraan Aug 6, 2025
e981c6f
Fix missing guard
Daraan Aug 6, 2025
012bedd
Add more string guards for older versions
Daraan Aug 6, 2025
e832e0d
Add DeviceType for NumpyToTensor
Daraan Aug 6, 2025
4b6a378
Merge branch 'master' into fix-tensor-type
Daraan Aug 21, 2025
dd828ac
Merge branch 'master' into fix-tensor-type
Daraan Aug 22, 2025
8e2ed7c
Merge branch 'master' into fix-tensor-type
kamil-kaczmarek Sep 18, 2025
7fc842a
Merge branch 'master' into fix-tensor-type
kamil-kaczmarek Sep 18, 2025
5a2be91
Merge branch 'master' into fix-tensor-type
Daraan Sep 18, 2025
58a79c1
typing: torch devices as internal alias
Daraan Sep 18, 2025
95b7396
Merge remote-tracking branch 'origin/fix-tensor-type' into fix-tensor…
Daraan Sep 18, 2025
336e51b
Merge branch 'master' into fix-tensor-type
kamil-kaczmarek Sep 18, 2025
676ea71
revert changes to data
Daraan Sep 19, 2025
9dbe2c8
Move tensor types to DeviceType
Daraan Sep 19, 2025
33fd7ab
Merge branch 'master' into fix-tensor-type
Daraan Sep 19, 2025
7bfd886
Merge branch 'master' into fix-tensor-type
kamil-kaczmarek Sep 24, 2025
5050a4c
Merge branch 'master' into fix-tensor-type
kamil-kaczmarek Sep 24, 2025
522bab8
Merge branch 'master' into fix-tensor-type
kamil-kaczmarek Sep 25, 2025
4bf43eb
Merge branch 'master' into fix-tensor-type
Daraan Sep 27, 2025
ae2e422
Merge branch 'master' into fix-tensor-type
Daraan Sep 28, 2025
b9bd7d9
Merge branch 'master' into fix-tensor-type
kamil-kaczmarek Oct 4, 2025
f5d96f5
sort imports
Daraan Oct 4, 2025
a99d49f
Merge branch 'master' into fix-tensor-type
Daraan Oct 20, 2025
46731ec
Merge branch 'master' into fix-tensor-type
kamil-kaczmarek Oct 21, 2025
3af8145
Merge branch 'master' into fix-tensor-type
Daraan Oct 22, 2025
3 changes: 2 additions & 1 deletion python/ray/data/dataset.py
@@ -131,6 +131,7 @@
import torch
import torch.utils.data
from tensorflow_metadata.proto.v0 import schema_pb2
from torch._prims_common import DeviceLikeType

from ray.data._internal.execution.interfaces import Executor, NodeIdStr
from ray.data.grouped_data import GroupedData
@@ -5039,7 +5040,7 @@ def iter_torch_batches(
prefetch_batches: int = 1,
batch_size: Optional[int] = 256,
dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
device: str = "auto",
device: Union["DeviceLikeType", Literal["auto"]] = "auto",
collate_fn: Optional[Callable[[Dict[str, np.ndarray]], CollatedData]] = None,
drop_last: bool = False,
local_shuffle_buffer_size: Optional[int] = None,
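For reference, a minimal sketch of how the widened `device` annotation could be exercised (assuming a local Ray runtime and CPU-only torch; `ray.data.range` is used only to keep the example self-contained):

import ray
import torch

ds = ray.data.range(8)

# With the new annotation, any torch DeviceLikeType (str, int, or
# torch.device) should type-check, alongside the "auto" sentinel default.
for _ in ds.iter_torch_batches(batch_size=4, device="cpu"):
    pass
for _ in ds.iter_torch_batches(batch_size=4, device=torch.device("cpu")):
    pass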
4 changes: 3 additions & 1 deletion python/ray/data/iterator.py
@@ -9,6 +9,7 @@
Iterable,
Iterator,
List,
Literal,
Optional,
Tuple,
TypeVar,
@@ -40,6 +41,7 @@
if TYPE_CHECKING:
import tensorflow as tf
import torch
from torch._prims_common import DeviceLikeType

from ray.data.dataset import (
CollatedData,
@@ -272,7 +274,7 @@ def iter_torch_batches(
prefetch_batches: int = 1,
batch_size: Optional[int] = 256,
dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
device: str = "auto",
device: Union["DeviceLikeType", Literal["auto"]] = "auto",
collate_fn: Optional[
Union[Callable[[Dict[str, np.ndarray]], "CollatedData"], CollateFn]
] = None,
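The same sketch applies on the iterator path, since DataIterator.iter_torch_batches mirrors the Dataset signature (again a hedged example assuming a local Ray runtime and CPU-only torch):

import ray
import torch

it = ray.data.range(8).iterator()
# An explicit torch.device now satisfies the annotation here as well.
for batch in it.iter_torch_batches(batch_size=4, device=torch.device("cpu")):
    print(batch["id"].device)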
7 changes: 5 additions & 2 deletions rllib/connectors/common/numpy_to_tensor.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TYPE_CHECKING

import gymnasium as gym

@@ -12,6 +12,9 @@
from ray.rllib.utils.typing import EpisodeType
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.rllib.utils.typing import DeviceType


@PublicAPI(stability="alpha")
class NumpyToTensor(ConnectorV2):
@@ -59,7 +62,7 @@ def __init__(
input_action_space: Optional[gym.Space] = None,
*,
pin_memory: bool = False,
device: Optional[str] = None,
device: Optional["DeviceType"] = None,
**kwargs,
):
"""Initializes a NumpyToTensor instance.
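As an illustration of the widened connector signature, a hedged construction sketch (the import path simply follows the file shown above and should be treated as an assumption rather than the documented public path):

import torch
from ray.rllib.connectors.common.numpy_to_tensor import NumpyToTensor

# `device` may now be any RLlib DeviceType, e.g. a concrete torch.device,
# not only a plain string.
connector = NumpyToTensor(pin_memory=False, device=torch.device("cpu"))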
11 changes: 6 additions & 5 deletions rllib/core/learner/differentiable_learner.py
@@ -1,17 +1,18 @@
import abc
import logging
import numpy
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
Optional,
Tuple,
TYPE_CHECKING,
Union,
)

import numpy

from ray.rllib.connectors.learner.learner_connector_pipeline import (
LearnerConnectorPipeline,
)
@@ -22,19 +23,19 @@
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.utils import unflatten_dict
from ray.rllib.utils.annotations import (
override,
OverrideToImplementCustomLogic,
OverrideToImplementCustomLogic_CallToSuperRecommended,
override,
)
from ray.rllib.utils.checkpoints import Checkpointable
from ray.rllib.utils.metrics import (
DATASET_NUM_ITERS_TRAINED,
DATASET_NUM_ITERS_TRAINED_LIFETIME,
MODULE_TRAIN_BATCH_SIZE_MEAN,
NUM_ENV_STEPS_TRAINED,
NUM_ENV_STEPS_TRAINED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
NUM_MODULE_STEPS_TRAINED_LIFETIME,
MODULE_TRAIN_BATCH_SIZE_MEAN,
WEIGHTS_SEQ_NO,
)
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
@@ -124,7 +125,7 @@ def build(self, device: Optional[DeviceType] = None) -> None:
if self._is_built:
logger.debug("DifferentiableLearner already built. Skipping built.")

# If a dvice was passed, set the `DifferentiableLearner`'s device.
# If a device was passed, set the `DifferentiableLearner`'s device.
if device:
self._device = device

10 changes: 6 additions & 4 deletions rllib/offline/offline_evaluation_runner.py
@@ -389,9 +389,11 @@ def set_device(self):
try:
self.__device = get_device(
self.config,
0
if not self.worker_index
else self.config.num_gpus_per_offline_eval_runner,
(
0
if not self.worker_index
else self.config.num_gpus_per_offline_eval_runner
),
)
except NotImplementedError:
self.__device = None
@@ -456,7 +458,7 @@ def _batch_iterator(self) -> MiniBatchRayDataIterator:
return self.__batch_iterator

@property
def _device(self) -> DeviceType:
def _device(self) -> Union[DeviceType, None]:
return self.__device

@property
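Because `_device` can now legitimately be `None` (the `set_device` fallback when `get_device` raises `NotImplementedError`), downstream transfer code should guard for it. A hedged sketch with a hypothetical helper, not code from this PR:

import torch

def maybe_to_device(batch, device):
    # `device` is a DeviceType or None; leave tensors on the host when None.
    if device is None:
        return batch
    return {k: torch.as_tensor(v).to(device) for k, v in batch.items()}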
30 changes: 18 additions & 12 deletions rllib/offline/offline_policy_evaluation_runner.py
@@ -102,9 +102,11 @@ def __call__(self, batch: Dict[str, numpy.ndarray]) -> Dict[str, numpy.ndarray]:
# TODO (simon): Refactor into a single code block for both cases.
episodes = self.episode_buffer.sample(
num_items=self.config.train_batch_size_per_learner,
batch_length_T=self.config.model_config.get("max_seq_len", 0)
if self._module.is_stateful()
else None,
batch_length_T=(
self.config.model_config.get("max_seq_len", 0)
if self._module.is_stateful()
else None
),
n_step=self.config.get("n_step", 1) or 1,
# TODO (simon): This can be removed as soon as DreamerV3 has been
# cleaned up, i.e. can use episode samples for training.
@@ -131,9 +133,11 @@ def __call__(self, batch: Dict[str, numpy.ndarray]) -> Dict[str, numpy.ndarray]:
# Sample steps from the buffer.
episodes = self.episode_buffer.sample(
num_items=self.config.train_batch_size_per_learner,
batch_length_T=self.config.model_config.get("max_seq_len", 0)
if self._module.is_stateful()
else None,
batch_length_T=(
self.config.model_config.get("max_seq_len", 0)
if self._module.is_stateful()
else None
),
n_step=self.config.get("n_step", 1) or 1,
# TODO (simon): This can be removed as soon as DreamerV3 has been
# cleaned up, i.e. can use episode samples for training.
@@ -241,14 +245,14 @@ def _create_batch_iterator(self, **kwargs) -> Iterable:
# Define the collate function that converts the flattened dictionary
# to a `MultiAgentBatch` with Tensors.
def _collate_fn(
_batch: Dict[str, numpy.ndarray]
_batch: Dict[str, numpy.ndarray],
) -> Dict[EpisodeID, Dict[str, numpy.ndarray]]:

return _batch["episodes"]

# Define the finalize function that makes the host-to-device transfer.
def _finalize_fn(
_batch: Dict[EpisodeID, Dict[str, numpy.ndarray]]
_batch: Dict[EpisodeID, Dict[str, numpy.ndarray]],
) -> Dict[EpisodeID, Dict[str, TensorType]]:

return [
@@ -556,9 +560,11 @@ def set_device(self):
try:
self.__device = get_device(
self.config,
0
if not self.worker_index
else self.config.num_gpus_per_offline_eval_runner,
(
0
if not self.worker_index
else self.config.num_gpus_per_offline_eval_runner
),
)
except NotImplementedError:
self.__device = None
@@ -613,7 +619,7 @@ def _batch_iterator(self) -> MiniBatchRayDataIterator:
return self.__batch_iterator

@property
def _device(self) -> DeviceType:
def _device(self) -> Union[DeviceType, None]:
return self.__device

@property
7 changes: 6 additions & 1 deletion rllib/utils/from_config.py
@@ -9,9 +9,14 @@
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils import force_list, merge_dicts

from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
from ray.rllib.utils.typing import FromConfigSpec


@DeveloperAPI
def from_config(cls, config=None, **kwargs):
def from_config(cls, config: Optional["FromConfigSpec"] = None, **kwargs):
"""Uses the given config to create an object.

If `config` is a dict, an optional "type" key can be used as a
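A hedged usage sketch for the newly annotated `from_config` (the `LRSchedule` class is hypothetical, and the spec semantics are taken from the docstring excerpt above rather than verified against the full implementation):

from ray.rllib.utils.from_config import from_config

class LRSchedule:  # hypothetical class, used only for illustration
    def __init__(self, lr: float = 0.1):
        self.lr = lr

# A plain dict spec: keys other than the optional "type" key become kwargs.
schedule = from_config(LRSchedule, {"lr": 0.01})
assert isinstance(schedule, LRSchedule) and schedule.lr == 0.01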
6 changes: 3 additions & 3 deletions rllib/utils/runners/runner.py
@@ -1,7 +1,7 @@
import abc
import logging

from typing import Any, TYPE_CHECKING
from typing import Any, Union, TYPE_CHECKING

from ray.rllib.utils.actor_manager import FaultAwareApply
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
Expand Down Expand Up @@ -87,8 +87,8 @@ def stop(self) -> None:

@property
@abc.abstractmethod
def _device(self) -> DeviceType:
"""Returns the device of this `Runner`."""
def _device(self) -> Union[DeviceType, None]:
"""Returns the device of this `Runner`. None if framework is not supported."""
pass

@abc.abstractmethod
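A hedged sketch of how a subclass could satisfy the widened abstract property (`MyRunner` is illustrative only; the base-class import path follows the file shown above and is an assumption):

from typing import Optional, TYPE_CHECKING

from ray.rllib.utils.runners.runner import Runner

if TYPE_CHECKING:
    from ray.rllib.utils.typing import DeviceType

class MyRunner(Runner):
    @property
    def _device(self) -> Optional["DeviceType"]:
        # Return None when no framework-specific device could be resolved.
        return getattr(self, "_resolved_device", None)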