Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions homeassistant/components/sql/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from homeassistant import config_entries
from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL
from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import selector

Expand Down Expand Up @@ -44,24 +44,23 @@ def validate_sql_select(value: str) -> str | None:

def validate_query(db_url: str, query: str, column: str) -> bool:
"""Validate SQL query."""
try:
engine = sqlalchemy.create_engine(db_url, future=True)
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
except SQLAlchemyError as error:
raise error

engine = sqlalchemy.create_engine(db_url, future=True)
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
sess: scoped_session = sessmaker()

try:
result: Result = sess.execute(sqlalchemy.text(query))
for res in result.mappings():
data = res[column]
_LOGGER.debug("Return value from query: %s", data)
except SQLAlchemyError as error:
_LOGGER.debug("Execution error %s", error)
if sess:
sess.close()
raise ValueError(error) from error

for res in result.mappings():
data = res[column]
_LOGGER.debug("Return value from query: %s", data)

if sess:
sess.close()

Expand All @@ -73,9 +72,6 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):

VERSION = 1

entry: config_entries.ConfigEntry
hass: HomeAssistant

@staticmethod
@callback
def async_get_options_flow(
Expand Down
22 changes: 10 additions & 12 deletions homeassistant/components/sql/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging

import sqlalchemy
from sqlalchemy.engine import Result
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import scoped_session, sessionmaker
import voluptuous as vol
Expand Down Expand Up @@ -73,11 +74,11 @@ async def async_setup_platform(
for query in config[CONF_QUERIES]:
new_config = {
CONF_DB_URL: config.get(CONF_DB_URL, default_db_url),
CONF_NAME: query.get(CONF_NAME),
CONF_QUERY: query.get(CONF_QUERY),
CONF_NAME: query[CONF_NAME],
CONF_QUERY: query[CONF_QUERY],
CONF_UNIT_OF_MEASUREMENT: query.get(CONF_UNIT_OF_MEASUREMENT),
CONF_VALUE_TEMPLATE: query.get(CONF_VALUE_TEMPLATE),
CONF_COLUMN_NAME: query.get(CONF_COLUMN_NAME),
CONF_COLUMN_NAME: query[CONF_COLUMN_NAME],
}
hass.async_create_task(
hass.config_entries.flow.async_init(
Expand Down Expand Up @@ -119,11 +120,10 @@ async def async_setup_entry(

# MSSQL uses TOP and not LIMIT
if not ("LIMIT" in query_str.upper() or "SELECT TOP" in query_str.upper()):
query_str = (
query_str.replace("SELECT", "SELECT TOP 1")
if "mssql" in db_url
else query_str.replace(";", " LIMIT 1;")
)
if "mssql" in db_url:
query_str = query_str.upper().replace("SELECT", "SELECT TOP 1")
else:
query_str = query_str.replace(";", "") + " LIMIT 1;"

async_add_entities(
[
Expand Down Expand Up @@ -179,7 +179,7 @@ def update(self) -> None:
self._attr_extra_state_attributes = {}
sess: scoped_session = self.sessionmaker()
try:
result = sess.execute(sqlalchemy.text(self._query))
result: Result = sess.execute(sqlalchemy.text(self._query))
except SQLAlchemyError as err:
_LOGGER.error(
"Error executing query %s: %s",
Expand All @@ -188,10 +188,8 @@ def update(self) -> None:
)
return

_LOGGER.debug("Result %s, ResultMapping %s", result, result.mappings())

for res in result.mappings():
_LOGGER.debug("result = %s", res.items())
_LOGGER.debug("Query %s result in %s", self._query, res.items())
data = res[self._column_name]
for key, value in res.items():
if isinstance(value, decimal.Decimal):
Expand Down
6 changes: 2 additions & 4 deletions homeassistant/components/sql/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
},
"error": {
"db_url_invalid": "Database URL invalid",
"query_invalid": "SQL Query invalid",
"value_template_invalid": "Value Template invalid"
"query_invalid": "SQL Query invalid"
},
"step": {
"user": {
Expand Down Expand Up @@ -52,8 +51,7 @@
},
"error": {
"db_url_invalid": "[%key:component::sql::config::error::db_url_invalid%]",
"query_invalid": "[%key:component::sql::config::error::query_invalid%]",
"value_template_invalid": "[%key:component::sql::config::error::value_template_invalid%]"
"query_invalid": "[%key:component::sql::config::error::query_invalid%]"
}
}
}
6 changes: 2 additions & 4 deletions homeassistant/components/sql/translations/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
},
"error": {
"db_url_invalid": "Database URL invalid",
"query_invalid": "SQL Query invalid",
"value_template_invalid": "Value Template invalid"
"query_invalid": "SQL Query invalid"
},
"step": {
"user": {
Expand All @@ -32,8 +31,7 @@
"options": {
"error": {
"db_url_invalid": "Database URL invalid",
"query_invalid": "SQL Query invalid",
"value_template_invalid": "Value Template invalid"
"query_invalid": "SQL Query invalid"
},
"step": {
"init": {
Expand Down
1 change: 0 additions & 1 deletion tests/components/sql/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ async def test_form(hass: HomeAssistant) -> None:
ENTRY_CONFIG,
)
await hass.async_block_till_done()
print(ENTRY_CONFIG)

assert result2["type"] == RESULT_TYPE_CREATE_ENTRY
assert result2["title"] == "Get Value"
Expand Down
25 changes: 24 additions & 1 deletion tests/components/sql/test_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,30 @@ async def test_query_no_value(
state = hass.states.get("sensor.count_tables")
assert state.state == STATE_UNKNOWN

text = "SELECT 5 as value where 1=2 returned no results"
text = "SELECT 5 as value where 1=2 LIMIT 1; returned no results"
assert text in caplog.text


async def test_query_mssql_no_result(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test the SQL sensor with a query that returns no value."""
config = {
"db_url": "mssql://",
"query": "SELECT 5 as value where 1=2",
"column": "value",
"name": "count_tables",
}
with patch("homeassistant.components.sql.sensor.sqlalchemy"), patch(
"homeassistant.components.sql.sensor.sqlalchemy.text",
return_value="SELECT TOP 1 5 as value where 1=2",
):
await init_integration(hass, config)

state = hass.states.get("sensor.count_tables")
assert state.state == STATE_UNKNOWN

text = "SELECT TOP 1 5 AS VALUE WHERE 1=2 returned no results"
assert text in caplog.text


Expand Down