Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use SSL by default #67

Merged
merged 3 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

## Upgrading

<!-- Here goes notes on how to upgrade from previous versions, including deprecations and what they should be replaced with -->
- The `parse_grpc_uri` function (and `BaseApiClient` constructor) now enables SSL by default (`ssl=false` should be passed to disable it).
- The `parse_grpc_uri` function now accepts an optional `default_ssl` parameter to set the default value for the `ssl` parameter when not present in the URI.

## New Features

Expand Down
15 changes: 11 additions & 4 deletions src/frequenz/client/base/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,19 @@ def _to_bool(value: str) -> bool:


def parse_grpc_uri(
uri: str, channel_type: type[ChannelT], /, *, default_port: int = 9090
uri: str,
channel_type: type[ChannelT],
/,
*,
default_port: int = 9090,
default_ssl: bool = True,
) -> ChannelT:
"""Create a grpclib client channel from a URI.

The URI must have the following format:

```
grpc://hostname[:port][?ssl=false]
grpc://hostname[:port][?ssl=<bool>]
```

A few things to consider about URI components:
Expand All @@ -39,14 +44,15 @@ def parse_grpc_uri(
- If the port is omitted, the `default_port` is used.
- If a query parameter is passed many times, the last value is used.
- The only supported query parameter is `ssl`, which must be a boolean value and
defaults to `false`.
defaults to the `default_ssl` argument if not present.
- Boolean query parameters can be specified with the following values
(case-insensitive): `true`, `1`, `on`, `false`, `0`, `off`.

Args:
uri: The gRPC URI specifying the connection parameters.
channel_type: The type of channel to create.
default_port: The default port number to use if the URI does not specify one.
default_ssl: The default SSL setting to use if the URI does not specify one.

Returns:
A grpclib client channel object.
Expand All @@ -69,7 +75,8 @@ def parse_grpc_uri(
)

options = {k: v[-1] for k, v in parse_qs(parsed_uri.query).items()}
ssl = _to_bool(options.pop("ssl", "false"))
ssl_option = options.pop("ssl", None)
ssl = _to_bool(ssl_option) if ssl_option is not None else default_ssl
if options:
raise ValueError(
f"Unexpected query parameters {options!r} in the URI '{uri}'",
Expand Down
61 changes: 49 additions & 12 deletions tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""Test cases for the channel module."""

from dataclasses import dataclass
from typing import NotRequired, TypedDict
from unittest import mock

import pytest
Expand All @@ -12,8 +13,8 @@
from frequenz.client.base.channel import parse_grpc_uri

VALID_URLS = [
("grpc://localhost", "localhost", 9090, False),
("grpc://localhost:1234", "localhost", 1234, False),
("grpc://localhost", "localhost", 9090, True),
("grpc://localhost:1234", "localhost", 1234, True),
("grpc://localhost:1234?ssl=true", "localhost", 1234, True),
("grpc://localhost:1234?ssl=false", "localhost", 1234, False),
("grpc://localhost:1234?ssl=1", "localhost", 1234, True),
Expand All @@ -29,12 +30,25 @@
]


class _CreateChannelKwargs(TypedDict):
default_port: NotRequired[int]
default_ssl: NotRequired[bool]


@pytest.mark.parametrize("uri, host, port, ssl", VALID_URLS)
def test_grpclib_parse_uri_ok(
@pytest.mark.parametrize(
"default_port", [None, 9090, 1234], ids=lambda x: f"default_port={x}"
)
@pytest.mark.parametrize(
"default_ssl", [None, True, False], ids=lambda x: f"default_ssl={x}"
)
def test_grpclib_parse_uri_ok( # pylint: disable=too-many-arguments
uri: str,
host: str,
port: int,
ssl: bool,
default_port: int | None,
default_ssl: bool | None,
) -> None:
"""Test successful parsing of gRPC URIs using grpclib."""

Expand All @@ -44,24 +58,39 @@ class _FakeChannel:
port: int
ssl: bool

kwargs = _CreateChannelKwargs()
if default_port is not None:
kwargs["default_port"] = default_port
if default_ssl is not None:
kwargs["default_ssl"] = default_ssl

expected_port = port if f":{port}" in uri or default_port is None else default_port
expected_ssl = ssl if "ssl" in uri or default_ssl is None else default_ssl

with mock.patch(
"frequenz.client.base.channel._grpchacks.grpclib_create_channel",
return_value=_FakeChannel(host, port, ssl),
):
channel = parse_grpc_uri(uri, _grpchacks.GrpclibChannel)
) as create_channel_mock:
channel = parse_grpc_uri(uri, _grpchacks.GrpclibChannel, **kwargs)

assert isinstance(channel, _FakeChannel)
assert channel.host == host
assert channel.port == port
assert channel.ssl == ssl
create_channel_mock.assert_called_once_with(host, expected_port, expected_ssl)


@pytest.mark.parametrize("uri, host, port, ssl", VALID_URLS)
def test_grpcio_parse_uri_ok(
@pytest.mark.parametrize(
"default_port", [None, 9090, 1234], ids=lambda x: f"default_port={x}"
)
@pytest.mark.parametrize(
"default_ssl", [None, True, False], ids=lambda x: f"default_ssl={x}"
)
def test_grpcio_parse_uri_ok( # pylint: disable=too-many-arguments,too-many-locals
uri: str,
host: str,
port: int,
ssl: bool,
default_port: int | None,
default_ssl: bool | None,
) -> None:
"""Test successful parsing of gRPC URIs using grpcio."""
expected_channel = mock.MagicMock(
Expand All @@ -70,6 +99,14 @@ def test_grpcio_parse_uri_ok(
expected_credentials = mock.MagicMock(
name="mock_credentials", spec=_grpchacks.GrpcioChannel
)
expected_port = port if f":{port}" in uri or default_port is None else default_port
expected_ssl = ssl if "ssl" in uri or default_ssl is None else default_ssl

kwargs = _CreateChannelKwargs()
if default_port is not None:
kwargs["default_port"] = default_port
if default_ssl is not None:
kwargs["default_ssl"] = default_ssl

with (
mock.patch(
Expand All @@ -85,11 +122,11 @@ def test_grpcio_parse_uri_ok(
return_value=expected_credentials,
) as ssl_channel_credentials_mock,
):
channel = parse_grpc_uri(uri, _grpchacks.GrpcioChannel)
channel = parse_grpc_uri(uri, _grpchacks.GrpcioChannel, **kwargs)

assert channel == expected_channel
expected_target = f"{host}:{port}"
if ssl:
expected_target = f"{host}:{expected_port}"
if expected_ssl:
ssl_channel_credentials_mock.assert_called_once_with()
secure_channel_mock.assert_called_once_with(
expected_target, expected_credentials
Expand Down