Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Set param from config #553

Merged
merged 14 commits into from
Apr 7, 2024
Merged
66 changes: 56 additions & 10 deletions sqllineage/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import os
import threading
from typing import Any, Dict, Set

from sqllineage.exceptions import ConfigException


class _SQLLineageConfigLoader:
Expand All @@ -17,10 +21,18 @@ class _SQLLineageConfigLoader:
# lateral column alias reference supported by some dialect (redshift, spark 3.4+, etc)
"LATERAL_COLUMN_ALIAS_REFERENCE": (bool, False),
}
BOOLEAN_TRUE_STRINGS = ("true", "on", "ok", "y", "yes", "1")

def __getattr__(self, item):
if item in self.config:
def __init__(self) -> None:
self._thread_config: Dict[int, Dict[str, Any]] = {}
self._thread_in_context_manager: Set[int] = set()

def __getattr__(self, item: str):
if item in self.config.keys():
if (
value := self._thread_config.get(self.get_ident(), {}).get(item)
) is not None:
return value

type_, default = self.config[item]
# require SQLLINEAGE_ prefix from environment variable
return self.parse_value(
Expand All @@ -29,23 +41,57 @@ def __getattr__(self, item):
else:
return super().__getattribute__(item)

@classmethod
def parse_value(cls, value, cast):
"""Parse and cast provided value
def __setattr__(self, key, value) -> None:
    """Reject direct assignment to config keys; set any other attribute normally.

    :param key: attribute name being assigned.
    :param value: value being assigned.

    :raises ConfigException: when ``key`` is one of the keys in ``self.config``.
    """
    if key not in self.config:
        # not a config entry (e.g. internal bookkeeping): regular assignment
        super().__setattr__(key, value)
        return
    raise ConfigException(
        "SQLLineageConfig is read-only. Use context manager to update thread level config."
    )

def __call__(self, *args, **kwargs):
if self.get_ident() not in self._thread_config.keys():
self._thread_config[self.get_ident()] = {}
for key, value in kwargs.items():
if key in self.config.keys():
self._thread_config[self.get_ident()][key] = self.parse_value(
value, self.config[key][0]
)
else:
raise ConfigException(f"Invalid config key: {key}")
return self

def __enter__(self):
if (thread_id := self.get_ident()) not in self._thread_in_context_manager:
self._thread_in_context_manager.add(thread_id)
else:
raise ConfigException("SQLLineageConfig context manager is not reentrant")

def __exit__(self, exc_type, exc_val, exc_tb):
thread_id = self.get_ident()
if thread_id in self._thread_config:
self._thread_config.pop(self.get_ident())
if thread_id in self._thread_in_context_manager:
self._thread_in_context_manager.remove(thread_id)

@staticmethod
def get_ident() -> int:
    """Return the identifier of the calling thread, as given by ``threading.get_ident``."""
    current_thread_id = threading.get_ident()
    return current_thread_id

@staticmethod
def parse_value(value, cast) -> Any:
    """Parse and cast provided value

    For ``bool`` the value is first read as an integer (non-zero means
    ``True``); when that fails with ``ValueError`` it is compared, lower-cased
    and stripped, against a fixed set of truthy words. Any other ``cast`` is
    simply called with the value.

    :param value: raw value, typically a string.
    :param cast: type to cast the return value as.

    :returns: cast value
    """
    if cast is not bool:
        return cast(value)
    try:
        return int(value) != 0
    except ValueError:
        # not numeric: fall back to the accepted truthy spellings
        return value.lower().strip() in ("true", "on", "ok", "y", "yes", "1")


Expand Down
4 changes: 4 additions & 0 deletions sqllineage/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ class InvalidSyntaxException(SQLLineageException):

class MetaDataProviderException(SQLLineageException):
    """Raised for MetaDataProvider errors."""


class ConfigException(SQLLineageException):
    """Raised for configuration errors.

    Examples are assigning a config value directly, passing an unknown config
    key, or re-entering the config context manager.
    """
56 changes: 56 additions & 0 deletions tests/core/test_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import concurrent.futures
import os
import random
import time
from unittest.mock import patch

import pytest

from sqllineage.config import SQLLineageConfig
from sqllineage.exceptions import ConfigException


@patch(
Expand All @@ -21,3 +27,53 @@ def test_config():

assert type(SQLLineageConfig.TSQL_NO_SEMICOLON) is bool
assert SQLLineageConfig.TSQL_NO_SEMICOLON is True


def test_disable_direct_update_config():
    """Assigning a config key directly on SQLLineageConfig must be rejected."""
    with pytest.raises(ConfigException):
        setattr(SQLLineageConfig, "DEFAULT_SCHEMA", "ods")


def test_update_config_using_context_manager():
    """Overrides are visible inside the ``with`` block and gone after it."""
    # boolean option
    lateral_override = SQLLineageConfig(LATERAL_COLUMN_ALIAS_REFERENCE=True)
    with lateral_override:
        assert SQLLineageConfig.LATERAL_COLUMN_ALIAS_REFERENCE is True
    assert SQLLineageConfig.LATERAL_COLUMN_ALIAS_REFERENCE is False

    # string option
    schema_override = SQLLineageConfig(DEFAULT_SCHEMA="ods")
    with schema_override:
        assert SQLLineageConfig.DEFAULT_SCHEMA == "ods"
    assert SQLLineageConfig.DEFAULT_SCHEMA == ""

    # an empty string is still honoured as an override
    directory_override = SQLLineageConfig(DIRECTORY="")
    with directory_override:
        assert SQLLineageConfig.DIRECTORY == ""
    assert SQLLineageConfig.DIRECTORY != ""


def test_update_config_context_manager_non_reentrant():
    """Entering the config context manager twice in one thread must fail."""
    with pytest.raises(ConfigException):
        with SQLLineageConfig(DEFAULT_SCHEMA="ods"), SQLLineageConfig(DEFAULT_SCHEMA="dwd"):
            pass


def test_disable_update_unknown_config():
    """Passing a key that is not a known config entry must be rejected."""
    with pytest.raises(ConfigException), SQLLineageConfig(UNKNOWN_KEY="value"):
        pass


def _check_schema(schema: str):
    # used by test_config_parallel; kept as a module-level function so that it
    # can be pickled between processes
    with SQLLineageConfig(DEFAULT_SCHEMA=schema):
        # sleep a random [0, 0.1) second to simulate a real parsing scenario
        pause = random.random() * 0.1
        time.sleep(pause)
        observed = SQLLineageConfig.DEFAULT_SCHEMA
    return observed


@pytest.mark.parametrize("pool", ["ThreadPoolExecutor", "ProcessPoolExecutor"])
def test_config_parallel(pool: str):
    """Each worker must read back exactly the schema it set, whatever the pool type."""
    schemas = [f"db{i}" for i in range(100)]
    with getattr(concurrent.futures, pool)() as executor:
        futures = [executor.submit(_check_schema, schema) for schema in schemas]
        # futures were submitted in schema order, so pair them back up positionally
        for expected, future in zip(schemas, futures):
            assert future.result() == expected
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ commands =
[flake8]
exclude = .tox,.git,__pycache__,build,sqllineagejs,venv,env
max-line-length = 120
# ignore = D100,D101
ignore = A005,W503
show-source = true
enable-extensions=G
application-import-names = sqllineage
Expand Down
Loading