Skip to content

Commit

Permalink
Merge branch 'isaac' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
boris-il-forte committed Dec 29, 2023
2 parents ad4be8d + 724f017 commit c271d77
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 0 deletions.
144 changes: 144 additions & 0 deletions examples/isaac_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import hydra
from omegaconf import DictConfig
from omniisaacgymenvs.utils.hydra_cfg.reformat import omegaconf_to_dict
from omniisaacgymenvs.utils.hydra_cfg.hydra_utils import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
from tqdm import trange

from mushroom_rl.core import VectorCore, Logger
from mushroom_rl.algorithms.actor_critic import TRPO, PPO

from mushroom_rl.policy import GaussianTorchPolicy
from mushroom_rl.environments import IsaacEnv
from mushroom_rl.utils import TorchUtils


class Network(nn.Module):
def __init__(self, input_shape, output_shape, n_features, **kwargs):
super(Network, self).__init__()

n_input = input_shape[-1]
n_output = output_shape[0]

self._h1 = nn.Linear(n_input, n_features)
self._h2 = nn.Linear(n_features, n_features)
self._h3 = nn.Linear(n_features, n_output)

nn.init.xavier_uniform_(self._h1.weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self._h2.weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self._h3.weight,
gain=nn.init.calculate_gain('linear'))

def forward(self, state, **kwargs):
features1 = F.relu(self._h1(torch.squeeze(state, 1).float()))
features2 = F.relu(self._h2(features1))
a = self._h3(features2)

return a


def experiment(cfg_dict, headless, alg, n_epochs, n_steps, n_steps_per_fit, n_episodes_test,
alg_params, policy_params):

logger = Logger(alg.__name__, results_dir=None)
logger.strong_line()
logger.info('Experiment Algorithm: ' + alg.__name__)

mdp = IsaacEnv(cfg_dict, headless=headless)


critic_params = dict(network=Network,
optimizer={'class': optim.Adam,
'params': {'lr': 3e-4}},
loss=F.mse_loss,
n_features=32,
batch_size=64,
use_cuda=True,
input_shape=mdp.info.observation_space.shape,
output_shape=(1,))

policy = GaussianTorchPolicy(Network,
mdp.info.observation_space.shape,
mdp.info.action_space.shape,
**policy_params)

alg_params['critic_params'] = critic_params

agent = alg(mdp.info, policy, **alg_params)
#agent.set_logger(logger)

core = VectorCore(agent, mdp)

dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

J = torch.mean(torch.stack(dataset.discounted_return))
R = torch.mean(torch.stack(dataset.undiscounted_return))
E = agent.policy.entropy()

logger.epoch_info(0, J=J, R=R, entropy=E)

for it in trange(n_epochs, leave=False):
core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit)
dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

J = torch.mean(torch.stack(dataset.discounted_return))
R = torch.mean(torch.stack(dataset.undiscounted_return))
E = agent.policy.entropy()

logger.epoch_info(it+1, J=J, R=R, entropy=E)

logger.info('Press a button to visualize')
input()
core.evaluate(n_episodes=5, render=True)


@hydra.main(config_name="config", config_path="./cfg")
def parse_hydra_configs(cfg: DictConfig):
TorchUtils.set_default_device('cuda')
headless = cfg.headless
cfg_dict = omegaconf_to_dict(cfg)

max_kl = .015

policy_params = dict(
std_0=1.,
n_features=32,
use_cuda=True

)

ppo_params = dict(actor_optimizer={'class': optim.Adam,
'params': {'lr': 3e-4}},
n_epochs_policy=4,
batch_size=64,
eps_ppo=.2,
lam=.95)

trpo_params = dict(ent_coeff=0.0,
max_kl=.01,
lam=.95,
n_epochs_line_search=10,
n_epochs_cg=100,
cg_damping=1e-2,
cg_residual_tol=1e-10)

algs_params = [
(PPO, 'ppo', ppo_params),
(TRPO, 'trpo', trpo_params)
]

for alg, alg_name, alg_params in algs_params:
experiment(cfg_dict=cfg_dict, headless=headless, alg=alg, n_epochs=40, n_steps=30000, n_steps_per_fit=3000,
n_episodes_test=512, alg_params=alg_params, policy_params=policy_params)


if __name__ == '__main__':
parse_hydra_configs()
5 changes: 5 additions & 0 deletions mushroom_rl/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
except ImportError:
pass

try:
IsaacEnv = None
from .isaac_env import IsaacEnv
except ImportError:
pass

try:
PyBullet = None
Expand Down
134 changes: 134 additions & 0 deletions mushroom_rl/environments/isaac_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import torch
from gym import spaces as gym_spaces

from omni.isaac.kit import SimulationApp
from omniisaacgymenvs.utils.task_util import initialize_task

from mushroom_rl.core import VectorizedEnvironment, MDPInfo
from mushroom_rl.utils.viewer import ImageViewer
from mushroom_rl.utils.isaac_utils import convert_task_observation
from mushroom_rl.rl_utils.spaces import *

# import carb


class IsaacEnv(VectorizedEnvironment):
"""
Interface for OmniIsaacGymEnvs environments. It makes it possible to use every
OmniIsaacGymEnvs environment just providing the task.
"""

def __init__(self, cfg=None, headless=False, backend='torch'):
""" Initializes RL and task parameters.
Args:
cfg (dict): dictionary containing the parameters required to build the task;
headless (bool): Whether to run training headless;
backend (str, 'torch'): The backend to be used by the environment.
"""
RENDER_WIDTH = 1280 # 1600
RENDER_HEIGHT = 720 # 900
RENDER_DT = 1.0 / 60.0 # 60 Hz

self._simulation_app = SimulationApp({"headless": headless,
"window_width": 1920,
"window_height": 1080,
"width": RENDER_WIDTH,
"height": RENDER_HEIGHT})

# TODO check if the next line is needed
#carb.settings.get_settings().set("/persistent/omnihydra/useSceneGraphInstancing", True)

self._render = not headless

self._viewer = ImageViewer([RENDER_WIDTH, RENDER_HEIGHT], RENDER_DT)

initialize_task(cfg, self)
action_space = self._convert_gym_space(self._task.action_space)
observation_space = self._convert_gym_space(self._task.observation_space)

# Create MDP info for mushroom
# default episod lenght
max_e_lenght = 1000
if hasattr(self._task, '_max_episode_length'):
max_e_lenght = self._task._max_episode_length
mdp_info = MDPInfo(observation_space, action_space, 0.99,
max_e_lenght, dt=RENDER_DT, backend=backend)

super().__init__(mdp_info, self._task.num_envs)

def set_task(self, task, backend="torch", sim_params=None, init_sim=True, rendering_dt = True, **kwargs):
from omni.isaac.core.world import World
RENDER_DT = 1.0 / 60.0 # 60 Hz

self._device = "cpu"
if sim_params and "use_gpu_pipeline" in sim_params:
if sim_params["use_gpu_pipeline"]:
self._device = sim_params["sim_device"]

self._world = World(
stage_units_in_meters=1.0,
rendering_dt=RENDER_DT,
backend=backend,
sim_params=sim_params,
device=self._device
)

self._task = task
self._world.add_task(task)
self._world.reset()

def seed(self, seed=-1):
from omni.isaac.core.utils.torch.maths import set_seed
return set_seed(seed)

def reset_all(self, env_mask, state=None):
idxs = torch.argwhere(env_mask).squeeze() # .cpu().numpy() # takes torch datatype
if idxs.dim() > 0: # only resets task for tensor with actual dimension
self._task.reset_idx(idxs)
# self._world.step(render=self._render) # TODO Check if we can do otherwise
task_obs = self._task.get_observations()
observation = convert_task_observation(task_obs)
return observation, [{}]*self._n_envs

def step_all(self, env_mask, action):
self._task.pre_physics_step(action)

# allow users to specify the control frequency through config
for _ in range(self._task.control_frequency_inv):
self._world.step(render=self._render)

observation, reward, done, info = self._task.post_physics_step()
# converts task obs from dictionary to tensor
observation = convert_task_observation(observation)

env_mask_cuda = torch.as_tensor(env_mask).cuda()

return observation, reward, torch.logical_and(done, env_mask_cuda), [info]*self._n_envs

def render_all(self, env_mask, record=False):
self._world.render()
task_render = self._task.get_render()

self._viewer.display(task_render)

if record:
return task_render

def stop(self):
self._world.reset()

def __del__(self):
self._simulation_app.close()

@staticmethod
def _convert_gym_space(space):
# import pdb; pdb.set_trace()
if isinstance(space, gym_spaces.Discrete):
return Discrete(space.n)
elif isinstance(space, gym_spaces.Box):
return Box(low=space.low, high=space.high, shape=space.shape)
else:
raise ValueError
10 changes: 10 additions & 0 deletions mushroom_rl/utils/isaac_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch


def convert_task_observation(observation):
obs_t = observation
for _ in range(5):
if torch.is_tensor(obs_t):
break
obs_t = obs_t[list(obs_t.keys())[0]]
return obs_t

0 comments on commit c271d77

Please sign in to comment.