Skip to content
Merged
1 change: 0 additions & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ omit =
homeassistant/components/canary/alarm_control_panel.py
homeassistant/components/canary/camera.py
homeassistant/components/cast/*
homeassistant/components/cert_expiry/sensor.py
homeassistant/components/cert_expiry/helper.py
homeassistant/components/channels/*
homeassistant/components/cisco_ios/device_tracker.py
Expand Down
95 changes: 36 additions & 59 deletions homeassistant/components/cert_expiry/config_flow.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
"""Config flow for the Cert Expiry platform."""
import logging
import socket
import ssl

import voluptuous as vol

from homeassistant import config_entries
from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PORT
from homeassistant.core import HomeAssistant, callback
from homeassistant.const import CONF_HOST, CONF_PORT

from .const import DEFAULT_NAME, DEFAULT_PORT, DOMAIN
from .helper import get_cert
from .const import DEFAULT_PORT, DOMAIN # pylint: disable=unused-import
from .errors import (
ConnectionRefused,
ConnectionTimeout,
ResolveFailed,
ValidationFailure,
)
from .helper import get_cert_time_to_expiry

_LOGGER = logging.getLogger(__name__)


@callback
def certexpiry_entries(hass: HomeAssistant):
"""Return the host,port tuples for the domain."""
return set(
(entry.data[CONF_HOST], entry.data[CONF_PORT])
for entry in hass.config_entries.async_entries(DOMAIN)
)


class CertexpiryConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow."""

Expand All @@ -34,69 +28,54 @@ def __init__(self) -> None:
"""Initialize the config flow."""
self._errors = {}

def _prt_in_configuration_exists(self, user_input) -> bool:
"""Return True if host, port combination exists in configuration."""
host = user_input[CONF_HOST]
port = user_input.get(CONF_PORT, DEFAULT_PORT)
if (host, port) in certexpiry_entries(self.hass):
return True
return False

async def _test_connection(self, user_input=None):
"""Test connection to the server and try to get the certtificate."""
host = user_input[CONF_HOST]
"""Test connection to the server and try to get the certificate."""
try:
await self.hass.async_add_executor_job(
get_cert, host, user_input.get(CONF_PORT, DEFAULT_PORT)
await get_cert_time_to_expiry(
self.hass,
user_input[CONF_HOST],
user_input.get(CONF_PORT, DEFAULT_PORT),
)
return True
except socket.gaierror:
_LOGGER.error("Host cannot be resolved: %s", host)
except ResolveFailed:
self._errors[CONF_HOST] = "resolve_failed"
except socket.timeout:
_LOGGER.error("Timed out connecting to %s", host)
except ConnectionTimeout:
self._errors[CONF_HOST] = "connection_timeout"
except ssl.CertificateError as err:
if "doesn't match" in err.args[0]:
_LOGGER.error("Certificate does not match host: %s", host)
self._errors[CONF_HOST] = "wrong_host"
else:
_LOGGER.error("Certificate could not be validated: %s", host)
self._errors[CONF_HOST] = "certificate_error"
except ssl.SSLError:
_LOGGER.error("Certificate could not be validated: %s", host)
self._errors[CONF_HOST] = "certificate_error"
except ConnectionRefused:
self._errors[CONF_HOST] = "connection_refused"
except ValidationFailure:
return True
return False

async def async_step_user(self, user_input=None):
"""Step when user initializes a integration."""
self._errors = {}
if user_input is not None:
# set some defaults in case we need to return to the form
if self._prt_in_configuration_exists(user_input):
self._errors[CONF_HOST] = "host_port_exists"
else:
if await self._test_connection(user_input):
return self.async_create_entry(
title=user_input.get(CONF_NAME, DEFAULT_NAME),
Comment thread
jjlawren marked this conversation as resolved.
data={
CONF_HOST: user_input[CONF_HOST],
CONF_PORT: user_input.get(CONF_PORT, DEFAULT_PORT),
},
)
host = user_input[CONF_HOST]
port = user_input.get(CONF_PORT, DEFAULT_PORT)
await self.async_set_unique_id(f"{host}:{port}")
self._abort_if_unique_id_configured()

if await self._test_connection(user_input):
title_port = f":{port}" if port != DEFAULT_PORT else ""
title = f"{host}{title_port}"
return self.async_create_entry(
title=title, data={CONF_HOST: host, CONF_PORT: port},
)
if ( # pylint: disable=no-member
self.context["source"] == config_entries.SOURCE_IMPORT
):
_LOGGER.error("Config import failed for %s", user_input[CONF_HOST])
return self.async_abort(reason="import_failed")
else:
user_input = {}
user_input[CONF_NAME] = DEFAULT_NAME
user_input[CONF_HOST] = ""
user_input[CONF_PORT] = DEFAULT_PORT

return self.async_show_form(
step_id="user",
data_schema=vol.Schema(
{
vol.Required(
CONF_NAME, default=user_input.get(CONF_NAME, DEFAULT_NAME)
): str,
vol.Required(CONF_HOST, default=user_input[CONF_HOST]): str,
vol.Required(
CONF_PORT, default=user_input.get(CONF_PORT, DEFAULT_PORT)
Expand All @@ -111,6 +90,4 @@ async def async_step_import(self, user_input=None):

Only host was required in the yaml file all other fields are optional
"""
if self._prt_in_configuration_exists(user_input):
return self.async_abort(reason="host_port_exists")
return await self.async_step_user(user_input)
1 change: 0 additions & 1 deletion homeassistant/components/cert_expiry/const.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Const for Cert Expiry."""

DOMAIN = "cert_expiry"
DEFAULT_NAME = "SSL Certificate Expiry"
DEFAULT_PORT = 443
TIMEOUT = 10.0
26 changes: 26 additions & 0 deletions homeassistant/components/cert_expiry/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Errors for the cert_expiry integration."""
from homeassistant.exceptions import HomeAssistantError


class CertExpiryException(HomeAssistantError):
"""Base class for cert_expiry exceptions."""


class TemporaryFailure(CertExpiryException):
"""Temporary failure has occurred."""


class ValidationFailure(CertExpiryException):
"""Certificate validation failure has occurred."""


class ResolveFailed(TemporaryFailure):
"""Name resolution failed."""


class ConnectionTimeout(TemporaryFailure):
"""Network connection timed out."""


class ConnectionRefused(TemporaryFailure):
"""Network connection refused."""
30 changes: 29 additions & 1 deletion homeassistant/components/cert_expiry/helper.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,44 @@
"""Helper functions for the Cert Expiry platform."""
from datetime import datetime
import socket
import ssl

from .const import TIMEOUT
from .errors import (
ConnectionRefused,
ConnectionTimeout,
ResolveFailed,
ValidationFailure,
)


def get_cert(host, port):
"""Get the ssl certificate for the host and port combination."""
"""Get the certificate for the host and port combination."""
ctx = ssl.create_default_context()
address = (host, port)
with socket.create_connection(address, timeout=TIMEOUT) as sock:
with ctx.wrap_socket(sock, server_hostname=address[0]) as ssock:
# pylint disable: https://github.com/PyCQA/pylint/issues/3166
cert = ssock.getpeercert() # pylint: disable=no-member
return cert


async def get_cert_time_to_expiry(hass, hostname, port):
"""Return the certificate's time to expiry in days."""
try:
cert = await hass.async_add_executor_job(get_cert, hostname, port)
except socket.gaierror:
raise ResolveFailed(f"Cannot resolve hostname: {hostname}")
except socket.timeout:
raise ConnectionTimeout(f"Connection timeout with server: {hostname}:{port}")
except ConnectionRefusedError:
raise ConnectionRefused(f"Connection refused by server: {hostname}:{port}")
except ssl.CertificateError as err:
raise ValidationFailure(err.verify_message)
except ssl.SSLError as err:
raise ValidationFailure(err.args[0])

ts_seconds = ssl.cert_time_to_seconds(cert["notAfter"])
timestamp = datetime.fromtimestamp(ts_seconds)
expiry = timestamp - datetime.today()
return expiry.days
Loading