Skip to content

minor code cleaning #982

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 25, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def prune_packages(req_file: str, *pkgs: str) -> None:
lines = [ln for ln in lines if not ln.startswith(pkg)]
logging.info(lines)

with open(req_file, "w") as fp:
with open(req_file, "w", encoding="utf_8") as fp:
fp.writelines(lines)

@staticmethod
Expand All @@ -71,17 +71,17 @@ def set_min_torch_by_python(fpath: str = "requirements.txt") -> None:
with open(fpath) as fp:
req = fp.read()
req = re.sub(r"torch>=[\d\.]+", f"torch>={LUT_PYTHON_TORCH[py_ver]}", req)
with open(fpath, "w") as fp:
with open(fpath, "w", encoding="utf_8") as fp:
fp.write(req)

@staticmethod
def replace_min_requirements(fpath: str) -> None:
"""Replace all `>=` by `==` in given file."""
logging.info(f"processing: {fpath}")
with open(fpath) as fp:
with open(fpath, encoding="utf_8") as fp:
req = fp.read()
req = req.replace(">=", "==")
with open(fpath, "w") as fp:
with open(fpath, "w", encoding="utf_8") as fp:
fp.write(req)

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
elif ln.startswith("### "):
ln = ln.replace("###", f"### {chlog_ver} -")
chlog_lines[i] = ln
with open(path_out, "w") as fp:
with open(path_out, "w", encoding="utf_8") as fp:
fp.writelines(chlog_lines)


Expand Down
10 changes: 5 additions & 5 deletions tests/audio/test_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_sdr(self, preds, target, sk_metric, ddp, dist_sync_on_step):
SignalDistortionRatio,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
metric_args=dict(),
metric_args={},
)

def test_sdr_functional(self, preds, target, sk_metric):
Expand All @@ -96,7 +96,7 @@ def test_sdr_functional(self, preds, target, sk_metric):
target,
signal_distortion_ratio,
sk_metric,
metric_args=dict(),
metric_args={},
)

def test_sdr_differentiability(self, preds, target, sk_metric):
Expand All @@ -105,7 +105,7 @@ def test_sdr_differentiability(self, preds, target, sk_metric):
target=target,
metric_module=SignalDistortionRatio,
metric_functional=signal_distortion_ratio,
metric_args=dict(),
metric_args={},
)

@pytest.mark.skipif(
Expand All @@ -117,7 +117,7 @@ def test_sdr_half_cpu(self, preds, target, sk_metric):
target=target,
metric_module=SignalDistortionRatio,
metric_functional=signal_distortion_ratio,
metric_args=dict(),
metric_args={},
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
Expand All @@ -127,7 +127,7 @@ def test_sdr_half_gpu(self, preds, target, sk_metric):
target=target,
metric_module=SignalDistortionRatio,
metric_functional=signal_distortion_ratio,
metric_args=dict(),
metric_args={},
)


Expand Down
2 changes: 1 addition & 1 deletion tests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def test_error_on_wrong_input():
metric.update([], torch.Tensor()) # type: ignore

with pytest.raises(ValueError, match="Expected argument `preds` and `target` to have the same length"):
metric.update([dict()], [dict(), dict()])
metric.update([{}], [{}, {}])

with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `boxes` key"):
metric.update(
Expand Down
7 changes: 3 additions & 4 deletions torchmetrics/functional/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,8 @@ def _auroc_compute(
# max_fpr parameter is only support for binary
if mode != DataType.BINARY:
raise ValueError(
f"Partial AUC computation not available in"
f" multilabel/multiclass setting, 'max_fpr' must be"
f" set to `None`, received `{max_fpr}`."
"Partial AUC computation not available in multilabel/multiclass setting,"
f" 'max_fpr' must be set to `None`, received `{max_fpr}`."
)

# calculate fpr, tpr
Expand Down Expand Up @@ -172,7 +171,7 @@ def _auroc_compute(

allowed_average = (AverageMethod.NONE.value, AverageMethod.MACRO.value, AverageMethod.WEIGHTED.value)
raise ValueError(
f"Argument `average` expected to be one of the following:" f" {allowed_average} but got {average}"
f"Argument `average` expected to be one of the following: {allowed_average} but got {average}"
)

return _auc_compute_without_check(fpr, tpr, 1.0)
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/text/rouge.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def rouge_score(
stemmer = nltk.stem.porter.PorterStemmer() if use_stemmer else None

if not isinstance(rouge_keys, tuple):
rouge_keys = tuple([rouge_keys])
rouge_keys = (rouge_keys,)
for key in rouge_keys:
if key not in ALLOWED_ROUGE_KEYS.keys():
raise ValueError(f"Got unknown rouge key {key}. Expected to be one of {list(ALLOWED_ROUGE_KEYS.keys())}")
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/text/squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _squad_input_check(
_fn_answer = lambda tgt: dict(
answers=[dict(text=txt) for txt in tgt["answers"]["text"]], id=tgt["id"] # type: ignore
)
targets_dict = [dict(paragraphs=[dict(qas=[_fn_answer(target) for target in targets])])]
targets_dict = [{"paragraphs": [dict(qas=[_fn_answer(target) for target in targets])]}]
return preds_dict, targets_dict


Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __init__(
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

if not (isinstance(kernel_size, Sequence) or isinstance(kernel_size, int)):
if not (isinstance(kernel_size, (Sequence, int))):
raise ValueError(
f"Argument `kernel_size` expected to be an sequence or an int, or a single int. Got {kernel_size}"
)
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
k: v for k, v in kwargs.items() if (k in _sign_params.keys() and _sign_params[k].kind not in _params)
}

exists_var_keyword = any([v.kind == inspect.Parameter.VAR_KEYWORD for v in _sign_params.values()])
exists_var_keyword = any(v.kind == inspect.Parameter.VAR_KEYWORD for v in _sign_params.values())
# if no kwargs filtered, return all kwargs as default
if not filtered_kwargs and not exists_var_keyword:
# no kwargs in update signature -> don't return any kwargs
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/text/rouge.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(
import nltk

if not isinstance(rouge_keys, tuple):
rouge_keys = tuple([rouge_keys])
rouge_keys = (rouge_keys,)
for key in rouge_keys:
if key not in ALLOWED_ROUGE_KEYS:
raise ValueError(f"Got unknown rouge key {key}. Expected to be one of {ALLOWED_ROUGE_KEYS}")
Expand Down