Added num_groups option to config used by Horovod distributed optimizer
Commit 0645a4c (parent 0f1d651), committed by rickybalin on Apr 25, 2023
Showing 3 changed files with 5 additions and 2 deletions.
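For context, Horovod's num_groups argument controls grouped allreduce: gradient reductions are fused into the given number of groups, which can improve scaling for some models. A value of 0, the default chosen here, leaves explicit grouping disabled, which is Horovod's standard behavior.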
src/config/framework.py (2 additions, 0 deletions)

@@ -17,12 +17,14 @@ class Tensorflow(Framework):
     name: str = "tensorflow"
     inter_op_parallelism_threads: int = 2
     intra_op_parallelism_threads: int = 24
+    num_groups: int = 0

 @dataclass
 class Torch(Framework):
     name: str = "torch"
     sparse: bool = False
     distributed_mode: DistributedMode = DistributedMode.DDP
+    num_groups: int = 0

 cs = ConfigStore.instance()
 cs.store(group="framework", name="tensorflow", node=Tensorflow)
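Because these dataclasses are registered with Hydra's ConfigStore, the new field should be overridable like any other config value. Below is a minimal, self-contained sketch of that path; it mirrors the Torch dataclass above rather than importing from this repo, and it omits the DistributedMode field for brevity.

```python
# Minimal sketch: how a structured config exposes the new num_groups field.
# This mirrors the Torch dataclass above; it is not this repo's actual config.
from dataclasses import dataclass
from omegaconf import OmegaConf

@dataclass
class Torch:
    name: str = "torch"
    sparse: bool = False
    num_groups: int = 0  # 0 keeps Horovod's default (no explicit grouping)

cfg = OmegaConf.structured(Torch)
# Equivalent in effect to passing framework.num_groups=2 on a Hydra command line:
cfg.merge_with_dotlist(["num_groups=2"])
print(cfg.num_groups)  # -> 2
```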
src/utils/tensorflow2/distributed_trainer.py (1 addition, 1 deletion)

@@ -66,7 +66,7 @@ def init_optimizer(self):

         # Wrap the optimizer in horovod:
         # self._opt = hvd.DistributedOptimizer(self._opt)
-        self.tape = hvd.DistributedGradientTape(self.tape, num_groups=1)
+        self.tape = hvd.DistributedGradientTape(self.tape, num_groups=self.args.framework.num_groups)

     def init_saver(self):
         if hvd.rank() == 0:
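For reference, here is a hedged sketch of how a value like this typically reaches hvd.DistributedGradientTape in a TF2 training step; the model, loss, and optimizer are placeholders, not this repo's trainer objects.

```python
# Sketch of a TF2 + Horovod training step with configurable gradient grouping.
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

@tf.function
def train_step(model, optimizer, loss_fn, x, y, num_groups=0):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    # num_groups > 0 fuses the gradient allreduces into that many groups;
    # 0 keeps Horovod's default of no explicit grouping.
    tape = hvd.DistributedGradientTape(tape, num_groups=num_groups)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```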
src/utils/torch/distributed_trainer.py (2 additions, 1 deletion)

@@ -195,7 +195,8 @@ def init_optimizer(self):
         torch_trainer.init_optimizer(self)

         if self.args.framework.distributed_mode == DistributedMode.horovod:
-            self._opt = hvd.DistributedOptimizer(self._opt, named_parameters=self._net.named_parameters())
+            self._opt = hvd.DistributedOptimizer(self._opt, named_parameters=self._net.named_parameters(),
+                                                 num_groups=self.args.framework.num_groups)
             # self._opt.param_groups[0]['capturable'] = True
             self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self._opt, self.lr_calculator, last_epoch=-1)
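And a self-contained sketch of the PyTorch side of the same wiring; the linear model and optimizer settings are illustrative stand-ins for this repo's network.

```python
# Self-contained sketch of hvd.DistributedOptimizer with grouped allreduce.
import torch
import horovod.torch as hvd

hvd.init()

net = torch.nn.Linear(32, 10)  # toy model standing in for this repo's network
opt = torch.optim.Adam(net.parameters(), lr=1e-3 * hvd.size())
# num_groups would come from self.args.framework.num_groups in the trainer;
# 0 preserves Horovod's standard (ungrouped) allreduce behavior.
opt = hvd.DistributedOptimizer(opt,
                               named_parameters=net.named_parameters(),
                               num_groups=0)
# Keep all workers in sync before training starts.
hvd.broadcast_parameters(net.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(opt, root_rank=0)
```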
