Skip to content
9 changes: 7 additions & 2 deletions ding/model/template/q_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
norm_type: Optional[str] = None,
dropout: Optional[float] = None,
init_bias: Optional[float] = None,
noise: bool = False,
) -> None:
"""
Overview:
Expand All @@ -57,6 +58,8 @@ def __init__(
- dropout (:obj:`Optional[float]`): The dropout rate of the dropout layer. \
if ``None`` then default disable dropout layer.
- init_bias (:obj:`Optional[float]`): The initial value of the last layer bias in the head network.
- noise (:obj:`bool`): Whether to use ``NoiseLinearLayer`` as ``layer_fn`` to boost exploration in \
Q networks' MLP. Default to ``False``.
"""
super(DQN, self).__init__()
# Squeeze data from tuple, list or dict to single object. For example, from (4, ) to 4
Expand Down Expand Up @@ -90,7 +93,8 @@ def __init__(
layer_num=head_layer_num,
activation=activation,
norm_type=norm_type,
dropout=dropout
dropout=dropout,
noise=noise,
)
else:
self.head = head_cls(
Expand All @@ -99,7 +103,8 @@ def __init__(
head_layer_num,
activation=activation,
norm_type=norm_type,
dropout=dropout
dropout=dropout,
noise=noise,
)
if init_bias is not None and head_cls == DuelingHead:
# Zero the last layer bias of advantage head
Expand Down
19 changes: 19 additions & 0 deletions ding/policy/common_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,28 @@
from typing import List, Any, Dict, Callable
import torch
import torch.nn as nn
import numpy as np
import treetensor.torch as ttorch
from ding.utils.data import default_collate
from ding.torch_utils import to_tensor, to_ndarray, unsqueeze, squeeze
from ding.torch_utils import NoiseLinearLayer


def set_noise_mode(module: nn.Module, noise_enabled: bool):
    """
    Overview:
        Walk ``module`` (as enumerated by ``module.modules()``) and write ``noise_enabled`` into the
        ``enable_noise`` attribute of every ``NoiseLinearLayer`` that is found. Modules of any other
        type are left untouched.
        This helper is typically used by algorithms such as NoisyNet and Rainbow: pass ``True`` while
        training so the noisy layers explore, and ``False`` for inference or evaluation so the
        network behaves deterministically.

    Arguments:
        - module (:obj:`nn.Module`): The root module whose sub-modules are searched for \
            ``NoiseLinearLayer`` instances.
        - noise_enabled (:obj:`bool`): The value assigned to ``enable_noise`` on each matching layer.
    """
    # Lazily filter the module tree so each layer is checked and updated in traversal order.
    noisy_layers = (sub for sub in module.modules() if isinstance(sub, NoiseLinearLayer))
    for noisy_layer in noisy_layers:
        noisy_layer.enable_noise = noise_enabled


def default_preprocess_learn(
Expand Down
44 changes: 43 additions & 1 deletion ding/policy/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ding.utils.data import default_collate, default_decollate

from .base_policy import Policy
from .common_utils import default_preprocess_learn
from .common_utils import default_preprocess_learn, set_noise_mode


@POLICY_REGISTRY.register('dqn')
Expand Down Expand Up @@ -97,6 +97,8 @@ class DQNPolicy(Policy):
discount_factor=0.97,
# (int) The number of steps for calculating target q_value.
nstep=1,
# (bool) Whether to use NoisyNet for exploration in both learning and collecting. Default is False.
noisy_net=False,
model=dict(
# (list(int)) Sequence of ``hidden_size`` of subsequent conv layers and the final dense layer.
encoder_hidden_size_list=[128, 128, 64],
Expand Down Expand Up @@ -248,6 +250,21 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
.. note::
For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``.
"""
# Set noise mode for NoisyNet for exploration in learning if enabled in config
# We need to reset set_noise_mode every _forward_xxx because the model is reused across different
# phases (learn/collect/eval).
if self._cfg.noisy_net:
set_noise_mode(self._learn_model, True)
set_noise_mode(self._target_model, True)

# A noisy network agent samples a new set of parameters after every step of optimisation.
# Between optimisation steps, the agent acts according to a fixed set of parameters (weights and biases).
# This ensures that the agent always acts according to parameters that are drawn from
# the current noise distribution.
if self._cfg.noisy_net:
self._reset_noise(self._learn_model)
self._reset_noise(self._target_model)

# Data preprocessing operations, such as stack data, cpu to cuda device
data = default_preprocess_learn(
data,
Expand Down Expand Up @@ -380,10 +397,17 @@ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
.. note::
For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``.
"""
# Set noise mode for NoisyNet for exploration in collecting if enabled in config.
# We need to reset set_noise_mode every _forward_xxx because the model is reused across different
# phases (learn/collect/eval).
if self._cfg.noisy_net:
set_noise_mode(self._collect_model, True)

data_id = list(data.keys())
data = default_collate(list(data.values()))
if self._cuda:
data = to_device(data, self._device)

self._collect_model.eval()
with torch.no_grad():
output = self._collect_model.forward(data, eps=eps)
Expand Down Expand Up @@ -472,10 +496,16 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
.. note::
For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``.
"""
# We need to reset set_noise_mode every _forward_xxx because the model is reused across different
# phases (learn/collect/eval).
# Ensure that in evaluation mode noise is disabled.
set_noise_mode(self._eval_model, False)

data_id = list(data.keys())
data = default_collate(list(data.values()))
if self._cuda:
data = to_device(data, self._device)

self._eval_model.eval()
with torch.no_grad():
output = self._eval_model.forward(data)
Expand Down Expand Up @@ -533,6 +563,18 @@ def calculate_priority(self, data: Dict[int, Any], update_target_model: bool = F
)
return {'priority': td_error_per_sample.abs().tolist()}

def _reset_noise(self, model: torch.nn.Module):
r"""
Overview:
Reset the noise of model.

Arguments:
- model (:obj:`torch.nn.Module`): the model to reset, must contain reset_noise method
"""
for m in model.modules():
if hasattr(m, 'reset_noise'):
m.reset_noise()


@POLICY_REGISTRY.register('dqn_stdim')
class DQNSTDIMPolicy(DQNPolicy):
Expand Down
16 changes: 13 additions & 3 deletions ding/policy/rainbow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .dqn import DQNPolicy
from .common_utils import default_preprocess_learn
from .common_utils import default_preprocess_learn, set_noise_mode


@POLICY_REGISTRY.register('rainbow')
Expand Down Expand Up @@ -86,8 +86,9 @@ class RainbowDQNPolicy(DQNPolicy):
discount_factor=0.99,
# (int) N-step reward for target q_value estimation
nstep=3,
# (bool) Whether to use NoisyNet for exploration in both learning and collecting. Default is True.
noisy_net=True,
learn=dict(

# How many updates(iterations) to train after collector's one collection.
# Bigger "update_per_collect" means bigger off-policy.
# collect data -> update policy-> collect data -> ...
Expand Down Expand Up @@ -201,6 +202,11 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
# ====================
self._learn_model.train()
self._target_model.train()

# Enable noise in the NoisyNet layers for learning. NOTE: this is applied unconditionally here; it is not gated by ``self._cfg.noisy_net``.
set_noise_mode(self._learn_model, True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use `self._cfg.noisy_net` to control this logic?

set_noise_mode(self._target_model, True)

# reset noise of noisenet for both main model and target model
self._reset_noise(self._learn_model)
self._reset_noise(self._target_model)
Expand Down Expand Up @@ -262,12 +268,16 @@ def _forward_collect(self, data: dict, eps: float) -> dict:
ReturnsKeys
- necessary: ``action``
"""
# Enable noise in the NoisyNet layers for collecting. NOTE: this is applied unconditionally here; it is not gated by ``self._cfg.noisy_net``.
# We need to reset set_noise_mode every _forward_xxx because the model is reused across
# different phases (learn/collect/eval).
set_noise_mode(self._collect_model, True)

data_id = list(data.keys())
data = default_collate(list(data.values()))
if self._cuda:
data = to_device(data, self._device)
self._collect_model.eval()
self._reset_noise(self._collect_model)
with torch.no_grad():
output = self._collect_model.forward(data, eps=eps)
if self._cuda:
Expand Down
9 changes: 7 additions & 2 deletions ding/torch_utils/network/nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,10 @@ class NoiseLinearLayer(nn.Module):
def __init__(self, in_channels: int, out_channels: int, sigma0: int = 0.4) -> None:
"""
Overview:
Initialize the NoiseLinearLayer class.
Initialize the NoiseLinearLayer class. The 'enable_noise' attribute enables external control over whether \
noise is applied.
- If enable_noise is True, the layer adds noise even if the module is in evaluation mode.
- If enable_noise is False, no noise is added regardless of self.training.
Arguments:
- in_channels (:obj:`int`): Number of channels in the input tensor.
- out_channels (:obj:`int`): Number of channels in the output tensor.
Expand All @@ -654,6 +657,7 @@ def __init__(self, in_channels: int, out_channels: int, sigma0: int = 0.4) -> No
self.register_buffer("weight_eps", torch.empty(out_channels, in_channels))
self.register_buffer("bias_eps", torch.empty(out_channels))
self.sigma0 = sigma0
self.enable_noise = False
self.reset_parameters()
self.reset_noise()

Expand Down Expand Up @@ -703,7 +707,8 @@ def forward(self, x: torch.Tensor):
Returns:
- output (:obj:`torch.Tensor`): The output tensor with noise.
"""
if self.training:
# Determine whether to add noise:
if self.enable_noise:
return F.linear(
x,
self.weight_mu + self.weight_sigma * self.weight_eps,
Expand Down
26 changes: 25 additions & 1 deletion ding/torch_utils/network/tests/test_nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ding.torch_utils import build_activation
from ding.torch_utils.network.nn_module import MLP, conv1d_block, conv2d_block, fc_block, deconv2d_block, \
ChannelShuffle, one_hot, NearestUpsample, BilinearUpsample, binary_encode, weight_init_, NaiveFlatten, \
normed_linear, normed_conv2d
normed_linear, normed_conv2d, NoiseLinearLayer

batch_size = 2
in_channels = 2
Expand Down Expand Up @@ -238,3 +238,27 @@ def test_flatten(self):
model3 = NaiveFlatten(1, 3)
output3 = model2(inputs)
assert output1.shape == (4, 3 * 8 * 8)

def test_noise_linear_layer(self):
    # Exercise NoiseLinearLayer with its default flag, with noise switched on,
    # across successive noise resets, and after re-initialising its parameters.
    inputs = torch.rand(batch_size, in_channels).requires_grad_(True)
    noisy_layer = NoiseLinearLayer(in_channels, out_channels, sigma0=0.5)
    expected_shape = (batch_size, out_channels)

    # Default construction: the noise flag has not been turned on.
    plain_output = self.run_model(inputs, noisy_layer)
    assert plain_output.shape == expected_shape

    # Turn noise on and draw a noise sample before running the model again.
    noisy_layer.enable_noise = True
    noisy_layer.reset_noise()
    noisy_output = self.run_model(inputs, noisy_layer)
    assert noisy_output.shape == expected_shape

    # Two forward passes separated by a noise reset should (almost surely) disagree.
    with torch.no_grad():
        noisy_layer.reset_noise()
        first_pass = noisy_layer(inputs)
        noisy_layer.reset_noise()
        second_pass = noisy_layer(inputs)
    assert not torch.allclose(first_pass, second_pass)

    # Re-initialising the parameters must leave their shapes unchanged.
    noisy_layer.reset_parameters()
    assert noisy_layer.weight_mu.shape == (out_channels, in_channels)
    assert noisy_layer.bias_mu.shape == (out_channels, )
60 changes: 60 additions & 0 deletions dizoo/atari/config/serial/demon_attack/demon_attack_dqn_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from easydict import EasyDict

# Experiment settings for DQN on Atari DemonAttack with NoisyNet switched on:
# ``policy.model.noise`` and ``policy.noisy_net`` are both set to True.
main_config = EasyDict(
    {
        'exp_name': 'DemonAttack_dqn_seed0',
        'env': {
            'collector_env_num': 8,
            'evaluator_env_num': 8,
            'n_evaluator_episode': 8,
            'stop_value': 1e6,
            'env_id': 'DemonAttackNoFrameskip-v4',
            'frame_stack': 4,
        },
        'policy': {
            'cuda': True,
            'priority': False,
            'model': {
                'obs_shape': [4, 84, 84],
                'action_shape': 6,
                'encoder_hidden_size_list': [128, 128, 512],
                'noise': True,
            },
            'nstep': 3,
            'discount_factor': 0.99,
            'learn': {
                'update_per_collect': 10,
                'batch_size': 32,
                'learning_rate': 0.0001,
                'target_update_freq': 500,
            },
            'noisy_net': True,
            'collect': {'n_sample': 96},
            'eval': {'evaluator': {'eval_freq': 4000}},
            'other': {
                'eps': {
                    'type': 'exp',
                    'start': 1.,
                    'end': 0.05,
                    'decay': 250000,
                },
                'replay_buffer': {'replay_buffer_size': 100000},
            },
        },
    }
)
# Both names refer to the same EasyDict object.
demon_attack_dqn_config = main_config

# Component selection: which env, env manager and policy types are used.
create_config = EasyDict(
    {
        'env': {
            'type': 'atari',
            'import_names': ['dizoo.atari.envs.atari_env'],
        },
        'env_manager': {'type': 'subprocess'},
        'policy': {'type': 'dqn'},
    }
)
demon_attack_dqn_create_config = create_config

if __name__ == '__main__':
    # or you can enter `ding -m serial -c demon_attack_dqn_config.py -s 0`
    from ding.entry import serial_pipeline

    serial_pipeline((main_config, create_config), seed=0, max_env_step=int(10e6))
Loading