Skip to content

Commit

Permalink
Updated all preprocessors to new array backend.
Browse files Browse the repository at this point in the history
- all preprocessors now support numpy and torch backend.
  • Loading branch information
robfiras committed Jan 19, 2024
1 parent f408f2d commit 5dc8422
Show file tree
Hide file tree
Showing 10 changed files with 268 additions and 92 deletions.
3 changes: 2 additions & 1 deletion mushroom_rl/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .array_backend import ArrayBackend
from .core import Core
from .dataset import Dataset, VectorizedDataset
from .environment import Environment, MDPInfo
Expand All @@ -11,5 +12,5 @@

import mushroom_rl.environments

__all__ = ['Core', 'Dataset', 'Environment', 'MDPInfo', 'Agent', 'AgentInfo', 'Serializable', 'Logger',
__all__ = ['ArrayBackend', 'Core', 'Dataset', 'Environment', 'MDPInfo', 'Agent', 'AgentInfo', 'Serializable', 'Logger',
'VectorCore', 'VectorizedEnvironment', 'MultiprocessEnvironment']
1 change: 0 additions & 1 deletion mushroom_rl/core/_impl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .numpy_dataset import NumpyDataset
from .torch_dataset import TorchDataset
from .list_dataset import ListDataset
from .array_backend import ArrayBackend, NumpyBackend, TorchBackend, ListBackend
from .core_logic import CoreLogic
from .vectorized_core_logic import VectorizedCoreLogic
2 changes: 1 addition & 1 deletion mushroom_rl/core/_impl/vectorized_core_logic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .array_backend import ArrayBackend
from mushroom_rl.core import ArrayBackend
from .core_logic import CoreLogic


Expand Down
7 changes: 3 additions & 4 deletions mushroom_rl/core/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from mushroom_rl.core.serialization import Serializable

from ._impl import *
from .array_backend import ArrayBackend


class AgentInfo(Serializable):
Expand Down Expand Up @@ -210,10 +209,10 @@ def core_preprocessors(self):
return self._core_preprocessors

def _convert_to_env_backend(self, array):
return self._env_backend.to_backend_array(self._agent_backend, array)
return self._env_backend.convert_to_backend(self._agent_backend, array)

def _convert_to_agent_backend(self, array):
return self._agent_backend.to_backend_array(self._env_backend, array)
return self._agent_backend.convert_to_backend(self._env_backend, array)

@property
def info(self):
Expand Down
Loading

0 comments on commit 5dc8422

Please sign in to comment.