class RLSMethod(enum.Enum):
    """
    Methods for enforcing RLS.

    - ``AS_PREDICATE``: the rule is AND-ed into the ``WHERE``/``ON`` clause of the
      query that references the table.
    - ``AS_SUBQUERY``: the table reference is replaced by a subquery that selects
      only the rows matching the rule.
    """

    AS_PREDICATE = enum.auto()
    AS_SUBQUERY = enum.auto()


class RLSTransformer:
    """
    AST transformer to apply RLS rules.

    This base class only stores the defaults used to qualify table names and knows
    how to look up the predicate for a table; subclasses implement ``__call__`` to
    decide how the predicate is injected into the tree.
    """

    def __init__(
        self,
        catalog: str | None,
        schema: str | None,
        rules: dict[Table, list[exp.Expression]],
    ) -> None:
        """
        :param catalog: The default catalog for non-qualified table names
        :param schema: The default schema for non-qualified table names
        :param rules: RLS predicates keyed by fully qualified table
        """
        self.catalog = catalog
        self.schema = schema
        self.rules = rules

    def get_predicate(self, table_node: exp.Table) -> exp.Expression | None:
        """
        Get the combined RLS predicate for a table.

        :param table_node: A table reference found in the statement. Missing schema
            and catalog are filled in with the transformer defaults before the
            lookup.
        :returns: All rules registered for the table AND-ed together, or ``None``
            when no rule applies (no entry, or an empty list).
        """
        table = Table(
            table_node.name,
            table_node.db or self.schema,
            table_node.catalog or self.catalog,
        )
        predicates = self.rules.get(table)
        if not predicates:
            return None

        # `exp.And` is a binary node (`this` / `expression`); it has no `expressions`
        # argument, so building it by hand with a list silently loses every rule
        # after the first one. `exp.and_` folds any number of conditions into a
        # chain of ANDs and parenthesises operands that are themselves connectors
        # (e.g. a rule containing a top-level OR). It copies its operands by
        # default, so callers may qualify/modify the result without touching the
        # expressions stored in `self.rules`.
        return exp.and_(*predicates)
class RLSAsPredicateTransformer(RLSTransformer):
    """
    Apply Row Level Security rule as a predicate.

    This transformer will apply any RLS predicates to the relevant tables. For example,
    given the RLS rule:

        table: some_table
        clause: id = 42

    If a user subject to the rule runs the following query:

        SELECT foo FROM some_table WHERE bar = 'baz'

    The query will be modified to:

        SELECT foo FROM some_table WHERE some_table.id = 42 AND (bar = 'baz')

    Only tables that are the direct child of a ``FROM`` or a ``JOIN`` are rewritten;
    table nodes found anywhere else are returned untouched.

    This approach is probably less secure than using subqueries, so it's only used for
    databases without support for subqueries.
    """

    @staticmethod
    def _restrict(predicate: exp.Expression, existing: exp.Expression) -> exp.And:
        """
        Return ``predicate AND (existing)``.

        The existing condition is always wrapped in parentheses. The RLS predicate
        is wrapped as well when it is a connector other than ``AND`` (e.g. ``OR``):
        connectors are rendered as nested, without implicit parentheses, so
        ``a OR b`` combined with ``c`` would otherwise read ``a OR b AND (c)`` and
        the user filter would only constrain ``b``.
        """
        if isinstance(predicate, exp.Connector) and not isinstance(predicate, exp.And):
            predicate = exp.Paren(this=predicate)

        return exp.And(this=predicate, expression=exp.Paren(this=existing))

    def __call__(self, node: exp.Expression) -> exp.Expression:
        """
        Inject the RLS predicate for ``node`` into its parent ``FROM``/``JOIN``.

        :param node: Any node visited by ``Expression.transform``
        :returns: The same node; the enclosing query is modified in place.
        """
        if not isinstance(node, exp.Table):
            return node

        predicate = self.get_predicate(node)
        if not predicate:
            return node

        # The rule may be the very object stored in `self.rules`. Qualify a private
        # copy, otherwise a table referenced twice under different aliases would end
        # up with the last alias on every occurrence.
        predicate = predicate.copy()

        # qualify columns with table name; the identifier is copied because `set`
        # re-parents the value, which would detach it from the table node
        for column in predicate.find_all(exp.Column):
            column.set("table", node.alias or node.this.copy())

        parent = node.parent
        if isinstance(parent, exp.From):
            query = parent.parent
            if where := query.args.get("where"):
                predicate = self._restrict(predicate, where.this)
            query.set("where", exp.Where(this=predicate))

        elif isinstance(parent, exp.Join):
            if on := parent.args.get("on"):
                predicate = self._restrict(predicate, on)
            parent.set("on", predicate)

        return node
class RLSAsSubqueryTransformer(RLSTransformer):
    """
    Apply Row Level Security rule as a subquery.

    This transformer will apply any RLS predicates to the relevant tables. For example,
    given the RLS rule:

        table: some_table
        clause: id = 42

    If a user subject to the rule runs the following query:

        SELECT foo FROM some_table WHERE bar = 'baz'

    The query will be modified to:

        SELECT foo FROM (SELECT * FROM some_table WHERE id = 42) AS some_table
        WHERE bar = 'baz'

    This approach is probably more secure than using predicates, but it doesn't work for
    all databases.
    """

    def __call__(self, node: exp.Expression) -> exp.Expression:
        """
        Replace a table subject to RLS with a filtered subquery.

        :param node: Any node visited by ``Expression.transform``
        :returns: A subquery exposing the table under its alias (or its own name
            when it has none) if a rule applies, otherwise the node unchanged.
        """
        if not isinstance(node, exp.Table):
            return node

        predicate = self.get_predicate(node)
        if not predicate:
            return node

        # The subquery must be exposed under a single identifier. Reuse the alias
        # node when the table has one, so quoting is preserved; otherwise use the
        # bare table identifier. Rendering the whole reference with `node.sql()`
        # would yield a dotted name (`schema.table`) for qualified tables, which is
        # not a valid alias.
        alias_node = node.args.get("alias")
        if alias_node is not None and node.alias:
            alias = alias_node.copy()
        else:
            alias = exp.TableAlias(this=node.this.copy())

        # the inner table reference keeps its qualification but not its alias
        inner = node.copy()
        inner.set("alias", None)

        return exp.Subquery(
            this=exp.Select(
                expressions=[exp.Star()],
                # copy: the rule may be shared with other references to the table
                where=exp.Where(this=predicate.copy()),
                **{"from": exp.From(this=inner)},
            ),
            alias=alias,
        )
def apply_rls(
    self,
    catalog: str | None,
    schema: str | None,
    predicates: dict[Table, list[exp.Expression]],
    method: RLSMethod,
) -> None:
    """
    Rewrite the parsed statement so that the given RLS rules are enforced.

    The statement is modified in place: ``self._parsed`` is replaced by the
    transformed tree.

    :param catalog: The default catalog for non-qualified table names
    :param schema: The default schema for non-qualified table names
    :param predicates: RLS predicates to enforce, keyed by table
    :param method: The method to use for applying the rules.
    :raises ValueError: If ``method`` has no associated transformer
    """
    # dispatch table: one transformer class per supported method
    transformer_class = {
        RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer,
        RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer,
    }.get(method)
    if transformer_class is None:
        raise ValueError(f"Invalid RLS method: {method}")

    rewrite = transformer_class(catalog, schema, predicates)
    self._parsed = self._parsed.transform(rewrite)
AND anon_1.b = 2 + """.strip() assert SQLStatement(sql, "sqlite").optimize().format() == optimized assert SQLStatement(sql, "dremio").optimize().format() == not_optimized @@ -1195,9 +1202,11 @@ def test_firebolt_old() -> None: sql = "SELECT * FROM t1 UNNEST(col1 AS foo)" assert ( SQLStatement(sql, "firebolt").format() - == """SELECT + == """ +SELECT * -FROM t1 UNNEST(col1 AS foo)""" +FROM t1 UNNEST(col1 AS foo) + """.strip() ) @@ -1216,9 +1225,11 @@ def test_firebolt_old_escape_string() -> None: # but they normalize to '' assert ( SQLStatement(sql, "firebolt").format() - == """SELECT + == """ +SELECT 'foo''bar', - 'foo''bar'""" + 'foo''bar' + """.strip() ) @@ -1410,7 +1421,8 @@ def test_get_kql_limit_value(kql: str, expected: str) -> None: "mssql", 1000, LimitMethod.FORCE_LIMIT, - """WITH abc AS ( + """ +WITH abc AS ( SELECT * FROM test @@ -1422,7 +1434,8 @@ def test_get_kql_limit_value(kql: str, expected: str) -> None: SELECT TOP 1000 * -FROM currency""", +FROM currency + """.strip(), ), ( "SELECT DISTINCT x from tbl", @@ -1457,10 +1470,12 @@ def test_get_kql_limit_value(kql: str, expected: str) -> None: "postgresql", 1000, LimitMethod.FORCE_LIMIT, - """SELECT + """ +SELECT * FROM birth_names /* SOME COMMENT WITH LIMIT 555 */ -LIMIT 1000""", +LIMIT 1000 + """.strip(), ), ( "SELECT * FROM birth_names LIMIT 555", @@ -1602,7 +1617,8 @@ def test_has_cte(sql: str, engine: str, expected: bool) -> None: SELECT * FROM currency_2 """, "postgresql", - """WITH currency AS ( + """ +WITH currency AS ( SELECT 'INR' AS cur ), currency_2 AS ( @@ -1616,7 +1632,8 @@ def test_has_cte(sql: str, engine: str, expected: bool) -> None: SELECT * FROM currency_2 -)""", +) + """.strip(), ), ], ) @@ -1625,3 +1642,608 @@ def test_as_cte(sql: str, engine: str, expected: str) -> None: Test that we can covert select to CTE. 
""" assert SQLStatement(sql, engine).as_cte().format() == expected + + +@pytest.mark.parametrize( + "sql, rules, expected", + [ + ( + "SELECT t.foo FROM some_table AS t", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM ( + SELECT + * + FROM some_table + WHERE + id = 42 +) AS t + """.strip(), + ), + ( + "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM ( + SELECT + * + FROM some_table + WHERE + id = 42 +) AS t +WHERE + bar = 'baz' + """.strip(), + ), + ( + "SELECT t.foo FROM schema1.some_table AS t", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM ( + SELECT + * + FROM schema1.some_table + WHERE + id = 42 +) AS t + """.strip(), + ), + ( + "SELECT t.foo FROM schema1.some_table AS t", + {Table("some_table", "schema2"): "id = 42"}, + "SELECT\n t.foo\nFROM schema1.some_table AS t", + ), + ( + "SELECT t.foo FROM catalog1.schema1.some_table AS t", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM ( + SELECT + * + FROM catalog1.schema1.some_table + WHERE + id = 42 +) AS t + """.strip(), + ), + ( + "SELECT t.foo FROM catalog1.schema1.some_table AS t", + {Table("some_table", "schema1", "catalog2"): "id = 42"}, + "SELECT\n t.foo\nFROM catalog1.schema1.some_table AS t", + ), + ( + "SELECT * FROM some_table WHERE 1=1", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM ( + SELECT + * + FROM some_table + WHERE + id = 42 +) AS some_table +WHERE + 1 = 1 + """.strip(), + ), + ( + "SELECT * FROM table WHERE 1=1", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM ( + SELECT + * + FROM table + WHERE + id = 42 +) AS table +WHERE + 1 = 1 + """.strip(), + ), + ( + 'SELECT * FROM "table" WHERE 1=1', + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM ( + SELECT + * + FROM "table" + WHERE + id = 42 +) 
AS "table" +WHERE + 1 = 1 + """.strip(), + ), + ( + "SELECT * FROM table WHERE 1=1", + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + 1 = 1 + """.strip(), + ), + ( + "SELECT * FROM other_table WHERE 1=1", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM other_table +WHERE + 1 = 1 + """.strip(), + ), + ( + "SELECT * FROM table JOIN other_table ON table.id = other_table.id", + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +JOIN ( + SELECT + * + FROM other_table + WHERE + id = 42 +) AS other_table + ON table.id = other_table.id + """.strip(), + ), + ( + 'SELECT * FROM "table" JOIN other_table ON "table".id = other_table.id', + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM ( + SELECT + * + FROM "table" + WHERE + id = 42 +) AS "table" +JOIN other_table + ON "table".id = other_table.id + """.strip(), + ), + ( + "SELECT * FROM (SELECT * FROM some_table)", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM ( + SELECT + * + FROM ( + SELECT + * + FROM some_table + WHERE + id = 42 + ) AS some_table +) + """.strip(), + ), + ( + "SELECT * FROM table UNION ALL SELECT * FROM other_table", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM ( + SELECT + * + FROM table + WHERE + id = 42 +) AS table +UNION ALL +SELECT + * +FROM other_table + """.strip(), + ), + ( + "SELECT * FROM table UNION ALL SELECT * FROM other_table", + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +UNION ALL +SELECT + * +FROM ( + SELECT + * + FROM other_table + WHERE + id = 42 +) AS other_table + """.strip(), + ), + ( + "SELECT a.*, b.* FROM tbl_a AS a INNER JOIN tbl_b AS b ON a.col = b.col", + {Table("tbl_a", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + a.*, + b.* +FROM ( + SELECT + * + FROM tbl_a + WHERE + id = 42 +) AS a +INNER JOIN tbl_b AS b + ON a.col 
= b.col + """.strip(), + ), + ( + "SELECT a.*, b.* FROM tbl_a a INNER JOIN tbl_b b ON a.col = b.col", + {Table("tbl_a", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + a.*, + b.* +FROM ( + SELECT + * + FROM tbl_a + WHERE + id = 42 +) AS a +INNER JOIN tbl_b AS b + ON a.col = b.col + """.strip(), + ), + ], +) +def test_rls_subquery_transformer( + sql: str, + rules: dict[Table, str], + expected: str, +) -> None: + """ + Test `RLSAsSubqueryTransformer`. + """ + statement = SQLStatement(sql) + statement.apply_rls( + "catalog1", + "schema1", + {k: [parse_one(v)] for k, v in rules.items()}, + RLSMethod.AS_SUBQUERY, + ) + assert statement.format() == expected + + +@pytest.mark.parametrize( + "sql, rules, expected", + [ + ( + "SELECT t.foo FROM some_table AS t", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM some_table AS t +WHERE + t.id = 42 + """.strip(), + ), + ( + "SELECT t.foo FROM schema2.some_table AS t", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM schema2.some_table AS t + """.strip(), + ), + ( + "SELECT t.foo FROM catalog2.schema1.some_table AS t", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM catalog2.schema1.some_table AS t + """.strip(), + ), + ( + "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM some_table AS t +WHERE + t.id = 42 AND ( + bar = 'baz' + ) + """.strip(), + ), + ( + "SELECT t.foo FROM some_table AS t WHERE bar = 'baz' OR foo = 'qux'", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM some_table AS t +WHERE + t.id = 42 AND ( + bar = 'baz' OR foo = 'qux' + ) + """.strip(), + ), + ( + "SELECT * FROM some_table WHERE 1=1", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM some_table +WHERE + some_table.id = 42 AND ( + 1 = 1 + ) + """.strip(), + ), + ( + "SELECT * 
FROM some_table WHERE TRUE OR FALSE", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM some_table +WHERE + some_table.id = 42 AND ( + TRUE OR FALSE + ) + """.strip(), + ), + ( + "SELECT * FROM table WHERE 1=1", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + table.id = 42 AND ( + 1 = 1 + ) + """.strip(), + ), + ( + 'SELECT * FROM "table" WHERE 1=1', + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM "table" +WHERE + "table".id = 42 AND ( + 1 = 1 + ) + """.strip(), + ), + ( + "SELECT * FROM table WHERE 1=1", + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + 1 = 1 + """.strip(), + ), + ( + "SELECT * FROM other_table WHERE 1=1", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM other_table +WHERE + 1 = 1 + """.strip(), + ), + ( + "SELECT * FROM table", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + table.id = 42 + """.strip(), + ), + ( + "SELECT * FROM some_table", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM some_table +WHERE + some_table.id = 42 + """.strip(), + ), + ( + "SELECT * FROM table ORDER BY id", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + table.id = 42 +ORDER BY + id + """.strip(), + ), + ( + "SELECT * FROM table WHERE 1=1 AND table.id=42", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + table.id = 42 AND ( + 1 = 1 AND table.id = 42 + ) + """.strip(), + ), + ( + """ +SELECT * FROM table +JOIN other_table +ON table.id = other_table.id +AND other_table.id=42 + """, + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +JOIN other_table + ON other_table.id = 42 AND ( + table.id = other_table.id AND other_table.id = 42 + ) + """.strip(), + ), + ( + "SELECT * FROM table 
WHERE 1=1 AND id=42", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + table.id = 42 AND ( + 1 = 1 AND id = 42 + ) + """.strip(), + ), + ( + "SELECT * FROM table JOIN other_table ON table.id = other_table.id", + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +JOIN other_table + ON other_table.id = 42 AND ( + table.id = other_table.id + ) + """.strip(), + ), + ( + """ +SELECT * +FROM table +JOIN other_table +ON table.id = other_table.id +WHERE 1=1 + """, + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +JOIN other_table + ON other_table.id = 42 AND ( + table.id = other_table.id + ) +WHERE + 1 = 1 + """.strip(), + ), + ( + "SELECT * FROM (SELECT * FROM other_table)", + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM ( + SELECT + * + FROM other_table + WHERE + other_table.id = 42 +) + """.strip(), + ), + ( + "SELECT * FROM table UNION ALL SELECT * FROM other_table", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + table.id = 42 +UNION ALL +SELECT + * +FROM other_table + """.strip(), + ), + ( + "SELECT * FROM table UNION ALL SELECT * FROM other_table", + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +UNION ALL +SELECT + * +FROM other_table +WHERE + other_table.id = 42 + """.strip(), + ), + ], +) +def test_rls_predicate_transformer( + sql: str, + rules: dict[Table, str], + expected: str, +) -> None: + """ + Test `RLSPredicateTransformer`. + """ + statement = SQLStatement(sql) + statement.apply_rls( + "catalog1", + "schema1", + {k: [parse_one(v)] for k, v in rules.items()}, + RLSMethod.AS_PREDICATE, + ) + assert statement.format() == expected