Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
unify docformatter config (#1642)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

additional_dependencies: [tomli]
  • Loading branch information
Borda committed Jul 12, 2023
1 parent bbb65d3 commit 269e852
Show file tree
Hide file tree
Showing 69 changed files with 181 additions and 21 deletions.
14 changes: 2 additions & 12 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,15 @@ repos:
rev: v1.7.3
hooks:
- id: docformatter
args:
- "--in-place"
- "--wrap-summaries=120"
- "--wrap-descriptions=120"
additional_dependencies: [tomli]
args: ["--in-place"]

- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
name: Format code

- repo: https://github.com/asottile/blacken-docs
rev: 1.14.0
hooks:
- id: blacken-docs
args:
- "--line-length=120"
- "--skip-errors"

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.276
hooks:
Expand Down
14 changes: 5 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,11 @@ exclude_lines = [
line-length = 120
exclude = "(.eggs|.git|.hg|.mypy_cache|.venv|_build|buck-out|build|dist)"

[tool.isort]
known_first_party = [
"flash",
"examples",
"tests",
]
skip_glob = []
profile = "black"
line_length = 120
[tool.docformatter]
recursive = true
wrap-summaries = 120
wrap-descriptions = 120
blank = true


[tool.ruff]
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str:
>>> _load_readme_description(_PATH_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
'<div align="center">...'
"""
path_readme = os.path.join(path_dir, "README.md")
text = open(path_readme, encoding="utf-8").read()
Expand Down Expand Up @@ -65,6 +66,7 @@ def _augment_requirement(ln: str, comment_char: str = "#", unfreeze: bool = True
'arrow>=1.2.0, <=1.2.2 # strict'
>>> _augment_requirement("arrow", unfreeze=True)
'arrow'
"""
# filer all comments
if comment_char in ln:
Expand Down Expand Up @@ -95,6 +97,7 @@ def _load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: boo
>>> path_req = os.path.join(_PATH_ROOT, "requirements")
>>> _load_requirements(path_req, "docs.txt") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['sphinx>=4.0', ...]
"""
with open(os.path.join(path_dir, file_name)) as file:
lines = [ln.strip() for ln in file.readlines()]
Expand Down
1 change: 1 addition & 0 deletions src/flash/audio/speech_recognition/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class DataCollatorCTCWithPadding:
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
"""

processor: AutoProcessor
Expand Down
1 change: 1 addition & 0 deletions src/flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class ClassificationOutput(Output):
Args:
multi_label: If true, treats outputs as multi label logits.
"""

def __init__(self, multi_label: bool = False):
Expand Down
1 change: 1 addition & 0 deletions src/flash/core/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def default_uncollate(batch: Any) -> List[Any]:
ValueError: If the input is a ``dict`` whose values are not all list-like.
ValueError: If the input is a ``dict`` whose values are not all the same length.
ValueError: If the input is not a ``dict`` or list-like.
"""
if isinstance(batch, dict):
if any(not _is_list_like_excluding_str(sub_batch) for sub_batch in batch.values()):
Expand Down
1 change: 1 addition & 0 deletions src/flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,7 @@ def _split_train_val(
Returns:
A tuple containing the training and validation datasets
"""

if not isinstance(val_split, float) or (isinstance(val_split, float) and val_split > 1 or val_split < 0):
Expand Down
1 change: 1 addition & 0 deletions src/flash/core/data/io/classification_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,6 @@ def format_target(self, target: Any) -> Any:
Returns:
The formatted target.
"""
return getattr(self, "target_formatter", lambda x: x)(target)
1 change: 1 addition & 0 deletions src/flash/core/data/io/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def _has_len(data: Union[Sequence, Iterable]) -> bool:
Args:
data: The object to check for length support.
"""
try:
len(data)
Expand Down
27 changes: 27 additions & 0 deletions src/flash/core/data/io/input_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def per_sample_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_sample_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
pass

Expand All @@ -97,6 +98,7 @@ def train_per_sample_transform(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
"""
return self.per_sample_transform()

Expand All @@ -121,6 +123,7 @@ def val_per_sample_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_sample_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
return self.per_sample_transform()

Expand All @@ -134,6 +137,7 @@ def test_per_sample_transform(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
"""
return self.per_sample_transform()

Expand All @@ -158,6 +162,7 @@ def predict_per_sample_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_sample_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
return self.per_sample_transform()

Expand All @@ -182,6 +187,7 @@ def serve_per_sample_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_sample_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
return self.per_sample_transform()

Expand Down Expand Up @@ -210,6 +216,7 @@ def per_sample_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def per_sample_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
pass

Expand All @@ -223,6 +230,7 @@ def train_per_sample_transform_on_device(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
"""
return self.per_sample_transform_on_device()

Expand All @@ -247,6 +255,7 @@ def val_per_sample_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def per_sample_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
return self.per_sample_transform_on_device()

Expand All @@ -260,6 +269,7 @@ def test_per_sample_transform_on_device(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
"""
return self.per_sample_transform_on_device()

Expand All @@ -284,6 +294,7 @@ def predict_per_sample_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def per_sample_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
return self.per_sample_transform_on_device()

Expand All @@ -308,6 +319,7 @@ def serve_per_sample_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def serve_per_sample_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
return self.per_sample_transform_on_device()

Expand Down Expand Up @@ -336,6 +348,7 @@ def per_batch_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_batch_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
pass

Expand All @@ -349,6 +362,7 @@ def train_per_batch_transform(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
"""
return self.per_batch_transform()

Expand All @@ -373,6 +387,7 @@ def val_per_batch_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_batch_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
return self.per_batch_transform()

Expand All @@ -386,6 +401,7 @@ def test_per_batch_transform(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
"""
return self.per_batch_transform()

Expand All @@ -410,6 +426,7 @@ def predict_per_batch_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_batch_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
return self.per_batch_transform()

Expand All @@ -434,6 +451,7 @@ def serve_per_batch_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_batch_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
return self.per_batch_transform()

Expand Down Expand Up @@ -462,6 +480,7 @@ def per_batch_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def per_batch_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
pass

Expand All @@ -475,6 +494,7 @@ def train_per_batch_transform_on_device(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
"""
return self.per_batch_transform_on_device()

Expand All @@ -499,6 +519,7 @@ def val_per_batch_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def per_batch_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
return self.per_batch_transform_on_device()

Expand All @@ -512,6 +533,7 @@ def test_per_batch_transform_on_device(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
"""
return self.per_batch_transform_on_device()

Expand All @@ -536,6 +558,7 @@ def predict_per_batch_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def per_batch_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
return self.per_batch_transform_on_device()

Expand All @@ -560,6 +583,7 @@ def serve_per_batch_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def serve_per_batch_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
"""
return self.per_batch_transform_on_device()

Expand Down Expand Up @@ -606,6 +630,7 @@ def _per_batch_transform(self, batch: Any, stage: RunningStage) -> Any:
.. note:: This option is mutually exclusive with :meth:`per_sample_transform_on_device`, since if both are
specified, uncollation has to be applied.
"""
return self.current_transform(stage=stage, current_fn="per_batch_transform")(batch)

Expand All @@ -620,6 +645,7 @@ def _per_sample_transform_on_device(self, sample: Any, stage: RunningStage) -> A
specified, uncollation has to be applied. .. note:: This function won't be called within the dataloader
workers, since to make that happen each of the workers would have to create it's own CUDA-context which
would pollute GPU memory (if on GPU).
"""
fn = self.current_transform(stage=stage, current_fn="per_sample_transform_on_device")
if isinstance(sample, list):
Expand All @@ -631,6 +657,7 @@ def _per_batch_transform_on_device(self, batch: Any, stage: RunningStage) -> Any
.. note:: This function won't be called within the dataloader workers, since to make that happen each of
the workers would have to create it's own CUDA-context which would pollute GPU memory (if on GPU).
"""
return self.current_transform(stage=stage, current_fn="per_batch_transform_on_device")(batch)

Expand Down
1 change: 1 addition & 0 deletions src/flash/core/data/io/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def transform(sample: Any) -> Any:
Returns:
The converted output.
"""
return sample

Expand Down
3 changes: 3 additions & 0 deletions src/flash/core/data/io/output_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def per_batch_transform(batch: Any) -> Any:
"""Transforms to apply on a whole batch before uncollation to individual samples.
Can involve both CPU and Device transforms as this is not applied in separate workers.
"""
return batch

Expand All @@ -33,6 +34,7 @@ def per_sample_transform(sample: Any) -> Any:
"""Transforms to apply to a single sample after splitting up the batch.
Can involve both CPU and Device transforms as this is not applied in separate workers.
"""
return sample

Expand All @@ -41,6 +43,7 @@ def uncollate(batch: Any) -> Any:
"""Uncollates a batch into single samples.
Tries to preserve the type wherever possible.
"""
return default_uncollate(batch)

Expand Down
1 change: 1 addition & 0 deletions src/flash/core/data/splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class SplitDataset(Properties, Dataset):
split_ds = SplitDataset(dataset, indices=[10, 14, 25])
split_ds = SplitDataset(dataset, indices=[10, 10, 10, 14, 25], use_duplicated_indices=True)
"""

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions src/flash/core/data/utilities/data_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def resolve_targets(data_frame: pd.DataFrame, target_keys: Union[str, List[str]]
Args:
data_frame: The ``pd.DataFrame`` containing the target column / columns.
target_keys: The column in the data frame (or a list of columns) from which to resolve the target.
"""
if not isinstance(target_keys, List):
return data_frame[target_keys].tolist()
Expand Down Expand Up @@ -63,6 +64,7 @@ def resolve_files(
root: The root path to use when resolving files.
resolver: The resolver function to use. This function should receive the root and a file ID as input and return
the path to an existing file.
"""
if resolver is None:
resolver = default_resolver
Expand Down
Loading

0 comments on commit 269e852

Please sign in to comment.