Commit

🚀 [RofuncRL] RofuncDTrans example is available
Skylark0924 committed Aug 28, 2023
1 parent 3c6292e commit 6c08f5c
Showing 3 changed files with 38 additions and 3 deletions.
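
Background for the diffs below: the Decision Transformer ("DTrans") agent conditions its action predictions on returns-to-go, the suffix sums of the reward sequence that appear as the returns_to_go argument of act() in dtrans_agent.py. A minimal standalone sketch of that quantity (illustrative only, not Rofunc code):

def returns_to_go(rewards):
    # suffix sums: rtg[t] = rewards[t] + rewards[t+1] + ... + rewards[-1]
    rtg, running = [], 0.0
    for r in reversed(rewards):
        running += r
        rtg.append(running)
    return rtg[::-1]

print(returns_to_go([1.0, 0.0, 2.0]))  # [3.0, 2.0, 2.0]
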
6 changes: 4 additions & 2 deletions rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py
@@ -59,6 +59,7 @@ def __init__(self,
 
         self.track_losses = collections.deque(maxlen=100)
         self.tracking_data = collections.defaultdict(list)
+        self.checkpoint_best_modules = {"timestep": 0, "loss": 2 ** 31, "saved": False, "modules": {}}
 
         '''Get hyper-parameters from config'''
         self._td_lambda = self.cfg.Agent.td_lambda
@@ -67,7 +68,6 @@ def __init__(self,
         self._weight_decay = self.cfg.Agent.weight_decay
         self._max_seq_length = self.cfg.Trainer.max_seq_length
 
-
         self._set_up()
 
     def _set_up(self):
@@ -99,7 +99,8 @@ def act(self, states, actions, rewards, returns_to_go, timesteps):
         timesteps = timesteps[:, -self._max_seq_length:]
 
         # pad all tokens to sequence length
-        attention_mask = torch.cat([torch.zeros(self._max_seq_length - states.shape[1]), torch.ones(states.shape[1])])
+        attention_mask = torch.cat(
+            [torch.zeros(self._max_seq_length - states.shape[1]), torch.ones(states.shape[1])])
         attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
         states = torch.cat(
             [torch.zeros((states.shape[0], self._max_seq_length - states.shape[1], self.dtrans.state_dim),
@@ -150,3 +151,4 @@ def update_net(self, batch):
 
         # record data
         self.track_data("Loss", loss.item())
+        self.track_data("Action_error", torch.mean((action_preds - action_target) ** 2).item())
1 change: 0 additions & 1 deletion rofunc/learning/RofuncRL/trainers/base_trainer.py
@@ -14,7 +14,6 @@
 
 import copy
 import datetime
-import json
 import multiprocessing
 import os
 import random
34 changes: 34 additions & 0 deletions rofunc/learning/RofuncRL/trainers/dtrans_trainer.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import os
 import pickle
 import random
@@ -38,6 +39,8 @@ def __init__(self, cfg, env, device, env_name):
         self.max_episode_steps = self.cfg.Trainer.max_episode_steps
         self.max_seq_length = self.cfg.Trainer.max_seq_length
 
+        self.loss_mean = 0
+
         self.load_dataset()
 
     def load_dataset(self):
@@ -148,7 +151,38 @@ def train(self):
         for _ in self.t_bar:
             batch = self.get_batch()
             self.agent.update_net(batch)
+            self.post_interaction()
+            self._step += 1
 
         # close the logger
         self.writer.close()
         self.rofunc_logger.info('Training complete.')
 
+    def post_interaction(self):
+        # Update best models and tensorboard
+        if not self._step % self.write_interval and self.write_interval > 0:
+            # update best models
+            self.loss_mean = np.mean(self.agent.tracking_data.get("Loss", -1e4))
+            if self.loss_mean < self.agent.checkpoint_best_modules["loss"]:
+                self.agent.checkpoint_best_modules["timestep"] = self._step
+                self.agent.checkpoint_best_modules["loss"] = self.loss_mean
+                self.agent.checkpoint_best_modules["saved"] = False
+                self.agent.checkpoint_best_modules["modules"] = {k: copy.deepcopy(self.agent._get_internal_value(v)) for
+                                                                 k, v in self.agent.checkpoint_modules.items()}
+                self.agent.save_ckpt(os.path.join(self.agent.checkpoint_dir, "best_ckpt.pth"))
+
+            # Update tensorboard
+            self.write_tensorboard()
+
+            # Update tqdm bar message
+            if self.eval_flag:
+                post_str = f"Loss/Best/Eval: {self.loss_mean:.2f}/{self.agent.checkpoint_best_modules['loss']:.2f}/{self.eval_loss_mean:.2f}"
+            else:
+                post_str = f"Loss/Best: {self.loss_mean:.2f}/{self.agent.checkpoint_best_modules['loss']:.2f}"
+            self.t_bar.set_postfix_str(post_str)
+            self.rofunc_logger.info(f"Step: {self._step}, {post_str}", local_verbose=False)
+
+        # Save checkpoints
+        if not (self._step + 1) % self.agent.checkpoint_interval and \
+                self.agent.checkpoint_interval > 0 and self._step > 1:
+            self.agent.save_ckpt(os.path.join(self.agent.checkpoint_dir, f"ckpt_{self._step + 1}.pth"))

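The new post_interaction() follows a common bookkeeping pattern: every write_interval steps, compare the mean of recently tracked losses against the best value seen so far, and deep-copy the module state whenever it improves so a later optimizer step cannot overwrite the snapshot. A minimal sketch of that pattern in isolation, using a hypothetical BestTracker class rather than Rofunc's API:

import copy

class BestTracker:
    # mirrors the checkpoint_best_modules bookkeeping added to the agent
    def __init__(self):
        self.best = {"timestep": 0, "loss": 2 ** 31, "modules": {}}

    def update(self, step, loss_mean, modules):
        if loss_mean < self.best["loss"]:
            # deep copy so subsequent training steps cannot mutate the snapshot
            self.best = {"timestep": step, "loss": loss_mean,
                         "modules": {k: copy.deepcopy(v) for k, v in modules.items()}}

tracker = BestTracker()
for step, loss in enumerate([3.0, 2.5, 2.7, 1.9]):
    tracker.update(step, loss, {"policy": {"w": [loss]}})
print(tracker.best["timestep"], tracker.best["loss"])  # 3 1.9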