51 changes: 34 additions & 17 deletions ding/model/template/vac.py
@@ -246,15 +246,19 @@ def compute_actor(self, x: torch.Tensor) -> Dict:
Overview:
VAC forward computation graph for actor part, input observation tensor to predict action logit.
Arguments:
- x (:obj:`torch.Tensor`): The input observation tensor data.
- x (:obj:`torch.Tensor` or :obj:`Dict`): The input observation tensor data. If a dictionary is provided, \
it should contain keys 'observation' and optionally 'action_mask'.
Returns:
- outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for actor, including ``logit``.
- outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for actor, including ``logit`` \
and optionally ``action_mask`` if the input is a dictionary.
ReturnsKeys:
- logit (:obj:`torch.Tensor`): The predicted action logit tensor. For a discrete action space, it is a \
real-valued tensor with one entry per possible action; for a continuous action space, it is the mu and \
sigma of the Gaussian distribution, with one (mu, sigma) pair per continuous action dimension. A hybrid \
action space combines discrete and continuous actions, so the logit is a dict with ``action_type`` and \
``action_args``.
- action_mask (:obj:`torch.Tensor`, optional): The action mask tensor, included if the input is a \
dictionary containing 'action_mask'.
Shapes:
- logit (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``

@@ -264,13 +268,17 @@ def compute_actor(self, x: torch.Tensor) -> Dict:
>>> actor_outputs = model(inputs,'compute_actor')
>>> assert actor_outputs['logit'].shape == torch.Size([4, 64])
"""
if self.share_encoder:
x = self.encoder(x)
if isinstance(x, dict):
action_mask = x.get('action_mask', None)
x = self.encoder(x['observation']) if self.share_encoder else self.actor_encoder(x['observation'])
else:
x = self.actor_encoder(x)
x = self.encoder(x) if self.share_encoder else self.actor_encoder(x)

if self.action_space == 'discrete':
return self.actor_head(x)
result = {'logit': self.actor_head(x)['logit']}
if action_mask is not None:
result['action_mask'] = action_mask
return result
elif self.action_space == 'continuous':
x = self.actor_head(x) # mu, sigma
return {'logit': x}
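A minimal usage sketch of the new dict-input path for compute_actor (constructor arguments and the 4/8/6 shapes are illustrative toy values, not taken from this PR):

import torch
from ding.model.template.vac import VAC

model = VAC(obs_shape=8, action_shape=6, action_space='discrete')
inputs = {
    'observation': torch.randn(4, 8),
    'action_mask': torch.ones(4, 6),  # 1 = legal action, 0 = masked action
}
actor_outputs = model(inputs, 'compute_actor')
assert actor_outputs['logit'].shape == torch.Size([4, 6])
# The mask is passed through unchanged so downstream policies can apply it to the logits.
assert torch.equal(actor_outputs['action_mask'], inputs['action_mask'])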
@@ -284,7 +292,8 @@ def compute_critic(self, x: torch.Tensor) -> Dict:
Overview:
VAC forward computation graph for critic part, input observation tensor to predict state value.
Arguments:
- x (:obj:`torch.Tensor`): The input observation tensor data.
- x (:obj:`torch.Tensor` or :obj:`Dict`): The input observation tensor data. If a dictionary is provided, \
it should contain the key 'observation'.
Returns:
- outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for critic, including ``value``.
ReturnsKeys:
@@ -298,10 +307,10 @@
>>> critic_outputs = model(inputs,'compute_critic')
>>> assert critic_outputs['value'].shape == torch.Size([4])
"""
if self.share_encoder:
x = self.encoder(x)
if isinstance(x, dict):
x = self.encoder(x['observation']) if self.share_encoder else self.critic_encoder(x['observation'])
else:
x = self.critic_encoder(x)
x = self.encoder(x) if self.share_encoder else self.critic_encoder(x)
x = self.critic_head(x)
return {'value': x['pred']}

@@ -311,17 +320,20 @@ def compute_actor_critic(self, x: torch.Tensor) -> Dict:
VAC forward computation graph for both actor and critic part, input observation tensor to predict action \
logit and state value.
Arguments:
- x (:obj:`torch.Tensor`): The input observation tensor data.
- x (:obj:`torch.Tensor` or :obj:`Dict`): The input observation tensor data. If a dictionary is provided, \
it should contain keys 'observation' and optionally 'action_mask'.
Returns:
- outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for both actor and critic, \
including ``logit`` and ``value``.
including ``logit``, ``value``, and optionally ``action_mask`` if the input is a dictionary.
ReturnsKeys:
- logit (:obj:`torch.Tensor`): The predicted action logit tensor. For a discrete action space, it is a \
real-valued tensor with one entry per possible action; for a continuous action space, it is the mu and \
sigma of the Gaussian distribution, with one (mu, sigma) pair per continuous action dimension. A hybrid \
action space combines discrete and continuous actions, so the logit is a dict with ``action_type`` and \
``action_args``.
- value (:obj:`torch.Tensor`): The predicted state value tensor.
- action_mask (:obj:`torch.Tensor`, optional): The action mask tensor, included if the input is a \
dictionary containing 'action_mask'.
Shapes:
- logit (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
- value (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch size, (B, 1) is squeezed to (B, ).
@@ -338,17 +350,22 @@ def compute_actor_critic(self, x: torch.Tensor) -> Dict:
The ``compute_actor_critic`` interface aims to save computation when the encoder is shared, and returns the \
combined dict output.
"""
if self.share_encoder:
actor_embedding = critic_embedding = self.encoder(x)
if isinstance(x, dict):
action_mask = x.get('action_mask', None)
actor_embedding = critic_embedding = self.encoder(
x['observation']
) if self.share_encoder else self.actor_encoder(x['observation'])
else:
actor_embedding = self.actor_encoder(x)
critic_embedding = self.critic_encoder(x)
actor_embedding = critic_embedding = self.encoder(x) if self.share_encoder else self.actor_encoder(x)

value = self.critic_head(critic_embedding)['pred']

if self.action_space == 'discrete':
logit = self.actor_head(actor_embedding)['logit']
return {'logit': logit, 'value': value}
result = {'logit': logit, 'value': value}
if action_mask is not None:
result['action_mask'] = action_mask
return result
elif self.action_space == 'continuous':
x = self.actor_head(actor_embedding)
return {'logit': x, 'value': value}
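A similar self-contained sketch for compute_actor_critic with dict input, showing the combined output (again with illustrative toy shapes):

import torch
from ding.model.template.vac import VAC

model = VAC(obs_shape=8, action_shape=6, action_space='discrete')
inputs = {'observation': torch.randn(4, 8), 'action_mask': torch.ones(4, 6)}
outputs = model(inputs, 'compute_actor_critic')
assert outputs['logit'].shape == torch.Size([4, 6])
assert outputs['value'].shape == torch.Size([4])
# When the input dict carries an 'action_mask', it is echoed back in the output.
assert torch.equal(outputs['action_mask'], inputs['action_mask'])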
18 changes: 10 additions & 8 deletions ding/policy/base_policy.py
@@ -93,10 +93,6 @@ def default_config(cls: type) -> EasyDict:
traj_len_inf=False,
# neural network model config
model=dict(),
# If resume_training is True, the environment step count (collector.envstep) and training iteration (train_iter)
# will be loaded from the pretrained checkpoint, allowing training to resume seamlessly
# from where the ckpt left off.
learn=dict(resume_training=False),
)

def __init__(
@@ -420,15 +416,21 @@ def sync_gradients(self, model: torch.nn.Module) -> None:
- model (:obj:`torch.nn.Module`): The model to synchronize gradients.

.. note::
This method is only used in multi-gpu training, and it should be called after ``backward`` method and \
before ``step`` method. The user can also use ``bp_update_sync`` config to control whether to synchronize \
gradients allreduce and optimizer updates.
This method is only used in multi-gpu training, and it should be called after the ``backward`` method and \
before the ``step`` method. The user can also use the ``bp_update_sync`` config to control whether to \
synchronize gradients allreduce and optimizer updates.
"""

if self._bp_update_sync:
for name, param in model.named_parameters():
if param.requires_grad:
allreduce(param.grad.data)
if param.grad is not None:
allreduce(param.grad.data)
else:
# If the gradient is None, create a zero tensor
# with the same size as param.grad and perform allreduce
zero_grad = torch.zeros_like(param.data)
allreduce(zero_grad)
else:
synchronize()
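The zero-tensor branch exists because collective calls must stay in lockstep across ranks: if one rank skips an allreduce because a parameter received no gradient (e.g. an unused branch of the network), the remaining ranks block. A minimal standalone sketch of the same idea, using raw torch.distributed instead of DI-engine's allreduce helper:

import torch
import torch.distributed as dist

def sync_gradients_sketch(model: torch.nn.Module) -> None:
    """Allreduce every gradient, substituting zeros when this rank has no grad for a parameter."""
    for _, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.grad is not None:
            dist.all_reduce(param.grad.data)
        else:
            # Unused parameters still issue the collective call so all ranks stay aligned.
            dist.all_reduce(torch.zeros_like(param.data))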

30 changes: 18 additions & 12 deletions ding/utils/pytorch_ddp_dist_helper.py
@@ -5,6 +5,7 @@
import numpy as np
import torch
import torch.distributed as dist
import datetime

from .default_helper import error_wrapper

@@ -138,20 +139,25 @@ def wrapper(*args, **kwargs):
return wrapper


def dist_init(backend: str = 'nccl',
addr: str = None,
port: str = None,
rank: int = None,
world_size: int = None) -> Tuple[int, int]:
def dist_init(
backend: str = 'nccl',
addr: str = None,
port: str = None,
rank: int = None,
world_size: int = None,
timeout: datetime.timedelta = datetime.timedelta(seconds=60000)
) -> Tuple[int, int]:
"""
Overview:
Initialize the distributed training setting
Initialize the distributed training setting.
Arguments:
- backend (:obj:`str`): The backend of the distributed training, support ``['nccl', 'gloo']``
- addr (:obj:`str`): The address of the master node
- port (:obj:`str`): The port of the master node
- rank (:obj:`int`): The rank of current process
- world_size (:obj:`int`): The total number of processes
- backend (:obj:`str`): The backend of the distributed training, supports ``['nccl', 'gloo']``.
- addr (:obj:`str`): The address of the master node.
- port (:obj:`str`): The port of the master node.
- rank (:obj:`int`): The rank of the current process.
- world_size (:obj:`int`): The total number of processes.
- timeout (:obj:`datetime.timedelta`): The timeout for operations executed against the process group. \
Default is 60000 seconds.
"""

assert backend in ['nccl', 'gloo'], backend
@@ -171,7 +177,7 @@ def dist_init(backend: str = 'nccl',
else:
world_size = int(ntasks)

dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
dist.init_process_group(backend=backend, rank=rank, world_size=world_size, timeout=timeout)

num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
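A hypothetical call site for the new timeout argument, useful when some ranks spend a long time in setup before their first collective; master address, port and rank are assumed to come from environment variables or SLURM, as in the helper's existing defaults:

import datetime
from ding.utils.pytorch_ddp_dist_helper import dist_init

rank, world_size = dist_init(
    backend='nccl',
    timeout=datetime.timedelta(hours=2),  # overrides the 60000-second default
)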
21 changes: 15 additions & 6 deletions ding/worker/learner/base_learner.py
@@ -59,7 +59,9 @@ def __init__(
Overview:
Initialization method, build common learner components according to cfg, such as hook, wrapper and so on.
Arguments:
- cfg (:obj:`EasyDict`): Learner config, you can refer cls.config for details.
- cfg (:obj:`EasyDict`): Learner config, you can refer to cls.config for details. It can additionally include \
`is_unizero_multitask_pipeline` to indicate whether the pipeline is UniZero multitask (default: False), and \
`only_monitor_rank0` to control whether only rank 0 needs a monitor and tb_logger (default: True).
- policy (:obj:`namedtuple`): A collection of policy functions of learn mode. The policy can also be \
initialized at runtime.
- tb_logger (:obj:`SummaryWriter`): Tensorboard summary writer.
@@ -78,6 +80,12 @@ def __init__(
self._instance_name = instance_name
self._ckpt_name = None
self._timer = EasyTimer()
self._is_unizero_multitask_pipeline = self._cfg.get('is_unizero_multitask_pipeline', False)
self._only_monitor_rank0 = self._cfg.get('only_monitor_rank0', True)

# Adjust only_monitor_rank0 based on is_unizero_multitask_pipeline
if self._is_unizero_multitask_pipeline:
self._only_monitor_rank0 = False

# These 2 attributes are only used in parallel mode.
self._end_flag = False
@@ -92,8 +100,10 @@ def __init__(
self._cfg.hook.log_reduce_after_iter = True

# Logger (Monitor will be initialized in policy setter)
# Only rank == 0 learner needs monitor and tb_logger, others only need text_logger to display terminal output.
if self._rank == 0:
# In the unizero multitask pipeline, each rank needs its own tb_logger.
# Otherwise, only rank == 0 learner needs monitor and tb_logger,
# others only need text_logger to display terminal output.
if self._rank == 0 or self._is_unizero_multitask_pipeline:
if tb_logger is not None:
self._logger, _ = build_logger(
'./{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
@@ -108,6 +118,7 @@ def __init__(
'./{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
)
self._tb_logger = None

self._log_buffer = {
'scalar': build_log_buffer(),
'scalars': build_log_buffer(),
@@ -122,8 +133,6 @@ def __init__(
self._hooks = {'before_run': [], 'before_iter': [], 'after_iter': [], 'after_run': []}
# Last iteration. Used to record current iter.
self._last_iter = CountVar(init_val=0)
# Collector envstep. Used to record current envstep.
self._collector_envstep = 0

# Setup time wrapper and hook.
self._setup_wrapper()
@@ -454,7 +463,7 @@ def policy(self, _policy: 'Policy') -> None: # noqa
The policy variable monitor is set alongside the policy, because its variables are determined by the specific policy.
"""
self._policy = _policy
if self._rank == 0:
if self._rank == 0 or not self._only_monitor_rank0:
self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)
if self._cfg.log_policy:
self.info(self._policy.info())
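A hypothetical config fragment exercising the two new switches (everything except the two new keys is illustrative, not a complete learner config):

from easydict import EasyDict

learner_cfg = EasyDict(dict(
    log_policy=True,
    # Run the UniZero multitask pipeline; this forces only_monitor_rank0 to False,
    # so every rank builds its own monitor and tb_logger.
    is_unizero_multitask_pipeline=True,
    # Honoured only when the flag above is False.
    only_monitor_rank0=True,
))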