diff --git a/sota-implementations/MCTS/AlphaZero/mcts_node.py b/sota-implementations/MCTS/AlphaZero/mcts_node.py deleted file mode 100644 index 7dd07538760..00000000000 --- a/sota-implementations/MCTS/AlphaZero/mcts_node.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -from __future__ import annotations - -import torch -from tensordict import tensorclass, TensorDict - - -@tensorclass(autocast=True) -class MCTSNode: - prior_action: torch.Tensor - parent: MCTSNode | None - children_values: torch.Tensor - children_priors: torch.Tensor - children_visits: torch.Tensor - score: torch.Tensor - children: MCTSNode - children_ids: torch.Tensor - state: TensorDict - terminated: torch.Tensor - - def __init__( - self, - action: torch.Tensor, - parent: MCTSNode | None, - ): - self.prior_action = action - self.parent = parent - self.children_ids = torch.tensor([], dtype=torch.int32) - - @property - def visits(self) -> torch.Tensor: - assert self.parent != None - return self.parent.children_visits[self.prior_action] - - @visits.setter - def visits(self, x) -> None: - assert self.parent != None - self.parent.children_visits[self.prior_action] = x - - @property - def value(self) -> torch.Tensor: - assert self.parent != None - return self.parent.children_values[self.prior_action] - - @value.setter - def value(self, x) -> None: - assert self.parent != None - self.parent.children_values[self.prior_action] = x - - @property - def expanded(self) -> bool: - return self.children_ids.numel() > 0 - - def get_child(self, action: torch.Tensor) -> MCTSNode: - idx = (self.children_ids == action).all(-1) - return self.children[idx] # type: ignore - - @classmethod - def root(cls) -> MCTSNode: - return cls(torch.Tensor(-1), None) diff --git a/sota-implementations/MCTS/AlphaZero/mcts_policy.py b/sota-implementations/MCTS/AlphaZero/mcts_policy.py index 346160061a1..5d80729f134 100644 --- a/sota-implementations/MCTS/AlphaZero/mcts_policy.py +++ b/sota-implementations/MCTS/AlphaZero/mcts_policy.py @@ -19,7 +19,7 @@ from torchrl.objectives.value.functional import reward2go -from .mcts_node import MCTSNode +from torchrl.data import MCTSNode, MCTSChildren @dataclass @@ -64,17 +64,17 @@ def forward(self, node: MCTSNode) -> TensorDictBase: if exploration_type() == ExplorationType.RANDOM or exploration_type() is None: tensordict[self.action_key] = self.explore_action(node) - elif exploration_type() == ExplorationType.MODE: + elif exploration_type() in (ExplorationType.MODE, ExplorationType.DETERMINISTIC, ExplorationType.MEAN): tensordict[self.action_key] = self.get_greedy_action(node) return tensordict def get_greedy_action(self, node: MCTSNode) -> torch.Tensor: - action = torch.argmax(node.children_visits) + action = torch.argmax(node.children.visits) return action def explore_action(self, node: MCTSNode) -> torch.Tensor: - action_scores = node.scores + action_scores = node.score max_value = torch.max(action_scores) action = torch.argmax( @@ -156,9 +156,6 @@ class ExpansionStrategy: This policy will use to initialize a node when it gets expanded at the first time. """ - def __init__(self): - super().__init__() - def forward(self, node: MCTSNode) -> MCTSNode: """The node to be expanded. @@ -179,7 +176,7 @@ def forward(self, node: MCTSNode) -> MCTSNode: @abstractmethod def expand(self, node: MCTSNode) -> None: - pass + ... def set_node(self, node: MCTSNode) -> None: self.node = node @@ -189,7 +186,7 @@ class BatchedRootExpansionStrategy(ExpansionStrategy): def __init__( self, policy_module: TensorDictModule, - module_action_value_key: str = "action_value", + module_action_value_key: NestedKey = "action_value", ): super().__init__() assert module_action_value_key in policy_module.out_keys @@ -200,9 +197,7 @@ def expand(self, node: MCTSNode) -> None: policy_netword_td = node.state.select(*self.policy_module.in_keys) policy_netword_td = self.policy_module(policy_netword_td) p_sa = policy_netword_td[self.action_value_key] - node.children_priors = p_sa # prior_action_value - node.children_values = torch.zeros_like(p_sa) # action_value - node.children_visits = torch.zeros_like(p_sa) # action_count + node.children = MCTSChildren.init_from_prob(p_sa) # setattr(node, "truncated", torch.ones(1, dtype=torch.bool)) @@ -210,7 +205,7 @@ class AlphaZeroExpansionStrategy(ExpansionStrategy): def __init__( self, policy_module: TensorDictModule, - module_action_value_key: str = "action_value", + module_action_value_key: NestedKey = "action_value", ): super().__init__() assert module_action_value_key in policy_module.out_keys @@ -221,9 +216,9 @@ def expand(self, node: MCTSNode) -> None: policy_netword_td = node.state.select(*self.policy_module.in_keys) policy_netword_td = self.policy_module(policy_netword_td) p_sa = policy_netword_td[self.action_value_key] - node.children_priors = p_sa # prior_action_value - node.children_values = torch.zeros_like(p_sa) # action_value - node.children_visits = torch.zeros_like(p_sa) # action_count + node.children.priors = p_sa # prior_action_value + node.children.vals = torch.zeros_like(p_sa) # action_value + node.children.visits = torch.zeros_like(p_sa) # action_count # setattr(node, "truncated", torch.ones(1, dtype=torch.bool)) @@ -244,15 +239,15 @@ def __init__( self.node: MCTSNode def forward(self, node: MCTSNode) -> MCTSNode: - n = torch.sum(node.children_visits, dim=-1) + 1 + n = torch.sum(node.children.visits, dim=-1) + 1 u_sa = ( self.cpuct - * node.children_priors + * node.children.priors * torch.sqrt(n) - / (1 + node.children_visits) + / (1 + node.children.visits) ) - optimism_estimation = node.children_values + u_sa + optimism_estimation = node.children.vals + u_sa node.scores = optimism_estimation return node @@ -270,17 +265,17 @@ def __init__( self.epsilon = epsilon def forward(self, node: MCTSNode) -> MCTSNode: - if node.children_priors.device.type == "mps": - device = node.children_priors.device + if node.children.priors.device.type == "mps": + device = node.children.priors.device noise = _Dirichlet.apply( - self.alpha * torch.ones_like(node.children_priors).cpu() + self.alpha * torch.ones_like(node.children.priors).cpu() ) noise = noise.to(device) # type: ignore else: - noise = _Dirichlet.apply(self.alpha * torch.ones_like(node.children_priors)) + noise = _Dirichlet.apply(self.alpha * torch.ones_like(node.children.priors)) - noisy_priors = (1 - self.epsilon) * node.children_priors + self.epsilon * noise # type: ignore - node.children_priors = noisy_priors + noisy_priors = (1 - self.epsilon) * node.children.priors + self.epsilon * noise # type: ignore + node.children.priors = noisy_priors return node @@ -293,6 +288,8 @@ class MCTSPolicy(TensorDictModuleBase): exploration_strategy: a policy to exploration vs exploitation """ + node: MCTSNode + def __init__( self, expansion_strategy: AlphaZeroExpansionStrategy, @@ -313,10 +310,11 @@ def __init__( self.expansion_strategy = expansion_strategy self.selection_strategy = selection_strategy self.exploration_strategy = exploration_strategy - self.node: MCTSNode self.batch_size = batch_size def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if not hasattr(self, "node"): + raise RuntimeError("the MCTS policy has not been initialized. Please provide a node through policy.set_node().") if not self.node.expanded: self.node.state = tensordict # type: ignore self.expansion_strategy.forward(self.node) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 3749e6e8cbc..6fd9f3815ca 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -74,3 +74,4 @@ UnboundedDiscreteTensorSpec, ) from .utils import check_no_exclusive_keys, consolidate_spec, contains_lazy_spec +from .mcts import MCTSNode, MCTSChildren diff --git a/torchrl/data/mcts/__init__.py b/torchrl/data/mcts/__init__.py new file mode 100644 index 00000000000..65c6bc3a476 --- /dev/null +++ b/torchrl/data/mcts/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .nodes import MCTSNode, MCTSChildren diff --git a/torchrl/data/mcts/nodes.py b/torchrl/data/mcts/nodes.py new file mode 100644 index 00000000000..8c1128d0294 --- /dev/null +++ b/torchrl/data/mcts/nodes.py @@ -0,0 +1,112 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import torch +from tensordict import tensorclass, TensorDict + +@tensorclass(autocast=True) +class MCTSChildren: + vals: torch.Tensor + priors: torch.Tensor + visits: torch.Tensor + ids: torch.Tensor | None = None + nodes: MCTSNode | None = None + + @classmethod + def init_from_prob(cls, probs): + vals = torch.zeros_like(probs) + visits = torch.zeros_like(probs) + return cls(vals=vals, priors=probs, visits=visits) + + +@tensorclass(autocast=True) +class MCTSNode: + prior_action: torch.Tensor + _children: MCTSChildren | None = None + score: torch.Tensor | None = None + state: TensorDict | None = None + terminated: torch.Tensor | None = None + parent: MCTSNode | None = None + + @classmethod + def from_action( + cls, + action: torch.Tensor, + parent: MCTSNode | None, + ): + return cls(prior_action=action, parent=parent) + + @property + def children(self) -> MCTSChildren: + children = self._children + if children is None: + return MCTSChildren(*[torch.zeros((), device=self.device) for _ in range(4)]) + return children + + @children.setter + def children(self, value): + self._children = value + + @property + def visits(self) -> torch.Tensor: + assert self.parent is not None + return self.parent.children.visits[self.prior_action] + + @visits.setter + def visits(self, x) -> None: + assert self.parent is not None + self.parent.children.visits[self.prior_action] = x + + @property + def value(self) -> torch.Tensor: + assert self.parent is not None + return self.parent.children.vals[self.prior_action] + + @value.setter + def value(self, x) -> None: + assert self.parent is not None + self.parent.children.vals[self.prior_action] = x + + @property + def expanded(self) -> bool: + return self.children.ids.numel() > 0 + + def get_child(self, action: torch.Tensor) -> MCTSNode: + idx = (self.children.ids == action).all(-1) + return self.children.nodes[idx] # type: ignore + + @classmethod + def root(cls) -> MCTSNode: + return cls(torch.Tensor(-1), None) + + @classmethod + def dummy(cls): + """Creates a 'dummy' MCTSNode that can be used to explore TorchRL's MCTS API.""" + children_values = stuff + children_priors = stuff + children_visits = stuff + children_ids = stuff + children_nodes = stuff + children = MCTSChildren( + values = children_values, + priors = children_priors, + visits = children_visits, + ids = children_ids, + nodes = children_nodes, + ) + prior_action = stuff + score = stuff + state = stuff + terminated = stuff + parent = None + return cls( + prior_action=prior_action, + children=children, + score=score, + state=state, + terminated=terminated, + parent=parent, + ) diff --git a/torchrl/modules/tensordict_module/mcts.py b/torchrl/modules/tensordict_module/mcts.py new file mode 100644 index 00000000000..56a73de5a6d --- /dev/null +++ b/torchrl/modules/tensordict_module/mcts.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations +