From 5bc920c9f89f40111d7ae29193e68ca43d9e2bbe Mon Sep 17 00:00:00 2001 From: greg Date: Tue, 18 Sep 2018 00:30:00 -0400 Subject: [PATCH] np -> torch --- treeqn/utils/treeqn_utils.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/treeqn/utils/treeqn_utils.py b/treeqn/utils/treeqn_utils.py index 4a8166d..a7e87af 100644 --- a/treeqn/utils/treeqn_utils.py +++ b/treeqn/utils/treeqn_utils.py @@ -1,4 +1,3 @@ -import numpy as np import torch import torch.nn.functional as F from treeqn.utils.pytorch_utils import cudify @@ -15,12 +14,10 @@ def discount_with_dones(rewards, dones, gamma): return discounted[::-1] def make_seq_mask(mask): - mask = mask.numpy().astype(np.bool) - max_i = np.argmax(mask, axis=0) - if mask[max_i] == True: - mask[max_i:] = True - mask = ~np.expand_dims(mask, axis=-1) # tilde flips true and falses - return torch.from_numpy(mask.astype(np.float)) + max_i = torch.max(mask, 0)[1] + if (mask[max_i] == 1).all(): + mask[int(max_i):].fill_(1) + return (1 - mask).unsqueeze(1) # some utilities for interpreting the trees we return def build_sequences(sequences, masks, nenvs, nsteps, depth, return_mask=False, offset=0): @@ -28,7 +25,7 @@ def build_sequences(sequences, masks, nenvs, nsteps, depth, return_mask=False, o # returns bs x depth x size processed sequences with a sliding window set by 'depth', padded with 0's # if return_mask=True also returns a mask showing where the sequences were padded # This can be used to produce targets for tree outputs, from the true observed sequences - tmp_masks = torch.from_numpy(masks.reshape(nenvs, nsteps).astype(np.int)) + tmp_masks = torch.from_numpy(masks.reshape(nenvs, nsteps).astype(int)) tmp_masks = F.pad(tmp_masks, (0, 0, 0, depth+offset), mode="constant", value=0).data sequences = [s.view(nenvs, nsteps, -1) for s in sequences]