From 4990ce190f89276d00e2af2926d41a671a916616 Mon Sep 17 00:00:00 2001 From: Serge Smertin <259697+nfx@users.noreply.github.com> Date: Tue, 2 Jul 2024 14:59:24 +0200 Subject: [PATCH] Ensure propagation of `lsql` version into `User-Agent` header when it is used as library (#206) This PR ensures correct library attribution. --- pyproject.toml | 2 +- src/databricks/labs/lsql/__init__.py | 8 +++- src/databricks/labs/lsql/backends.py | 8 ++-- tests/unit/test_useragent.py | 59 ++++++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 6 deletions(-) create mode 100644 tests/unit/test_useragent.py diff --git a/pyproject.toml b/pyproject.toml index e9076d92..5a87c27a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ ] dependencies = [ "databricks-labs-blueprint[yaml]>=0.4.2", - "databricks-sdk>=0.22.0", + "databricks-sdk>=0.29.0", "sqlglot>=22.3.1" ] diff --git a/src/databricks/labs/lsql/__init__.py b/src/databricks/labs/lsql/__init__.py index a598eaf1..36d0eed7 100644 --- a/src/databricks/labs/lsql/__init__.py +++ b/src/databricks/labs/lsql/__init__.py @@ -1,3 +1,9 @@ -from databricks.labs.lsql.core import Row +from databricks.sdk.core import with_user_agent_extra + +from .__about__ import __version__ +from .core import Row __all__ = ["Row"] + + +with_user_agent_extra("lsql", __version__) diff --git a/src/databricks/labs/lsql/backends.py b/src/databricks/labs/lsql/backends.py index eb2897a1..3c230e80 100644 --- a/src/databricks/labs/lsql/backends.py +++ b/src/databricks/labs/lsql/backends.py @@ -171,12 +171,12 @@ def _row_to_sql(row: DataclassInstance, fields: tuple[dataclasses.Field[Any], .. field_type = field_type.__args__[0] if value is None: data.append("NULL") - elif field_type == bool: + elif field_type is bool: data.append("TRUE" if value else "FALSE") - elif field_type == str: + elif field_type is str: value = str(value).replace("'", "''") data.append(f"'{value}'") - elif field_type == int: + elif field_type is int: data.append(f"{value}") else: msg = f"unknown type: {field_type}" @@ -336,7 +336,7 @@ def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: D rows = self._filter_none_rows(rows, klass) if mode == "overwrite": self._save_table = [] - if klass.__class__ == type: + if klass.__class__ == type: # noqa: E721 row_factory = self._row_factory(klass) rows = [row_factory(*dataclasses.astuple(r)) for r in rows] self._save_table.append((full_name, rows, mode)) diff --git a/tests/unit/test_useragent.py b/tests/unit/test_useragent.py new file mode 100644 index 00000000..2ec33da3 --- /dev/null +++ b/tests/unit/test_useragent.py @@ -0,0 +1,59 @@ +import contextlib +import functools +import typing +from http.server import BaseHTTPRequestHandler + +from databricks.sdk import WorkspaceClient + +from databricks.labs.lsql.__about__ import __version__ +from databricks.labs.lsql.dashboards import Dashboards + + +@contextlib.contextmanager +def http_fixture_server(handler: typing.Callable[[BaseHTTPRequestHandler], None]): + from http.server import HTTPServer + from threading import Thread + + class _handler(BaseHTTPRequestHandler): + def __init__(self, handler: typing.Callable[[BaseHTTPRequestHandler], None], *args): + self._handler = handler + super().__init__(*args) + + def __getattr__(self, item): + if "do_" != item[0:3]: + raise AttributeError(f"method {item} not found") + return functools.partial(self._handler, self) + + handler_factory = functools.partial(_handler, handler) + srv = HTTPServer(("localhost", 0), handler_factory) + t = Thread(target=srv.serve_forever) + try: + t.daemon = True + t.start() + yield "http://{0}:{1}".format(*srv.server_address) + finally: + srv.shutdown() + + +def test_user_agent_is_propagated(): + user_agent = {} + + def inner(h: BaseHTTPRequestHandler): + for pair in h.headers["User-Agent"].split(" "): + if "/" not in pair: + continue + k, v = pair.split("/") + user_agent[k] = v + h.send_response(200) + h.send_header("Content-Type", "application/json") + h.end_headers() + h.wfile.write(b"{}") + h.wfile.flush() + + with http_fixture_server(inner) as host: + ws = WorkspaceClient(host=host, token="_") + d = Dashboards(ws) + d.get_dashboard("...") + + assert "lsql" in user_agent + assert user_agent["lsql"] == __version__