Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@

<!-- Changes to blackd -->

- Implemented BlackDClient. This simple python client allows to easily send formatting
requests to blackd (#4774)

### Integrations

<!-- For example, Docker, GitHub Actions, pre-commit, editors -->
Expand Down
20 changes: 18 additions & 2 deletions docs/usage_and_configuration/black_as_a_server.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,30 @@ formatting requests.

```

There is no official `blackd` client tool (yet!). You can test that blackd is working
using `curl`:
You can test that blackd is working using `curl`:

```sh
blackd --bind-port 9090 & # or let blackd choose a port
curl -s -XPOST "localhost:9090" -d "print('valid')"
```

Or using the python client:

```python
import asyncio

from blackd.client import BlackDClient

async def main():
client = BlackDClient(url="http://127.0.0.1:9090")
unformatted_code = "def hello(): print('Hello, World!')"
formatted_code = await client.format_code(unformatted_code)
print(formatted_code)

if __name__ == "__main__":
asyncio.run(main())
```

## Protocol

`blackd` only accepts `POST` requests at the `/` path. The body of the request should
Expand Down
94 changes: 94 additions & 0 deletions src/blackd/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from typing import Optional

import aiohttp
from aiohttp.typedefs import StrOrURL

import black

_DEFAULT_HEADERS = {"Content-Type": "text/plain; charset=utf-8"}


class BlackDClient:
def __init__(
self,
url: StrOrURL = "http://localhost:9090",
line_length: Optional[int] = None,
skip_source_first_line: bool = False,
skip_string_normalization: bool = False,
skip_magic_trailing_comma: bool = False,
preview: bool = False,
fast: bool = False,
python_variant: Optional[str] = None,
diff: bool = False,
headers: Optional[dict[str, str]] = None,
):
"""
Initialize a BlackDClient object.
:param url: The URL of the BlackD server.
:param line_length: The maximum line length.
Corresponds to the ``--line-length`` CLI option.
:param skip_source_first_line: True to skip the first line of the source.
Corresponds to the ``--skip-source-first-line`` CLI option.
:param skip_string_normalization: True to skip string normalization.
Corresponds to the ``--skip-string-normalization`` CLI option.
:param skip_magic_trailing_comma: True to skip magic trailing comma.
Corresponds to the ``--skip-magic-trailing-comma`` CLI option.
:param preview: True to enable experimental preview mode.
Corresponds to the ``--preview`` CLI option.
:param fast: True to enable fast mode.
Corresponds to the ``--fast`` CLI option.
:param python_variant: The Python variant to use.
Corresponds to the ``--pyi`` CLI option if this is "pyi".
Otherwise, corresponds to the ``--target-version`` CLI option.
:param diff: True to enable diff mode.
Corresponds to the ``--diff`` CLI option.
:param headers: A dictionary of additional custom headers to send with
the request.
"""
self.url = url
self.headers = _DEFAULT_HEADERS.copy()

if line_length is not None:
self.headers["X-Line-Length"] = str(line_length)
if skip_source_first_line:
self.headers["X-Skip-Source-First-Line"] = "yes"
if skip_string_normalization:
self.headers["X-Skip-String-Normalization"] = "yes"
if skip_magic_trailing_comma:
self.headers["X-Skip-Magic-Trailing-Comma"] = "yes"
if preview:
self.headers["X-Preview"] = "yes"
if fast:
self.headers["X-Fast-Or-Safe"] = "fast"
if python_variant is not None:
self.headers["X-Python-Variant"] = python_variant
if diff:
self.headers["X-Diff"] = "yes"

if headers is not None:
self.headers.update(headers)

async def format_code(self, unformatted_code: str) -> str:
async with aiohttp.ClientSession() as session:
async with session.post(
self.url, headers=self.headers, data=unformatted_code.encode("utf-8")
) as response:
if response.status == 204:
# Input is already well-formatted
return unformatted_code
elif response.status == 200:
# Formatting was needed
return await response.text()
elif response.status == 400:
# Input contains a syntax error
error_message = await response.text()
raise black.InvalidInput(error_message)
elif response.status == 500:
# Other kind of error while formatting
error_message = await response.text()
raise RuntimeError(f"Error while formatting: {error_message}")
else:
# Unexpected response status code
raise RuntimeError(
f"Unexpected response status code: {response.status}"
)
114 changes: 114 additions & 0 deletions tests/test_blackd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
from aiohttp.test_utils import AioHTTPTestCase

import blackd
import blackd.client
except ImportError as e:
raise RuntimeError("Please install Black with the 'd' extra") from e

import black


@pytest.mark.blackd
class BlackDTestCase(AioHTTPTestCase):
Expand Down Expand Up @@ -218,3 +221,114 @@ async def test_single_character(self) -> None:
response = await self.client.post("/", data="1")
self.assertEqual(await response.text(), "1\n")
self.assertEqual(response.status, 200)


@pytest.mark.blackd
class BlackDClientTestCase(AioHTTPTestCase):
def tearDown(self) -> None:
# Work around https://github.com/python/cpython/issues/124706
gc.collect()
super().tearDown()

async def get_application(self) -> web.Application:
return blackd.make_app()

async def test_unformatted_code(self) -> None:
client = blackd.client.BlackDClient(self.client.make_url("/"))
unformatted_code = "def hello(): print('Hello, World!')"
expected = 'def hello():\n print("Hello, World!")\n'
formatted_code = await client.format_code(unformatted_code)

self.assertEqual(formatted_code, expected)

async def test_formatted_code(self) -> None:
client = blackd.client.BlackDClient(self.client.make_url("/"))
initial_code = 'def hello():\n print("Hello, World!")\n'
expected = 'def hello():\n print("Hello, World!")\n'
formatted_code = await client.format_code(initial_code)

self.assertEqual(formatted_code, expected)

async def test_line_length(self) -> None:
client = blackd.client.BlackDClient(self.client.make_url("/"), line_length=10)
unformatted_code = "def hello(): print('Hello, World!')"
expected = 'def hello():\n print(\n "Hello, World!"\n )\n'
formatted_code = await client.format_code(unformatted_code)

self.assertEqual(formatted_code, expected)

async def test_skip_source_first_line(self) -> None:
client = blackd.client.BlackDClient(
self.client.make_url("/"), skip_source_first_line=True
)
invalid_first_line = "Header will be skipped\r\ni = [1,2,3]\nj = [1,2,3]\n"
expected_result = "Header will be skipped\r\ni = [1, 2, 3]\nj = [1, 2, 3]\n"
formatted_code = await client.format_code(invalid_first_line)

self.assertEqual(formatted_code, expected_result)

async def test_skip_string_normalization(self) -> None:
client = blackd.client.BlackDClient(
self.client.make_url("/"), skip_string_normalization=True
)
unformatted_code = "def hello(): print('Hello, World!')"
expected = "def hello():\n print('Hello, World!')\n"
formatted_code = await client.format_code(unformatted_code)

self.assertEqual(formatted_code, expected)

async def test_skip_magic_trailing_comma(self) -> None:
client = blackd.client.BlackDClient(
self.client.make_url("/"), skip_magic_trailing_comma=True
)
unformatted_code = "def hello(): print('Hello, World!')"
expected = 'def hello():\n print("Hello, World!")\n'
formatted_code = await client.format_code(unformatted_code)

self.assertEqual(formatted_code, expected)

async def test_preview(self) -> None:
client = blackd.client.BlackDClient(self.client.make_url("/"), preview=True)
unformatted_code = "def hello(): print('Hello, World!')"
expected = 'def hello():\n print("Hello, World!")\n'
formatted_code = await client.format_code(unformatted_code)

self.assertEqual(formatted_code, expected)

async def test_fast(self) -> None:
client = blackd.client.BlackDClient(self.client.make_url("/"), fast=True)
unformatted_code = "def hello(): print('Hello, World!')"
expected = 'def hello():\n print("Hello, World!")\n'
formatted_code = await client.format_code(unformatted_code)

self.assertEqual(formatted_code, expected)

async def test_python_variant(self) -> None:
client = blackd.client.BlackDClient(
self.client.make_url("/"), python_variant="3.6"
)
unformatted_code = "def hello(): print('Hello, World!')"
expected = 'def hello():\n print("Hello, World!")\n'
formatted_code = await client.format_code(unformatted_code)

self.assertEqual(formatted_code, expected)

async def test_diff(self) -> None:
diff_header = re.compile(
r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d\+\d\d:\d\d"
)

client = blackd.client.BlackDClient(self.client.make_url("/"), diff=True)
source, _ = read_data("miscellaneous", "blackd_diff")
expected, _ = read_data("miscellaneous", "blackd_diff.diff")

diff = await client.format_code(source)
diff = diff_header.sub(DETERMINISTIC_HEADER, diff)

self.assertEqual(diff, expected)

async def test_syntax_error(self) -> None:
client = blackd.client.BlackDClient(self.client.make_url("/"))
with_syntax_error = "def hello(): a 'Hello, World!'"
with self.assertRaises(black.InvalidInput):
_ = await client.format_code(with_syntax_error)