Skip to content

Commit

Permalink
Add missing annotations in `torchtnt/tests/framework/callbacks/test_l…
Browse files Browse the repository at this point in the history
…ambda.py`

Summary:
Adding annotations to and removing error suppressions from this module helps us move towards a well-typed codebase.

Added annotation to _get_members_in_different_name() and four call() functions.

Reviewed By: daniellepintz

Differential Revision: D46009990

fbshipit-source-id: 284ca95513f3cbe465bbc14c29d67f55da51523e
  • Loading branch information
Feng Hu authored and facebook-github-bot committed May 21, 2023
1 parent 49f84ca commit 60331d0
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tests/framework/callbacks/test_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import unittest
from functools import partial
from inspect import getmembers, isfunction
from typing import Tuple
from typing import Set, Tuple

import torch
from torchtnt.framework._test_utils import (
Expand Down Expand Up @@ -36,7 +36,7 @@ def train_step(self, state: State, data: Batch) -> None:
raise RuntimeError("testing")


def _get_members_in_different_name(cls: Callback, phase: str):
def _get_members_in_different_name(cls: Callback, phase: str) -> Set[str]:
# retrieve Callback in different phases, including: train, predict, fit, eval
return {
h
Expand All @@ -57,7 +57,7 @@ def test_lambda_callback_train(self) -> None:
)
checker = set()

def call(hook, *_, **__):
def call(hook: str, *_, **__) -> None:
checker.add(hook)

hooks = _get_members_in_different_name(Callback, "train")
Expand All @@ -81,7 +81,7 @@ def test_lambda_callback_eval(self) -> None:
)
checker = set()

def call(hook, *_, **__):
def call(hook: str, *_, **__) -> None:
checker.add(hook)

hooks = _get_members_in_different_name(Callback, "eval")
Expand All @@ -100,7 +100,7 @@ def test_lambda_callback_predict(self) -> None:
max_steps_per_epoch = 6
checker = set()

def call(hook, *_, **__):
def call(hook: str, *_, **__) -> None:
checker.add(hook)

hooks = _get_members_in_different_name(Callback, "predict")
Expand All @@ -126,7 +126,7 @@ def test_lambda_callback_train_with_except(self) -> None:
)
checker = set()

def call(hook, *_, **__):
def call(hook: str, *_, **__) -> None:
checker.add(hook)

# with on_exception, training will not be ended
Expand Down

0 comments on commit 60331d0

Please sign in to comment.