diff --git a/MANIFEST.in b/MANIFEST.in index f06d1310742f..363faac03b3d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -18,7 +18,7 @@ include NOTICE include LICENSE.txt graft licenses/ include README.md -recursive-include superset/data * +recursive-include superset/examples * recursive-include superset/migrations * recursive-include superset/static * recursive-exclude superset/static/assets/docs * diff --git a/superset/cli.py b/superset/cli.py index cb363c2b5955..c4f83097fd85 100755 --- a/superset/cli.py +++ b/superset/cli.py @@ -26,7 +26,7 @@ from pathlib2 import Path import yaml -from superset import app, appbuilder, data, db, security_manager +from superset import app, appbuilder, db, examples, security_manager from superset.utils import core as utils, dashboard_import_export, dict_import_export config = app.config @@ -46,6 +46,7 @@ def make_shell_context(): def init(): """Inits the Superset application""" utils.get_or_create_main_db() + utils.get_example_database() appbuilder.add_permissions(update_perms=True) security_manager.sync_role_definitions() @@ -67,66 +68,76 @@ def version(verbose): print(Style.RESET_ALL) -def load_examples_run(load_test_data): - print("Loading examples into {}".format(db)) +def load_examples_run(load_test_data, only_metadata=False, force=False): + if only_metadata: + print("Loading examples metadata") + else: + examples_db = utils.get_example_database() + print(f"Loading examples metadata and related data into {examples_db}") - data.load_css_templates() + examples.load_css_templates() print("Loading energy related dataset") - data.load_energy() + examples.load_energy(only_metadata, force) print("Loading [World Bank's Health Nutrition and Population Stats]") - data.load_world_bank_health_n_pop() + examples.load_world_bank_health_n_pop(only_metadata, force) print("Loading [Birth names]") - data.load_birth_names() + examples.load_birth_names(only_metadata, force) print("Loading [Unicode test data]") - data.load_unicode_test_data() + examples.load_unicode_test_data(only_metadata, force) if not load_test_data: print("Loading [Random time series data]") - data.load_random_time_series_data() + examples.load_random_time_series_data(only_metadata, force) print("Loading [Random long/lat data]") - data.load_long_lat_data() + examples.load_long_lat_data(only_metadata, force) print("Loading [Country Map data]") - data.load_country_map_data() + examples.load_country_map_data(only_metadata, force) print("Loading [Multiformat time series]") - data.load_multiformat_time_series() + examples.load_multiformat_time_series(only_metadata, force) print("Loading [Paris GeoJson]") - data.load_paris_iris_geojson() + examples.load_paris_iris_geojson(only_metadata, force) print("Loading [San Francisco population polygons]") - data.load_sf_population_polygons() + examples.load_sf_population_polygons(only_metadata, force) print("Loading [Flights data]") - data.load_flights() + examples.load_flights(only_metadata, force) print("Loading [BART lines]") - data.load_bart_lines() + examples.load_bart_lines(only_metadata, force) print("Loading [Multi Line]") - data.load_multi_line() + examples.load_multi_line(only_metadata) print("Loading [Misc Charts] dashboard") - data.load_misc_dashboard() + examples.load_misc_dashboard() print("Loading DECK.gl demo") - data.load_deck_dash() + examples.load_deck_dash() print("Loading [Tabbed dashboard]") - data.load_tabbed_dashboard() + examples.load_tabbed_dashboard(only_metadata) @app.cli.command() @click.option("--load-test-data", "-t", is_flag=True, 
help="Load additional test data") -def load_examples(load_test_data): +@click.option( + "--only-metadata", "-m", is_flag=True, help="Only load metadata, skip actual data" +) +@click.option( + "--force", "-f", is_flag=True, help="Force load data even if table already exists" +) +def load_examples(load_test_data, only_metadata=False, force=False): """Loads a set of Slices and Dashboards and a supporting dataset """ - load_examples_run(load_test_data) + load_examples_run(load_test_data, only_metadata, force) @app.cli.command() @@ -405,7 +416,7 @@ def load_test_users_run(): for perm in security_manager.find_role("Gamma").permissions: security_manager.add_permission_role(gamma_sqllab_role, perm) utils.get_or_create_main_db() - db_perm = utils.get_main_database(security_manager.get_session).perm + db_perm = utils.get_main_database().perm security_manager.add_permission_view_menu("database_access", db_perm) db_pvm = security_manager.find_permission_view_menu( view_menu_name=db_perm, permission_name="database_access" diff --git a/superset/config.py b/superset/config.py index 7679d204daf7..b676e5179541 100644 --- a/superset/config.py +++ b/superset/config.py @@ -617,6 +617,10 @@ class CeleryConfig(object): "force_https_permanent": False, } +# URI to database storing the example data, points to +# SQLALCHEMY_DATABASE_URI by default if set to `None` +SQLALCHEMY_EXAMPLES_URI = None + try: if CONFIG_PATH_ENV_VAR in os.environ: # Explicitly import config module that is not in pythonpath; useful diff --git a/superset/connectors/connector_registry.py b/superset/connectors/connector_registry.py index be31a37ee070..d5e951aabab7 100644 --- a/superset/connectors/connector_registry.py +++ b/superset/connectors/connector_registry.py @@ -55,18 +55,9 @@ def get_datasource_by_name( cls, session, datasource_type, datasource_name, schema, database_name ): datasource_class = ConnectorRegistry.sources[datasource_type] - datasources = session.query(datasource_class).all() - - # Filter datasoures that don't have database. 
- db_ds = [ - d - for d in datasources - if d.database - and d.database.name == database_name - and d.name == datasource_name - and schema == schema - ] - return db_ds[0] + return datasource_class.get_datasource_by_name( + session, datasource_name, schema, database_name + ) @classmethod def query_datasources_by_permissions(cls, session, database, permissions): diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 3f81ca7e57c0..6a0873cf77d0 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -732,6 +732,16 @@ def time_offset(granularity): return 6 * 24 * 3600 * 1000 # 6 days return 0 + @classmethod + def get_datasource_by_name(cls, session, datasource_name, schema, database_name): + query = ( + session.query(cls) + .join(DruidCluster) + .filter(cls.datasource_name == datasource_name) + .filter(DruidCluster.cluster_name == database_name) + ) + return query.first() + # uses https://en.wikipedia.org/wiki/ISO_8601 # http://druid.io/docs/0.8.0/querying/granularities.html # TODO: pass origin from the UI diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 6805777995ca..a3f56906eb9f 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -374,6 +374,21 @@ def datasource_name(self): def database_name(self): return self.database.name + @classmethod + def get_datasource_by_name(cls, session, datasource_name, schema, database_name): + schema = schema or None + query = ( + session.query(cls) + .join(Database) + .filter(cls.table_name == datasource_name) + .filter(Database.database_name == database_name) + ) + # Handling schema being '' or None, which is easier to handle + # in python than in the SQLA query in a multi-dialect way + for tbl in query.all(): + if schema == (tbl.schema or None): + return tbl + @property def link(self): name = escape(self.name) diff --git a/superset/data/__init__.py b/superset/examples/__init__.py similarity index 100% rename from superset/data/__init__.py rename to superset/examples/__init__.py diff --git a/superset/data/bart_lines.py b/superset/examples/bart_lines.py similarity index 56% rename from superset/data/bart_lines.py rename to superset/examples/bart_lines.py index 8e615fcba094..203a60e5c086 100644 --- a/superset/data/bart_lines.py +++ b/superset/examples/bart_lines.py @@ -21,37 +21,42 @@ from sqlalchemy import String, Text from superset import db -from superset.utils.core import get_or_create_main_db -from .helpers import TBL, get_example_data +from superset.utils.core import get_example_database +from .helpers import get_example_data, TBL -def load_bart_lines(): +def load_bart_lines(only_metadata=False, force=False): tbl_name = "bart_lines" - content = get_example_data("bart-lines.json.gz") - df = pd.read_json(content, encoding="latin-1") - df["path_json"] = df.path.map(json.dumps) - df["polyline"] = df.path.map(polyline.encode) - del df["path"] + database = get_example_database() + table_exists = database.has_table_by_name(tbl_name) + + if not only_metadata and (not table_exists or force): + content = get_example_data("bart-lines.json.gz") + df = pd.read_json(content, encoding="latin-1") + df["path_json"] = df.path.map(json.dumps) + df["polyline"] = df.path.map(polyline.encode) + del df["path"] + + df.to_sql( + tbl_name, + database.get_sqla_engine(), + if_exists="replace", + chunksize=500, + dtype={ + "color": String(255), + "name": String(255), + "polyline": Text, + "path_json": Text, + }, + index=False, + ) 
- df.to_sql( - tbl_name, - db.engine, - if_exists="replace", - chunksize=500, - dtype={ - "color": String(255), - "name": String(255), - "polyline": Text, - "path_json": Text, - }, - index=False, - ) print("Creating table {} reference".format(tbl_name)) tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first() if not tbl: tbl = TBL(table_name=tbl_name) tbl.description = "BART lines" - tbl.database = get_or_create_main_db() + tbl.database = database db.session.merge(tbl) db.session.commit() tbl.fetch_metadata() diff --git a/superset/data/birth_names.py b/superset/examples/birth_names.py similarity index 93% rename from superset/data/birth_names.py rename to superset/examples/birth_names.py index 9040847ab07f..1fcdc55f6ad6 100644 --- a/superset/data/birth_names.py +++ b/superset/examples/birth_names.py @@ -23,7 +23,7 @@ from superset import db, security_manager from superset.connectors.sqla.models import SqlMetric, TableColumn -from superset.utils.core import get_or_create_main_db +from superset.utils.core import get_example_database from .helpers import ( config, Dash, @@ -36,33 +36,39 @@ ) -def load_birth_names(): +def load_birth_names(only_metadata=False, force=False): """Loading birth name dataset from a zip file in the repo""" - data = get_example_data("birth_names.json.gz") - pdf = pd.read_json(data) - pdf.ds = pd.to_datetime(pdf.ds, unit="ms") - pdf.to_sql( - "birth_names", - db.engine, - if_exists="replace", - chunksize=500, - dtype={ - "ds": DateTime, - "gender": String(16), - "state": String(10), - "name": String(255), - }, - index=False, - ) - print("Done loading table!") - print("-" * 80) + # pylint: disable=too-many-locals + tbl_name = "birth_names" + database = get_example_database() + table_exists = database.has_table_by_name(tbl_name) + + if not only_metadata and (not table_exists or force): + pdf = pd.read_json(get_example_data("birth_names.json.gz")) + pdf.ds = pd.to_datetime(pdf.ds, unit="ms") + pdf.to_sql( + tbl_name, + database.get_sqla_engine(), + if_exists="replace", + chunksize=500, + dtype={ + "ds": DateTime, + "gender": String(16), + "state": String(10), + "name": String(255), + }, + index=False, + ) + print("Done loading table!") + print("-" * 80) - print("Creating table [birth_names] reference") - obj = db.session.query(TBL).filter_by(table_name="birth_names").first() + obj = db.session.query(TBL).filter_by(table_name=tbl_name).first() if not obj: - obj = TBL(table_name="birth_names") + print(f"Creating table [{tbl_name}] reference") + obj = TBL(table_name=tbl_name) + db.session.add(obj) obj.main_dttm_col = "ds" - obj.database = get_or_create_main_db() + obj.database = database obj.filter_select_enabled = True if not any(col.column_name == "num_california" for col in obj.columns): @@ -79,7 +85,6 @@ def load_birth_names(): col = str(column("num").compile(db.engine)) obj.metrics.append(SqlMetric(metric_name="sum__num", expression=f"SUM({col})")) - db.session.merge(obj) db.session.commit() obj.fetch_metadata() tbl = obj @@ -384,10 +389,12 @@ def load_birth_names(): merge_slice(slc) print("Creating a dashboard") - dash = db.session.query(Dash).filter_by(dashboard_title="Births").first() + dash = db.session.query(Dash).filter_by(slug="births").first() if not dash: dash = Dash() + db.session.add(dash) + dash.published = True js = textwrap.dedent( # pylint: disable=line-too-long """\ @@ -649,5 +656,4 @@ def load_birth_names(): dash.dashboard_title = "Births" dash.position_json = json.dumps(pos, indent=4) dash.slug = "births" - db.session.merge(dash) 
db.session.commit() diff --git a/superset/data/countries.md b/superset/examples/countries.md similarity index 100% rename from superset/data/countries.md rename to superset/examples/countries.md diff --git a/superset/data/countries.py b/superset/examples/countries.py similarity index 100% rename from superset/data/countries.py rename to superset/examples/countries.py diff --git a/superset/data/country_map.py b/superset/examples/country_map.py similarity index 61% rename from superset/data/country_map.py rename to superset/examples/country_map.py index d2b12cf117ce..d664e4b161ef 100644 --- a/superset/data/country_map.py +++ b/superset/examples/country_map.py @@ -33,44 +33,50 @@ ) -def load_country_map_data(): +def load_country_map_data(only_metadata=False, force=False): """Loading data for map with country map""" - csv_bytes = get_example_data( - "birth_france_data_for_country_map.csv", is_gzip=False, make_bytes=True - ) - data = pd.read_csv(csv_bytes, encoding="utf-8") - data["dttm"] = datetime.datetime.now().date() - data.to_sql( # pylint: disable=no-member - "birth_france_by_region", - db.engine, - if_exists="replace", - chunksize=500, - dtype={ - "DEPT_ID": String(10), - "2003": BigInteger, - "2004": BigInteger, - "2005": BigInteger, - "2006": BigInteger, - "2007": BigInteger, - "2008": BigInteger, - "2009": BigInteger, - "2010": BigInteger, - "2011": BigInteger, - "2012": BigInteger, - "2013": BigInteger, - "2014": BigInteger, - "dttm": Date(), - }, - index=False, - ) - print("Done loading table!") - print("-" * 80) + tbl_name = "birth_france_by_region" + database = utils.get_example_database() + table_exists = database.has_table_by_name(tbl_name) + + if not only_metadata and (not table_exists or force): + csv_bytes = get_example_data( + "birth_france_data_for_country_map.csv", is_gzip=False, make_bytes=True + ) + data = pd.read_csv(csv_bytes, encoding="utf-8") + data["dttm"] = datetime.datetime.now().date() + data.to_sql( # pylint: disable=no-member + tbl_name, + database.get_sqla_engine(), + if_exists="replace", + chunksize=500, + dtype={ + "DEPT_ID": String(10), + "2003": BigInteger, + "2004": BigInteger, + "2005": BigInteger, + "2006": BigInteger, + "2007": BigInteger, + "2008": BigInteger, + "2009": BigInteger, + "2010": BigInteger, + "2011": BigInteger, + "2012": BigInteger, + "2013": BigInteger, + "2014": BigInteger, + "dttm": Date(), + }, + index=False, + ) + print("Done loading table!") + print("-" * 80) + print("Creating table reference") - obj = db.session.query(TBL).filter_by(table_name="birth_france_by_region").first() + obj = db.session.query(TBL).filter_by(table_name=tbl_name).first() if not obj: - obj = TBL(table_name="birth_france_by_region") + obj = TBL(table_name=tbl_name) obj.main_dttm_col = "dttm" - obj.database = utils.get_or_create_main_db() + obj.database = database if not any(col.metric_name == "avg__2004" for col in obj.metrics): col = str(column("2004").compile(db.engine)) obj.metrics.append(SqlMetric(metric_name="avg__2004", expression=f"AVG({col})")) diff --git a/superset/data/css_templates.py b/superset/examples/css_templates.py similarity index 100% rename from superset/data/css_templates.py rename to superset/examples/css_templates.py diff --git a/superset/data/deck.py b/superset/examples/deck.py similarity index 99% rename from superset/data/deck.py rename to superset/examples/deck.py index 974be9033c42..6cd044199005 100644 --- a/superset/data/deck.py +++ b/superset/examples/deck.py @@ -501,6 +501,7 @@ def load_deck_dash(): if not dash: dash = Dash() + 
dash.published = True js = POSITION_JSON pos = json.loads(js) update_slice_ids(pos, slices) diff --git a/superset/data/energy.py b/superset/examples/energy.py similarity index 85% rename from superset/data/energy.py rename to superset/examples/energy.py index a3edb2bca668..00dae11a091f 100644 --- a/superset/data/energy.py +++ b/superset/examples/energy.py @@ -25,36 +25,33 @@ from superset import db from superset.connectors.sqla.models import SqlMetric from superset.utils import core as utils -from .helpers import ( - DATA_FOLDER, - get_example_data, - merge_slice, - misc_dash_slices, - Slice, - TBL, -) +from .helpers import get_example_data, merge_slice, misc_dash_slices, Slice, TBL -def load_energy(): +def load_energy(only_metadata=False, force=False): """Loads an energy related dataset to use with sankey and graphs""" tbl_name = "energy_usage" - data = get_example_data("energy.json.gz") - pdf = pd.read_json(data) - pdf.to_sql( - tbl_name, - db.engine, - if_exists="replace", - chunksize=500, - dtype={"source": String(255), "target": String(255), "value": Float()}, - index=False, - ) + database = utils.get_example_database() + table_exists = database.has_table_by_name(tbl_name) + + if not only_metadata and (not table_exists or force): + data = get_example_data("energy.json.gz") + pdf = pd.read_json(data) + pdf.to_sql( + tbl_name, + database.get_sqla_engine(), + if_exists="replace", + chunksize=500, + dtype={"source": String(255), "target": String(255), "value": Float()}, + index=False, + ) print("Creating table [wb_health_population] reference") tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first() if not tbl: tbl = TBL(table_name=tbl_name) tbl.description = "Energy consumption" - tbl.database = utils.get_or_create_main_db() + tbl.database = database if not any(col.metric_name == "sum__value" for col in tbl.metrics): col = str(column("value").compile(db.engine)) diff --git a/superset/data/flights.py b/superset/examples/flights.py similarity index 52% rename from superset/data/flights.py rename to superset/examples/flights.py index 8876d5430c10..d386ae2d850a 100644 --- a/superset/data/flights.py +++ b/superset/examples/flights.py @@ -22,38 +22,45 @@ from .helpers import get_example_data, TBL -def load_flights(): +def load_flights(only_metadata=False, force=False): """Loading random time series data from a zip file in the repo""" tbl_name = "flights" - data = get_example_data("flight_data.csv.gz", make_bytes=True) - pdf = pd.read_csv(data, encoding="latin-1") - - # Loading airports info to join and get lat/long - airports_bytes = get_example_data("airports.csv.gz", make_bytes=True) - airports = pd.read_csv(airports_bytes, encoding="latin-1") - airports = airports.set_index("IATA_CODE") - - pdf["ds"] = pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str) - pdf.ds = pd.to_datetime(pdf.ds) - del pdf["YEAR"] - del pdf["MONTH"] - del pdf["DAY"] - - pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG") - pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST") - pdf.to_sql( - tbl_name, - db.engine, - if_exists="replace", - chunksize=500, - dtype={"ds": DateTime}, - index=False, - ) + database = utils.get_example_database() + table_exists = database.has_table_by_name(tbl_name) + + if not only_metadata and (not table_exists or force): + data = get_example_data("flight_data.csv.gz", make_bytes=True) + pdf = pd.read_csv(data, encoding="latin-1") + + # Loading airports info to join and get lat/long + airports_bytes = 
get_example_data("airports.csv.gz", make_bytes=True) + airports = pd.read_csv(airports_bytes, encoding="latin-1") + airports = airports.set_index("IATA_CODE") + + pdf["ds"] = ( + pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str) + ) + pdf.ds = pd.to_datetime(pdf.ds) + del pdf["YEAR"] + del pdf["MONTH"] + del pdf["DAY"] + + pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG") + pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST") + pdf.to_sql( + tbl_name, + database.get_sqla_engine(), + if_exists="replace", + chunksize=500, + dtype={"ds": DateTime}, + index=False, + ) + tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first() if not tbl: tbl = TBL(table_name=tbl_name) tbl.description = "Random set of flights in the US" - tbl.database = utils.get_or_create_main_db() + tbl.database = database db.session.merge(tbl) db.session.commit() tbl.fetch_metadata() diff --git a/superset/data/helpers.py b/superset/examples/helpers.py similarity index 97% rename from superset/data/helpers.py rename to superset/examples/helpers.py index a6dd734a4402..cff7da6ad86f 100644 --- a/superset/data/helpers.py +++ b/superset/examples/helpers.py @@ -38,7 +38,7 @@ config = app.config -DATA_FOLDER = os.path.join(config.get("BASE_DIR"), "data") +EXAMPLES_FOLDER = os.path.join(config.get("BASE_DIR"), "examples") misc_dash_slices = set() # slices assembled in a 'Misc Chart' dashboard diff --git a/superset/data/long_lat.py b/superset/examples/long_lat.py similarity index 50% rename from superset/data/long_lat.py rename to superset/examples/long_lat.py index 4ecb618bf95b..28b8dbbf01ef 100644 --- a/superset/data/long_lat.py +++ b/superset/examples/long_lat.py @@ -33,52 +33,59 @@ ) -def load_long_lat_data(): +def load_long_lat_data(only_metadata=False, force=False): """Loading lat/long data from a csv file in the repo""" - data = get_example_data("san_francisco.csv.gz", make_bytes=True) - pdf = pd.read_csv(data, encoding="utf-8") - start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) - pdf["datetime"] = [ - start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1)) - for i in range(len(pdf)) - ] - pdf["occupancy"] = [random.randint(1, 6) for _ in range(len(pdf))] - pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))] - pdf["geohash"] = pdf[["LAT", "LON"]].apply(lambda x: geohash.encode(*x), axis=1) - pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",") - pdf.to_sql( # pylint: disable=no-member - "long_lat", - db.engine, - if_exists="replace", - chunksize=500, - dtype={ - "longitude": Float(), - "latitude": Float(), - "number": Float(), - "street": String(100), - "unit": String(10), - "city": String(50), - "district": String(50), - "region": String(50), - "postcode": Float(), - "id": String(100), - "datetime": DateTime(), - "occupancy": Float(), - "radius_miles": Float(), - "geohash": String(12), - "delimited": String(60), - }, - index=False, - ) - print("Done loading table!") - print("-" * 80) + tbl_name = "long_lat" + database = utils.get_example_database() + table_exists = database.has_table_by_name(tbl_name) + + if not only_metadata and (not table_exists or force): + data = get_example_data("san_francisco.csv.gz", make_bytes=True) + pdf = pd.read_csv(data, encoding="utf-8") + start = datetime.datetime.now().replace( + hour=0, minute=0, second=0, microsecond=0 + ) + pdf["datetime"] = [ + start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1)) + for i in range(len(pdf)) + ] + pdf["occupancy"] = 
[random.randint(1, 6) for _ in range(len(pdf))] + pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))] + pdf["geohash"] = pdf[["LAT", "LON"]].apply(lambda x: geohash.encode(*x), axis=1) + pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",") + pdf.to_sql( # pylint: disable=no-member + tbl_name, + database.get_sqla_engine(), + if_exists="replace", + chunksize=500, + dtype={ + "longitude": Float(), + "latitude": Float(), + "number": Float(), + "street": String(100), + "unit": String(10), + "city": String(50), + "district": String(50), + "region": String(50), + "postcode": Float(), + "id": String(100), + "datetime": DateTime(), + "occupancy": Float(), + "radius_miles": Float(), + "geohash": String(12), + "delimited": String(60), + }, + index=False, + ) + print("Done loading table!") + print("-" * 80) print("Creating table reference") - obj = db.session.query(TBL).filter_by(table_name="long_lat").first() + obj = db.session.query(TBL).filter_by(table_name=tbl_name).first() if not obj: - obj = TBL(table_name="long_lat") + obj = TBL(table_name=tbl_name) obj.main_dttm_col = "datetime" - obj.database = utils.get_or_create_main_db() + obj.database = database db.session.merge(obj) db.session.commit() obj.fetch_metadata() diff --git a/superset/data/misc_dashboard.py b/superset/examples/misc_dashboard.py similarity index 100% rename from superset/data/misc_dashboard.py rename to superset/examples/misc_dashboard.py diff --git a/superset/data/multi_line.py b/superset/examples/multi_line.py similarity index 93% rename from superset/data/multi_line.py rename to superset/examples/multi_line.py index d2b4d18c5ac4..390e5f84e85a 100644 --- a/superset/data/multi_line.py +++ b/superset/examples/multi_line.py @@ -22,9 +22,9 @@ from .world_bank import load_world_bank_health_n_pop -def load_multi_line(): - load_world_bank_health_n_pop() - load_birth_names() +def load_multi_line(only_metadata=False): + load_world_bank_health_n_pop(only_metadata) + load_birth_names(only_metadata) ids = [ row.id for row in db.session.query(Slice).filter( diff --git a/superset/data/multiformat_time_series.py b/superset/examples/multiformat_time_series.py similarity index 66% rename from superset/data/multiformat_time_series.py rename to superset/examples/multiformat_time_series.py index b33391d4aba2..2875d375c770 100644 --- a/superset/data/multiformat_time_series.py +++ b/superset/examples/multiformat_time_series.py @@ -19,7 +19,7 @@ from sqlalchemy import BigInteger, Date, DateTime, String from superset import db -from superset.utils import core as utils +from superset.utils.core import get_example_database from .helpers import ( config, get_example_data, @@ -31,38 +31,44 @@ ) -def load_multiformat_time_series(): +def load_multiformat_time_series(only_metadata=False, force=False): """Loading time series data from a zip file in the repo""" - data = get_example_data("multiformat_time_series.json.gz") - pdf = pd.read_json(data) + tbl_name = "multiformat_time_series" + database = get_example_database() + table_exists = database.has_table_by_name(tbl_name) - pdf.ds = pd.to_datetime(pdf.ds, unit="s") - pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s") - pdf.to_sql( - "multiformat_time_series", - db.engine, - if_exists="replace", - chunksize=500, - dtype={ - "ds": Date, - "ds2": DateTime, - "epoch_s": BigInteger, - "epoch_ms": BigInteger, - "string0": String(100), - "string1": String(100), - "string2": String(100), - "string3": String(100), - }, - index=False, - ) - print("Done loading table!") - print("-" * 
80) - print("Creating table [multiformat_time_series] reference") - obj = db.session.query(TBL).filter_by(table_name="multiformat_time_series").first() + if not only_metadata and (not table_exists or force): + data = get_example_data("multiformat_time_series.json.gz") + pdf = pd.read_json(data) + + pdf.ds = pd.to_datetime(pdf.ds, unit="s") + pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s") + pdf.to_sql( + tbl_name, + database.get_sqla_engine(), + if_exists="replace", + chunksize=500, + dtype={ + "ds": Date, + "ds2": DateTime, + "epoch_s": BigInteger, + "epoch_ms": BigInteger, + "string0": String(100), + "string1": String(100), + "string2": String(100), + "string3": String(100), + }, + index=False, + ) + print("Done loading table!") + print("-" * 80) + + print(f"Creating table [{tbl_name}] reference") + obj = db.session.query(TBL).filter_by(table_name=tbl_name).first() if not obj: - obj = TBL(table_name="multiformat_time_series") + obj = TBL(table_name=tbl_name) obj.main_dttm_col = "ds" - obj.database = utils.get_or_create_main_db() + obj.database = database dttm_and_expr_dict = { "ds": [None, None], "ds2": [None, None], diff --git a/superset/data/paris.py b/superset/examples/paris.py similarity index 61% rename from superset/data/paris.py rename to superset/examples/paris.py index 6387272e6874..a4f252b0343e 100644 --- a/superset/data/paris.py +++ b/superset/examples/paris.py @@ -21,35 +21,39 @@ from superset import db from superset.utils import core as utils -from .helpers import TBL, get_example_data +from .helpers import get_example_data, TBL -def load_paris_iris_geojson(): +def load_paris_iris_geojson(only_metadata=False, force=False): tbl_name = "paris_iris_mapping" + database = utils.get_example_database() + table_exists = database.has_table_by_name(tbl_name) - data = get_example_data("paris_iris.json.gz") - df = pd.read_json(data) - df["features"] = df.features.map(json.dumps) + if not only_metadata and (not table_exists or force): + data = get_example_data("paris_iris.json.gz") + df = pd.read_json(data) + df["features"] = df.features.map(json.dumps) + + df.to_sql( + tbl_name, + database.get_sqla_engine(), + if_exists="replace", + chunksize=500, + dtype={ + "color": String(255), + "name": String(255), + "features": Text, + "type": Text, + }, + index=False, + ) - df.to_sql( - tbl_name, - db.engine, - if_exists="replace", - chunksize=500, - dtype={ - "color": String(255), - "name": String(255), - "features": Text, - "type": Text, - }, - index=False, - ) print("Creating table {} reference".format(tbl_name)) tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first() if not tbl: tbl = TBL(table_name=tbl_name) tbl.description = "Map of Paris" - tbl.database = utils.get_or_create_main_db() + tbl.database = database db.session.merge(tbl) db.session.commit() tbl.fetch_metadata() diff --git a/superset/data/random_time_series.py b/superset/examples/random_time_series.py similarity index 67% rename from superset/data/random_time_series.py rename to superset/examples/random_time_series.py index 477cb142ab0d..26dd00398453 100644 --- a/superset/data/random_time_series.py +++ b/superset/examples/random_time_series.py @@ -23,28 +23,33 @@ from .helpers import config, get_example_data, get_slice_json, merge_slice, Slice, TBL -def load_random_time_series_data(): +def load_random_time_series_data(only_metadata=False, force=False): """Loading random time series data from a zip file in the repo""" - data = get_example_data("random_time_series.json.gz") - pdf = pd.read_json(data) - pdf.ds = 
pd.to_datetime(pdf.ds, unit="s") - pdf.to_sql( - "random_time_series", - db.engine, - if_exists="replace", - chunksize=500, - dtype={"ds": DateTime}, - index=False, - ) - print("Done loading table!") - print("-" * 80) + tbl_name = "random_time_series" + database = utils.get_example_database() + table_exists = database.has_table_by_name(tbl_name) + + if not only_metadata and (not table_exists or force): + data = get_example_data("random_time_series.json.gz") + pdf = pd.read_json(data) + pdf.ds = pd.to_datetime(pdf.ds, unit="s") + pdf.to_sql( + tbl_name, + database.get_sqla_engine(), + if_exists="replace", + chunksize=500, + dtype={"ds": DateTime}, + index=False, + ) + print("Done loading table!") + print("-" * 80) - print("Creating table [random_time_series] reference") - obj = db.session.query(TBL).filter_by(table_name="random_time_series").first() + print(f"Creating table [{tbl_name}] reference") + obj = db.session.query(TBL).filter_by(table_name=tbl_name).first() if not obj: - obj = TBL(table_name="random_time_series") + obj = TBL(table_name=tbl_name) obj.main_dttm_col = "ds" - obj.database = utils.get_or_create_main_db() + obj.database = database db.session.merge(obj) db.session.commit() obj.fetch_metadata() diff --git a/superset/data/sf_population_polygons.py b/superset/examples/sf_population_polygons.py similarity index 61% rename from superset/data/sf_population_polygons.py rename to superset/examples/sf_population_polygons.py index 49550b3d8498..738ac41f2cc3 100644 --- a/superset/data/sf_population_polygons.py +++ b/superset/examples/sf_population_polygons.py @@ -21,35 +21,39 @@ from superset import db from superset.utils import core as utils -from .helpers import TBL, get_example_data +from .helpers import get_example_data, TBL -def load_sf_population_polygons(): +def load_sf_population_polygons(only_metadata=False, force=False): tbl_name = "sf_population_polygons" + database = utils.get_example_database() + table_exists = database.has_table_by_name(tbl_name) - data = get_example_data("sf_population.json.gz") - df = pd.read_json(data) - df["contour"] = df.contour.map(json.dumps) + if not only_metadata and (not table_exists or force): + data = get_example_data("sf_population.json.gz") + df = pd.read_json(data) + df["contour"] = df.contour.map(json.dumps) + + df.to_sql( + tbl_name, + database.get_sqla_engine(), + if_exists="replace", + chunksize=500, + dtype={ + "zipcode": BigInteger, + "population": BigInteger, + "contour": Text, + "area": BigInteger, + }, + index=False, + ) - df.to_sql( - tbl_name, - db.engine, - if_exists="replace", - chunksize=500, - dtype={ - "zipcode": BigInteger, - "population": BigInteger, - "contour": Text, - "area": BigInteger, - }, - index=False, - ) print("Creating table {} reference".format(tbl_name)) tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first() if not tbl: tbl = TBL(table_name=tbl_name) tbl.description = "Population density of San Francisco" - tbl.database = utils.get_or_create_main_db() + tbl.database = database db.session.merge(tbl) db.session.commit() tbl.fetch_metadata() diff --git a/superset/data/tabbed_dashboard.py b/superset/examples/tabbed_dashboard.py similarity index 95% rename from superset/data/tabbed_dashboard.py rename to superset/examples/tabbed_dashboard.py index 09f18ad8907b..ab4552cf36f5 100644 --- a/superset/data/tabbed_dashboard.py +++ b/superset/examples/tabbed_dashboard.py @@ -17,30 +17,13 @@ """Loads datasets, dashboards and slices in a new superset instance""" # pylint: disable=C,R,W import json -import os 
import textwrap -import pandas as pd -from sqlalchemy import DateTime, String - from superset import db -from superset.connectors.sqla.models import SqlMetric -from superset.utils import core as utils -from .helpers import ( - config, - Dash, - DATA_FOLDER, - get_example_data, - get_slice_json, - merge_slice, - misc_dash_slices, - Slice, - TBL, - update_slice_ids, -) +from .helpers import Dash, Slice, update_slice_ids -def load_tabbed_dashboard(): +def load_tabbed_dashboard(only_metadata=False): """Creating a tabbed dashboard""" print("Creating a dashboard with nested tabs") diff --git a/superset/data/unicode_test_data.py b/superset/examples/unicode_test_data.py similarity index 72% rename from superset/data/unicode_test_data.py rename to superset/examples/unicode_test_data.py index 3f3ed55d916d..3f91f3f3df18 100644 --- a/superset/data/unicode_test_data.py +++ b/superset/examples/unicode_test_data.py @@ -35,38 +35,43 @@ ) -def load_unicode_test_data(): +def load_unicode_test_data(only_metadata=False, force=False): """Loading unicode test dataset from a csv file in the repo""" - data = get_example_data( - "unicode_utf8_unixnl_test.csv", is_gzip=False, make_bytes=True - ) - df = pd.read_csv(data, encoding="utf-8") - # generate date/numeric data - df["dttm"] = datetime.datetime.now().date() - df["value"] = [random.randint(1, 100) for _ in range(len(df))] - df.to_sql( # pylint: disable=no-member - "unicode_test", - db.engine, - if_exists="replace", - chunksize=500, - dtype={ - "phrase": String(500), - "short_phrase": String(10), - "with_missing": String(100), - "dttm": Date(), - "value": Float(), - }, - index=False, - ) - print("Done loading table!") - print("-" * 80) + tbl_name = "unicode_test" + database = utils.get_example_database() + table_exists = database.has_table_by_name(tbl_name) + + if not only_metadata and (not table_exists or force): + data = get_example_data( + "unicode_utf8_unixnl_test.csv", is_gzip=False, make_bytes=True + ) + df = pd.read_csv(data, encoding="utf-8") + # generate date/numeric data + df["dttm"] = datetime.datetime.now().date() + df["value"] = [random.randint(1, 100) for _ in range(len(df))] + df.to_sql( # pylint: disable=no-member + tbl_name, + database.get_sqla_engine(), + if_exists="replace", + chunksize=500, + dtype={ + "phrase": String(500), + "short_phrase": String(10), + "with_missing": String(100), + "dttm": Date(), + "value": Float(), + }, + index=False, + ) + print("Done loading table!") + print("-" * 80) print("Creating table [unicode_test] reference") - obj = db.session.query(TBL).filter_by(table_name="unicode_test").first() + obj = db.session.query(TBL).filter_by(table_name=tbl_name).first() if not obj: - obj = TBL(table_name="unicode_test") + obj = TBL(table_name=tbl_name) obj.main_dttm_col = "dttm" - obj.database = utils.get_or_create_main_db() + obj.database = database db.session.merge(obj) db.session.commit() obj.fetch_metadata() @@ -104,7 +109,7 @@ def load_unicode_test_data(): merge_slice(slc) print("Creating a dashboard") - dash = db.session.query(Dash).filter_by(dashboard_title="Unicode Test").first() + dash = db.session.query(Dash).filter_by(slug="unicode-test").first() if not dash: dash = Dash() diff --git a/superset/data/world_bank.py b/superset/examples/world_bank.py similarity index 93% rename from superset/data/world_bank.py rename to superset/examples/world_bank.py index a64bd2ba1e9a..fbe45127a306 100644 --- a/superset/data/world_bank.py +++ b/superset/examples/world_bank.py @@ -30,7 +30,7 @@ from .helpers import ( config, Dash, - 
DATA_FOLDER, + EXAMPLES_FOLDER, get_example_data, get_slice_json, merge_slice, @@ -41,34 +41,38 @@ ) -def load_world_bank_health_n_pop(): +def load_world_bank_health_n_pop(only_metadata=False, force=False): """Loads the world bank health dataset, slices and a dashboard""" tbl_name = "wb_health_population" - data = get_example_data("countries.json.gz") - pdf = pd.read_json(data) - pdf.columns = [col.replace(".", "_") for col in pdf.columns] - pdf.year = pd.to_datetime(pdf.year) - pdf.to_sql( - tbl_name, - db.engine, - if_exists="replace", - chunksize=50, - dtype={ - "year": DateTime(), - "country_code": String(3), - "country_name": String(255), - "region": String(255), - }, - index=False, - ) + database = utils.get_example_database() + table_exists = database.has_table_by_name(tbl_name) + + if not only_metadata and (not table_exists or force): + data = get_example_data("countries.json.gz") + pdf = pd.read_json(data) + pdf.columns = [col.replace(".", "_") for col in pdf.columns] + pdf.year = pd.to_datetime(pdf.year) + pdf.to_sql( + tbl_name, + database.get_sqla_engine(), + if_exists="replace", + chunksize=50, + dtype={ + "year": DateTime(), + "country_code": String(3), + "country_name": String(255), + "region": String(255), + }, + index=False, + ) print("Creating table [wb_health_population] reference") tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first() if not tbl: tbl = TBL(table_name=tbl_name) - tbl.description = utils.readfile(os.path.join(DATA_FOLDER, "countries.md")) + tbl.description = utils.readfile(os.path.join(EXAMPLES_FOLDER, "countries.md")) tbl.main_dttm_col = "year" - tbl.database = utils.get_or_create_main_db() + tbl.database = database tbl.filter_select_enabled = True metrics = [ @@ -328,6 +332,7 @@ def load_world_bank_health_n_pop(): if not dash: dash = Dash() + dash.published = True js = textwrap.dedent( """\ { diff --git a/superset/models/core.py b/superset/models/core.py index 0e5d1dbbba9b..66f2d0e76bbe 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -666,12 +666,13 @@ def export_dashboards(cls, dashboard_ids): ) make_transient(copied_dashboard) for slc in copied_dashboard.slices: + make_transient(slc) datasource_ids.add((slc.datasource_id, slc.datasource_type)) # add extra params for the import slc.alter_params( remote_id=slc.id, datasource_name=slc.datasource.name, - schema=slc.datasource.name, + schema=slc.datasource.schema, database_name=slc.datasource.database.name, ) copied_dashboard.alter_params(remote_id=dashboard_id) @@ -1169,6 +1170,10 @@ def has_table(self, table): engine = self.get_sqla_engine() return engine.has_table(table.table_name, table.schema or None) + def has_table_by_name(self, table_name, schema=None): + engine = self.get_sqla_engine() + return engine.has_table(table_name, schema) + @utils.memoized def get_dialect(self): sqla_url = url.make_url(self.sqlalchemy_uri_decrypted) diff --git a/superset/models/tags.py b/superset/models/tags.py index d3181154f5a8..0bd6c2b70d3b 100644 --- a/superset/models/tags.py +++ b/superset/models/tags.py @@ -81,7 +81,7 @@ class TaggedObject(Model, AuditMixinNullable): object_id = Column(Integer) object_type = Column(Enum(ObjectTypes)) - tag = relationship("Tag") + tag = relationship("Tag", backref="objects") def get_tag(name, session, type_): diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py index d99719c3f269..73dc75687f72 100644 --- a/superset/tasks/cache.py +++ b/superset/tasks/cache.py @@ -18,10 +18,8 @@ import json import logging -import urllib.parse from 
celery.utils.log import get_task_logger -from flask import url_for import requests from requests.exceptions import RequestException from sqlalchemy import and_, func @@ -75,13 +73,13 @@ def get_form_data(chart_id, dashboard=None): return form_data -def get_url(params): +def get_url(chart): """Return external URL for warming up a given chart/table cache.""" - baseurl = "http://{SUPERSET_WEBSERVER_ADDRESS}:{SUPERSET_WEBSERVER_PORT}/".format( - **app.config - ) with app.test_request_context(): - return urllib.parse.urljoin(baseurl, url_for("Superset.explore_json", **params)) + baseurl = "{SUPERSET_WEBSERVER_ADDRESS}:{SUPERSET_WEBSERVER_PORT}".format( + **app.config + ) + return f"{baseurl}{chart.url}" class Strategy: @@ -136,7 +134,7 @@ def get_urls(self): session = db.create_scoped_session() charts = session.query(Slice).all() - return [get_url({"form_data": get_form_data(chart.id)}) for chart in charts] + return [get_url(chart) for chart in charts] class TopNDashboardsStrategy(Strategy): @@ -180,7 +178,7 @@ def get_urls(self): dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all() for dashboard in dashboards: for chart in dashboard.slices: - urls.append(get_url({"form_data": get_form_data(chart.id, dashboard)})) + urls.append(get_url(chart)) return urls @@ -229,7 +227,7 @@ def get_urls(self): tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)) for dashboard in tagged_dashboards: for chart in dashboard.slices: - urls.append(get_url({"form_data": get_form_data(chart.id, dashboard)})) + urls.append(get_url(chart)) # add charts that are tagged tagged_objects = ( @@ -245,7 +243,7 @@ def get_urls(self): chart_ids = [tagged_object.object_id for tagged_object in tagged_objects] tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids)) for chart in tagged_charts: - urls.append(get_url({"form_data": get_form_data(chart.id)})) + urls.append(get_url(chart)) return urls diff --git a/superset/utils/core.py b/superset/utils/core.py index ebbb08278483..b0ca2dc20499 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -942,25 +942,37 @@ def user_label(user: User) -> Optional[str]: def get_or_create_main_db(): - from superset import conf, db + get_main_database() + + +def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs): + from superset import db from superset.models import core as models - logging.info("Creating database reference") - dbobj = get_main_database(db.session) - if not dbobj: - dbobj = models.Database( - database_name="main", allow_csv_upload=True, expose_in_sqllab=True - ) - dbobj.set_sqlalchemy_uri(conf.get("SQLALCHEMY_DATABASE_URI")) - db.session.add(dbobj) + database = ( + db.session.query(models.Database).filter_by(database_name=database_name).first() + ) + if not database: + logging.info(f"Creating database reference for {database_name}") + database = models.Database(database_name=database_name, *args, **kwargs) + db.session.add(database) + + database.set_sqlalchemy_uri(sqlalchemy_uri) db.session.commit() - return dbobj + return database -def get_main_database(session): - from superset.models import core as models +def get_main_database(): + from superset import conf + + return get_or_create_db("main", conf.get("SQLALCHEMY_DATABASE_URI")) + + +def get_example_database(): + from superset import conf - return session.query(models.Database).filter_by(database_name="main").first() + db_uri = conf.get("SQLALCHEMY_EXAMPLES_URI") or conf.get("SQLALCHEMY_DATABASE_URI") + return 
get_or_create_db("examples", db_uri) def is_adhoc_metric(metric) -> bool: diff --git a/superset/viz.py b/superset/viz.py index 5498becec671..d81e16e0664d 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -1746,7 +1746,7 @@ def query_obj(self): return qry def get_data(self, df): - from superset.data import countries + from superset.examples import countries fd = self.form_data cols = [fd.get("entity")] diff --git a/tests/access_tests.py b/tests/access_tests.py index 835ac226c21c..999db1f87b16 100644 --- a/tests/access_tests.py +++ b/tests/access_tests.py @@ -31,7 +31,7 @@ "database": [ { "datasource_type": "table", - "name": "main", + "name": "examples", "schema": [{"name": "", "datasources": ["birth_names"]}], } ], @@ -42,7 +42,7 @@ "database": [ { "datasource_type": "table", - "name": "main", + "name": "examples", "schema": [{"name": "", "datasources": ["birth_names"]}], }, { diff --git a/tests/base_tests.py b/tests/base_tests.py index cca317ea5163..8ac5bcdeeebb 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -168,9 +168,6 @@ def revoke_public_access_to_table(self, table): ): security_manager.del_permission_role(public_role, perm) - def get_main_database(self): - return get_main_database(db.session) - def run_sql( self, sql, @@ -182,7 +179,7 @@ def run_sql( if user_name: self.logout() self.login(username=(user_name if user_name else "admin")) - dbid = self.get_main_database().id + dbid = get_main_database().id resp = self.get_json_resp( "/superset/sql_json/", raise_on_error=False, @@ -202,7 +199,7 @@ def validate_sql(self, sql, client_id=None, user_name=None, raise_on_error=False if user_name: self.logout() self.login(username=(user_name if user_name else "admin")) - dbid = self.get_main_database().id + dbid = get_main_database().id resp = self.get_json_resp( "/superset/validate_sql_json/", raise_on_error=False, @@ -223,3 +220,7 @@ def test_nonexistent_feature_flags(self): def test_feature_flags(self): self.assertEquals(is_feature_enabled("foo"), "bar") self.assertEquals(is_feature_enabled("super"), "set") + + def get_dash_by_slug(self, dash_slug): + sesh = db.session() + return sesh.query(models.Dashboard).filter_by(slug=dash_slug).first() diff --git a/tests/celery_tests.py b/tests/celery_tests.py index f204fa6833ba..dba087a66942 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -128,14 +128,14 @@ def run_sql( return json.loads(resp.data) def test_run_sync_query_dont_exist(self): - main_db = get_main_database(db.session) + main_db = get_main_database() db_id = main_db.id sql_dont_exist = "SELECT name FROM table_dont_exist" result1 = self.run_sql(db_id, sql_dont_exist, "1", cta="true") self.assertTrue("error" in result1) def test_run_sync_query_cta(self): - main_db = get_main_database(db.session) + main_db = get_main_database() backend = main_db.backend db_id = main_db.id tmp_table_name = "tmp_async_22" @@ -158,7 +158,7 @@ def test_run_sync_query_cta(self): self.assertGreater(len(results["data"]), 0) def test_run_sync_query_cta_no_data(self): - main_db = get_main_database(db.session) + main_db = get_main_database() db_id = main_db.id sql_empty_result = "SELECT * FROM ab_user WHERE id=666" result3 = self.run_sql(db_id, sql_empty_result, "3") @@ -179,7 +179,7 @@ def drop_table_if_exists(self, table_name, database=None): return self.run_sql(db_id, sql) def test_run_async_query(self): - main_db = get_main_database(db.session) + main_db = get_main_database() db_id = main_db.id self.drop_table_if_exists("tmp_async_1", main_db) @@ -212,7 +212,7 @@ def 
test_run_async_query(self): self.assertEqual(True, query.select_as_cta_used) def test_run_async_query_with_lower_limit(self): - main_db = get_main_database(db.session) + main_db = get_main_database() db_id = main_db.id self.drop_table_if_exists("tmp_async_2", main_db) diff --git a/tests/core_tests.py b/tests/core_tests.py index ea3b88913b3d..b401c8c0ff62 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -39,7 +39,6 @@ from superset.models import core as models from superset.models.sql_lab import Query from superset.utils import core as utils -from superset.utils.core import get_main_database from superset.views.core import DatabaseView from .base_tests import SupersetTestCase from .fixtures.pyodbcRow import Row @@ -345,7 +344,7 @@ def test_misc(self): def test_testconn(self, username="admin"): self.login(username=username) - database = get_main_database(db.session) + database = utils.get_main_database() # validate that the endpoint works with the password-masked sqlalchemy uri data = json.dumps( @@ -376,7 +375,7 @@ def test_testconn(self, username="admin"): assert response.headers["Content-Type"] == "application/json" def test_custom_password_store(self): - database = get_main_database(db.session) + database = utils.get_main_database() conn_pre = sqla.engine.url.make_url(database.sqlalchemy_uri_decrypted) def custom_password_store(uri): @@ -394,13 +393,13 @@ def test_databaseview_edit(self, username="admin"): # validate that sending a password-masked uri does not over-write the decrypted # uri self.login(username=username) - database = get_main_database(db.session) + database = utils.get_main_database() sqlalchemy_uri_decrypted = database.sqlalchemy_uri_decrypted url = "databaseview/edit/{}".format(database.id) data = {k: database.__getattribute__(k) for k in DatabaseView.add_columns} data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri() self.client.post(url, data=data) - database = get_main_database(db.session) + database = utils.get_main_database() self.assertEqual(sqlalchemy_uri_decrypted, database.sqlalchemy_uri_decrypted) def test_warm_up_cache(self): @@ -483,27 +482,27 @@ def test_csv_endpoint(self): def test_extra_table_metadata(self): self.login("admin") - dbid = get_main_database(db.session).id + dbid = utils.get_main_database().id self.get_json_resp( f"/superset/extra_table_metadata/{dbid}/" "ab_permission_view/panoramix/" ) def test_process_template(self): - maindb = get_main_database(db.session) + maindb = utils.get_main_database() sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'" tp = jinja_context.get_template_processor(database=maindb) rendered = tp.process_template(sql) self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered) def test_get_template_kwarg(self): - maindb = get_main_database(db.session) + maindb = utils.get_main_database() s = "{{ foo }}" tp = jinja_context.get_template_processor(database=maindb, foo="bar") rendered = tp.process_template(s) self.assertEqual("bar", rendered) def test_template_kwarg(self): - maindb = get_main_database(db.session) + maindb = utils.get_main_database() s = "{{ foo }}" tp = jinja_context.get_template_processor(database=maindb) rendered = tp.process_template(s, foo="bar") @@ -516,7 +515,7 @@ def test_templated_sql_json(self): self.assertEqual(data["data"][0]["test"], "2017-01-01T00:00:00") def test_table_metadata(self): - maindb = get_main_database(db.session) + maindb = utils.get_main_database() backend = maindb.backend data = self.get_json_resp("/superset/table/{}/ab_user/null/".format(maindb.id)) 
self.assertEqual(data["name"], "ab_user") @@ -615,15 +614,16 @@ def test_import_csv(self): test_file.write("john,1\n") test_file.write("paul,2\n") test_file.close() - main_db_uri = ( - db.session.query(models.Database).filter_by(database_name="main").one() - ) + example_db = utils.get_example_database() + example_db.allow_csv_upload = True + db_id = example_db.id + db.session.commit() test_file = open(filename, "rb") form_data = { "csv_file": test_file, "sep": ",", "name": table_name, - "con": main_db_uri.id, + "con": db_id, "if_exists": "append", "index_label": "test_label", "mangle_dupe_cols": False, @@ -638,8 +638,8 @@ def test_import_csv(self): try: # ensure uploaded successfully - form_post = self.get_resp(url, data=form_data) - assert 'CSV file "testCSV.csv" uploaded to table' in form_post + resp = self.get_resp(url, data=form_data) + assert 'CSV file "testCSV.csv" uploaded to table' in resp finally: os.remove(filename) @@ -769,7 +769,8 @@ def test_schemas_access_for_csv_upload_endpoint( def test_select_star(self): self.login(username="admin") - resp = self.get_resp("/superset/select_star/1/birth_names") + examples_db = utils.get_example_database() + resp = self.get_resp(f"/superset/select_star/{examples_db.id}/birth_names") self.assertIn("gender", resp) diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index 7ab389ee726e..733b9bea1272 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -39,6 +39,7 @@ from superset.db_engine_specs.postgres import PostgresEngineSpec from superset.db_engine_specs.presto import PrestoEngineSpec from superset.models.core import Database +from superset.utils.core import get_example_database from .base_tests import SupersetTestCase @@ -925,14 +926,14 @@ def test_pinot_time_expression_sec_1m_grain(self): ) # noqa def test_column_datatype_to_string(self): - main_db = self.get_main_database() - sqla_table = main_db.get_table("energy_usage") - dialect = main_db.get_dialect() + example_db = get_example_database() + sqla_table = example_db.get_table("energy_usage") + dialect = example_db.get_dialect() col_names = [ - main_db.db_engine_spec.column_datatype_to_string(c.type, dialect) + example_db.db_engine_spec.column_datatype_to_string(c.type, dialect) for c in sqla_table.columns ] - if main_db.backend == "postgresql": + if example_db.backend == "postgresql": expected = ["VARCHAR(255)", "VARCHAR(255)", "DOUBLE PRECISION"] else: expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"] diff --git a/tests/dict_import_export_tests.py b/tests/dict_import_export_tests.py index ba6766bbb9cf..fafa024e9afe 100644 --- a/tests/dict_import_export_tests.py +++ b/tests/dict_import_export_tests.py @@ -63,7 +63,7 @@ def create_table(self, name, schema="", id=0, cols_names=[], metric_names=[]): params = {DBREF: id, "database_name": database_name} dict_rep = { - "database_id": get_main_database(db.session).id, + "database_id": get_main_database().id, "table_name": name, "schema": schema, "id": id, diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index f48fa2d06011..d5c36091eac8 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -63,7 +63,7 @@ def create_slice( name, ds_id=None, id=None, - db_name="main", + db_name="examples", table_name="wb_health_population", ): params = { @@ -102,7 +102,7 @@ def create_dashboard(self, title, id=0, slcs=[]): ) def create_table(self, name, schema="", id=0, cols_names=[], metric_names=[]): - params = {"remote_id": id, "database_name": "main"} + 
params = {"remote_id": id, "database_name": "examples"} table = SqlaTable( id=id, schema=schema, table_name=name, params=json.dumps(params) ) @@ -135,10 +135,6 @@ def get_slice_by_name(self, name): def get_dash(self, dash_id): return db.session.query(models.Dashboard).filter_by(id=dash_id).first() - def get_dash_by_slug(self, dash_slug): - sesh = db.session() - return sesh.query(models.Dashboard).filter_by(slug=dash_slug).first() - def get_datasource(self, datasource_id): return db.session.query(DruidDatasource).filter_by(id=datasource_id).first() @@ -192,9 +188,21 @@ def assert_slice_equals(self, expected_slc, actual_slc): self.assertEquals(expected_slc_name, actual_slc_name) self.assertEquals(expected_slc.datasource_type, actual_slc.datasource_type) self.assertEquals(expected_slc.viz_type, actual_slc.viz_type) - self.assertEquals( - json.loads(expected_slc.params), json.loads(actual_slc.params) - ) + exp_params = json.loads(expected_slc.params) + actual_params = json.loads(actual_slc.params) + diff_params_keys = ( + "schema", + "database_name", + "datasource_name", + "remote_id", + "import_time", + ) + for k in diff_params_keys: + if k in actual_params: + actual_params.pop(k) + if k in exp_params: + exp_params.pop(k) + self.assertEquals(exp_params, actual_params) def test_export_1_dashboard(self): self.login("admin") @@ -233,11 +241,11 @@ def test_export_2_dashboards(self): birth_dash.id, world_health_dash.id ) resp = self.client.get(export_dash_url) + resp_data = json.loads( + resp.data.decode("utf-8"), object_hook=utils.decode_dashboards + ) exported_dashboards = sorted( - json.loads(resp.data.decode("utf-8"), object_hook=utils.decode_dashboards)[ - "dashboards" - ], - key=lambda d: d.dashboard_title, + resp_data.get("dashboards"), key=lambda d: d.dashboard_title ) self.assertEquals(2, len(exported_dashboards)) @@ -255,10 +263,7 @@ def test_export_2_dashboards(self): ) exported_tables = sorted( - json.loads(resp.data.decode("utf-8"), object_hook=utils.decode_dashboards)[ - "datasources" - ], - key=lambda t: t.table_name, + resp_data.get("datasources"), key=lambda t: t.table_name ) self.assertEquals(2, len(exported_tables)) self.assert_table_equals( @@ -297,7 +302,7 @@ def test_import_2_slices_for_same_table(self): self.assertEquals(imported_slc_2.datasource.perm, imported_slc_2.perm) def test_import_slices_for_non_existent_table(self): - with self.assertRaises(IndexError): + with self.assertRaises(AttributeError): models.Slice.import_obj( self.create_slice("Import Me 3", id=10004, table_name="non_existent"), None, @@ -447,7 +452,7 @@ def test_import_table_1_col_1_met(self): imported = self.get_table(imported_id) self.assert_table_equals(table, imported) self.assertEquals( - {"remote_id": 10002, "import_time": 1990, "database_name": "main"}, + {"remote_id": 10002, "import_time": 1990, "database_name": "examples"}, json.loads(imported.params), ) diff --git a/tests/load_examples_test.py b/tests/load_examples_test.py index 6ca3b2d38f7e..0d1db798b1f8 100644 --- a/tests/load_examples_test.py +++ b/tests/load_examples_test.py @@ -14,23 +14,26 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from superset import data +from superset import examples from superset.cli import load_test_users_run from .base_tests import SupersetTestCase class SupersetDataFrameTestCase(SupersetTestCase): def test_load_css_templates(self): - data.load_css_templates() + examples.load_css_templates() def test_load_energy(self): - data.load_energy() + examples.load_energy() def test_load_world_bank_health_n_pop(self): - data.load_world_bank_health_n_pop() + examples.load_world_bank_health_n_pop() def test_load_birth_names(self): - data.load_birth_names() + examples.load_birth_names() def test_load_test_users_run(self): load_test_users_run() + + def test_load_unicode_test_data(self): + examples.load_unicode_test_data() diff --git a/tests/model_tests.py b/tests/model_tests.py index 445ec52db019..fda941bb1848 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -20,9 +20,9 @@ import pandas from sqlalchemy.engine.url import make_url -from superset import app, db +from superset import app from superset.models.core import Database -from superset.utils.core import get_main_database, QueryStatus +from superset.utils.core import get_example_database, get_main_database, QueryStatus from .base_tests import SupersetTestCase @@ -101,7 +101,7 @@ def test_database_impersonate_user(self): self.assertNotEquals(example_user, user_name) def test_select_star(self): - main_db = get_main_database(db.session) + main_db = get_example_database() table_name = "energy_usage" sql = main_db.select_star(table_name, show_cols=False, latest_partition=False) expected = textwrap.dedent( @@ -124,7 +124,7 @@ def test_select_star(self): assert sql.startswith(expected) def test_single_statement(self): - main_db = get_main_database(db.session) + main_db = get_main_database() if main_db.backend == "mysql": df = main_db.get_df("SELECT 1", None) @@ -134,7 +134,7 @@ def test_single_statement(self): self.assertEquals(df.iat[0, 0], 1) def test_multi_statement(self): - main_db = get_main_database(db.session) + main_db = get_main_database() if main_db.backend == "mysql": df = main_db.get_df("USE superset; SELECT 1", None) diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py index 5dc56bfa3a9a..b16b796abdaa 100644 --- a/tests/sqllab_tests.py +++ b/tests/sqllab_tests.py @@ -83,7 +83,7 @@ def test_explain(self): self.assertLess(0, len(data["data"])) def test_sql_json_has_access(self): - main_db = get_main_database(db.session) + main_db = get_main_database() security_manager.add_permission_view_menu("database_access", main_db.perm) db.session.commit() main_db_permission_view = ( diff --git a/tests/strategy_tests.py b/tests/strategy_tests.py index 0f5a20e11b10..0786ed36f74d 100644 --- a/tests/strategy_tests.py +++ b/tests/strategy_tests.py @@ -23,14 +23,13 @@ from superset.models.tags import get_tag, ObjectTypes, TaggedObject, TagTypes from superset.tasks.cache import ( DashboardTagsStrategy, - DummyStrategy, get_form_data, TopNDashboardsStrategy, ) from .base_tests import SupersetTestCase -TEST_URL = "http://0.0.0.0:8081/superset/explore_json" +URL_PREFIX = "0.0.0.0:8081" class CacheWarmUpTests(SupersetTestCase): @@ -141,61 +140,61 @@ def test_get_form_data(self): } self.assertEqual(result, expected) - def test_dummy_strategy(self): - strategy = DummyStrategy() - result = sorted(strategy.get_urls()) - expected = [ - f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+1%7D", - f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+17%7D", - f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+18%7D", - f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+19%7D", - 
f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+30%7D", - f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+31%7D", - f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+8%7D", - ] - self.assertEqual(result, expected) - def test_top_n_dashboards_strategy(self): # create a top visited dashboard db.session.query(Log).delete() self.login(username="admin") + dash = self.get_dash_by_slug("births") for _ in range(10): - self.client.get("/superset/dashboard/3/") + self.client.get(f"/superset/dashboard/{dash.id}/") strategy = TopNDashboardsStrategy(1) result = sorted(strategy.get_urls()) - expected = [f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+31%7D"] + expected = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices]) self.assertEqual(result, expected) + def reset_tag(self, tag): + """Remove associated object from tag, used to reset tests""" + if tag.objects: + for o in tag.objects: + db.session.delete(o) + db.session.commit() + def test_dashboard_tags(self): - strategy = DashboardTagsStrategy(["tag1"]) + tag1 = get_tag("tag1", db.session, TagTypes.custom) + # delete first to make test idempotent + self.reset_tag(tag1) + strategy = DashboardTagsStrategy(["tag1"]) result = sorted(strategy.get_urls()) expected = [] self.assertEqual(result, expected) - # tag dashboard 3 with `tag1` + # tag dashboard 'births' with `tag1` tag1 = get_tag("tag1", db.session, TagTypes.custom) - object_id = 3 + dash = self.get_dash_by_slug("births") + tag1_urls = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices]) tagged_object = TaggedObject( - tag_id=tag1.id, object_id=object_id, object_type=ObjectTypes.dashboard + tag_id=tag1.id, object_id=dash.id, object_type=ObjectTypes.dashboard ) db.session.add(tagged_object) db.session.commit() - result = sorted(strategy.get_urls()) - expected = [f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+31%7D"] - self.assertEqual(result, expected) + self.assertEqual(sorted(strategy.get_urls()), tag1_urls) strategy = DashboardTagsStrategy(["tag2"]) + tag2 = get_tag("tag2", db.session, TagTypes.custom) + self.reset_tag(tag2) result = sorted(strategy.get_urls()) expected = [] self.assertEqual(result, expected) - # tag chart 30 with `tag2` - tag2 = get_tag("tag2", db.session, TagTypes.custom) - object_id = 30 + # tag first slice + dash = self.get_dash_by_slug("unicode-test") + slc = dash.slices[0] + tag2_urls = [f"{URL_PREFIX}{slc.url}"] + object_id = slc.id tagged_object = TaggedObject( tag_id=tag2.id, object_id=object_id, object_type=ObjectTypes.chart ) @@ -203,14 +202,10 @@ def test_dashboard_tags(self): db.session.commit() result = sorted(strategy.get_urls()) - expected = [f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+30%7D"] - self.assertEqual(result, expected) + self.assertEqual(result, tag2_urls) strategy = DashboardTagsStrategy(["tag1", "tag2"]) result = sorted(strategy.get_urls()) - expected = [ - f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+30%7D", - f"{TEST_URL}/?form_data=%7B%27slice_id%27%3A+31%7D", - ] + expected = sorted(tag1_urls + tag2_urls) self.assertEqual(result, expected) diff --git a/tests/viz_tests.py b/tests/viz_tests.py index 136fdf8423ec..b5638fa581be 100644 --- a/tests/viz_tests.py +++ b/tests/viz_tests.py @@ -109,7 +109,6 @@ def test_get_df_handles_dttm_col(self): datasource.get_col = Mock(return_value=mock_dttm_col) mock_dttm_col.python_date_format = "epoch_ms" result = test_viz.get_df(query_obj) - print(result) import logging logging.info(result) diff --git a/tox.ini b/tox.ini index 189862e8bba6..3625ed64a001 100644 --- a/tox.ini +++ b/tox.ini @@ -46,7 +46,7 @@ setenv = 
PYTHONPATH = {toxinidir} SUPERSET_CONFIG = tests.superset_test_config SUPERSET_HOME = {envtmpdir} - py36-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset + py36-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset?charset=utf8 py36-postgres: SUPERSET__SQLALCHEMY_DATABASE_URI = postgresql+psycopg2://postgresuser:pguserpassword@localhost/superset py36-sqlite: SUPERSET__SQLALCHEMY_DATABASE_URI = sqlite:////{envtmpdir}/superset.db whitelist_externals =
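
For illustration, a minimal sketch of the test pattern these hunks converge on: resolving the examples database through get_example_database() instead of querying for a database named "main". It reuses get_example_database(), Database.get_table(), get_dialect() and db_engine_spec.column_datatype_to_string() exactly as they appear in the hunks above; the function name, the standalone-function form, and the final assertion are assumptions made for the sketch, not code from this change.

from superset.utils.core import get_example_database

def test_energy_usage_column_types():
    # Resolve the examples database via the helper (no session argument needed).
    example_db = get_example_database()
    # "energy_usage" is one of the bundled example tables referenced above.
    sqla_table = example_db.get_table("energy_usage")
    dialect = example_db.get_dialect()
    # Render each column type for the active backend, as db_engine_specs_test.py does.
    col_types = [
        example_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
        for c in sqla_table.columns
    ]
    # energy_usage has three columns (two VARCHAR(255) plus one numeric column).
    assert len(col_types) == 3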