Skip to content

Commit

Permalink
minor code cleaning (#982)
Browse files Browse the repository at this point in the history
* minor code cleaning
* Apply suggestions from code review

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <[email protected]>
  • Loading branch information
3 people authored Apr 25, 2022
1 parent ee60f67 commit e236821
Show file tree
Hide file tree
Showing 10 changed files with 19 additions and 20 deletions.
8 changes: 4 additions & 4 deletions .github/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def prune_packages(req_file: str, *pkgs: str) -> None:
lines = [ln for ln in lines if not ln.startswith(pkg)]
logging.info(lines)

with open(req_file, "w") as fp:
with open(req_file, "w", encoding="utf-8") as fp:
fp.writelines(lines)

@staticmethod
Expand All @@ -71,17 +71,17 @@ def set_min_torch_by_python(fpath: str = "requirements.txt") -> None:
with open(fpath) as fp:
req = fp.read()
req = re.sub(r"torch>=[\d\.]+", f"torch>={LUT_PYTHON_TORCH[py_ver]}", req)
with open(fpath, "w") as fp:
with open(fpath, "w", encoding="utf-8") as fp:
fp.write(req)

@staticmethod
def replace_min_requirements(fpath: str) -> None:
"""Replace all `>=` by `==` in given file."""
logging.info(f"processing: {fpath}")
with open(fpath) as fp:
with open(fpath, encoding="utf-8") as fp:
req = fp.read()
req = req.replace(">=", "==")
with open(fpath, "w") as fp:
with open(fpath, "w", encoding="utf-8") as fp:
fp.write(req)

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
elif ln.startswith("### "):
ln = ln.replace("###", f"### {chlog_ver} -")
chlog_lines[i] = ln
with open(path_out, "w") as fp:
with open(path_out, "w", encoding="utf-8") as fp:
fp.writelines(chlog_lines)


Expand Down
10 changes: 5 additions & 5 deletions tests/audio/test_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_sdr(self, preds, target, sk_metric, ddp, dist_sync_on_step):
SignalDistortionRatio,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
metric_args=dict(),
metric_args={},
)

def test_sdr_functional(self, preds, target, sk_metric):
Expand All @@ -96,7 +96,7 @@ def test_sdr_functional(self, preds, target, sk_metric):
target,
signal_distortion_ratio,
sk_metric,
metric_args=dict(),
metric_args={},
)

def test_sdr_differentiability(self, preds, target, sk_metric):
Expand All @@ -105,7 +105,7 @@ def test_sdr_differentiability(self, preds, target, sk_metric):
target=target,
metric_module=SignalDistortionRatio,
metric_functional=signal_distortion_ratio,
metric_args=dict(),
metric_args={},
)

@pytest.mark.skipif(
Expand All @@ -117,7 +117,7 @@ def test_sdr_half_cpu(self, preds, target, sk_metric):
target=target,
metric_module=SignalDistortionRatio,
metric_functional=signal_distortion_ratio,
metric_args=dict(),
metric_args={},
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
Expand All @@ -127,7 +127,7 @@ def test_sdr_half_gpu(self, preds, target, sk_metric):
target=target,
metric_module=SignalDistortionRatio,
metric_functional=signal_distortion_ratio,
metric_args=dict(),
metric_args={},
)


Expand Down
2 changes: 1 addition & 1 deletion tests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def test_error_on_wrong_input():
metric.update([], torch.Tensor()) # type: ignore

with pytest.raises(ValueError, match="Expected argument `preds` and `target` to have the same length"):
metric.update([dict()], [dict(), dict()])
metric.update([{}], [{}, {}])

with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `boxes` key"):
metric.update(
Expand Down
7 changes: 3 additions & 4 deletions torchmetrics/functional/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,8 @@ def _auroc_compute(
# max_fpr parameter is only support for binary
if mode != DataType.BINARY:
raise ValueError(
f"Partial AUC computation not available in"
f" multilabel/multiclass setting, 'max_fpr' must be"
f" set to `None`, received `{max_fpr}`."
"Partial AUC computation not available in multilabel/multiclass setting,"
f" 'max_fpr' must be set to `None`, received `{max_fpr}`."
)

# calculate fpr, tpr
Expand Down Expand Up @@ -172,7 +171,7 @@ def _auroc_compute(

allowed_average = (AverageMethod.NONE.value, AverageMethod.MACRO.value, AverageMethod.WEIGHTED.value)
raise ValueError(
f"Argument `average` expected to be one of the following:" f" {allowed_average} but got {average}"
f"Argument `average` expected to be one of the following: {allowed_average} but got {average}"
)

return _auc_compute_without_check(fpr, tpr, 1.0)
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/text/rouge.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def rouge_score(
stemmer = nltk.stem.porter.PorterStemmer() if use_stemmer else None

if not isinstance(rouge_keys, tuple):
rouge_keys = tuple([rouge_keys])
rouge_keys = (rouge_keys,)
for key in rouge_keys:
if key not in ALLOWED_ROUGE_KEYS.keys():
raise ValueError(f"Got unknown rouge key {key}. Expected to be one of {list(ALLOWED_ROUGE_KEYS.keys())}")
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/text/squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _squad_input_check(
_fn_answer = lambda tgt: dict(
answers=[dict(text=txt) for txt in tgt["answers"]["text"]], id=tgt["id"] # type: ignore
)
targets_dict = [dict(paragraphs=[dict(qas=[_fn_answer(target) for target in targets])])]
targets_dict = [{"paragraphs": [{"qas": [_fn_answer(target) for target in targets]}]}]
return preds_dict, targets_dict


Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def __init__(
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

if not (isinstance(kernel_size, Sequence) or isinstance(kernel_size, int)):
if not (isinstance(kernel_size, (Sequence, int))):
raise ValueError(
f"Argument `kernel_size` expected to be an sequence or an int, or a single int. Got {kernel_size}"
)
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
k: v for k, v in kwargs.items() if (k in _sign_params.keys() and _sign_params[k].kind not in _params)
}

exists_var_keyword = any([v.kind == inspect.Parameter.VAR_KEYWORD for v in _sign_params.values()])
exists_var_keyword = any(v.kind == inspect.Parameter.VAR_KEYWORD for v in _sign_params.values())
# if no kwargs filtered, return all kwargs as default
if not filtered_kwargs and not exists_var_keyword:
# no kwargs in update signature -> don't return any kwargs
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/text/rouge.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(
import nltk

if not isinstance(rouge_keys, tuple):
rouge_keys = tuple([rouge_keys])
rouge_keys = (rouge_keys,)
for key in rouge_keys:
if key not in ALLOWED_ROUGE_KEYS:
raise ValueError(f"Got unknown rouge key {key}. Expected to be one of {ALLOWED_ROUGE_KEYS}")
Expand Down

0 comments on commit e236821

Please sign in to comment.