Skip to content

Commit

Permalink
added data generation
Browse files (browse the repository at this point in the history)
  • Loading branch information
DenSumy committed Oct 8, 2024
1 parent a79fcc2 commit 5c04699
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 2 deletions.
5 changes: 4 additions & 1 deletion rl_games/common/env_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
import numpy as np
import math


# Optional dependency: attempt to import the bark_ml gym environments.
# NOTE(review): presumably imported for its side effects (environment
# registration) -- confirm against bark_ml's documentation.
try:
    import bark_ml.environments.gym  # noqa: F401
except ImportError:
    # bark_ml is not installed or cannot be loaded; continue without it.
    # Only ImportError is caught so that KeyboardInterrupt/SystemExit and
    # unrelated errors are no longer silently swallowed.
    pass

class HCRewardEnv(gym.RewardWrapper):
def __init__(self, env):
Expand Down
15 changes: 15 additions & 0 deletions rl_games/common/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from rl_games.common import env_configurations
from rl_games.algos_torch import model_builder

import pandas as pd

class BasePlayer(object):

Expand Down Expand Up @@ -271,6 +272,9 @@ def init_rnn(self):
)[2]), dtype=torch.float32).to(self.device) for s in rnn_states]

def run(self):
# Accumulate one row per environment per step with columns: game_index, observation, action, reward, done.
df = pd.DataFrame(columns=['game_index', 'observation', 'action', 'reward', 'done'])

n_games = self.games_num
render = self.render_env
n_game_life = self.n_game_life
Expand Down Expand Up @@ -313,6 +317,8 @@ def run(self):

print_game_res = False

game_indices = torch.arange(0, batch_size).to(self.device)
cur_games = batch_size
for n in range(self.max_steps):
if self.evaluation and n % self.update_checkpoint_freq == 0:
self.maybe_load_new_checkpoint()
Expand All @@ -324,7 +330,11 @@ def run(self):
else:
action = self.get_action(obses, is_deterministic)

prev_obses = obses
obses, r, done, info = self.env_step(self.env, action)

for i in range(batch_size):
df.loc[len(df)] = [game_indices[i].cpu().numpy().item(), prev_obses[i].cpu().numpy(), action[i].cpu().numpy(), r[i].cpu().numpy().item(), done[i].cpu().numpy().item()]
cr += r
steps += 1

Expand All @@ -337,6 +347,9 @@ def run(self):
done_count = len(done_indices)
games_played += done_count

for bid in done_indices:
game_indices[bid] = cur_games
cur_games += 1
if done_count > 0:
if self.is_rnn:
for s in self.states:
Expand Down Expand Up @@ -379,6 +392,8 @@ def run(self):
else:
print('av reward:', sum_rewards / games_played * n_game_life,
'av steps:', sum_steps / games_played * n_game_life)

df.to_parquet('game_data.parquet')

def get_batch_size(self, obses, batch_size):
obs_shape = self.obs_shape
Expand Down
5 changes: 4 additions & 1 deletion rl_games/configs/mujoco/ant_envpool.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,7 @@ params:
#flat_observation: True

player:
render: False
render: False
num_actors: 64
games_num: 1000
use_vecenv: True

0 comments on commit 5c04699

Please sign in to comment.