Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions superset-frontend/src/SqlLab/actions/sqlLab.js
Original file line number Diff line number Diff line change
Expand Up @@ -914,11 +914,27 @@ export function queryEditorSetAndSaveSql(targetQueryEditor, sql, queryId) {

export function formatQuery(queryEditor) {
return function (dispatch, getState) {
const { sql } = getUpToDateQuery(getState(), queryEditor);
const { sql, dbId, templateParams } = getUpToDateQuery(
getState(),
queryEditor,
);
const body = { sql };

// Include database_id and template_params if available for Jinja processing
if (dbId) {
body.database_id = dbId;
}
if (templateParams) {
// Send templateParams as a JSON string to match the backend schema
body.template_params =
typeof templateParams === 'string'
? templateParams
: JSON.stringify(templateParams);
}

return SupersetClient.post({
endpoint: `/api/v1/sqllab/format_sql/`,
// TODO (betodealmeida): pass engine as a parameter for better formatting
body: JSON.stringify({ sql }),
body: JSON.stringify(body),
headers: { 'Content-Type': 'application/json' },
}).then(({ json }) => {
dispatch(queryEditorSetSql(queryEditor, json.result));
Expand Down
195 changes: 194 additions & 1 deletion superset-frontend/src/SqlLab/actions/sqlLab.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,16 @@ describe('async actions', () => {
describe('formatQuery', () => {
const formatQueryEndpoint = 'glob:*/api/v1/sqllab/format_sql/';
const expectedSql = 'SELECT 1';
fetchMock.post(formatQueryEndpoint, { result: expectedSql });

beforeEach(() => {
fetchMock.post(
formatQueryEndpoint,
{ result: expectedSql },
{
overwriteRoutes: true,
},
);
});

test('posts to the correct url', async () => {
const store = mockStore(initialState);
Expand All @@ -181,6 +190,190 @@ describe('async actions', () => {
expect(store.getActions()[0].type).toBe(actions.QUERY_EDITOR_SET_SQL);
expect(store.getActions()[0].sql).toBe(expectedSql);
});

test('sends only sql in request body when no dbId or templateParams', async () => {
const queryEditorWithoutExtras = {
...defaultQueryEditor,
sql: 'SELECT * FROM table',
dbId: null,
templateParams: null,
};
const state = {
sqlLab: {
queryEditors: [queryEditorWithoutExtras],
unsavedQueryEditor: {},
},
};
const store = mockStore(state);

store.dispatch(actions.formatQuery(queryEditorWithoutExtras));

await waitFor(() =>
expect(fetchMock.calls(formatQueryEndpoint)).toHaveLength(1),
);

const call = fetchMock.calls(formatQueryEndpoint)[0];
const body = JSON.parse(call[1].body);

expect(body).toEqual({ sql: 'SELECT * FROM table' });
expect(body.database_id).toBeUndefined();
expect(body.template_params).toBeUndefined();
});

test('includes database_id in request when dbId is provided', async () => {
const queryEditorWithDb = {
...defaultQueryEditor,
sql: 'SELECT * FROM table',
dbId: 5,
templateParams: null,
};
const state = {
sqlLab: {
queryEditors: [queryEditorWithDb],
unsavedQueryEditor: {},
},
};
const store = mockStore(state);

store.dispatch(actions.formatQuery(queryEditorWithDb));

await waitFor(() =>
expect(fetchMock.calls(formatQueryEndpoint)).toHaveLength(1),
);

const call = fetchMock.calls(formatQueryEndpoint)[0];
const body = JSON.parse(call[1].body);

expect(body).toEqual({
sql: 'SELECT * FROM table',
database_id: 5,
});
});

test('includes template_params as string when provided as string', async () => {
const queryEditorWithTemplateString = {
...defaultQueryEditor,
sql: 'SELECT * FROM table WHERE id = {{ user_id }}',
dbId: 5,
templateParams: '{"user_id": 123}',
};
const state = {
sqlLab: {
queryEditors: [queryEditorWithTemplateString],
unsavedQueryEditor: {},
},
};
const store = mockStore(state);

store.dispatch(actions.formatQuery(queryEditorWithTemplateString));

await waitFor(() =>
expect(fetchMock.calls(formatQueryEndpoint)).toHaveLength(1),
);

const call = fetchMock.calls(formatQueryEndpoint)[0];
const body = JSON.parse(call[1].body);

expect(body).toEqual({
sql: 'SELECT * FROM table WHERE id = {{ user_id }}',
database_id: 5,
template_params: '{"user_id": 123}',
});
});

test('stringifies template_params when provided as object', async () => {
const queryEditorWithTemplateObject = {
...defaultQueryEditor,
sql: 'SELECT * FROM table WHERE id = {{ user_id }}',
dbId: 5,
templateParams: { user_id: 123, status: 'active' },
};
const state = {
sqlLab: {
queryEditors: [queryEditorWithTemplateObject],
unsavedQueryEditor: {},
},
};
const store = mockStore(state);

store.dispatch(actions.formatQuery(queryEditorWithTemplateObject));

await waitFor(() =>
expect(fetchMock.calls(formatQueryEndpoint)).toHaveLength(1),
);

const call = fetchMock.calls(formatQueryEndpoint)[0];
const body = JSON.parse(call[1].body);

expect(body).toEqual({
sql: 'SELECT * FROM table WHERE id = {{ user_id }}',
database_id: 5,
template_params: '{"user_id":123,"status":"active"}',
});
});

test('dispatches QUERY_EDITOR_SET_SQL with formatted result', async () => {
const formattedSql = 'SELECT\n *\nFROM\n table';
fetchMock.post(
formatQueryEndpoint,
{ result: formattedSql },
{
overwriteRoutes: true,
},
);

const queryEditorToFormat = {
...defaultQueryEditor,
sql: 'SELECT * FROM table',
};
const state = {
sqlLab: {
queryEditors: [queryEditorToFormat],
unsavedQueryEditor: {},
},
};
const store = mockStore(state);

await store.dispatch(actions.formatQuery(queryEditorToFormat));

const dispatchedActions = store.getActions();
expect(dispatchedActions).toHaveLength(1);
expect(dispatchedActions[0].type).toBe(actions.QUERY_EDITOR_SET_SQL);
expect(dispatchedActions[0].sql).toBe(formattedSql);
});

test('uses up-to-date query editor state from store', async () => {
const outdatedQueryEditor = {
...defaultQueryEditor,
sql: 'OLD SQL',
dbId: 1,
};
const upToDateQueryEditor = {
...defaultQueryEditor,
sql: 'SELECT * FROM updated_table',
dbId: 10,
};
const state = {
sqlLab: {
queryEditors: [upToDateQueryEditor],
unsavedQueryEditor: {},
},
};
const store = mockStore(state);

// Pass outdated query editor, but expect the function to use the up-to-date one from store
store.dispatch(actions.formatQuery(outdatedQueryEditor));

await waitFor(() =>
expect(fetchMock.calls(formatQueryEndpoint)).toHaveLength(1),
);

const call = fetchMock.calls(formatQueryEndpoint)[0];
const body = JSON.parse(call[1].body);

expect(body.sql).toBe('SELECT * FROM updated_table');
expect(body.database_id).toBe(10);
});
});

// eslint-disable-next-line no-restricted-globals -- TODO: Migrate from describe blocks
Expand Down
29 changes: 28 additions & 1 deletion superset/sqllab/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,34 @@ def format_sql(self) -> FlaskResponse:
"""
try:
model = self.format_model_schema.load(request.json)
result = SQLScript(model["sql"], model.get("engine")).format()
sql = model["sql"]
template_params = model.get("template_params")
database_id = model.get("database_id")

# Process Jinja templates if template_params and database_id are provided
if template_params and database_id is not None:
database = DatabaseDAO.find_by_id(database_id)
if database:
try:
template_params = (
json.loads(template_params)
if isinstance(template_params, str)
else template_params
)
if template_params:
template_processor = get_template_processor(
database=database
)
sql = template_processor.process_template(
sql, **template_params
)
except json.JSONDecodeError:
logger.warning(
"Invalid template parameter %s. Skipping processing",
str(template_params),
)

result = SQLScript(sql, model.get("engine")).format()
return self.response(200, result=result)
except ValidationError as error:
return self.response_400(message=error.messages)
Expand Down
8 changes: 8 additions & 0 deletions superset/sqllab/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ class EstimateQueryCostSchema(Schema):
class FormatQueryPayloadSchema(Schema):
sql = fields.String(required=True)
engine = fields.String(required=False, allow_none=True)
database_id = fields.Integer(
required=False, allow_none=True, metadata={"description": "The database id"}
)
template_params = fields.String(
required=False,
allow_none=True,
metadata={"description": "The SQL query template params as JSON string"},
)


class ExecutePayloadSchema(Schema):
Expand Down
19 changes: 19 additions & 0 deletions tests/integration_tests/sql_lab/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,25 @@ def test_format_sql_request(self):
self.assertDictEqual(resp_data, success_resp) # noqa: PT009
assert rv.status_code == 200

def test_format_sql_request_with_jinja(self):
self.login(ADMIN_USERNAME)
example_db = get_example_database()

data = {
"sql": "select * from {{tbl}}",
"database_id": example_db.id,
"template_params": json.dumps({"tbl": '"Vehicle Sales"'}),
}
rv = self.client.post(
"/api/v1/sqllab/format_sql/",
json=data,
)
resp_data = json.loads(rv.data.decode("utf-8"))
# Verify that Jinja template was processed before formatting
assert "{{tbl}}" not in resp_data["result"]
assert '"Vehicle Sales"' in resp_data["result"]
assert rv.status_code == 200

@mock.patch("superset.commands.sql_lab.results.results_backend_use_msgpack", False)
def test_execute_required_params(self):
self.login(ADMIN_USERNAME)
Expand Down
Loading