Skip to content

Commit

Permalink
Merge pull request #325 from DeNA/develop
Browse files: browse the repository at this point in the history
Merge develop branch into master, August 2022.
  • Loading branch information
ikki407 authored Nov 2, 2022
2 parents f66eedd + 896a351 commit 8fb63a3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
22 changes: 12 additions & 10 deletions handyrl/envs/geister.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,19 @@ class GeisterNet(nn.Module):
def __init__(self):
super().__init__()

layers, filters, p_filters = 3, 32, 8
layers, filters = 3, 32
p_filters, v_filters = 8, 2
input_channels = 7 + 18 # board channels + scalar inputs
self.input_size = (input_channels, 6, 6)

self.conv1 = nn.Conv2d(input_channels, filters, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(filters)
self.body = DRC(layers, filters, filters)

self.head_p_move = Conv2dHead((filters * 2, 6, 6), p_filters, 4)
self.head_p_move = Conv2dHead((filters, 6, 6), p_filters, 4)
self.head_p_set = nn.Linear(1, 70, bias=True)
self.head_v = ScalarHead((filters * 2, 6, 6), 1, 1)
self.head_r = ScalarHead((filters * 2, 6, 6), 1, 1)
self.head_v = ScalarHead((filters, 6, 6), v_filters, 1)
self.head_r = ScalarHead((filters, 6, 6), v_filters, 1)

def init_hidden(self, batch_size=[]):
return self.body.init_hidden(self.input_size[1:], batch_size)
Expand All @@ -154,7 +155,6 @@ def forward(self, x, hidden):

h_e = F.relu(self.bn1(self.conv1(h)))
h, hidden = self.body(h_e, hidden, num_repeats=3)
h = torch.cat([h_e, h], -3)

h_p_move = self.head_p_move(h)
turn_color = s[:, :1]
Expand Down Expand Up @@ -189,10 +189,11 @@ class Environment(BaseEnvironment):

def __init__(self, args=None):
super().__init__()
self.args = args if args is not None else {}
self.reset()

def reset(self, args={}):
self.args = args
def reset(self, args=None):
self.game_args = args if args is not None else {}
self.board = -np.ones((6, 6), dtype=np.int32) # (x, y) -1 is empty
self.color = self.BLACK
self.turn_count = -2 # before setting original positions
Expand Down Expand Up @@ -343,8 +344,9 @@ def _piece(p):
s = ' ' + ' '.join(self.Y) + '\n'
for i in range(6):
s += self.X[i] + ' ' + ' '.join([self.P[_piece(self.board[i, j])] for j in range(6)]) + '\n'
s += 'color = ' + self.C[self.color] + '\n'
s += 'record = ' + self.record_string()
s += 'remained = B:%d R:%d b:%d r:%d' % tuple(self.piece_cnt) + '\n'
s += 'turn = ' + str(self.turn_count).ljust(3) + ' color = ' + self.C[self.color]
# s += 'record = ' + self.record_string()
return s

def _set(self, layout):
Expand Down Expand Up @@ -409,7 +411,7 @@ def diff_info(self, player):

def update(self, info, reset):
if reset:
self.args = {**self.args, **info}
self.game_args = {**self.game_args, **info}
self.reset(info)
elif 'set' in info:
self._set(info['set'])
Expand Down
2 changes: 1 addition & 1 deletion handyrl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def select_episode(self):
while True:
ep_count = min(len(self.episodes), self.args['maximum_episodes'])
ep_idx = random.randrange(ep_count)
accept_rate = 1 - (ep_count - 1 - ep_idx) / self.args['maximum_episodes']
accept_rate = 1 - (ep_count - 1 - ep_idx) / ep_count
if random.random() < accept_rate:
break
ep = self.episodes[ep_idx]
Expand Down

0 comments on commit 8fb63a3

Please sign in to comment.