Skip to content

Commit

Permalink
Validate docstring examples (#384)
Browse files Browse the repository at this point in the history
- Add module to validate ``` python wrapped code examples in docstrings
- Make all examples validate correctly

fixes #81
  • Loading branch information
Marenz authored May 16, 2023
2 parents 3edf6a5 + b25272d commit fe4727c
Show file tree
Hide file tree
Showing 10 changed files with 358 additions and 87 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ pytest = [
"pytest-asyncio == 0.21.0",
"time-machine == 2.9.0",
"async-solipsism == 0.5",
# For checking docstring code examples
"sybil == 5.0.1",
"pylint == 2.17.4",
]
mypy = [
"mypy == 1.3.0",
Expand Down
209 changes: 209 additions & 0 deletions src/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# License: MIT
# Copyright © 2023 Frequenz Energy-as-a-Service GmbH

"""Pytest plugin to validate docstring code examples.
Code examples are often wrapped in triple backticks (```) within our docstrings.
This plugin extracts these code examples and validates them using pylint.
"""

from __future__ import annotations

import ast
import os
import subprocess
from pathlib import Path

from sybil import Example, Sybil
from sybil.evaluators.python import pad
from sybil.parsers.abstract.lexers import textwrap
from sybil.parsers.myst import CodeBlockParser

PYLINT_DISABLE_COMMENT = (
"# pylint: {}=unused-import,wildcard-import,unused-wildcard-import"
)

FORMAT_STRING = """
# Generated auto-imports for code example
{disable_pylint}
{imports}
{enable_pylint}
{code}"""


def get_import_statements(code: str) -> list[str]:
"""Get all import statements from a given code string.
Args:
code: The code to extract import statements from.
Returns:
A list of import statements.
"""
tree = ast.parse(code)
import_statements = []

for node in ast.walk(tree):
if isinstance(node, (ast.Import, ast.ImportFrom)):
import_statement = ast.get_source_segment(code, node)
import_statements.append(import_statement)

return import_statements


def path_to_import_statement(path: Path) -> str:
"""Convert a path to a Python file to an import statement.
Args:
path: The path to convert.
Returns:
The import statement.
Raises:
ValueError: If the path does not point to a Python file.
"""
# Make the path relative to the present working directory
if path.is_absolute():
path = path.relative_to(Path.cwd())

# Check if the path is a Python file
if path.suffix != ".py":
raise ValueError("Path must point to a Python file (.py)")

# Remove 'src' prefix if present
parts = path.parts
if parts[0] == "src":
parts = parts[1:]

# Remove the '.py' extension and join parts with '.'
module_path = ".".join(parts)[:-3]

# Create the import statement
import_statement = f"from {module_path} import *"
return import_statement


class CustomPythonCodeBlockParser(CodeBlockParser):
"""Code block parser that validates extracted code examples using pylint.
This parser is a modified version of the default Python code block parser
from the Sybil library.
It uses pylint to validate the extracted code examples.
All code examples are preceded by the original file's import statements as
well as an wildcard import of the file itself.
This allows us to use the code examples as if they were part of the original
file.
Additionally, the code example is padded with empty lines to make sure the
line numbers are correct.
Pylint warnings which are unimportant for code examples are disabled.
"""

def __init__(self):
"""Initialize the parser."""
super().__init__("python")

def evaluate(self, example: Example) -> None | str:
"""Validate the extracted code example using pylint.
Args:
example: The extracted code example.
Returns:
None if the code example is valid, otherwise the pylint output.
"""
# Get the import statements for the original file
import_header = get_import_statements(example.document.text)
# Add a wildcard import of the original file
import_header.append(
path_to_import_statement(Path(os.path.relpath(example.path)))
)
imports_code = "\n".join(import_header)

# Dedent the code example
# There is also example.parsed that is already prepared, but it has
# empty lines stripped and thus fucks up the line numbers.
example_code = textwrap.dedent(
example.document.text[example.start : example.end]
)
# Remove first line (the line with the triple backticks)
example_code = example_code[example_code.find("\n") + 1 :]

example_with_imports = FORMAT_STRING.format(
disable_pylint=PYLINT_DISABLE_COMMENT.format("disable"),
imports=imports_code,
enable_pylint=PYLINT_DISABLE_COMMENT.format("enable"),
code=example_code,
)

# Make sure the line numbers are correct
source = pad(
example_with_imports,
example.line - imports_code.count("\n") - FORMAT_STRING.count("\n"),
)

# pylint disable parameters
pylint_disable_params = [
"missing-module-docstring",
"missing-class-docstring",
"missing-function-docstring",
"reimported",
"unused-variable",
"no-name-in-module",
"await-outside-async",
]

response = validate_with_pylint(source, example.path, pylint_disable_params)

if len(response) > 0:
return (
f"Pylint validation failed for code example:\n"
f"{example_with_imports}\nOutput: {response}"
)

return None


def validate_with_pylint(
code_example: str, path: str, disable_params: list[str]
) -> list[str]:
"""Validate a code example using pylint.
Args:
code_example: The code example to validate.
path: The path to the original file.
disable_params: The pylint disable parameters.
Returns:
A list of pylint messages.
"""
try:
pylint_command = [
"pylint",
"--disable",
",".join(disable_params),
"--from-stdin",
path,
]

subprocess.run(
pylint_command,
input=code_example,
text=True,
capture_output=True,
check=True,
)
except subprocess.CalledProcessError as exception:
return exception.output.splitlines()

return []


pytest_collect_file = Sybil(
parsers=[CustomPythonCodeBlockParser()],
patterns=["*.py"],
).pytest()
20 changes: 11 additions & 9 deletions src/frequenz/sdk/actor/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def actor(cls: Type[Any]) -> Type[Any]:
TypeError: when the class doesn't have a `run` method as per spec.
Example (one actor receiving from two receivers):
``` python
```python
from frequenz.channels import Broadcast, Receiver, Sender
from frequenz.channels.util import Select
@actor
class EchoActor:
def __init__(
Expand Down Expand Up @@ -115,11 +117,12 @@ async def run(self) -> None:
echo_rx = echo_chan.new_receiver()
await input_chan_2.new_sender().send(True)
msg = await echo_rx.receive()
received_msg = await echo_rx.receive()
```
Example (two Actors composed):
``` python
```python
from frequenz.channels import Broadcast, Receiver, Sender
@actor
class Actor1:
def __init__(
Expand Down Expand Up @@ -153,16 +156,15 @@ async def run(self) -> None:
async for msg in self._recv:
await self._output.send(msg)
input_chan: Broadcast[bool] = Broadcast("Input to A1")
a1_chan: Broadcast[bool] = Broadcast["A1 stream"]
a2_chan: Broadcast[bool] = Broadcast["A2 stream"]
a1 = Actor1(
a1_chan: Broadcast[bool] = Broadcast("A1 stream")
a2_chan: Broadcast[bool] = Broadcast("A2 stream")
a_1 = Actor1(
name="ActorOne",
recv=input_chan.new_receiver(),
output=a1_chan.new_sender(),
)
a2 = Actor2(
a_2 = Actor2(
name="ActorTwo",
recv=a1_chan.new_receiver(),
output=a2_chan.new_sender(),
Expand All @@ -171,7 +173,7 @@ async def run(self) -> None:
a2_rx = a2_chan.new_receiver()
await input_chan.new_sender().send(True)
msg = await a2_rx.receive()
received_msg = await a2_rx.receive()
```
"""
Expand Down
40 changes: 28 additions & 12 deletions src/frequenz/sdk/actor/power_distributing/power_distributing.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,10 @@ class PowerDistributingActor:
printed.
Example:
``` python
import grpc.aio as grpcaio
from frequenz.sdk.microgrid.graph import _MicrogridComponentGraph
```python
from frequenz.sdk import microgrid
from frequenz.sdk.microgrid.component import ComponentCategory
from frequenz.sdk.actor import ResamplerConfig
from frequenz.sdk.actor.power_distributing import (
PowerDistributingActor,
Request,
Expand All @@ -103,29 +102,45 @@ class PowerDistributingActor:
PartialFailure,
Ignored,
)
from frequenz.channels import Bidirectional, Broadcast, Receiver, Sender
from datetime import timedelta
from frequenz.sdk import actor
HOST = "localhost"
PORT = 50051
target = f"{host}:{port}"
grpc_channel = grpcaio.insecure_channel(target)
api = MicrogridGrpcClient(grpc_channel, target)
await microgrid.initialize(
HOST,
PORT,
ResamplerConfig(resampling_period=timedelta(seconds=1))
)
graph = _MicrogridComponentGraph()
await graph.refresh_from_api(api)
graph = microgrid.connection_manager.get().component_graph
batteries = graph.components(component_category={ComponentCategory.BATTERY})
batteries_ids = {c.component_id for c in batteries}
battery_status_channel = Broadcast[BatteryStatus]("battery-status")
channel = Bidirectional[Request, Result]("user1", "power_distributor")
power_distributor = PowerDistributingActor(
mock_api, component_graph, {"user1": channel.service_handle}
users_channels={"user1": channel.service_handle},
battery_status_sender=battery_status_channel.new_sender(),
)
# Start the actor
await actor.run(power_distributor)
client_handle = channel.client_handle
# Set power 1200W to given batteries.
request = Request(power=1200.0, batteries=batteries_ids, request_timeout_sec=10.0)
await client_handle.send(request)
# Set power 1200W to given batteries.
request = Request(power=1200, batteries=batteries_ids, request_timeout_sec=10.0)
await client_handle.send(request)
# It is recommended to use timeout when waiting for the response!
result: Result = await asyncio.wait_for(client_handle.receive(), timeout=10)
Expand All @@ -134,9 +149,10 @@ class PowerDistributingActor:
elif isinstance(result, PartialFailure):
print(
f"Batteries {result.failed_batteries} failed, total failed power" \
f"{result.failed_power}")
f"{result.failed_power}"
)
elif isinstance(result, Ignored):
print(f"Request was ignored, because of newer request")
print("Request was ignored, because of newer request")
elif isinstance(result, Error):
print(f"Request failed with error: {result.msg}")
```
Expand Down
Loading

0 comments on commit fe4727c

Please sign in to comment.