Fix accumulated_grad_batches typehint (#9071)
* Fix `accumulated_grad_batches` typehint
ananthsub authored Aug 24, 2021
1 parent 1a2468f commit 376734a
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions &amp; 3 deletions pytorch_lightning/trainer/connectors/training_trick_connector.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List, Union
+from typing import Dict, Union

 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 from pytorch_lightning.utilities import GradClipAlgorithmType
@@ -27,7 +27,7 @@ def on_trainer_init(
         gradient_clip_val: float,
         gradient_clip_algorithm: str,
         track_grad_norm: Union[int, float, str],
-        accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
+        accumulate_grad_batches: Union[int, Dict[int, int]],
         terminate_on_nan: bool,
     ):

@@ -48,7 +48,7 @@ def on_trainer_init(
         self.trainer.accumulate_grad_batches = accumulate_grad_batches
         self.configure_accumulated_gradients(accumulate_grad_batches)

-    def configure_accumulated_gradients(self, accumulate_grad_batches):
+    def configure_accumulated_gradients(self, accumulate_grad_batches: Union[int, Dict[int, int]]) -> None:
         if isinstance(accumulate_grad_batches, dict):
             self.trainer.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
         elif isinstance(accumulate_grad_batches, int):
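For context, a minimal usage sketch of the two forms the narrowed hint still allows, assuming the 1.4-era API shown in this diff (the sketch itself is not part of the commit):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# An int applies a single factor for the whole run: gradients are
# accumulated over 4 batches before each optimizer step.
trainer = Trainer(accumulate_grad_batches=4)

# A Dict[int, int] maps an epoch index to the factor that takes effect
# from that epoch onward; the connector wraps it in a
# GradientAccumulationScheduler, as the hunk above shows.
trainer = Trainer(accumulate_grad_batches={0: 8, 4: 4, 8: 1})

# The equivalent explicit form, passing the callback directly.
trainer = Trainer(callbacks=[GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})])
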
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -123,7 +123,7 @@ def __init__(
         track_grad_norm: Union[int, float, str] = -1,
         check_val_every_n_epoch: int = 1,
         fast_dev_run: Union[int, bool] = False,
-        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
+        accumulate_grad_batches: Union[int, Dict[int, int]] = 1,
         max_epochs: Optional[int] = None,
         min_epochs: Optional[int] = None,
         max_steps: Optional[int] = None,
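The effect on the public signature: a static checker such as mypy can now reject the list form before it ever reaches the connector. A hypothetical user-code sketch, not part of this commit:

from pytorch_lightning import Trainer

# Accepted by both the type checker and the runtime dispatch:
Trainer(accumulate_grad_batches=2)
Trainer(accumulate_grad_batches={0: 4, 10: 2})

# Under the old List[list] member this line type-checked but was never
# handled by the connector's isinstance dispatch; it is now a static
# type error as well (and still fails at runtime).
Trainer(accumulate_grad_batches=[[0, 4]])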
