Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 24, 2024
1 parent 30ec91c commit 3f4c392
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 90 deletions.
63 changes: 0 additions & 63 deletions sota-implementations/MCTS/AlphaZero/mcts_node.py

This file was deleted.

52 changes: 25 additions & 27 deletions sota-implementations/MCTS/AlphaZero/mcts_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from torchrl.objectives.value.functional import reward2go

from .mcts_node import MCTSNode
from torchrl.data import MCTSNode, MCTSChildren


@dataclass
Expand Down Expand Up @@ -64,17 +64,17 @@ def forward(self, node: MCTSNode) -> TensorDictBase:

if exploration_type() == ExplorationType.RANDOM or exploration_type() is None:
tensordict[self.action_key] = self.explore_action(node)
elif exploration_type() == ExplorationType.MODE:
elif exploration_type() in (ExplorationType.MODE, ExplorationType.DETERMINISTIC, ExplorationType.MEAN):
tensordict[self.action_key] = self.get_greedy_action(node)

return tensordict

def get_greedy_action(self, node: MCTSNode) -> torch.Tensor:
action = torch.argmax(node.children_visits)
action = torch.argmax(node.children.visits)
return action

def explore_action(self, node: MCTSNode) -> torch.Tensor:
action_scores = node.scores
action_scores = node.score

max_value = torch.max(action_scores)
action = torch.argmax(
Expand Down Expand Up @@ -156,9 +156,6 @@ class ExpansionStrategy:
This policy will use to initialize a node when it gets expanded at the first time.
"""

def __init__(self):
super().__init__()

def forward(self, node: MCTSNode) -> MCTSNode:
"""The node to be expanded.
Expand All @@ -179,7 +176,7 @@ def forward(self, node: MCTSNode) -> MCTSNode:

@abstractmethod
def expand(self, node: MCTSNode) -> None:
pass
...

def set_node(self, node: MCTSNode) -> None:
self.node = node
Expand All @@ -189,7 +186,7 @@ class BatchedRootExpansionStrategy(ExpansionStrategy):
def __init__(
self,
policy_module: TensorDictModule,
module_action_value_key: str = "action_value",
module_action_value_key: NestedKey = "action_value",
):
super().__init__()
assert module_action_value_key in policy_module.out_keys
Expand All @@ -200,17 +197,15 @@ def expand(self, node: MCTSNode) -> None:
policy_netword_td = node.state.select(*self.policy_module.in_keys)
policy_netword_td = self.policy_module(policy_netword_td)
p_sa = policy_netword_td[self.action_value_key]
node.children_priors = p_sa # prior_action_value
node.children_values = torch.zeros_like(p_sa) # action_value
node.children_visits = torch.zeros_like(p_sa) # action_count
node.children = MCTSChildren.init_from_prob(p_sa)
# setattr(node, "truncated", torch.ones(1, dtype=torch.bool))


class AlphaZeroExpansionStrategy(ExpansionStrategy):
def __init__(
self,
policy_module: TensorDictModule,
module_action_value_key: str = "action_value",
module_action_value_key: NestedKey = "action_value",
):
super().__init__()
assert module_action_value_key in policy_module.out_keys
Expand All @@ -221,9 +216,9 @@ def expand(self, node: MCTSNode) -> None:
policy_netword_td = node.state.select(*self.policy_module.in_keys)
policy_netword_td = self.policy_module(policy_netword_td)
p_sa = policy_netword_td[self.action_value_key]
node.children_priors = p_sa # prior_action_value
node.children_values = torch.zeros_like(p_sa) # action_value
node.children_visits = torch.zeros_like(p_sa) # action_count
node.children.priors = p_sa # prior_action_value
node.children.vals = torch.zeros_like(p_sa) # action_value
node.children.visits = torch.zeros_like(p_sa) # action_count
# setattr(node, "truncated", torch.ones(1, dtype=torch.bool))


Expand All @@ -244,15 +239,15 @@ def __init__(
self.node: MCTSNode

def forward(self, node: MCTSNode) -> MCTSNode:
n = torch.sum(node.children_visits, dim=-1) + 1
n = torch.sum(node.children.visits, dim=-1) + 1
u_sa = (
self.cpuct
* node.children_priors
* node.children.priors
* torch.sqrt(n)
/ (1 + node.children_visits)
/ (1 + node.children.visits)
)

optimism_estimation = node.children_values + u_sa
optimism_estimation = node.children.vals + u_sa
node.scores = optimism_estimation

return node
Expand All @@ -270,17 +265,17 @@ def __init__(
self.epsilon = epsilon

def forward(self, node: MCTSNode) -> MCTSNode:
if node.children_priors.device.type == "mps":
device = node.children_priors.device
if node.children.priors.device.type == "mps":
device = node.children.priors.device
noise = _Dirichlet.apply(
self.alpha * torch.ones_like(node.children_priors).cpu()
self.alpha * torch.ones_like(node.children.priors).cpu()
)
noise = noise.to(device) # type: ignore
else:
noise = _Dirichlet.apply(self.alpha * torch.ones_like(node.children_priors))
noise = _Dirichlet.apply(self.alpha * torch.ones_like(node.children.priors))

noisy_priors = (1 - self.epsilon) * node.children_priors + self.epsilon * noise # type: ignore
node.children_priors = noisy_priors
noisy_priors = (1 - self.epsilon) * node.children.priors + self.epsilon * noise # type: ignore
node.children.priors = noisy_priors
return node


Expand All @@ -293,6 +288,8 @@ class MCTSPolicy(TensorDictModuleBase):
exploration_strategy: a policy to exploration vs exploitation
"""

node: MCTSNode

def __init__(
self,
expansion_strategy: AlphaZeroExpansionStrategy,
Expand All @@ -313,10 +310,11 @@ def __init__(
self.expansion_strategy = expansion_strategy
self.selection_strategy = selection_strategy
self.exploration_strategy = exploration_strategy
self.node: MCTSNode
self.batch_size = batch_size

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if not hasattr(self, "node"):
raise RuntimeError("the MCTS policy has not been initialized. Please provide a node through policy.set_node().")
if not self.node.expanded:
self.node.state = tensordict # type: ignore
self.expansion_strategy.forward(self.node)
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,4 @@
UnboundedDiscreteTensorSpec,
)
from .utils import check_no_exclusive_keys, consolidate_spec, contains_lazy_spec
from .mcts import MCTSNode, MCTSChildren
6 changes: 6 additions & 0 deletions torchrl/data/mcts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .nodes import MCTSNode, MCTSChildren
112 changes: 112 additions & 0 deletions torchrl/data/mcts/nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import torch
from tensordict import tensorclass, TensorDict

@tensorclass(autocast=True)
class MCTSChildren:
vals: torch.Tensor
priors: torch.Tensor
visits: torch.Tensor
ids: torch.Tensor | None = None
nodes: MCTSNode | None = None

@classmethod
def init_from_prob(cls, probs):
vals = torch.zeros_like(probs)
visits = torch.zeros_like(probs)
return cls(vals=vals, priors=probs, visits=visits)


@tensorclass(autocast=True)
class MCTSNode:
prior_action: torch.Tensor
_children: MCTSChildren | None = None
score: torch.Tensor | None = None
state: TensorDict | None = None
terminated: torch.Tensor | None = None
parent: MCTSNode | None = None

@classmethod
def from_action(
cls,
action: torch.Tensor,
parent: MCTSNode | None,
):
return cls(prior_action=action, parent=parent)

@property
def children(self) -> MCTSChildren:
children = self._children
if children is None:
return MCTSChildren(*[torch.zeros((), device=self.device) for _ in range(4)])
return children

@children.setter
def children(self, value):
self._children = value

@property
def visits(self) -> torch.Tensor:
assert self.parent is not None
return self.parent.children.visits[self.prior_action]

@visits.setter
def visits(self, x) -> None:
assert self.parent is not None
self.parent.children.visits[self.prior_action] = x

@property
def value(self) -> torch.Tensor:
assert self.parent is not None
return self.parent.children.vals[self.prior_action]

@value.setter
def value(self, x) -> None:
assert self.parent is not None
self.parent.children.vals[self.prior_action] = x

@property
def expanded(self) -> bool:
return self.children.ids.numel() > 0

def get_child(self, action: torch.Tensor) -> MCTSNode:
idx = (self.children.ids == action).all(-1)
return self.children.nodes[idx] # type: ignore

@classmethod
def root(cls) -> MCTSNode:
return cls(torch.Tensor(-1), None)

@classmethod
def dummy(cls):
"""Creates a 'dummy' MCTSNode that can be used to explore TorchRL's MCTS API."""
children_values = stuff
children_priors = stuff
children_visits = stuff
children_ids = stuff
children_nodes = stuff
children = MCTSChildren(
values = children_values,
priors = children_priors,
visits = children_visits,
ids = children_ids,
nodes = children_nodes,
)
prior_action = stuff
score = stuff
state = stuff
terminated = stuff
parent = None
return cls(
prior_action=prior_action,
children=children,
score=score,
state=state,
terminated=terminated,
parent=parent,
)
7 changes: 7 additions & 0 deletions torchrl/modules/tensordict_module/mcts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

0 comments on commit 3f4c392

Please sign in to comment.