Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 48 additions & 10 deletions homeassistant/components/python_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

import voluptuous as vol

from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import sanitize_filename

DOMAIN = 'python_script'
REQUIREMENTS = ['restrictedpython==4.0a2']
FOLDER = 'python_scripts'
Expand All @@ -14,6 +17,18 @@
DOMAIN: vol.Schema(dict)
}, extra=vol.ALLOW_EXTRA)

ALLOWED_HASS = set(['bus', 'services', 'states'])
ALLOWED_EVENTBUS = set(['fire'])
ALLOWED_STATEMACHINE = set(['entity_ids', 'all', 'get', 'is_state',
'is_state_attr', 'remove', 'set'])
ALLOWED_SERVICEREGISTRY = set(['services', 'has_service', 'call'])


class ScriptError(HomeAssistantError):
"""When a script error occurs."""

pass


def setup(hass, config):
"""Initialize the python_script component."""
Expand All @@ -23,21 +38,27 @@ def setup(hass, config):
_LOGGER.warning('Folder %s not found in config folder', FOLDER)
return False

def service_handler(call):
def python_script_service_handler(call):
"""Handle python script service calls."""
filename = '{}.py'.format(call.service)
with open(hass.config.path(FOLDER, filename)) as fil:
execute(hass, filename, fil.read(), call.data)
execute_script(hass, call.service, call.data)

for fil in glob.iglob(os.path.join(path, '*.py')):
name = os.path.splitext(os.path.basename(fil))[0]
hass.services.register(DOMAIN, name, service_handler)
hass.services.register(DOMAIN, name, python_script_service_handler)

return True


def execute(hass, filename, source, data):
def execute_script(hass, name, data=None):
"""Execute a script."""
filename = '{}.py'.format(name)
with open(hass.config.path(FOLDER, sanitize_filename(filename))) as fil:
source = fil.read()
execute(hass, filename, source, data)


def execute(hass, filename, source, data=None):
"""Execute Python source."""
from RestrictedPython import compile_restricted_exec
from RestrictedPython.Guards import safe_builtins, full_write_guard

Expand All @@ -52,24 +73,41 @@ def execute(hass, filename, source, data):
_LOGGER.warning('Warning loading script %s: %s', filename,
', '.join(compiled.warnings))

def protected_getattr(obj, name, default=None):
"""Restricted method to get attributes."""
# pylint: disable=too-many-boolean-expressions
if name.startswith('async_'):
raise ScriptError('Not allowed to access async methods')
elif (obj is hass and name not in ALLOWED_HASS or
obj is hass.bus and name not in ALLOWED_EVENTBUS or
obj is hass.states and name not in ALLOWED_STATEMACHINE or
obj is hass.services and name not in ALLOWED_SERVICEREGISTRY):
raise ScriptError('Not allowed to access {}.{}'.format(
obj.__class__.__name__, name))

return getattr(obj, name, default)

restricted_globals = {
'__builtins__': safe_builtins,
'_print_': StubPrinter,
'_getattr_': getattr,
'_getattr_': protected_getattr,
'_write_': full_write_guard,
}
logger = logging.getLogger('{}.{}'.format(__name__, filename))
local = {
'hass': hass,
'data': data,
'logger': logging.getLogger('{}.{}'.format(__name__, filename))
'data': data or {},
'logger': logger
}

try:
_LOGGER.info('Executing %s: %s', filename, data)
# pylint: disable=exec-used
exec(compiled.code, restricted_globals, local)
except ScriptError as err:
logger.error('Error executing script: %s', err)
except Exception as err: # pylint: disable=broad-except
_LOGGER.exception('Error executing script %s: %s', filename, err)
logger.exception('Error executing script: %s', err)


class StubPrinter:
Expand Down
30 changes: 29 additions & 1 deletion tests/components/test_python_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,32 @@ def test_execute_runtime_error(hass, caplog):
hass.async_add_job(execute, hass, 'test.py', source, {})
yield from hass.async_block_till_done()

assert "Error executing script test.py" in caplog.text
assert "Error executing script: boom" in caplog.text


@asyncio.coroutine
def test_accessing_async_methods(hass, caplog):
"""Test compile error logs error."""
caplog.set_level(logging.ERROR)
source = """
hass.async_stop()
"""

hass.async_add_job(execute, hass, 'test.py', source, {})
yield from hass.async_block_till_done()

assert "Not allowed to access async methods" in caplog.text


@asyncio.coroutine
def test_accessing_forbidden_methods(hass, caplog):
"""Test compile error logs error."""
caplog.set_level(logging.ERROR)
source = """
hass.stop()
"""

hass.async_add_job(execute, hass, 'test.py', source, {})
yield from hass.async_block_till_done()

assert "Not allowed to access HomeAssistant.stop" in caplog.text