Skip to content

Commit

Permalink
Merge pull request #118 from ai4co/debug-atsp
Browse files Browse the repository at this point in the history
[BugFix] fix the running bug of MatNet for ATSP
  • Loading branch information
cbhua committed Feb 26, 2024
2 parents 83b8550 + e7dae57 commit b3f1446
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 7 deletions.
1 change: 1 addition & 0 deletions rl4co/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
HeterogeneousAttentionModel,
HeterogeneousAttentionModelPolicy,
)
from rl4co.models.zoo.matnet import MatNet, MatNetPolicy
from rl4co.models.zoo.mdam import MDAM, MDAMPolicy
from rl4co.models.zoo.pomo import POMO, POMOPolicy
from rl4co.models.zoo.ppo import PPOModel, PPOPolicy
Expand Down
1 change: 1 addition & 0 deletions rl4co/models/zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
HeterogeneousAttentionModel,
HeterogeneousAttentionModelPolicy,
)
from rl4co.models.zoo.matnet import MatNet, MatNetPolicy
from rl4co.models.zoo.mdam import MDAM, MDAMPolicy
from rl4co.models.zoo.pomo import POMO, POMOPolicy
from rl4co.models.zoo.ppo import PPOModel, PPOPolicy
Expand Down
9 changes: 3 additions & 6 deletions rl4co/models/zoo/matnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,16 @@ def __init__(
batch_size: int = 200,
policy_params: dict = {},
model_params: dict = {},
**kwargs,
):
if policy is None:
policy = MatNetPolicy(env_name=env.name, **policy_params)

super(MatNet, self).__init__(
env=env,
policy=policy,
optimizer_kwargs=optimizer_kwargs,
lr_scheduler=lr_scheduler,
lr_scheduler_kwargs=lr_scheduler_kwargs,
use_dihedral_8=use_dihedral_8,
num_starts=num_starts,
train_data_size=train_data_size,
batch_size=batch_size,
**model_params,
num_augment=0, # NOTE: for MatNet we don't use augmentation
**kwargs,
)
2 changes: 1 addition & 1 deletion rl4co/utils/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def select_start_nodes(td, env, num_starts):
env: Environment may determine the node selection strategy
num_starts: Number of nodes to select. This may be passed when calling the policy directly. See :class:`rl4co.models.AutoregressiveDecoder`
"""
if env.name in ["tsp"]:
if env.name in ["tsp", "atsp"]:
selected = torch.arange(num_starts, device=td.device).repeat_interleave(
td.shape[0]
)
Expand Down

0 comments on commit b3f1446

Please sign in to comment.