diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py
index 00ef4162b9..2ab1cf1f4d 100644
--- a/ding/model/template/vac.py
+++ b/ding/model/template/vac.py
@@ -241,20 +241,24 @@ def forward(self, x: torch.Tensor, mode: str) -> Dict:
         assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
         return getattr(self, mode)(x)
 
-    def compute_actor(self, x: torch.Tensor) -> Dict:
+    def compute_actor(self, x: Union[torch.Tensor, Dict]) -> Dict:
         """
         Overview:
             VAC forward computation graph for actor part, input observation tensor to predict action logit.
         Arguments:
-            - x (:obj:`torch.Tensor`): The input observation tensor data.
+            - x (:obj:`Union[torch.Tensor, Dict]`): The input observation tensor data. If a dictionary is provided, \
+                it should contain the key 'observation' and optionally 'action_mask'.
         Returns:
-            - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for actor, including ``logit``.
+            - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for actor, including ``logit`` \
+                and optionally ``action_mask`` if the input is a dictionary.
         ReturnsKeys:
             - logit (:obj:`torch.Tensor`): The predicted action logit tensor, for discrete action space, it will be \
                 the same dimension real-value ranged tensor of possible action choices, and for continuous action \
                 space, it will be the mu and sigma of the Gaussian distribution, and the number of mu and sigma is the \
                 same as the number of continuous actions. Hybrid action space is a kind of combination of discrete \
                 and continuous action space, so the logit will be a dict with ``action_type`` and ``action_args``.
+            - action_mask (:obj:`Optional[torch.Tensor]`): The action mask tensor, included if the input is a \
+                dictionary containing 'action_mask'.
         Shapes:
             - logit (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
@@ -264,13 +268,18 @@ def compute_actor(self, x: torch.Tensor) -> Dict:
             >>> actor_outputs = model(inputs,'compute_actor')
             >>> assert actor_outputs['logit'].shape == torch.Size([4, 64])
         """
-        if self.share_encoder:
-            x = self.encoder(x)
+        if isinstance(x, dict):
+            action_mask = x.get('action_mask')
+            x = self.encoder(x['observation']) if self.share_encoder else self.actor_encoder(x['observation'])
         else:
-            x = self.actor_encoder(x)
+            action_mask = None
+            x = self.encoder(x) if self.share_encoder else self.actor_encoder(x)
         if self.action_space == 'discrete':
-            return self.actor_head(x)
+            result = {'logit': self.actor_head(x)['logit']}
+            if action_mask is not None:
+                result['action_mask'] = action_mask
+            return result
         elif self.action_space == 'continuous':
             x = self.actor_head(x)  # mu, sigma
             return {'logit': x}
@@ -279,12 +288,13 @@ def compute_actor(self, x: torch.Tensor) -> Dict:
             action_args = self.actor_head[1](x)
             return {'logit': {'action_type': action_type['logit'], 'action_args': action_args}}
 
-    def compute_critic(self, x: torch.Tensor) -> Dict:
+    def compute_critic(self, x: Union[torch.Tensor, Dict]) -> Dict:
         """
         Overview:
             VAC forward computation graph for critic part, input observation tensor to predict state value.
         Arguments:
-            - x (:obj:`torch.Tensor`): The input observation tensor data.
+            - x (:obj:`Union[torch.Tensor, Dict]`): The input observation tensor data. If a dictionary is provided, \
+                it should contain the key 'observation'.
         Returns:
             - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for critic, including ``value``.
         ReturnsKeys:
@@ -298,23 +308,24 @@ def compute_critic(self, x: torch.Tensor) -> Dict:
             >>> critic_outputs = model(inputs,'compute_critic')
             >>> assert critic_outputs['value'].shape == torch.Size([4])
         """
-        if self.share_encoder:
-            x = self.encoder(x)
+        if isinstance(x, dict):
+            x = self.encoder(x['observation']) if self.share_encoder else self.critic_encoder(x['observation'])
         else:
-            x = self.critic_encoder(x)
+            x = self.encoder(x) if self.share_encoder else self.critic_encoder(x)
         x = self.critic_head(x)
         return {'value': x['pred']}
 
-    def compute_actor_critic(self, x: torch.Tensor) -> Dict:
+    def compute_actor_critic(self, x: Union[torch.Tensor, Dict]) -> Dict:
         """
         Overview:
             VAC forward computation graph for both actor and critic part, input observation tensor to predict action \
            logit and state value.
         Arguments:
-            - x (:obj:`torch.Tensor`): The input observation tensor data.
+            - x (:obj:`Union[torch.Tensor, Dict]`): The input observation tensor data. If a dictionary is provided, \
+                it should contain keys 'observation' and optionally 'action_mask'.
         Returns:
             - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for both actor and critic, \
-                including ``logit`` and ``value``.
+                including ``logit``, ``value``, and optionally ``action_mask`` if the input is a dictionary.
         ReturnsKeys:
             - logit (:obj:`torch.Tensor`): The predicted action logit tensor, for discrete action space, it will be \
                 the same dimension real-value ranged tensor of possible action choices, and for continuous action \
                 space, it will be the mu and sigma of the Gaussian distribution, and the number of mu and sigma is the \
@@ -322,6 +333,8 @@ def compute_actor_critic(self, x: torch.Tensor) -> Dict:
                 same as the number of continuous actions. Hybrid action space is a kind of combination of discrete \
                 and continuous action space, so the logit will be a dict with ``action_type`` and ``action_args``.
             - value (:obj:`torch.Tensor`): The predicted state value tensor.
+            - action_mask (:obj:`Optional[torch.Tensor]`): The action mask tensor, included if the input is a \
+                dictionary containing 'action_mask'.
         Shapes:
             - logit (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
             - value (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch size, (B, 1) is squeezed to (B, ).
@@ -338,17 +351,29 @@ def compute_actor_critic(self, x: torch.Tensor) -> Dict:
             ``compute_actor_critic`` interface aims to save computation when shares encoder and return the combination \
            dict output.
""" - if self.share_encoder: - actor_embedding = critic_embedding = self.encoder(x) + if isinstance(x, dict): + action_mask = x['action_mask'] + if self.share_encoder: + actor_embedding = critic_embedding = self.encoder(x['observation']) + else: + actor_embedding = self.actor_encoder(x['observation']) + critic_embedding = self.critic_encoder(x['observation']) else: - actor_embedding = self.actor_encoder(x) - critic_embedding = self.critic_encoder(x) + action_mask = None + if self.share_encoder: + actor_embedding = critic_embedding = self.encoder(x) + else: + actor_embedding = self.actor_encoder(x) + critic_embedding = self.critic_encoder(x) value = self.critic_head(critic_embedding)['pred'] if self.action_space == 'discrete': logit = self.actor_head(actor_embedding)['logit'] - return {'logit': logit, 'value': value} + result = {'logit': logit, 'value': value} + if action_mask is not None: + result['action_mask'] = action_mask + return result elif self.action_space == 'continuous': x = self.actor_head(actor_embedding) return {'logit': x, 'value': value} diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py index 61d8c1e845..1c5f32d1db 100644 --- a/ding/policy/base_policy.py +++ b/ding/policy/base_policy.py @@ -7,8 +7,8 @@ import torch from ding.model import create_model -from ding.utils import import_module, allreduce, broadcast, get_rank, allreduce_async, synchronize, deep_merge_dicts, \ - POLICY_REGISTRY +from ding.utils import import_module, allreduce, allreduce_with_indicator, broadcast, get_rank, allreduce_async, \ + synchronize, deep_merge_dicts, POLICY_REGISTRY class Policy(ABC): @@ -415,20 +415,39 @@ def __repr__(self) -> str: def sync_gradients(self, model: torch.nn.Module) -> None: """ Overview: - Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training. + Synchronize (allreduce) gradients of model parameters in data-parallel multi-GPU training. + For parameters that did not participate in the forward/backward pass in some GPUs, + assign a zero gradient with an indicator of 0. This ensures that only GPUs which contributed + to the gradient computation are considered when averaging, thereby avoiding an incorrect + division by the total number of GPUs. Arguments: - model (:obj:`torch.nn.Module`): The model to synchronize gradients. .. note:: - This method is only used in multi-gpu training, and it should be called after ``backward`` method and \ - before ``step`` method. The user can also use ``bp_update_sync`` config to control whether to synchronize \ - gradients allreduce and optimizer updates. + This method is only used in multi-gpu training, and it should be called after the ``backward`` method and \ + before the ``step`` method. The user can also use the ``bp_update_sync`` config to control whether to \ + synchronize gradients allreduce and optimizer updates. """ - if self._bp_update_sync: for name, param in model.named_parameters(): if param.requires_grad: - allreduce(param.grad.data) + # Create an indicator tensor on the same device as the parameter (or its gradient) + if param.grad is not None: + # If the gradient exists, extract its data and set indicator to 1. + grad_tensor = param.grad.data + indicator = torch.tensor(1.0, device=grad_tensor.device) + else: + # If the parameter did not participate in the computation (grad is None), + # create a zero tensor for the gradient and set the indicator to 0. 
diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py
index 61d8c1e845..1c5f32d1db 100644
--- a/ding/policy/base_policy.py
+++ b/ding/policy/base_policy.py
@@ -7,8 +7,8 @@
 import torch
 from ding.model import create_model
-from ding.utils import import_module, allreduce, broadcast, get_rank, allreduce_async, synchronize, deep_merge_dicts, \
-    POLICY_REGISTRY
+from ding.utils import import_module, allreduce, allreduce_with_indicator, broadcast, get_rank, allreduce_async, \
+    synchronize, deep_merge_dicts, POLICY_REGISTRY
 
 
 class Policy(ABC):
@@ -415,20 +415,39 @@ def __repr__(self) -> str:
     def sync_gradients(self, model: torch.nn.Module) -> None:
         """
         Overview:
-            Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.
+            Synchronize (allreduce) gradients of model parameters in data-parallel multi-GPU training.
+            For parameters that did not participate in the forward/backward pass on some GPUs,
+            a zero gradient with an indicator of 0 is assigned. This ensures that only GPUs that
+            actually contributed to the gradient computation are counted when averaging, thereby
+            avoiding an incorrect division by the total number of GPUs.
         Arguments:
             - model (:obj:`torch.nn.Module`): The model to synchronize gradients.
 
         .. note::
-            This method is only used in multi-gpu training, and it should be called after ``backward`` method and \
-            before ``step`` method. The user can also use ``bp_update_sync`` config to control whether to synchronize \
-            gradients allreduce and optimizer updates.
+            This method is only used in multi-gpu training, and it should be called after the ``backward`` method and \
+            before the ``step`` method. The user can also use the ``bp_update_sync`` config to control whether the \
+            gradient allreduce and the optimizer update are synchronized.
         """
-
         if self._bp_update_sync:
             for name, param in model.named_parameters():
                 if param.requires_grad:
-                    allreduce(param.grad.data)
+                    # Create an indicator tensor on the same device as the parameter (or its gradient).
+                    if param.grad is not None:
+                        # If the gradient exists, extract its data and set the indicator to 1.
+                        grad_tensor = param.grad.data
+                        indicator = torch.tensor(1.0, device=grad_tensor.device)
+                    else:
+                        # If the parameter did not participate in the computation (grad is None),
+                        # create a zero tensor for the gradient and set the indicator to 0.
+                        grad_tensor = torch.zeros_like(param.data)
+                        indicator = torch.tensor(0.0, device=grad_tensor.device)
+
+                    # Assign the zero gradient to param.grad so that all GPUs participate
+                    # in the subsequent allreduce call (avoiding deadlock).
+                    param.grad = grad_tensor
+
+                    # Reduce the gradient with the custom indicator-aware allreduce.
+                    allreduce_with_indicator(param.grad, indicator)
         else:
             synchronize()
diff --git a/ding/utils/__init__.py b/ding/utils/__init__.py
index 68a92efcb5..3623cab824 100644
--- a/ding/utils/__init__.py
+++ b/ding/utils/__init__.py
@@ -39,5 +39,5 @@
         allreduce, broadcast, DistContext, allreduce_async, synchronize
 else:
     from .pytorch_ddp_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
-        allreduce, broadcast, DDPContext, allreduce_async, synchronize, reduce_data, broadcast_object_list, \
-        to_ddp_config, allreduce_data
+        allreduce, allreduce_with_indicator, broadcast, DDPContext, allreduce_async, synchronize, reduce_data, \
+        broadcast_object_list, to_ddp_config, allreduce_data
diff --git a/ding/utils/pytorch_ddp_dist_helper.py b/ding/utils/pytorch_ddp_dist_helper.py
index 13d9e1e299..fcea6c81e8 100644
--- a/ding/utils/pytorch_ddp_dist_helper.py
+++ b/ding/utils/pytorch_ddp_dist_helper.py
@@ -5,6 +5,7 @@
 import numpy as np
 import torch
 import torch.distributed as dist
+import datetime
 
 from .default_helper import error_wrapper
 
@@ -46,6 +47,27 @@ def allreduce(x: torch.Tensor) -> None:
     x.div_(get_world_size())
 
 
+def allreduce_with_indicator(grad: torch.Tensor, indicator: torch.Tensor) -> None:
+    """
+    Overview:
+        Custom allreduce: sum both the gradient and indicator tensors across all processes.
+        Then, if at least one process contributed (i.e., the summed indicator > 0), divide
+        the gradient by the summed indicator. This ensures that if only a subset of GPUs
+        contributed a gradient, the averaging is performed over the actual number of
+        contributors rather than the total number of GPUs.
+    Arguments:
+        - grad (:obj:`torch.Tensor`): Local gradient tensor to be reduced.
+        - indicator (:obj:`torch.Tensor`): A scalar flag tensor (1 if the gradient is computed locally, 0 otherwise).
+    """
+    # Allreduce (sum) the gradient and the indicator.
+    dist.all_reduce(grad)
+    dist.all_reduce(indicator)
+
+    # Avoid division by zero: if no process contributed (extreme case), grad stays all zeros.
+    if indicator.item() > 0:
+        grad.div_(indicator.item())
+
+
 def allreduce_async(name: str, x: torch.Tensor) -> None:
     """
     Overview:
@@ -138,20 +160,25 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
-def dist_init(backend: str = 'nccl',
-              addr: str = None,
-              port: str = None,
-              rank: int = None,
-              world_size: int = None) -> Tuple[int, int]:
+def dist_init(
+        backend: str = 'nccl',
+        addr: str = None,
+        port: str = None,
+        rank: int = None,
+        world_size: int = None,
+        timeout: datetime.timedelta = datetime.timedelta(seconds=60000)
+) -> Tuple[int, int]:
     """
     Overview:
-        Initialize the distributed training setting
+        Initialize the distributed training setting.
     Arguments:
-        - backend (:obj:`str`): The backend of the distributed training, support ``['nccl', 'gloo']``
-        - addr (:obj:`str`): The address of the master node
-        - port (:obj:`str`): The port of the master node
-        - rank (:obj:`int`): The rank of current process
-        - world_size (:obj:`int`): The total number of processes
+        - backend (:obj:`str`): The backend of the distributed training, supports ``['nccl', 'gloo']``.
+        - addr (:obj:`str`): The address of the master node.
+        - port (:obj:`str`): The port of the master node.
+        - rank (:obj:`int`): The rank of the current process.
+        - world_size (:obj:`int`): The total number of processes.
+        - timeout (:obj:`datetime.timedelta`): The timeout for operations executed against the process group. \
+            Default is 60000 seconds.
     """
     assert backend in ['nccl', 'gloo'], backend
@@ -171,7 +198,7 @@ def dist_init(backend: str = 'nccl',
     else:
         world_size = int(ntasks)
 
-    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
+    dist.init_process_group(backend=backend, rank=rank, world_size=world_size, timeout=timeout)
     num_gpus = torch.cuda.device_count()
     torch.cuda.set_device(rank % num_gpus)
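To make the averaging rule concrete, here is a self-contained sketch (no process group needed) of what `allreduce_with_indicator` computes when, say, three of four ranks produce a gradient for a parameter. Values are illustrative only:

```python
import torch

# Simulate a 4-GPU setup where ranks 0-2 computed a gradient for one
# parameter and rank 3 did not (it contributes a zero placeholder, as
# sync_gradients does when param.grad is None).
local_grads = [torch.tensor([0.3]), torch.tensor([0.6]), torch.tensor([0.9]), torch.zeros(1)]
indicators = [1.0, 1.0, 1.0, 0.0]

summed_grad = torch.stack(local_grads).sum(dim=0)  # what dist.all_reduce(grad) would yield
summed_ind = sum(indicators)                       # 3.0: the true number of contributors

avg_grad = summed_grad / summed_ind                # tensor([0.6])  -- divide by contributors
naive_avg = summed_grad / len(local_grads)         # tensor([0.45]) -- plain allreduce biases low
print(avg_grad, naive_avg)
```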
diff --git a/ding/worker/learner/base_learner.py b/ding/worker/learner/base_learner.py
index 0b57b06c76..a070e7a58d 100644
--- a/ding/worker/learner/base_learner.py
+++ b/ding/worker/learner/base_learner.py
@@ -35,6 +35,8 @@ def default_config(cls: type) -> EasyDict:
         train_iterations=int(1e9),
         dataloader=dict(num_workers=0, ),
         log_policy=True,
+        is_multitask_pipeline=False,
+        only_monitor_rank0=True,
         # --- Hooks ---
         hook=dict(
             load_ckpt_before_run='',
@@ -59,7 +61,9 @@ def __init__(
         Overview:
             Initialization method, build common learner components according to cfg, such as hook, wrapper and so on.
         Arguments:
-            - cfg (:obj:`EasyDict`): Learner config, you can refer cls.config for details.
+            - cfg (:obj:`EasyDict`): Learner config, you can refer to ``cls.config`` for details. It should include \
+                ``is_multitask_pipeline`` to indicate whether the pipeline is multi-task (default: False), and \
+                ``only_monitor_rank0`` to control whether only rank 0 needs the monitor and tb_logger (default: True).
            - policy (:obj:`namedtuple`): A collection of policy function of learn mode. And policy can also be \
                initialized when runtime.
            - tb_logger (:obj:`SummaryWriter`): Tensorboard summary writer.
@@ -78,6 +82,12 @@ def __init__(
         self._instance_name = instance_name
         self._ckpt_name = None
         self._timer = EasyTimer()
+        self._is_multitask_pipeline = self._cfg.is_multitask_pipeline
+        self.only_monitor_rank0 = self._cfg.only_monitor_rank0
+
+        # In the multitask pipeline, every rank logs its own tasks, so monitoring cannot be limited to rank 0.
+        if self._is_multitask_pipeline:
+            self.only_monitor_rank0 = False
 
         # These 2 attributes are only used in parallel mode.
         self._end_flag = False
@@ -92,8 +102,10 @@ def __init__(
             self._cfg.hook.log_reduce_after_iter = True
 
         # Logger (Monitor will be initialized in policy setter)
-        # Only rank == 0 learner needs monitor and tb_logger, others only need text_logger to display terminal output.
-        if self._rank == 0:
+        # In the multitask pipeline, each rank needs its own tb_logger.
+        # Otherwise, only the rank == 0 learner needs the monitor and tb_logger,
+        # and other ranks only need a text_logger to display terminal output.
+        if self._rank == 0 or not self.only_monitor_rank0:
             if tb_logger is not None:
                 self._logger, _ = build_logger(
                     './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
@@ -108,6 +120,7 @@
                 './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
             )
             self._tb_logger = None
+
         self._log_buffer = {
             'scalar': build_log_buffer(),
             'scalars': build_log_buffer(),
@@ -454,7 +467,7 @@ def policy(self, _policy: 'Policy') -> None:  # noqa
             Policy variable monitor is set alongside with policy, because variables are determined by specific policy.
         """
         self._policy = _policy
-        if self._rank == 0:
+        if self._rank == 0 or not self.only_monitor_rank0:
            self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)
         if self._cfg.log_policy:
             self.info(self._policy.info())
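A short configuration sketch of how the two new flags interact; the field names come from the patch, and the surrounding config is the learner's own default:

```python
from ding.worker.learner.base_learner import BaseLearner

cfg = BaseLearner.default_config()
# Default single-task behaviour: only rank 0 builds the monitor and tb_logger.
assert cfg.only_monitor_rank0 and not cfg.is_multitask_pipeline

# Multi-task DDP pipeline: __init__ forces only_monitor_rank0 to False,
# so every rank gets its own tb_logger under ./<exp_name>/log/<instance_name>.
cfg.is_multitask_pipeline = True
```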
""" self._policy = _policy - if self._rank == 0: + if self._rank == 0 or not self.only_monitor_rank0: self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10) if self._cfg.log_policy: self.info(self._policy.info()) diff --git a/ding/worker/learner/learner_hook.py b/ding/worker/learner/learner_hook.py index 6797af077b..2ddfe825b8 100644 --- a/ding/worker/learner/learner_hook.py +++ b/ding/worker/learner/learner_hook.py @@ -187,29 +187,33 @@ class LogShowHook(LearnerHook): def __init__(self, *args, ext_args: EasyDict = EasyDict(), **kwargs) -> None: """ Overview: - init LogShowHook + Init LogShowHook. Arguments: - - ext_args (:obj:`EasyDict`): extended_args, use ext_args.freq to set freq + - ext_args (:obj:`EasyDict`): Extended arguments, use ext_args.freq to set frequency and \ + ext_args.only_monitor_rank0 to control if only rank 0 should monitor, default is True. """ super().__init__(*args, **kwargs) if ext_args == {}: self._freq = 1 else: self._freq = ext_args.freq + self._only_monitor_rank0 = None def __call__(self, engine: 'BaseLearner') -> None: # noqa """ Overview: Show log, update record and tb_logger if rank is 0 and at interval iterations, - clear the log buffer for all learners regardless of rank + clear the log buffer for all learners regardless of rank. Arguments: - - engine (:obj:`BaseLearner`): the BaseLearner + - engine (:obj:`BaseLearner`): The BaseLearner. """ - # Only show log for rank 0 learner - if engine.rank != 0: + self._only_monitor_rank0 = engine.only_monitor_rank0 + # Only show log for rank 0 learner if _only_monitor_rank0 is True + if engine.rank != 0 and self._only_monitor_rank0: for k in engine.log_buffer: engine.log_buffer[k].clear() return + # For 'scalar' type variables: log_buffer -> tick_monitor -> monitor_time.step for k, v in engine.log_buffer['scalar'].items(): setattr(engine.monitor, k, v) @@ -243,7 +247,7 @@ def __call__(self, engine: 'BaseLearner') -> None: # noqa class LogReduceHook(LearnerHook): """ Overview: - Hook to reduce the distributed(multi-gpu) logs + Hook to reduce the distributed (multi-gpu) logs. Interfaces: __init__, __call__ Property: @@ -253,32 +257,45 @@ class LogReduceHook(LearnerHook): def __init__(self, *args, ext_args: EasyDict = EasyDict(), **kwargs) -> None: """ Overview: - init LogReduceHook + Initialize LogReduceHook. Arguments: - - ext_args (:obj:`EasyDict`): extended_args, use ext_args.freq to set log_reduce_freq + - ext_args (:obj:`EasyDict`): Extended arguments, use ext_args.freq to set log_reduce_freq. """ super().__init__(*args, **kwargs) def __call__(self, engine: 'BaseLearner') -> None: # noqa """ Overview: - reduce the logs from distributed(multi-gpu) learners + Reduce the logs from distributed (multi-gpu) learners. Arguments: - - engine (:obj:`BaseLearner`): the BaseLearner + - engine (:obj:`BaseLearner`): The BaseLearner. """ def aggregate(data): r""" Overview: - aggregate the information from all ranks(usually use sync allreduce) + Aggregate the information from all ranks (usually using sync allreduce). Arguments: - data (:obj:`dict`): Data that needs to be reduced. \ - Could be dict, torch.Tensor, numbers.Integral or numbers.Real. + Could be dict, torch.Tensor, numbers.Integral, or numbers.Real. Returns: - - new_data (:obj:`dict`): data after reduce + - new_data (:obj:`dict`): Data after reduction. """ + + def should_reduce(key): + # Check if the key starts with the "noreduce_" prefix. 
+ # The "noreduce_" prefix is used in the unizero_multitask ddp pipeline + # to indicate data that should not be reduced. + return not key.startswith("noreduce_") + if isinstance(data, dict): - new_data = {k: aggregate(v) for k, v in data.items()} + new_data = {} + for k, v in data.items(): + if should_reduce(k): + new_data[k] = aggregate(v) # Perform allreduce on data that needs reduction. + else: + new_data[k] = v # Retain data that does not need reduction. + elif isinstance(data, list) or isinstance(data, tuple): new_data = [aggregate(t) for t in data] elif isinstance(data, torch.Tensor): @@ -299,7 +316,7 @@ def aggregate(data): new_data = new_data.cpu() new_data = new_data.item() else: - raise TypeError("invalid type in reduce: {}".format(type(data))) + raise TypeError("Invalid type in reduce: {}".format(type(data))) return new_data engine.log_buffer = aggregate(engine.log_buffer)