diff --git a/tests/security_tests.py b/tests/security_tests.py index 64c6f32cc6ad..50548fdeb407 100644 --- a/tests/security_tests.py +++ b/tests/security_tests.py @@ -1020,7 +1020,7 @@ def setUp(self): .filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"])) .all() ) - self.rls_entry.clause = "value > 1" + self.rls_entry.clause = "value > {{ cache_key_wrapper(1) }}" self.rls_entry.roles.append( security_manager.find_role("Gamma") ) # db.session.query(Role).filter_by(name="Gamma").first()) @@ -1052,7 +1052,8 @@ def test_rls_filter_alters_query(self): extras={}, ) sql = tbl.get_query_str(query_obj) - self.assertIn("value > 1", sql) + assert tbl.get_extra_cache_keys(query_obj) == [1] + assert "value > 1" in sql def test_rls_filter_doesnt_alter_query(self): g.user = self.get_user( @@ -1071,7 +1072,8 @@ def test_rls_filter_doesnt_alter_query(self): extras={}, ) sql = tbl.get_query_str(query_obj) - self.assertNotIn("value > 1", sql) + assert tbl.get_extra_cache_keys(query_obj) == [] + assert "value > 1" not in sql def test_multiple_table_filter_alters_another_tables_query(self): g.user = self.get_user( @@ -1090,4 +1092,5 @@ def test_multiple_table_filter_alters_another_tables_query(self): extras={}, ) sql = tbl.get_query_str(query_obj) - self.assertIn("value > 1", sql) + assert tbl.get_extra_cache_keys(query_obj) == [1] + assert "value > 1" in sql