Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not require lr_scheduler for all other recipes #2051

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions recipes/knowledge_distillation_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def setup(self, cfg: DictConfig) -> None:
# The learning rate scheduler can only be set up after the number of
# training steps has been computed
self._lr_scheduler = self._setup_lr_scheduler(
cfg_lr_scheduler=cfg.lr_scheduler,
cfg_lr_scheduler=cfg.get("lr_scheduler", None),
num_training_steps=self.total_epochs * self._steps_per_epoch,
last_epoch=self.global_step - 1,
)
Expand Down Expand Up @@ -626,10 +626,15 @@ def _setup_optimizer(

def _setup_lr_scheduler(
self,
cfg_lr_scheduler: DictConfig,
cfg_lr_scheduler: Optional[DictConfig],
num_training_steps: int,
last_epoch: int,
) -> Optimizer:
) -> Optional[Optimizer]:
if cfg_lr_scheduler is None:
log.info(
"No learning rate scheduler configured. Using constant learning rate."
)
return None
lr_scheduler = config.instantiate(
cfg_lr_scheduler,
self._optimizer,
Expand Down Expand Up @@ -886,7 +891,8 @@ def train(self) -> None:
kd_loss_to_log = running_kd_loss.item() / num_tokens
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
self._lr_scheduler.step()
if self._lr_scheduler is not None:
self._lr_scheduler.step()
# Update the number of steps when the weights are updated
self.global_step += 1

Expand Down
15 changes: 11 additions & 4 deletions recipes/knowledge_distillation_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def setup(self, cfg: DictConfig) -> None:
# The learning rate scheduler can only be set up after the number of
# training steps has been computed
self._lr_scheduler = self._setup_lr_scheduler(
cfg_lr_scheduler=cfg.lr_scheduler,
cfg_lr_scheduler=cfg.get("lr_scheduler", None),
num_training_steps=self.total_epochs * self._steps_per_epoch,
last_epoch=self.global_step - 1,
)
Expand Down Expand Up @@ -495,10 +495,16 @@ def _setup_optimizer(

def _setup_lr_scheduler(
self,
cfg_lr_scheduler: DictConfig,
cfg_lr_scheduler: Optional[DictConfig],
num_training_steps: int,
last_epoch: int,
) -> Optimizer:
) -> Optional[Optimizer]:
if cfg_lr_scheduler is None:
log.info(
"No learning rate scheduler configured. Using constant learning rate."
)
return None

lr_scheduler = config.instantiate(
cfg_lr_scheduler,
self._optimizer,
Expand Down Expand Up @@ -727,7 +733,8 @@ def train(self) -> None:
)
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
self._lr_scheduler.step()
if self._lr_scheduler is not None:
self._lr_scheduler.step()
# Update the number of steps when the weights are updated
self.global_step += 1

Expand Down
15 changes: 11 additions & 4 deletions recipes/lora_dpo_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def setup(self, cfg: DictConfig) -> None:
# The learning rate scheduler can only be set up after the number of
# training steps has been computed
self._lr_scheduler = self._setup_lr_scheduler(
cfg_lr_scheduler=cfg.lr_scheduler,
cfg_lr_scheduler=cfg.get("lr_scheduler", None),
num_training_steps=self.total_epochs * self._steps_per_epoch,
last_epoch=self.global_step - 1,
)
Expand Down Expand Up @@ -426,10 +426,16 @@ def _setup_optimizer(

def _setup_lr_scheduler(
self,
cfg_lr_scheduler: DictConfig,
cfg_lr_scheduler: Optional[DictConfig],
num_training_steps: int,
last_epoch: int,
) -> Optimizer:
) -> Optional[Optimizer]:
if cfg_lr_scheduler is None:
log.info(
"No learning rate scheduler configured. Using constant learning rate."
)
return None

lr_scheduler = config.instantiate(
cfg_lr_scheduler,
self._optimizer,
Expand Down Expand Up @@ -679,7 +685,8 @@ def train(self) -> None:
if (idx + 1) % self._gradient_accumulation_steps == 0:
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
self._lr_scheduler.step()
if self._lr_scheduler is not None:
self._lr_scheduler.step()

# Update the number of steps when the weights are updated
self.global_step += 1
Expand Down
15 changes: 11 additions & 4 deletions recipes/lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def setup(self, cfg: DictConfig) -> None:
# The learning rate scheduler can only be set up after the number of
# training steps has been computed
self._lr_scheduler = self._setup_lr_scheduler(
cfg_lr_scheduler=cfg.lr_scheduler,
cfg_lr_scheduler=cfg.get("lr_scheduler", None),
num_training_steps=self.total_epochs * self._steps_per_epoch,
last_epoch=self.global_step - 1,
)
Expand Down Expand Up @@ -563,10 +563,16 @@ def _setup_optimizer(

def _setup_lr_scheduler(
self,
cfg_lr_scheduler: DictConfig,
cfg_lr_scheduler: Optional[DictConfig],
num_training_steps: int,
last_epoch: int,
) -> Optimizer:
) -> Optional[Optimizer]:
if cfg_lr_scheduler is None:
log.info(
"No learning rate scheduler configured. Using constant learning rate."
)
return None

lr_scheduler = config.instantiate(
cfg_lr_scheduler,
self._optimizer,
Expand Down Expand Up @@ -837,7 +843,8 @@ def train(self) -> None:
)
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
self._lr_scheduler.step()
if self._lr_scheduler is not None:
self._lr_scheduler.step()

# Update the number of steps when the weights are updated
self.global_step += 1
Expand Down
15 changes: 11 additions & 4 deletions recipes/lora_finetune_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def setup(self, cfg: DictConfig) -> None:
# The learning rate scheduler can only be set up after the number of
# training steps has been computed
self._lr_scheduler = self._setup_lr_scheduler(
cfg_lr_scheduler=cfg.lr_scheduler,
cfg_lr_scheduler=cfg.get("lr_scheduler", None),
num_training_steps=self.total_epochs * self._steps_per_epoch,
last_epoch=self.global_step - 1,
)
Expand Down Expand Up @@ -497,10 +497,16 @@ def _setup_optimizer(

def _setup_lr_scheduler(
self,
cfg_lr_scheduler: DictConfig,
cfg_lr_scheduler: Optional[DictConfig],
num_training_steps: int,
last_epoch: int,
) -> Optimizer:
) -> Optional[Optimizer]:
if cfg_lr_scheduler is None:
log.info(
"No learning rate scheduler configured. Using constant learning rate."
)
return None

lr_scheduler = config.instantiate(
cfg_lr_scheduler,
self._optimizer,
Expand Down Expand Up @@ -717,7 +723,8 @@ def train(self) -> None:
)
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
self._lr_scheduler.step()
if self._lr_scheduler is not None:
self._lr_scheduler.step()
# Update the number of steps when the weights are updated
self.global_step += 1

Expand Down
Loading