🚀 [RofuncRL] Minor update in RofuncDTrans
Skylark0924 committed Aug 29, 2023
1 parent 6c08f5c commit a5d0a8f
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions rofunc/learning/RofuncRL/trainers/dtrans_trainer.py
@@ -24,6 +24,14 @@
from rofunc.learning.RofuncRL.trainers.base_trainer import BaseTrainer


def discount_cumsum(x, gamma):
    tmp = np.zeros_like(x)
    tmp[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        tmp[t] = x[t] + gamma * tmp[t + 1]
    return tmp


class DTransTrainer(BaseTrainer):
    def __init__(self, cfg, env, device, env_name):
        super().__init__(cfg, env, device, env_name)
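For reference, the module-level `discount_cumsum` added in this hunk computes the reward-to-go at every index of a reward sequence. A minimal standalone sketch of its behaviour (the sample rewards are invented for illustration):

```python
import numpy as np

def discount_cumsum(x, gamma):
    # tmp[t] = x[t] + gamma * x[t+1] + gamma^2 * x[t+2] + ...
    tmp = np.zeros_like(x)
    tmp[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        tmp[t] = x[t] + gamma * tmp[t + 1]
    return tmp

rewards = np.array([1.0, 2.0, 3.0])
print(discount_cumsum(rewards, gamma=1.0))  # [6. 5. 3.]
print(discount_cumsum(rewards, gamma=0.5))  # [2.75 3.5  3.  ]
```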
@@ -41,10 +49,16 @@ def __init__(self, cfg, env, device, env_name):

        self.loss_mean = 0

        # list of dict, each dict contains a traj with
        # ['observations', 'next_observations', 'actions', 'rewards', 'terminals']
        self.trajectories = None

        self.load_dataset()

    def load_dataset(self):
        """Load dataset"""
        """
        Load dataset from pickle file and preprocess it.
        """
        dataset_path = os.path.join(self.dataset_root_path, f'{self.env_name.lower()}-{self.dataset_type}-v2.pkl')
        with open(dataset_path, 'rb') as f:
            self.trajectories = pickle.load(f)
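`load_dataset` expects a pickled list of trajectory dictionaries at `{env_name.lower()}-{dataset_type}-v2.pkl`. A hedged sketch of producing such a file, assuming the D4RL-style keys named in the comment above; the `make_dummy_traj` helper, array shapes, and file name are illustrative, not part of the repository:

```python
import pickle
import numpy as np

def make_dummy_traj(length, obs_dim=11, act_dim=3):
    """Hypothetical helper: one trajectory with the keys the trainer expects."""
    return {
        'observations': np.random.randn(length, obs_dim).astype(np.float32),
        'next_observations': np.random.randn(length, obs_dim).astype(np.float32),
        'actions': np.random.randn(length, act_dim).astype(np.float32),
        'rewards': np.random.randn(length).astype(np.float32),
        'terminals': np.zeros(length, dtype=bool),
    }

trajectories = [make_dummy_traj(int(np.random.randint(50, 200))) for _ in range(10)]
# e.g. env_name='Hopper', dataset_type='medium' -> 'hopper-medium-v2.pkl'
with open('hopper-medium-v2.pkl', 'wb') as f:
    pickle.dump(trajectories, f)
```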
@@ -65,9 +79,9 @@ def load_dataset(self):
        self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

        num_timesteps = sum(traj_lens)

        num_timesteps = max(int(self.pct_traj * num_timesteps), 1)
        sorted_inds = np.argsort(returns)  # lowest to highest

        self.num_trajectories = 1
        timesteps = traj_lens[sorted_inds[-1]]
        ind = len(self.trajectories) - 2
@@ -77,21 +91,14 @@ def load_dataset(self):
            ind -= 1
        self.sorted_inds = sorted_inds[-self.num_trajectories:]

        # used to reweight sampling so we sample according to timesteps instead of trajectories
        # used to re-weight sampling, so we sample according to timesteps instead of trajectories
        self.p_sample = traj_lens[sorted_inds] / sum(traj_lens[sorted_inds])

        self.rofunc_logger.module(f'Starting new experiment: {self.env_name} {self.dataset_type}'
                                  f' with {len(traj_lens)} trajectories and {num_timesteps} timesteps'
                                  f' Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}'
                                  f' Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')

    def discount_cumsum(self, x, gamma):
        tmp = np.zeros_like(x)
        tmp[-1] = x[-1]
        for t in reversed(range(x.shape[0] - 1)):
            tmp[t] = x[t] + gamma * tmp[t + 1]
        return tmp
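The selection logic above keeps only the highest-return trajectories until their combined length reaches `pct_traj` of all timesteps, then re-weights sampling by trajectory length so every timestep is equally likely to be drawn. A self-contained sketch mirroring that logic, assuming invented lengths and returns; the while-loop body hidden in the collapsed part of the diff is reconstructed here for illustration:

```python
import numpy as np

traj_lens = np.array([100, 300, 50, 550])  # invented trajectory lengths
returns = np.array([10., 80., 5., 120.])   # invented trajectory returns
pct_traj = 1.0                             # keep all trajectories in this toy case

num_timesteps = max(int(pct_traj * traj_lens.sum()), 1)
sorted_inds = np.argsort(returns)          # lowest to highest return

# walk from the best trajectory downwards until the timestep budget is used up
num_trajectories = 1
timesteps = traj_lens[sorted_inds[-1]]
ind = len(traj_lens) - 2
while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] <= num_timesteps:
    timesteps += traj_lens[sorted_inds[ind]]
    num_trajectories += 1
    ind -= 1
sorted_inds = sorted_inds[-num_trajectories:]

# sampling probability proportional to length => uniform over timesteps
p_sample = traj_lens[sorted_inds] / traj_lens[sorted_inds].sum()
batch_inds = np.random.choice(np.arange(num_trajectories), size=4, replace=True, p=p_sample)
print(p_sample, batch_inds)
```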

    def get_batch(self, batch_size=256):
        state_dim = self.agent.dtrans.state_dim
        act_dim = self.agent.dtrans.action_dim
@@ -100,7 +107,7 @@ def get_batch(self, batch_size=256):
            np.arange(self.num_trajectories),
            size=batch_size,
            replace=True,
            p=self.p_sample,  # reweights so we sample according to timesteps
            p=self.p_sample,  # re-weights so we sample according to timesteps
        )

        s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []
@@ -118,7 +125,7 @@ def get_batch(self, batch_size=256):
            d.append(traj['dones'][si:si + self.max_seq_length].reshape(1, -1))
            timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1))
            timesteps[-1][timesteps[-1] >= self.max_episode_steps] = self.max_episode_steps - 1  # padding cutoff
            rtg.append(self.discount_cumsum(traj['rewards'][si:], gamma=1.)[:s[-1].shape[1] + 1].reshape(1, -1, 1))
            rtg.append(discount_cumsum(traj['rewards'][si:], gamma=1.)[:s[-1].shape[1] + 1].reshape(1, -1, 1))
            if rtg[-1].shape[1] <= s[-1].shape[1]:
                rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)
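For each sampled sub-sequence, the return-to-go target is the cumulative sum of all remaining rewards from the start index, truncated to one step beyond the sequence and zero-padded when the trajectory ends early. A small numeric sketch of that slicing, with invented rewards and indices; `seq_len` stands in for `s[-1].shape[1]` and `discount_cumsum` is the function added above:

```python
import numpy as np

def discount_cumsum(x, gamma):
    tmp = np.zeros_like(x)
    tmp[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        tmp[t] = x[t] + gamma * tmp[t + 1]
    return tmp

rewards = np.array([1., 1., 1., 1., 1.])  # invented 5-step trajectory
si, max_seq_length = 3, 4                 # start index near the end of the trajectory

seq_len = rewards[si:si + max_seq_length].shape[0]                # only 2 steps remain
rtg = discount_cumsum(rewards[si:], gamma=1.)[:seq_len + 1].reshape(1, -1, 1)
if rtg.shape[1] <= seq_len:
    rtg = np.concatenate([rtg, np.zeros((1, 1, 1))], axis=1)      # pad final return-to-go with 0
print(rtg.squeeze())  # [2. 1. 0.]
```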

