Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 187 additions & 3 deletions superset/sql/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
# "db2": ???
# "dremio": ???
"drill": Dialects.DRILL,
# "druid": ???
"druid": Dialects.DRUID,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this to sqlglot a while ago.

"duckdb": Dialects.DUCKDB,
# "dynamodb": ???
# "elasticsearch": ???
Expand Down Expand Up @@ -108,6 +108,150 @@ class LimitMethod(enum.Enum):
FETCH_MANY = enum.auto()


class RLSMethod(enum.Enum):
"""
Methods for enforcing RLS.
"""

AS_PREDICATE = enum.auto()
AS_SUBQUERY = enum.auto()


class RLSTransformer:
"""
AST transformer to apply RLS rules.
"""

def __init__(
self,
catalog: str | None,
schema: str | None,
rules: dict[Table, list[exp.Expression]],
) -> None:
self.catalog = catalog
self.schema = schema
self.rules = rules

def get_predicate(self, table_node: exp.Table) -> exp.Expression | None:
"""
Get the combined RLS predicate for a table.
"""
table = Table(
table_node.name,
table_node.db if table_node.db else self.schema,
table_node.catalog if table_node.catalog else self.catalog,
)
if predicates := self.rules.get(table):
return (
exp.And(
this=predicates[0],
expressions=predicates[1:],
)
if len(predicates) > 1
else predicates[0]
)

return None


class RLSAsPredicateTransformer(RLSTransformer):
"""
Apply Row Level Security role as a predicate.

This transformer will apply any RLS predicates to the relevant tables. For example,
given the RLS rule:

table: some_table
clause: id = 42

If a user subject to the rule runs the following query:

SELECT foo FROM some_table WHERE bar = 'baz'

The query will be modified to:

SELECT foo FROM some_table WHERE bar = 'baz' AND id = 42

This approach is probably less secure than using subqueries, so it's only used for
databases without support for subqueries.
"""

def __call__(self, node: exp.Expression) -> exp.Expression:
if not isinstance(node, exp.Table):
return node

predicate = self.get_predicate(node)
if not predicate:
return node

# qualify columns with table name
for column in predicate.find_all(exp.Column):
column.set("table", node.alias or node.this)

if isinstance(node.parent, exp.From):
select = node.parent.parent
if where := select.args.get("where"):
predicate = exp.And(
this=predicate,
expression=exp.Paren(this=where.this),
)
select.set("where", exp.Where(this=predicate))

elif isinstance(node.parent, exp.Join):
join = node.parent
if on := join.args.get("on"):
predicate = exp.And(
this=predicate,
expression=exp.Paren(this=on),
)
join.set("on", predicate)

return node


class RLSAsSubqueryTransformer(RLSTransformer):
"""
Apply Row Level Security role as a subquery.

This transformer will apply any RLS predicates to the relevant tables. For example,
given the RLS rule:

table: some_table
clause: id = 42

If a user subject to the rule runs the following query:

SELECT foo FROM some_table WHERE bar = 'baz'

The query will be modified to:

SELECT foo FROM (SELECT * FROM some_table WHERE id = 42) AS some_table
WHERE bar = 'baz'

This approach is probably more secure than using predicates, but it doesn't work for
all databases.
"""

def __call__(self, node: exp.Expression) -> exp.Expression:
if not isinstance(node, exp.Table):
return node

if predicate := self.get_predicate(node):
# use alias or name
alias = node.alias or node.sql()
node.set("alias", None)
node = exp.Subquery(
this=exp.Select(
expressions=[exp.Star()],
where=exp.Where(this=predicate),
**{"from": exp.From(this=node.copy())},
),
alias=alias,
)

return node


@dataclass(eq=True, frozen=True)
class Table:
"""
Expand Down Expand Up @@ -173,7 +317,7 @@ def __init__(
elif statement:
self._parsed = self._parse_statement(statement, engine)
else:
raise SupersetParseError("Either statement or ast must be provided")
raise ValueError("Either statement or ast must be provided")

self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
Expand Down Expand Up @@ -293,6 +437,22 @@ def as_cte(self, alias: str = "__cte") -> SQLStatement:
"""
raise NotImplementedError()

def apply_rls(
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[InternalRepresentation]],
method: RLSMethod,
) -> None:
"""
Apply relevant RLS rules to the statement inplace.

:param catalog: The default catalog for non-qualified table names
:param schema: The default schema for non-qualified table names
:param method: The method to use for applying the rules.
"""
raise NotImplementedError()

def __str__(self) -> str:
return self.format()

Expand Down Expand Up @@ -573,6 +733,30 @@ def as_cte(self, alias: str = "__cte") -> SQLStatement:
engine=self.engine,
)

def apply_rls(
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[exp.Expression]],
method: RLSMethod,
) -> None:
"""
Apply relevant RLS rules to the statement inplace.

:param catalog: The default catalog for non-qualified table names
:param schema: The default schema for non-qualified table names
:param method: The method to use for applying the rules.
"""
transformers = {
RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer,
RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer,
}
if method not in transformers:
raise ValueError(f"Invalid RLS method: {method}")

transformer = transformers[method](catalog, schema, predicates)
self._parsed = self._parsed.transform(transformer)


class KQLSplitState(enum.Enum):
"""
Expand Down Expand Up @@ -966,7 +1150,7 @@ def extract_tables_from_statement(
"""
Extract all table references in a single statement.

Please not that this is not trivial; consider the following queries:
Please note that this is not trivial; consider the following queries:

DESCRIBE some_table;
SHOW PARTITIONS FROM some_table;
Expand Down
Loading
Loading