Skip to content

Commit

Permalink
np -> torch
Browse files (browse the repository at this point in the history)
  • Loading branch information
Greg-Farquhar committed Sep 18, 2018
1 parent 2f08595 commit 5bc920c
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions treeqn/utils/treeqn_utils.py
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
import numpy as np
import torch
import torch.nn.functional as F
from treeqn.utils.pytorch_utils import cudify
Expand All @@ -15,20 +14,18 @@ def discount_with_dones(rewards, dones, gamma):
return discounted[::-1]

def make_seq_mask(mask):
# NOTE(review): this is a rendered diff hunk; indentation and +/- markers were
# lost in extraction, so two versions of the body appear back to back.
# Both versions build a mask that is 1 before the first set entry of `mask`
# and 0 from that entry onward, with one extra singleton dimension added.
# Assumes `mask` is a 1-D 0/1 tensor (the new body calls int(max_i), which
# needs a single index) -- TODO confirm against the caller.
#
# --- NumPy body (presumably the lines removed by this commit) ---
# Convert the tensor to a NumPy boolean array.
mask = mask.numpy().astype(np.bool)
# Index of the first True along axis 0; argmax yields 0 when nothing is set.
max_i = np.argmax(mask, axis=0)
# Guard for the all-False case, where index 0 does not point at a set entry.
if mask[max_i] == True:
# Mark every position from the first set entry onward.
mask[max_i:] = True
mask = ~np.expand_dims(mask, axis=-1) # tilde flips true and falses
# Hand the result back as a torch tensor of floats.
return torch.from_numpy(mask.astype(np.float))
# --- torch body (presumably the lines added by this commit) ---
# Element [1] of torch.max(mask, 0) is the index of a maximal entry along dim 0.
max_i = torch.max(mask, 0)[1]
# Same guard as above: only propagate when the maximal entry is actually 1.
if (mask[max_i] == 1).all():
# fill_ writes in place through the slice, so the tensor passed in by the
# caller is modified as a side effect.
mask[int(max_i):].fill_(1)
# Invert the 0/1 values and add a singleton dimension at position 1.
return (1 - mask).unsqueeze(1)

# some utilities for interpreting the trees we return
def build_sequences(sequences, masks, nenvs, nsteps, depth, return_mask=False, offset=0):
# sequences are bs x size, containing e.g. rewards, actions, state reps
# returns bs x depth x size processed sequences with a sliding window set by 'depth', padded with 0's
# if return_mask=True also returns a mask showing where the sequences were padded
# This can be used to produce targets for tree outputs, from the true observed sequences
tmp_masks = torch.from_numpy(masks.reshape(nenvs, nsteps).astype(np.int))
tmp_masks = torch.from_numpy(masks.reshape(nenvs, nsteps).astype(int))
tmp_masks = F.pad(tmp_masks, (0, 0, 0, depth+offset), mode="constant", value=0).data

sequences = [s.view(nenvs, nsteps, -1) for s in sequences]
Expand Down

0 comments on commit 5bc920c

Please sign in to comment.