Skip to content

Commit

Permalink
[BugFix] MatNet #108
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Feb 26, 2024
1 parent 8c395af commit 20269a9
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions rl4co/models/zoo/matnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,34 @@

from rl4co.models.zoo.pomo.model import POMO
from rl4co.envs.common.base import RL4COEnvBase
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


class MatNet(POMO):
def __init__(
self,
env: RL4COEnvBase,
policy: Union[nn.Module, MatNetPolicy] = None,
optimizer_kwargs: dict = {"lr": 4 * 1e-4, "weight_decay": 1e-6},
lr_scheduler: str = "MultiStepLR",
lr_scheduler_kwargs: dict = {"milestones": [2001, 2101], "gamma": 0.1},
use_dihedral_8: bool = False,
num_starts: int = None,
train_data_size: int = 10_000,
batch_size: int = 200,
policy_params: dict = {},
model_params: dict = {},
**kwargs,
):
if policy is None:
policy = MatNetPolicy(env_name=env.name, **policy_params)

# Check if num_augment is not 0 or if diheral_8 is True
if kwargs.get("num_augment", 0) != 0:
log.error("MatNet does not use symmetric augmentation. Setting num_augment to 0.")
kwargs["num_augment"] = 0
if kwargs.get("use_dihedral_8", True):
log.error("MatNet does not use symmetric Dihedral Augmentation. Setting use_dihedral_8 to False.")
kwargs["use_dihedral_8"] = False

super(MatNet, self).__init__(
env=env,
policy=policy,
use_dihedral_8=use_dihedral_8,
num_starts=num_starts,
num_augment=0, # NOTE: for MatNet we don't use augmentation
**kwargs,
)
)

0 comments on commit 20269a9

Please sign in to comment.