Skip to content

Commit b51a810

Browse files
committed
Add DeviceType for NumpyToTensor
Signed-off-by: Daniel Sperber <[email protected]>
1 parent 72a1329 commit b51a810

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

rllib/connectors/common/numpy_to_tensor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional
1+
from typing import Any, Dict, List, Optional, TYPE_CHECKING
22

33
import gymnasium as gym
44

@@ -12,6 +12,9 @@
1212
from ray.rllib.utils.typing import EpisodeType
1313
from ray.util.annotations import PublicAPI
1414

15+
if TYPE_CHECKING:
16+
from ray.rllib.utils.typing import DeviceType
17+
1518

1619
@PublicAPI(stability="alpha")
1720
class NumpyToTensor(ConnectorV2):
@@ -59,7 +62,7 @@ def __init__(
5962
input_action_space: Optional[gym.Space] = None,
6063
*,
6164
pin_memory: bool = False,
62-
device: Optional[str] = None,
65+
device: Optional["DeviceType"] = None,
6366
**kwargs,
6467
):
6568
"""Initializes a NumpyToTensor instance.

0 commit comments

Comments
 (0)