From 98b595d7f7d7777fc8fad25f77cab07bf27c0fc6 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Tue, 18 Apr 2023 19:42:59 +0900 Subject: [PATCH] fix: fill 0 for reward, return, value in make_batch() --- handyrl/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/handyrl/train.py b/handyrl/train.py index 5f8c43a..7dad705 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -77,9 +77,9 @@ def replace_none(a, b): obs = bimap_r(obs_zeros, obs, lambda _, o: np.array(o)) # datum that is not changed by training configuration - v = np.array([[replace_none(m['value'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1) - rew = np.array([[replace_none(m['reward'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1) - ret = np.array([[replace_none(m['return'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1) + v = np.array([[replace_none(m['value'][player], 0) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1) + rew = np.array([[replace_none(m['reward'][player], 0) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1) + ret = np.array([[replace_none(m['return'][player], 0) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1) oc = np.array([ep['outcome'][player] for player in players], dtype=np.float32).reshape(1, len(players), -1) emask = np.ones((len(moments), 1, 1), dtype=np.float32) # episode mask