
Commit

initial implementation
DenSumy committed Dec 28, 2023
1 parent 5d5fc27 commit a59d1a0
Showing 23 changed files with 41 additions and 40 deletions.
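
The diff follows one recurring pattern: each module's import gym becomes import gymnasium as gym, and every env_info returned by an environment is routed through patch_env_info so that its spaces are gymnasium spaces. A minimal sketch of that pattern, using only names that appear in the diff below (vec_env stands in for whatever vectorized environment the algorithm creates):

import gymnasium as gym  # previously: import gym
from rl_games.common.env_configurations import patch_env_info

# env_info may still carry old-style gym spaces, so normalize it to
# gymnasium spaces before the algorithm inspects it.
env_info = patch_env_info(vec_env.get_env_info())
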
2 changes: 1 addition & 1 deletion notebooks/train_and_export_onnx_example_continuous.ipynb
@@ -42,7 +42,7 @@
"import yaml\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import gym as gymnasium\n",
"import gymnasium as gym\n",
"from IPython import display\n",
"import numpy as np\n",
"import onnx\n",
2 changes: 1 addition & 1 deletion notebooks/train_and_export_onnx_example_discrete.ipynb
@@ -41,7 +41,7 @@
"import yaml\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import gym as gymnasium\n",
"import gymnasium as gym\n",
"from IPython import display\n",
"import numpy as np\n",
"import onnx\n",
@@ -46,7 +46,7 @@
"import yaml\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import gym as gymnasium\n",
"import gymnasium as gym\n",
"from IPython import display\n",
"import numpy as np\n",
"import onnx\n",
2 changes: 1 addition & 1 deletion rl_games/algos_torch/players.py
@@ -2,7 +2,7 @@
from rl_games.algos_torch import torch_ext
from rl_games.algos_torch.running_mean_std import RunningMeanStd
from rl_games.common.tr_helpers import unsqueeze_obs
import gym
import gymnasium as gym
import torch
from torch import nn
import numpy as np
4 changes: 2 additions & 2 deletions rl_games/algos_torch/sac_agent.py
@@ -4,7 +4,7 @@
from rl_games.common import schedulers
from rl_games.common import experience
from rl_games.common.a2c_common import print_statistics

from rl_games.common.env_configurations import patch_env_info
from rl_games.interfaces.base_algorithm import BaseAlgorithm
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
@@ -103,7 +103,7 @@ def base_init(self, base_name, config):
self.env_info = config.get('env_info')
if self.env_info is None:
self.vec_env = vecenv.create_vec_env(self.env_name, self.num_actors, **self.env_config)
self.env_info = self.vec_env.get_env_info()
self.env_info = patch_env_info(self.vec_env.get_env_info())

self._device = config.get('device', 'cuda:0')

5 changes: 3 additions & 2 deletions rl_games/common/a2c_common.py
@@ -12,9 +12,10 @@
from rl_games.common.diagnostics import DefaultDiagnostics, PpoDiagnostics
from rl_games.algos_torch import model_builder
from rl_games.interfaces.base_algorithm import BaseAlgorithm
from rl_games.common.env_configurations import patch_env_info
import numpy as np
import time
import gym
import gymnasium as gym

from datetime import datetime
from tensorboardX import SummaryWriter
@@ -127,7 +128,7 @@ def __init__(self, base_name, params):
self.env_info = config.get('env_info')
if self.env_info is None:
self.vec_env = vecenv.create_vec_env(self.env_name, self.num_actors, **self.env_config)
self.env_info = self.vec_env.get_env_info()
self.env_info = patch_env_info(self.vec_env.get_env_info())
else:
self.vec_env = config.get('vec_env', None)

23 changes: 11 additions & 12 deletions rl_games/common/env_configurations.py
@@ -5,7 +5,7 @@
from rl_games.envs.brax import create_brax_env
from rl_games.envs.envpool import create_envpool
from rl_games.envs.cule import create_cule
import gym
import gymnasium as gym
from gym.wrappers import FlattenObservation, FilterObservation
import numpy as np
import math
@@ -109,10 +109,10 @@ def create_dm_control_env(**kwargs):
return env

def create_super_mario_env(name='SuperMarioBros-v1'):
import gym
import gymnasium as gym
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT
import gym_super_mario_bros
import gymnasium as gym_super_mario_bros
env = gym_super_mario_bros.make(name)
env = JoypadSpace(env, SIMPLE_MOVEMENT)

@@ -121,11 +121,11 @@ def create_super_mario_env(name='SuperMarioBros-v1'):
return env

def create_super_mario_env_stage1(name='SuperMarioBrosRandomStage1-v1'):
import gym
import gymnasium as gym
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT

import gym_super_mario_bros
import gymnasium as gym_super_mario_bros
stage_names = [
'SuperMarioBros-1-1-v1',
'SuperMarioBros-1-2-v1',
@@ -143,13 +143,13 @@ def create_super_mario_env_stage1(name='SuperMarioBrosRandomStage1-v1'):
return env

def create_quadrupped_env():
import gym
import gymnasium as gym
import roboschool
import quadruppedEnv
return wrappers.FrameStack(wrappers.MaxAndSkipEnv(gym.make('QuadruppedWalk-v1'), 4, False), 2, True)

def create_roboschool_env(name):
import gym
import gymnasium as gym
import roboschool
return gym.make(name)

@@ -203,8 +203,8 @@ def create_test_env(name, **kwargs):
return env

def create_minigrid_env(name, **kwargs):
import gym_minigrid
import gym_minigrid.wrappers
import gym_minigrid
import gym_minigrid.wrappers


state_bonus = kwargs.pop('state_bonus', False)
@@ -442,7 +442,7 @@ def get_env_info(env):
if hasattr(env, "value_size"):
result_shapes['value_size'] = env.value_size
print(result_shapes)
return result_shapes
return patch_env_info(result_shapes)

def get_obs_and_action_spaces_from_config(config):
env_config = config.get('env_config', {})
@@ -453,9 +453,8 @@


def patch_env_info(env_info):
import gymnasium
env_info['observation_space'] = gymnasium_utils.convert_space(env_info['observation_space'] )
env_info['action_space'] = gymnasium_utils.convert_space(env_info['observation_space'] )
env_info['action_space'] = gymnasium_utils.convert_space(env_info['action_space'] )
if 'state_space' in env_info:
env_info['state_space'] = gymnasium_utils.convert_space(env_info['state_space'] )
return env_info
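
patch_env_info relies on gymnasium_utils.convert_space, which is not part of this diff. For context, a converter along these lines covers the common cases; this is a minimal sketch assuming only Box, Discrete, Tuple and Dict spaces need handling (the actual helper shipped with rl_games may differ):

import gym
import gymnasium

def convert_space(space):
    # Mirror an old-style gym space as its gymnasium counterpart.
    if isinstance(space, gymnasium.spaces.Space):
        return space  # already a gymnasium space, nothing to convert
    if isinstance(space, gym.spaces.Box):
        return gymnasium.spaces.Box(low=space.low, high=space.high, dtype=space.dtype)
    if isinstance(space, gym.spaces.Discrete):
        return gymnasium.spaces.Discrete(space.n)
    if isinstance(space, gym.spaces.Tuple):
        return gymnasium.spaces.Tuple(tuple(convert_space(s) for s in space.spaces))
    if isinstance(space, gym.spaces.Dict):
        return gymnasium.spaces.Dict({k: convert_space(s) for k, s in space.spaces.items()})
    raise NotImplementedError(f'Unsupported space type: {type(space)}')

With a converter like this, patch_env_info above only has to apply it to observation_space, action_space and, when present, state_space.
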
2 changes: 1 addition & 1 deletion rl_games/common/experience.py
@@ -1,6 +1,6 @@
import numpy as np
import random
import gym
import gymnasium as gym
import torch
from rl_games.common.segment_tree import SumSegmentTree, MinSegmentTree
import torch
8 changes: 4 additions & 4 deletions rl_games/common/player.py
@@ -2,7 +2,7 @@
import shutil
import threading
import time
import gym
import gymnasium as gym
import numpy as np
import torch
import copy
@@ -11,7 +11,7 @@
from rl_games.common import vecenv
from rl_games.common import env_configurations
from rl_games.algos_torch import model_builder

from rl_games.common.env_configurations import patch_env_info

class BasePlayer(object):

@@ -32,11 +32,11 @@ def __init__(self, params):
print('[BasePlayer] Creating vecenv: ', self.env_name)
self.env = vecenv.create_vec_env(
self.env_name, self.config['num_actors'], **self.env_config)
self.env_info = self.env.get_env_info()
self.env_info = patch_env_info(self.env.get_env_info())
else:
print('[BasePlayer] Creating regular env: ', self.env_name)
self.env = self.create_env()
self.env_info = env_configurations.get_env_info(self.env)
self.env_info = patch_env_info(env_configurations.get_env_info(self.env))
else:
self.env = config.get('vec_env')

2 changes: 1 addition & 1 deletion rl_games/common/vecenv.py
@@ -2,7 +2,7 @@
from rl_games.common.env_configurations import configurations
from rl_games.common.tr_helpers import dicts_to_dict_with_arrays
import numpy as np
import gym
import gymnasium as gym
import random
from time import sleep
import torch
2 changes: 1 addition & 1 deletion rl_games/common/wrappers.py
@@ -5,7 +5,7 @@
os.environ.setdefault('PATH', '')
from collections import deque

import gym
import gymnasium as gym
from gym import spaces
from copy import copy

2 changes: 1 addition & 1 deletion rl_games/envs/brax.py
@@ -1,5 +1,5 @@
from rl_games.common.ivecenv import IVecEnv
import gym
import gymnasium as gym
import numpy as np
import torch.utils.dlpack as tpack

2 changes: 1 addition & 1 deletion rl_games/envs/cule.py
@@ -1,5 +1,5 @@
from rl_games.common.ivecenv import IVecEnv
import gym
import gymnasium as gym
import torch
import numpy as np

2 changes: 1 addition & 1 deletion rl_games/envs/diambra/diambra.py
@@ -1,4 +1,4 @@
import gym
import gymnasium as gym
import numpy as np
import os
import random
2 changes: 1 addition & 1 deletion rl_games/envs/envpool.py
@@ -1,5 +1,5 @@
from rl_games.common.ivecenv import IVecEnv
import gym
import gymnasium as gym
import numpy as np


2 changes: 1 addition & 1 deletion rl_games/envs/multiwalker.py
@@ -1,4 +1,4 @@
import gym
import gymnasium as gym
import numpy as np
from pettingzoo.sisl import multiwalker_v6
import yaml
5 changes: 3 additions & 2 deletions rl_games/envs/slimevolley_selfplay.py
@@ -1,8 +1,9 @@
import gym
import gymnasium as gym
import numpy as np
import slimevolleygym
import yaml
from rl_games.torch_runner import Runner
from rl_games.common.env_configurations import patch_env_info
import os

class SlimeVolleySelfplay(gym.Env):
@@ -32,7 +33,7 @@ def create_agent(self, config='rl_games/configs/ma/ppo_slime_self_play.yaml'):
config = yaml.safe_load(stream)
runner = Runner()
from rl_games.common.env_configurations import get_env_info
config['params']['config']['env_info'] = get_env_info(self)
config['params']['config']['env_info'] = patch_env_info(get_env_info(self))
runner.load(config)
config = runner.get_prebuilt_config()

2 changes: 1 addition & 1 deletion rl_games/envs/smac_env.py
@@ -1,4 +1,4 @@
import gym
import gymnasium as gym
import numpy as np
from smac.env import StarCraft2Env
from smac.env import MultiAgentEnv
2 changes: 1 addition & 1 deletion rl_games/envs/smac_v2_env.py
@@ -1,4 +1,4 @@
import gym
import gymnasium as gym
import numpy as np
import yaml
from smacv2.env import StarCraft2Env
2 changes: 1 addition & 1 deletion rl_games/envs/test/__init__.py
@@ -1,4 +1,4 @@
import gym
import gymnasium as gym

gym.envs.register(
id='TestRnnEnv-v0',
2 changes: 1 addition & 1 deletion rl_games/envs/test/example_env.py
@@ -1,4 +1,4 @@
import gym
import gymnasium as gym
import numpy as np


2 changes: 1 addition & 1 deletion rl_games/envs/test/rnn_env.py
@@ -1,4 +1,4 @@
import gym
import gymnasium as gym
import numpy as np


2 changes: 1 addition & 1 deletion rl_games/envs/test/test_asymmetric_env.py
@@ -1,4 +1,4 @@
import gym
import gymnasium as gym
import numpy as np
from rl_games.common.wrappers import MaskVelocityWrapper

