Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 33 additions & 69 deletions superset/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def make_shell_context():
@app.cli.command()
def init():
"""Inits the Superset application"""
utils.get_or_create_main_db()
utils.get_example_database()
appbuilder.add_permissions(update_perms=True)
security_manager.sync_role_definitions()
Expand Down Expand Up @@ -430,75 +429,40 @@ def load_test_users_run():
Syncs permissions for those users/roles
"""
if config.get("TESTING"):
security_manager.sync_role_definitions()
gamma_sqllab_role = security_manager.add_role("gamma_sqllab")
for perm in security_manager.find_role("Gamma").permissions:
security_manager.add_permission_role(gamma_sqllab_role, perm)
utils.get_or_create_main_db()
db_perm = utils.get_main_database().perm
security_manager.add_permission_view_menu("database_access", db_perm)
db_pvm = security_manager.find_permission_view_menu(
view_menu_name=db_perm, permission_name="database_access"

sm = security_manager

examples_db = utils.get_example_database()

examples_pv = sm.add_permission_view_menu("database_access", examples_db.perm)

sm.sync_role_definitions()
gamma_sqllab_role = sm.add_role("gamma_sqllab")
sm.add_permission_role(gamma_sqllab_role, examples_pv)

for role in ["Gamma", "sql_lab"]:
for perm in sm.find_role(role).permissions:
sm.add_permission_role(gamma_sqllab_role, perm)

users = (
("admin", "Admin"),
("gamma", "Gamma"),
("gamma2", "Gamma"),
("gamma_sqllab", "gamma_sqllab"),
("alpha", "Alpha"),
)
gamma_sqllab_role.permissions.append(db_pvm)
for perm in security_manager.find_role("sql_lab").permissions:
security_manager.add_permission_role(gamma_sqllab_role, perm)

admin = security_manager.find_user("admin")
if not admin:
security_manager.add_user(
"admin",
"admin",
" user",
"[email protected]",
security_manager.find_role("Admin"),
password="general",
)

gamma = security_manager.find_user("gamma")
if not gamma:
security_manager.add_user(
"gamma",
"gamma",
"user",
"[email protected]",
security_manager.find_role("Gamma"),
password="general",
)

gamma2 = security_manager.find_user("gamma2")
if not gamma2:
security_manager.add_user(
"gamma2",
"gamma2",
"user",
"[email protected]",
security_manager.find_role("Gamma"),
password="general",
)

gamma_sqllab_user = security_manager.find_user("gamma_sqllab")
if not gamma_sqllab_user:
security_manager.add_user(
"gamma_sqllab",
"gamma_sqllab",
"user",
"[email protected]",
gamma_sqllab_role,
password="general",
)

alpha = security_manager.find_user("alpha")
if not alpha:
security_manager.add_user(
"alpha",
"alpha",
"user",
"[email protected]",
security_manager.find_role("Alpha"),
password="general",
)
security_manager.get_session.commit()
for username, role in users:
user = sm.find_user(username)
if not user:
sm.add_user(
username,
username,
"user",
username + "@fab.org",
sm.find_role(role),
password="general",
)
sm.get_session.commit()


@app.cli.command()
Expand Down
7 changes: 3 additions & 4 deletions superset/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ def database_access(self, database: "Database") -> bool:
:param database: The Superset database
:returns: Whether the user can access the Superset database
"""

return (
self.all_datasource_access()
or self.all_database_access()
Expand Down Expand Up @@ -269,9 +268,9 @@ def get_table_access_error_msg(self, tables: List[str]) -> str:
:param tables: The list of denied SQL table names
:returns: The error message
"""

return f"""You need access to the following tables: {", ".join(tables)}, all
database access or `all_datasource_access` permission"""
quoted_tables = [f"`{t}`" for t in tables]
return f"""You need access to the following tables: {", ".join(quoted_tables)},
`all_database_access` or `all_datasource_access` permission"""

def get_table_access_link(self, tables: List[str]) -> Optional[str]:
"""
Expand Down
10 changes: 0 additions & 10 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,10 +936,6 @@ def user_label(user: User) -> Optional[str]:
return None


def get_or_create_main_db():
get_main_database()


def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
from superset import db
from superset.models import core as models
Expand All @@ -957,12 +953,6 @@ def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
return database


def get_main_database():
from superset import conf

return get_or_create_db("main", conf.get("SQLALCHEMY_DATABASE_URI"))


def get_example_database():
from superset import conf

Expand Down
11 changes: 2 additions & 9 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2446,10 +2446,7 @@ def results(self, key):
)
if rejected_tables:
return json_error_response(
security_manager.get_table_access_error_msg(
"{}".format(rejected_tables)
),
status=403,
security_manager.get_table_access_error_msg(rejected_tables), status=403
)

payload = utils.zlib_decompress(blob, decode=not results_backend_use_msgpack)
Expand Down Expand Up @@ -2691,11 +2688,7 @@ def csv(self, client_id):
query.sql, query.database, query.schema
)
if rejected_tables:
flash(
security_manager.get_table_access_error_msg(
"{}".format(rejected_tables)
)
)
flash(security_manager.get_table_access_error_msg(rejected_tables))
return redirect("/")
blob = None
if results_backend and query.results_key:
Expand Down
41 changes: 36 additions & 5 deletions tests/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
from superset.models.core import Database
from superset.utils.core import get_main_database
from superset.utils.core import get_example_database

BASE_DIR = app.config.get("BASE_DIR")

Expand Down Expand Up @@ -168,18 +168,25 @@ def revoke_public_access_to_table(self, table):
):
security_manager.del_permission_role(public_role, perm)

def _get_database_by_name(self, database_name="main"):
if database_name == "examples":
return get_example_database()
else:
raise ValueError("Database doesn't exist")

def run_sql(
self,
sql,
client_id=None,
user_name=None,
raise_on_error=False,
query_limit=None,
database_name="examples",
):
if user_name:
self.logout()
self.login(username=(user_name if user_name else "admin"))
dbid = get_main_database().id
self.login(username=(user_name or "admin"))
dbid = self._get_database_by_name(database_name).id
resp = self.get_json_resp(
"/superset/sql_json/",
raise_on_error=False,
Expand All @@ -195,11 +202,35 @@ def run_sql(
raise Exception("run_sql failed")
return resp

def validate_sql(self, sql, client_id=None, user_name=None, raise_on_error=False):
def create_fake_db(self):
self.login(username="admin")
database_name = "fake_db_100"
db_id = 100
extra = """{
"schemas_allowed_for_csv_upload":
["this_schema_is_allowed", "this_schema_is_allowed_too"]
}"""

return self.get_or_create(
cls=models.Database,
criteria={"database_name": database_name},
session=db.session,
id=db_id,
extra=extra,
)

def validate_sql(
self,
sql,
client_id=None,
user_name=None,
raise_on_error=False,
database_name="examples",
):
if user_name:
self.logout()
self.login(username=(user_name if user_name else "admin"))
dbid = get_main_database().id
dbid = self._get_database_by_name(database_name).id
resp = self.get_json_resp(
"/superset/validate_sql_json/",
raise_on_error=False,
Expand Down
39 changes: 20 additions & 19 deletions tests/celery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from superset.models.helpers import QueryStatus
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery
from superset.utils.core import get_main_database
from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase


Expand Down Expand Up @@ -132,20 +132,20 @@ def run_sql(
return json.loads(resp.data)

def test_run_sync_query_dont_exist(self):
main_db = get_main_database()
main_db = get_example_database()
db_id = main_db.id
sql_dont_exist = "SELECT name FROM table_dont_exist"
result1 = self.run_sql(db_id, sql_dont_exist, "1", cta="true")
self.assertTrue("error" in result1)

def test_run_sync_query_cta(self):
main_db = get_main_database()
main_db = get_example_database()
backend = main_db.backend
db_id = main_db.id
tmp_table_name = "tmp_async_22"
self.drop_table_if_exists(tmp_table_name, main_db)
perm_name = "can_sql_json"
sql_where = "SELECT name FROM ab_permission WHERE name='{}'".format(perm_name)
name = "James"
sql_where = f"SELECT name FROM birth_names WHERE name='{name}' LIMIT 1"
result = self.run_sql(
db_id, sql_where, "2", tmp_table=tmp_table_name, cta="true"
)
Expand All @@ -162,9 +162,9 @@ def test_run_sync_query_cta(self):
self.assertGreater(len(results["data"]), 0)

def test_run_sync_query_cta_no_data(self):
main_db = get_main_database()
main_db = get_example_database()
db_id = main_db.id
sql_empty_result = "SELECT * FROM ab_user WHERE id=666"
sql_empty_result = "SELECT * FROM birth_names WHERE name='random'"
result3 = self.run_sql(db_id, sql_empty_result, "3")
self.assertEqual(QueryStatus.SUCCESS, result3["query"]["state"])
self.assertEqual([], result3["data"])
Expand All @@ -183,12 +183,12 @@ def drop_table_if_exists(self, table_name, database=None):
return self.run_sql(db_id, sql)

def test_run_async_query(self):
main_db = get_main_database()
main_db = get_example_database()
db_id = main_db.id

self.drop_table_if_exists("tmp_async_1", main_db)

sql_where = "SELECT name FROM ab_role WHERE name='Admin'"
sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
result = self.run_sql(
db_id, sql_where, "4", async_="true", tmp_table="tmp_async_1", cta="true"
)
Expand All @@ -202,12 +202,13 @@ def test_run_async_query(self):

query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)

self.assertTrue("FROM tmp_async_1" in query.select_sql)
self.assertEqual(
"CREATE TABLE tmp_async_1 AS \n"
"SELECT name FROM ab_role "
"WHERE name='Admin'\n"
"LIMIT 666",
"SELECT name FROM birth_names "
"WHERE name='James' "
"LIMIT 10",
query.executed_sql,
)
self.assertEqual(sql_where, query.sql)
Expand All @@ -216,13 +217,14 @@ def test_run_async_query(self):
self.assertEqual(True, query.select_as_cta_used)

def test_run_async_query_with_lower_limit(self):
main_db = get_main_database()
main_db = get_example_database()
db_id = main_db.id
self.drop_table_if_exists("tmp_async_2", main_db)
tmp_table = "tmp_async_2"
self.drop_table_if_exists(tmp_table, main_db)

sql_where = "SELECT name FROM ab_role WHERE name='Alpha' LIMIT 1"
sql_where = "SELECT name FROM birth_names LIMIT 1"
result = self.run_sql(
db_id, sql_where, "5", async_="true", tmp_table="tmp_async_2", cta="true"
db_id, sql_where, "5", async_="true", tmp_table=tmp_table, cta="true"
)
assert result["query"]["state"] in (
QueryStatus.PENDING,
Expand All @@ -234,10 +236,9 @@ def test_run_async_query_with_lower_limit(self):

query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue("FROM tmp_async_2" in query.select_sql)
self.assertTrue(f"FROM {tmp_table}" in query.select_sql)
self.assertEqual(
"CREATE TABLE tmp_async_2 AS \nSELECT name FROM ab_role "
"WHERE name='Alpha' LIMIT 1",
f"CREATE TABLE {tmp_table} AS \n" "SELECT name FROM birth_names LIMIT 1",
query.executed_sql,
)
self.assertEqual(sql_where, query.sql)
Expand Down
Loading