From 4ef29f752dcb7788789318f7b80759a93864203a Mon Sep 17 00:00:00 2001 From: Junjia Liu Date: Thu, 31 Aug 2023 00:37:58 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=80=20[RofuncRL]=20Update=20RofuncDTra?= =?UTF-8?q?ns?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/learning_rl/example_D4RL_RofuncRL.py | 4 ++-- rofunc/learning/utils/download_datasets.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/learning_rl/example_D4RL_RofuncRL.py b/examples/learning_rl/example_D4RL_RofuncRL.py index 26167886d..34aa782f8 100644 --- a/examples/learning_rl/example_D4RL_RofuncRL.py +++ b/examples/learning_rl/example_D4RL_RofuncRL.py @@ -82,11 +82,11 @@ def inference(custom_args): if __name__ == '__main__': - gpu_id = 0 + gpu_id = 3 parser = argparse.ArgumentParser() # Available tasks: Hopper, HalfCheetah, Walker2d, Reacher2d - parser.add_argument("--task", type=str, default="Walker2d") + parser.add_argument("--task", type=str, default="Hopper") parser.add_argument("--agent", type=str, default="dtrans") # dtrans parser.add_argument("--sim_device", type=str, default="cuda:{}".format(gpu_id)) parser.add_argument("--rl_device", type=str, default="cuda:{}".format(gpu_id)) diff --git a/rofunc/learning/utils/download_datasets.py b/rofunc/learning/utils/download_datasets.py index de0f24585..eb86de87c 100644 --- a/rofunc/learning/utils/download_datasets.py +++ b/rofunc/learning/utils/download_datasets.py @@ -17,7 +17,6 @@ import pickle import gym -import d4rl # Import required to register environments, you may need to also import the submodule import numpy as np import rofunc as rf @@ -34,6 +33,8 @@ def download_d4rl_dataset(save_dir): if os.path.exists(f'{save_dir}/{name}.pkl'): continue + import d4rl # Import required to register environments, you may need to also import the submodule + env = gym.make(name) dataset = env.get_dataset()