diff --git a/examples/learning_rl/example_D4RL_RofuncRL.py b/examples/learning_rl/example_D4RL_RofuncRL.py
new file mode 100644
index 000000000..7881cf24e
--- /dev/null
+++ b/examples/learning_rl/example_D4RL_RofuncRL.py
@@ -0,0 +1,102 @@
+"""
+D4RL (RofuncRL)
+=======================
+
+D4RL tasks with RofuncRL offline RL algorithms (BC, DTrans, CQL, etc.)
+"""
+
+import argparse
+
+import gymnasium as gym
+
+from rofunc.config.utils import omegaconf_to_dict, get_config
+from rofunc.learning.RofuncRL.tasks import task_map
+from rofunc.learning.RofuncRL.trainers import trainer_map
+from rofunc.learning.pre_trained_models.download import model_zoo
+from rofunc.learning.utils.download_datasets import download_d4rl_dataset
+from rofunc.learning.utils.utils import set_seed
+
+
+def train(custom_args):
+    # Config task and trainer parameters for the D4RL task
+    args_overrides = ["task={}".format(custom_args.task),
+                      "train={}{}RofuncRL".format(custom_args.task, custom_args.agent.upper()),
+                      "sim_device={}".format(custom_args.sim_device),
+                      "rl_device={}".format(custom_args.rl_device),
+                      "graphics_device_id={}".format(custom_args.graphics_device_id),
+                      "headless={}".format(custom_args.headless)]
+    cfg = get_config('./learning/rl', 'config', args=args_overrides)
+
+    download_d4rl_dataset(save_dir='../data/D4RL')
+
+    set_seed(cfg.train.Trainer.seed)
+
+    # Instantiate the D4RL Gym environment
+    env = gym.make(f'{custom_args.task}-v3')
+
+    # Instantiate the RL trainer
+    trainer = trainer_map[custom_args.agent](cfg=cfg.train,
+                                             env=env,
+                                             device=cfg.rl_device,
+                                             env_name=custom_args.task)
+
+    # Start training
+    trainer.train()
+
+
+def inference(custom_args):
+    # Config task and trainer parameters for the D4RL task
+    args_overrides = ["task={}".format(custom_args.task),
+                      "train={}{}RofuncRL".format(custom_args.task, custom_args.agent.upper()),
+                      "sim_device={}".format(custom_args.sim_device),
+                      "rl_device={}".format(custom_args.rl_device),
+                      "graphics_device_id={}".format(custom_args.graphics_device_id),
+                      "headless={}".format(False),
+                      "num_envs={}".format(16)]
+    cfg = get_config('./learning/rl', 'config', args=args_overrides)
+    cfg_dict = omegaconf_to_dict(cfg.task)
+
+    set_seed(cfg.train.Trainer.seed)
+
+    # Instantiate the environment for inference
+    infer_env = task_map[custom_args.task](cfg=cfg_dict,
+                                           rl_device=cfg.rl_device,
+                                           sim_device=cfg.sim_device,
+                                           graphics_device_id=cfg.graphics_device_id,
+                                           headless=cfg.headless,
+                                           virtual_screen_capture=cfg.capture_video,  # TODO: check
+                                           force_render=cfg.force_render)
+
+    # Instantiate the RL trainer
+    trainer = trainer_map[custom_args.agent](cfg=cfg.train,
+                                             env=infer_env,
+                                             device=cfg.rl_device,
+                                             env_name=custom_args.task)
+    # load checkpoint
+    if custom_args.ckpt_path is None:
+        custom_args.ckpt_path = model_zoo(name="CURICabinetRofuncRLPPO_left_arm.pth")
+    trainer.agent.load_ckpt(custom_args.ckpt_path)
+
+    # Start inference
+    trainer.inference()
+
+
+if __name__ == '__main__':
+    gpu_id = 0
+
+    parser = argparse.ArgumentParser()
+    # Available tasks: Hopper, HalfCheetah, Walker2d, Reacher2d
+    parser.add_argument("--task", type=str, default="Hopper")
+    parser.add_argument("--agent", type=str, default="dtrans")  # dtrans
+    parser.add_argument("--sim_device", type=str, default="cuda:{}".format(gpu_id))
+    parser.add_argument("--rl_device", type=str, default="cuda:{}".format(gpu_id))
+    parser.add_argument("--graphics_device_id", type=int, default=gpu_id)
+    parser.add_argument("--headless", type=str, default="True")
parser.add_argument("--inference", action="store_true", help="turn to inference mode while adding this argument") + parser.add_argument("--ckpt_path", type=str, default=None) + custom_args = parser.parse_args() + + if not custom_args.inference: + train(custom_args) + else: + inference(custom_args) diff --git a/rofunc/config/learning/rl/task/Hopper.yaml b/rofunc/config/learning/rl/task/Hopper.yaml new file mode 100644 index 000000000..994a216e1 --- /dev/null +++ b/rofunc/config/learning/rl/task/Hopper.yaml @@ -0,0 +1,2 @@ +name: Hopper + diff --git a/rofunc/config/learning/rl/train/BaseTaskDTRANSRofuncRL.yaml b/rofunc/config/learning/rl/train/BaseTaskDTRANSRofuncRL.yaml new file mode 100644 index 000000000..933fcc50f --- /dev/null +++ b/rofunc/config/learning/rl/train/BaseTaskDTRANSRofuncRL.yaml @@ -0,0 +1,90 @@ +# ========== Trainer parameters ========== +Trainer: + experiment_name: # Experiment name for logging. + experiment_directory: # Experiment directory for logging. + write_interval: 100 # TensorBoard write interval for logging. (timesteps) + checkpoint_interval: 1000 # Checkpoint interval for logging. (timesteps) + wandb: False # If true, log to Weights & Biases. + wandb_kwargs: # Weights & Biases kwargs. https://docs.wandb.ai/ref/python/init + project: # Weights & Biases project name. + rofunc_logger_kwargs: # Rofunc BeautyLogger kwargs. + verbose: True # If true, print to stdout. + maximum_steps: 100000 # The maximum number of steps to run for. + random_steps: 0 # The number of random exploration steps to take. + start_learning_steps: 0 # The number of steps to take before starting network updating. + seed: 42 # The random seed. + rollouts: 16 # The number of rollouts before updating. + eval_flag: False # If true, run evaluation. + eval_freq: 2500 # The frequency of evaluation. (timesteps) + eval_steps: 1000 # The number of steps to run for evaluation. + use_eval_thread: True # If true, use a separate thread for evaluation. + inference_steps: 1000 # The number of steps to run for inference. + max_episode_steps: 1000 # The maximum number of steps per episode. + + dataset_type: medium # medium, medium-replay, medium-expert, expert + mode: normal # normal for standard setting, delayed for sparse + dataset_root_path: /home/ubuntu/Github/Rofunc/examples/data/D4RL + env_targets: [ 3600, 1800 ] # evaluation conditioning targets + scale: 1000. # scale for reward and action + max_seq_length: 20 # Maximum length of the sequence for inputting to the GPT model. + + +# ========== Agent parameters ========== +Agent: + discount: 0.99 # The discount factor, gamma. + td_lambda: 0.95 # TD(lambda) coefficient (lam) for computing returns and advantages. + + learning_epochs: 8 # The number of epochs to train for per update. + batch_size: 1024 # Batch size for training. + + lr: 1e-4 # Learning rate for actor. + # lr_scheduler: # Learning rate scheduler type. + # lr_scheduler_kwargs: # Learning rate scheduler kwargs. + adam_eps: 1e-5 # Adam epsilon. + weight_decay: 1e-4 # Weight decay. + + # If true, use the Generalized Advantage Estimator (GAE) + # with a value function, see https://arxiv.org/pdf/1506.02438.pdf. 
+  use_gae: True
+
+  entropy_loss_scale: 0.01  # entropy loss scaling factor
+  value_loss_scale: 2.0  # value loss scaling factor
+
+  grad_norm_clip: 1.0  # clipping coefficient for the norm of the gradients
+  ratio_clip: 0.2  # clipping coefficient for computing the clipped surrogate objective
+  value_clip: 0.2  # clipping coefficient for computing the value loss (if clip_predicted_values is True)
+  clip_predicted_values: True  # clip predicted values during value loss computation
+
+  kl_threshold: 0  # Initial coefficient for KL divergence.
+
+# state_preprocessor:  # State preprocessor type.
+# state_preprocessor_kwargs:  # State preprocessor kwargs.
+# value_preprocessor:  # Value preprocessor type.
+# value_preprocessor_kwargs:  # Value preprocessor kwargs.
+# rewards_shaper:  # Rewards shaper type.
+
+
+# ========= Model parameters ==========
+Model:
+  use_init: True
+  use_action_clip: False  # If true, clip actions to the action space range.
+  use_action_out_tanh: True  # If true, apply tanh to the output of the actor.
+  action_clip: 1.0  # clipping coefficient for the norm of the actions
+  action_scale: 1.0  # scaling action range from [-1, 1] after tanh to [-action_scale, action_scale]
+  use_log_std_clip: True  # If true, clip log standard deviations to the range [-20, 2].
+  log_std_clip_max: 2.0  # clipping coefficient for the log standard deviations
+  log_std_clip_min: -20  # clipping coefficient for the log standard deviations
+
+  actor:
+    n_layer: 3
+    n_head: 1
+    n_embd: 128
+    dropout: 0.1
+    activation_function: relu
+    max_episode_steps: ${train.Trainer.max_episode_steps}
+
+
+
+
+
diff --git a/rofunc/config/utils.py b/rofunc/config/utils.py
index c942d1470..d98a75a4e 100644
--- a/rofunc/config/utils.py
+++ b/rofunc/config/utils.py
@@ -31,7 +31,7 @@ def get_config(config_path=None, config_name=None, args=None, debug=False, absl_
     :param config_name: name of the config file (without .yaml)
     :param args: custom args to rewrite some params in the config file
     :param debug: if True, print the config
-    :param absl_config_path: absolute path to the config file (for external user)
+    :param absl_config_path: absolute path to the folder that contains the config file (for external users)
     :return:
     """
     # reset current hydra config if already parsed (but not passed in here)
diff --git a/rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py b/rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py
index 23f4b8c29..12a508b34 100644
--- a/rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py
+++ b/rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py
@@ -23,7 +23,6 @@
 import rofunc as rf
 from rofunc.learning.RofuncRL.agents.base_agent import BaseAgent
 from rofunc.learning.RofuncRL.models.actor_models import ActorDTrans
-from rofunc.learning.RofuncRL.utils.memory import Memory
 
 
 class DTransAgent(BaseAgent):
@@ -37,7 +36,6 @@ def __init__(self,
                  cfg: DictConfig,
                  observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
                  action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
-                 memory: Optional[Union[Memory, Tuple[Memory]]] = None,
                  device: Optional[Union[str, torch.device]] = None,
                  experiment_dir: Optional[str] = None,
                  rofunc_logger: Optional[rf.logger.BeautyLogger] = None):
@@ -45,15 +43,14 @@
         :param cfg: Configurations
         :param observation_space: Observation space
         :param action_space: Action space
-        :param memory: Memory for storing transitions
         :param device: Device on which the torch tensor is allocated
         :param experiment_dir: Directory for storing experiment data
         :param rofunc_logger: Rofunc logger
         """
-        super().__init__(cfg, observation_space, action_space, memory, device, experiment_dir, rofunc_logger)
+        super().__init__(cfg, observation_space, action_space, None, device, experiment_dir, rofunc_logger)
 
-        self.dtrans = ActorDTrans(cfg.Model, observation_space, action_space, device)
+        self.dtrans = ActorDTrans(cfg.Model, observation_space, action_space, self.se).to(self.device)
 
         self.models = {"dtrans": self.dtrans}  # checkpoint models
@@ -68,7 +65,8 @@ def __init__(self,
         self._lr = self.cfg.Agent.lr
         self._adam_eps = self.cfg.Agent.adam_eps
         self._weight_decay = self.cfg.Agent.weight_decay
-        self._max_length = self.cfg.Agent.max_length
+        self._max_seq_length = self.cfg.Trainer.max_seq_length
+
         self._set_up()
@@ -94,25 +92,25 @@ def act(self, states, actions, rewards, returns_to_go, timesteps):
         returns_to_go = returns_to_go.reshape(1, -1, 1)
         timesteps = timesteps.reshape(1, -1)
 
-        if self._max_length is not None:
-            states = states[:, -self._max_length:]
-            actions = actions[:, -self._max_length:]
-            returns_to_go = returns_to_go[:, -self._max_length:]
-            timesteps = timesteps[:, -self._max_length:]
+        if self._max_seq_length is not None:
+            states = states[:, -self._max_seq_length:]
+            actions = actions[:, -self._max_seq_length:]
+            returns_to_go = returns_to_go[:, -self._max_seq_length:]
+            timesteps = timesteps[:, -self._max_seq_length:]
 
             # pad all tokens to sequence length
-            attention_mask = torch.cat([torch.zeros(self._max_length - states.shape[1]), torch.ones(states.shape[1])])
+            attention_mask = torch.cat([torch.zeros(self._max_seq_length - states.shape[1]), torch.ones(states.shape[1])])
             attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
             states = torch.cat(
-                [torch.zeros((states.shape[0], self._max_length - states.shape[1], self.dtrans.state_dim),
+                [torch.zeros((states.shape[0], self._max_seq_length - states.shape[1], self.dtrans.state_dim),
                              device=states.device), states], dim=1).to(dtype=torch.float32)
             actions = torch.cat(
-                [torch.zeros((actions.shape[0], self._max_length - actions.shape[1], self.dtrans.action_dim),
+                [torch.zeros((actions.shape[0], self._max_seq_length - actions.shape[1], self.dtrans.action_dim),
                              device=actions.device), actions], dim=1).to(dtype=torch.float32)
             returns_to_go = torch.cat(
-                [torch.zeros((returns_to_go.shape[0], self._max_length - returns_to_go.shape[1], 1),
+                [torch.zeros((returns_to_go.shape[0], self._max_seq_length - returns_to_go.shape[1], 1),
                              device=returns_to_go.device), returns_to_go], dim=1).to(dtype=torch.float32)
-            timesteps = torch.cat([torch.zeros((timesteps.shape[0], self._max_length - timesteps.shape[1]),
+            timesteps = torch.cat([torch.zeros((timesteps.shape[0], self._max_seq_length - timesteps.shape[1]),
                                                device=timesteps.device), timesteps], dim=1).to(dtype=torch.long)
         else:
             attention_mask = None
@@ -122,8 +120,8 @@ def act(self, states, actions, rewards, returns_to_go, timesteps):
 
         return action_preds[0, -1]
 
-    def update_net(self):
-        states, actions, rewards, dones, rtg, timesteps, attention_mask = self.get_batch(self.batch_size)
+    def update_net(self, batch):
+        states, actions, rewards, dones, rtg, timesteps, attention_mask = batch
         action_target = torch.clone(actions)
 
         state_preds, action_preds, reward_preds = self.dtrans.forward(
@@ -139,12 +137,12 @@ def update_net(self):
 
         self.optimizer.zero_grad()
         loss.backward()
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25)
+        torch.nn.utils.clip_grad_norm_(self.dtrans.parameters(), .25)
         self.optimizer.step()
 
-        with torch.no_grad():
-            self.diagnostics['training/action_error'] = torch.mean(
-                (action_preds - action_target) ** 2).detach().cpu().item()
+        # with torch.no_grad():
+        #     self.diagnostics['training/action_error'] = torch.mean(
+        #         (action_preds - action_target) ** 2).detach().cpu().item()
 
         # update learning rate
         if self._lr_scheduler is not None:
diff --git a/rofunc/learning/RofuncRL/models/actor_models.py b/rofunc/learning/RofuncRL/models/actor_models.py
index dcb03f714..622733edc 100644
--- a/rofunc/learning/RofuncRL/models/actor_models.py
+++ b/rofunc/learning/RofuncRL/models/actor_models.py
@@ -260,8 +260,8 @@ def __init__(self, cfg: DictConfig,
         self.cfg = cfg
         self.action_dim = get_space_dim(action_space)
 
-        self.gpt2_hidden_size = cfg.actor.gpt2_hidden_size
-        self.max_ep_len = cfg.actor.max_ep_len
+        self.n_embd = cfg.actor.n_embd
+        self.max_ep_len = cfg.actor.max_episode_steps
 
         # state encoder
         self.state_encoder = state_encoder
@@ -272,22 +272,29 @@ def __init__(self, cfg: DictConfig,
 
         gpt_config = transformers.GPT2Config(
             vocab_size=1,  # doesn't matter -- we don't use the vocab
-            n_embd=self.gpt2_hidden_size,
+            n_embd=self.n_embd,
+            n_layer=self.cfg.actor.n_layer,
+            n_head=self.cfg.actor.n_head,
+            n_inner=self.n_embd * 4,
+            activation_function=self.cfg.actor.activation_function,
+            resid_pdrop=self.cfg.actor.dropout,
+            attn_pdrop=self.cfg.actor.dropout,
+            n_positions=1024
         )
-        self.embed_timestep = nn.Embedding(self.max_ep_len, self.gpt2_hidden_size)
-        self.embed_return = torch.nn.Linear(1, self.gpt2_hidden_size)
-        self.embed_state = torch.nn.Linear(self.state_dim, self.gpt2_hidden_size)
-        self.embed_action = torch.nn.Linear(self.action_dim, self.gpt2_hidden_size)
-        self.embed_ln = nn.LayerNorm(self.gpt2_hidden_size)
+        self.embed_timestep = nn.Embedding(self.max_ep_len, self.n_embd)
+        self.embed_return = torch.nn.Linear(1, self.n_embd)
+        self.embed_state = torch.nn.Linear(self.state_dim, self.n_embd)
+        self.embed_action = torch.nn.Linear(self.action_dim, self.n_embd)
+        self.embed_ln = nn.LayerNorm(self.n_embd)
 
         self.backbone_net = transformers.GPT2Model(gpt_config)
 
         # note: we don't predict states or returns for the paper
-        self.predict_state = torch.nn.Linear(self.gpt2_hidden_size, self.state_dim)
-        self.predict_action = nn.Sequential(*([nn.Linear(self.gpt2_hidden_size, self.action_dim)] +
+        self.predict_state = torch.nn.Linear(self.n_embd, self.state_dim)
+        self.predict_action = nn.Sequential(*([nn.Linear(self.n_embd, self.action_dim)] +
                                               ([nn.Tanh()] if self.cfg.use_action_out_tanh else [])))
-        self.predict_return = torch.nn.Linear(self.gpt2_hidden_size, 1)
+        self.predict_return = torch.nn.Linear(self.n_embd, 1)
 
     def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
         batch_size, seq_length = states.shape[0], states.shape[1]
@@ -313,7 +320,7 @@ def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_
 
         # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
         # which works nice in an autoregressive sense since states predict actions
         stacked_inputs = torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1
-                                     ).permute(0, 2, 1, 3).reshape(batch_size, 3 * seq_length, self.gpt2_hidden_size)
+                                     ).permute(0, 2, 1, 3).reshape(batch_size, 3 * seq_length, self.n_embd)
         stacked_inputs = self.embed_ln(stacked_inputs)
 
         # to make the attention mask fit the stacked inputs, have to stack it as well
@@ -327,7 +334,7 @@ def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_
 
         # reshape x so that the second dimension corresponds to the original
         # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
-        x = x.reshape(batch_size, seq_length, 3, self.gpt2_hidden_size).permute(0, 2, 1, 3)
+        x = x.reshape(batch_size, seq_length, 3, self.n_embd).permute(0, 2, 1, 3)
 
         # get predictions
         return_preds = self.predict_return(x[:, 2])  # predict next return given state and action
diff --git a/rofunc/learning/RofuncRL/trainers/__init__.py b/rofunc/learning/RofuncRL/trainers/__init__.py
index 837e6e1ed..c7198578e 100644
--- a/rofunc/learning/RofuncRL/trainers/__init__.py
+++ b/rofunc/learning/RofuncRL/trainers/__init__.py
@@ -4,6 +4,7 @@
 from .a2c_trainer import A2CTrainer
 from .amp_trainer import AMPTrainer
 from .ase_trainer import ASETrainer
+from .dtrans_trainer import DTransTrainer
 
 trainer_map = {
     "ppo": PPOTrainer,
@@ -12,4 +13,5 @@
     "a2c": A2CTrainer,
     "amp": AMPTrainer,
     "ase": ASETrainer,
+    "dtrans": DTransTrainer,
 }
diff --git a/rofunc/learning/RofuncRL/trainers/base_trainer.py b/rofunc/learning/RofuncRL/trainers/base_trainer.py
index 8765a1574..999512e9f 100644
--- a/rofunc/learning/RofuncRL/trainers/base_trainer.py
+++ b/rofunc/learning/RofuncRL/trainers/base_trainer.py
@@ -47,6 +47,7 @@ def __init__(self,
         self.agent = None
         self.device = torch.device(
             "cuda:0" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
+        self.env_name = env_name
 
         '''Experiment log directory'''
         directory = self.cfg.Trainer.experiment_directory
@@ -113,7 +114,9 @@ def __init__(self,
                                  f" action_space: {self.env.action_space.shape}\n "
                                  f" observation_space: {self.env.observation_space.shape}\n "
                                  f" num_envs: {self.env.num_envs}")
-        self.rofunc_logger.info(f"Task configurations:\n{self.env._env.cfg}")
+
+        if hasattr(self.env._env, "cfg"):
+            self.rofunc_logger.info(f"Task configurations:\n{self.env._env.cfg}")
 
         '''Normalization'''
         self.state_norm = Normalization(shape=self.env.observation_space, device=device)
diff --git a/rofunc/learning/RofuncRL/trainers/dtrans_trainer.py b/rofunc/learning/RofuncRL/trainers/dtrans_trainer.py
index c8729acac..81691ee48 100644
--- a/rofunc/learning/RofuncRL/trainers/dtrans_trainer.py
+++ b/rofunc/learning/RofuncRL/trainers/dtrans_trainer.py
@@ -11,7 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import pickle
+import random
+
+import numpy as np
+import torch
 import tqdm
 
 from rofunc.learning.RofuncRL.agents.offline.dtrans_agent import DTransAgent
@@ -21,17 +26,128 @@ class DTransTrainer(BaseTrainer):
     def __init__(self, cfg, env, device, env_name):
         super().__init__(cfg, env, device, env_name)
-        self.agent = DTransAgent(cfg, self.env.observation_space, self.env.action_space, self.memory,
-                                 device, self.exp_dir, self.rofunc_logger)
+        self.agent = DTransAgent(cfg, self.env.observation_space, self.env.action_space, device, self.exp_dir,
+                                 self.rofunc_logger)
         self.setup_wandb()
 
+        self.pct_traj = 1
+        self.dataset_type = self.cfg.Trainer.dataset_type
+        self.dataset_root_path = self.cfg.Trainer.dataset_root_path
+        self.mode = self.cfg.Trainer.mode
+        self.scale = self.cfg.Trainer.scale
+        self.max_episode_steps = self.cfg.Trainer.max_episode_steps
+        self.max_seq_length = self.cfg.Trainer.max_seq_length
+
+        self.load_dataset()
+
+    def load_dataset(self):
+        """Load dataset"""
+        dataset_path = os.path.join(self.dataset_root_path, f'{self.env_name.lower()}-{self.dataset_type}-v2.pkl')
+        with open(dataset_path, 'rb') as f:
+            self.trajectories = pickle.load(f)
+
+        # save all path information into separate lists
+        states, traj_lens, returns = [], [], []
+        for path in self.trajectories:
+            if self.mode == 'delayed':  # delayed: all rewards moved to end of trajectory
+                path['rewards'][-1] = path['rewards'].sum()
+                path['rewards'][:-1] = 0.
+            states.append(path['observations'])
+            traj_lens.append(len(path['observations']))
+            returns.append(path['rewards'].sum())
+        traj_lens, returns = np.array(traj_lens), np.array(returns)
+
+        # used for input normalization
+        states = np.concatenate(states, axis=0)
+        self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
+
+        num_timesteps = sum(traj_lens)
+
+        num_timesteps = max(int(self.pct_traj * num_timesteps), 1)
+        sorted_inds = np.argsort(returns)  # lowest to highest
+        self.num_trajectories = 1
+        timesteps = traj_lens[sorted_inds[-1]]
+        ind = len(self.trajectories) - 2
+        while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] <= num_timesteps:
+            timesteps += traj_lens[sorted_inds[ind]]
+            self.num_trajectories += 1
+            ind -= 1
+        self.sorted_inds = sorted_inds[-self.num_trajectories:]
+
+        # used to reweight sampling so we sample according to timesteps instead of trajectories
+        self.p_sample = traj_lens[self.sorted_inds] / sum(traj_lens[self.sorted_inds])
+
+        self.rofunc_logger.module(f'Starting new experiment: {self.env_name} {self.dataset_type}'
+                                  f' with {len(traj_lens)} trajectories and {num_timesteps} timesteps'
+                                  f' Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}'
+                                  f' Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
+
+    def discount_cumsum(self, x, gamma):
+        tmp = np.zeros_like(x)
+        tmp[-1] = x[-1]
+        for t in reversed(range(x.shape[0] - 1)):
+            tmp[t] = x[t] + gamma * tmp[t + 1]
+        return tmp
+
+    def get_batch(self, batch_size=256):
+        state_dim = self.agent.dtrans.state_dim
+        act_dim = self.agent.dtrans.action_dim
+
+        batch_inds = np.random.choice(
+            np.arange(self.num_trajectories),
+            size=batch_size,
+            replace=True,
+            p=self.p_sample,  # reweights so we sample according to timesteps
+        )
+
+        s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []
+        for i in range(batch_size):
+            traj = self.trajectories[int(self.sorted_inds[batch_inds[i]])]
+            si = random.randint(0, traj['rewards'].shape[0] - 1)
+
+            # get sequences from dataset
+            s.append(traj['observations'][si:si + self.max_seq_length].reshape(1, -1, state_dim))
+            a.append(traj['actions'][si:si + self.max_seq_length].reshape(1, -1, act_dim))
+            r.append(traj['rewards'][si:si + self.max_seq_length].reshape(1, -1, 1))
+            if 'terminals' in traj:
+                d.append(traj['terminals'][si:si + self.max_seq_length].reshape(1, -1))
+            else:
+                d.append(traj['dones'][si:si + self.max_seq_length].reshape(1, -1))
+            timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1))
+            timesteps[-1][timesteps[-1] >= self.max_episode_steps] = self.max_episode_steps - 1  # padding cutoff
+            rtg.append(self.discount_cumsum(traj['rewards'][si:], gamma=1.)[:s[-1].shape[1] + 1].reshape(1, -1, 1))
+            if rtg[-1].shape[1] <= s[-1].shape[1]:
+                rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)
+
+            # padding and state + reward normalization
+            tlen = s[-1].shape[1]
+            s[-1] = np.concatenate([np.zeros((1, self.max_seq_length - tlen, state_dim)), s[-1]], axis=1)
+            s[-1] = (s[-1] - self.state_mean) / self.state_std
+            a[-1] = np.concatenate([np.ones((1, self.max_seq_length - tlen, act_dim)) * -10., a[-1]], axis=1)
+            r[-1] = np.concatenate([np.zeros((1, self.max_seq_length - tlen, 1)), r[-1]], axis=1)
+            d[-1] = np.concatenate([np.ones((1, self.max_seq_length - tlen)) * 2, d[-1]], axis=1)
+            rtg[-1] = np.concatenate([np.zeros((1, self.max_seq_length - tlen, 1)), rtg[-1]], axis=1) / self.scale
+            timesteps[-1] = np.concatenate([np.zeros((1, self.max_seq_length - tlen)), timesteps[-1]], axis=1)
+            mask.append(np.concatenate([np.zeros((1, self.max_seq_length - tlen)), np.ones((1, tlen))], axis=1))
+
+        s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=self.device)
+        a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=self.device)
+        r = torch.from_numpy(np.concatenate(r, axis=0)).to(dtype=torch.float32, device=self.device)
+        d = torch.from_numpy(np.concatenate(d, axis=0)).to(dtype=torch.long, device=self.device)
+        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=self.device)
+        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=self.device)
+        mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=self.device)
+
+        return s, a, r, d, rtg, timesteps, mask
+
     def train(self):
         """
         Main training loop.
""" with tqdm.trange(self.maximum_steps, ncols=80, colour='green') as self.t_bar: for _ in self.t_bar: - self.agent.update_net() + batch = self.get_batch() + self.agent.update_net(batch) # close the logger self.writer.close() diff --git a/rofunc/learning/utils/download_datasets.py b/rofunc/learning/utils/download_datasets.py index 3aff9104e..de0f24585 100644 --- a/rofunc/learning/utils/download_datasets.py +++ b/rofunc/learning/utils/download_datasets.py @@ -17,13 +17,15 @@ import pickle import gym +import d4rl # Import required to register environments, you may need to also import the submodule import numpy as np +import rofunc as rf from rofunc.utils.logger.beauty_logger import beauty_print -def download_d4rl_dataset(): - save_dir = os.path.join(os.getcwd(), '../../../examples/data/D4RL') +def download_d4rl_dataset(save_dir): + rf.oslab.create_dir(save_dir) for env_name in ['halfcheetah', 'hopper', 'walker2d']: for dataset_type in ['medium', 'medium-replay', 'expert']: @@ -64,14 +66,13 @@ def download_d4rl_dataset(): returns = np.array([np.sum(p['rewards']) for p in paths]) num_samples = np.sum([p['rewards'].shape[0] for p in paths]) print(f'Number of samples collected: {num_samples}') - print( - f'Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)}, min = {np.min(returns)}') + print(f'Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)},' + f' min = {np.min(returns)}') with open(f'{save_dir}/{name}.pkl', 'wb') as f: pickle.dump(paths, f) beauty_print('D4RL dataset downloaded', type='info') - # if __name__ == '__main__': # download_d4rl_dataset() diff --git a/setup.py b/setup.py index 3cd2e70a5..176fb9593 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ # Environment-specific dependencies. extras = { "baselines": ["skrl==0.10.2", "ray[rllib]==2.2.0", "stable-baselines3==1.8.0", "rl-games==1.6.0", - "mujoco_py==2.1.2.14", "gym[all]==0.26.2", "gymnasium[all]==0.28.1"], + "mujoco_py==2.1.2.14", "gym[all]==0.26.2", "gymnasium[all]==0.28.1", "mujoco-py<2.2,>=2.1"], } setup( @@ -23,7 +23,8 @@ packages=find_packages(exclude=["others"]), include_package_data=True, extras_require=extras, - install_requires=['setuptools==63.2.0', + install_requires=['cython==3.0.0a10', # for mujoco_py + 'setuptools==63.2.0', 'pandas', 'tqdm==4.65.0', 'pillow==9.5.0',