Skip to content

Commit

Permalink
add new model
Browse files Browse the repository at this point in the history
  • Loading branch information
BartekCupial committed Mar 2, 2024
1 parent 23df343 commit bb184f8
Show file tree
Hide file tree
Showing 5 changed files with 489 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sf_examples/nethack/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from sf_examples.nethack.models.chaotic_dwarf import ChaoticDwarvenGPT5
from sf_examples.nethack.models.scaled import ScaledNet

MODELS = [
ChaoticDwarvenGPT5,
ScaledNet,
]
MODELS_LOOKUP = {c.__name__: c for c in MODELS}
66 changes: 66 additions & 0 deletions sf_examples/nethack/models/crop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import logging

try:
import torch
from torch import nn
from torch.nn import functional as F
except ImportError:
logging.exception("PyTorch not found. Please install the agent dependencies with " '`pip install "nle[agent]"`')


def _step_to_range(delta, num_steps):
"""Range of `num_steps` integers with distance `delta` centered around zero."""
return delta * torch.arange(-num_steps // 2, num_steps // 2)


class Crop(nn.Module):
"""Helper class for NetHackNet below."""

def __init__(self, height, width, height_target, width_target):
super(Crop, self).__init__()
self.width = width
self.height = height
self.width_target = width_target
self.height_target = height_target
width_grid = _step_to_range(2 / (self.width - 1), self.width_target)[None, :].expand(self.height_target, -1)
height_grid = _step_to_range(2 / (self.height - 1), height_target)[:, None].expand(-1, self.width_target)

# "clone" necessary, https://github.com/pytorch/pytorch/issues/34880
self.register_buffer("width_grid", width_grid.clone())
self.register_buffer("height_grid", height_grid.clone())

def forward(self, inputs, coordinates):
"""Calculates centered crop around given x,y coordinates.
Args:
inputs [B x H x W]
coordinates [B x 2] x,y coordinates
Returns:
[B x H' x W'] inputs cropped and centered around x,y coordinates.
"""
assert inputs.shape[1] == self.height
assert inputs.shape[2] == self.width

inputs = inputs[:, None, :, :].float()

x = coordinates[:, 0]
y = coordinates[:, 1]

# NOTE: Need to do -self.width/2 + 1/2 here to cancel things out correctly
# with the width_grid below for both even and odd input dimensions.
x_shift = 2 / (self.width - 1) * (x.float() - self.width / 2 + 1 / 2)
y_shift = 2 / (self.height - 1) * (y.float() - self.height / 2 + 1 / 2)

grid = torch.stack(
[
self.width_grid[None, :, :] + x_shift[:, None, None],
self.height_grid[None, :, :] + y_shift[:, None, None],
],
dim=3,
)

# NOTE: Location x, y in grid tells you the shift from the cursor
# coordinates. The reason we do all this 2/(self.width - 1) stuff is because
# the inverse of this happens in the below F.grid_sample function. The F.grid_sample
# implementation is here: https://github.com/pytorch/pytorch/blob/f064c5aa33483061a48994608d890b968ae53fb5/aten/src/THNN/generic/SpatialGridSamplerBilinear.c#L41

return torch.round(F.grid_sample(inputs, grid, align_corners=True)).squeeze(1).long()
Loading

0 comments on commit bb184f8

Please sign in to comment.