diff --git a/homeassistant/components/pulseaudio_loopback/switch.py b/homeassistant/components/pulseaudio_loopback/switch.py index ec1adc7641bb21..fe4bda69f8a77f 100644 --- a/homeassistant/components/pulseaudio_loopback/switch.py +++ b/homeassistant/components/pulseaudio_loopback/switch.py @@ -2,7 +2,9 @@ from datetime import timedelta import logging import re +import select import socket +from time import monotonic import voluptuous as vol @@ -84,7 +86,7 @@ def __init__(self, host, port, buff_sz, tcp_timeout): self._buffer_size = int(buff_sz) self._tcp_timeout = int(tcp_timeout) - def _send_command(self, cmd, response_expected): + def _send_command(self, cmd, verifier=None): """Send a command to the pa server using a socket.""" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(self._tcp_timeout) @@ -92,8 +94,8 @@ def _send_command(self, cmd, response_expected): sock.connect((self._pa_host, self._pa_port)) _LOGGER.info("Calling pulseaudio: %s", cmd) sock.send((cmd + "\n").encode("utf-8")) - if response_expected: - return_data = self._get_full_response(sock) + if verifier: + return_data = self._verify_response(sock, verifier) _LOGGER.debug("Data received from pulseaudio: %s", return_data) else: return_data = "" @@ -101,30 +103,57 @@ def _send_command(self, cmd, response_expected): sock.close() return return_data - def _get_full_response(self, sock): + def _verify_response(self, sock, verifier): """Get the full response back from pulseaudio.""" result = "" - rcv_buffer = sock.recv(self._buffer_size) - result += rcv_buffer.decode("utf-8") - while len(rcv_buffer) == self._buffer_size: - rcv_buffer = sock.recv(self._buffer_size) - result += rcv_buffer.decode("utf-8") + sock.setblocking(False) + + start_time = monotonic() + remaining_timeout = float(self._tcp_timeout) + + while remaining_timeout > 0: + ready = select.select([sock], [], [], remaining_timeout) + if ready[0]: + rcv_buffer = sock.recv(self._buffer_size) + result += rcv_buffer.decode("utf-8") + else: + break + + if verifier(result): + break + + remaining_timeout -= monotonic() - start_time return result @util.Throttle(MIN_TIME_BETWEEN_SCANS, MIN_TIME_BETWEEN_FORCED_SCANS) def update_module_state(self): """Refresh state in case an alternate process modified this data.""" - self._current_module_state = self._send_command("list-modules", True) + + def verify_module_list(response): + """Test if result for 'list-modules' call is complete.""" + try: + expected_modules = int( + re.match(r"^(\d*)\smodule\(s\) loaded.*", response).group(1) + ) + actual_modules = int(len(re.findall(r"\n\s*index:\s\d*", response))) + + return (expected_modules == actual_modules) and response[-1] == "\n" + except AttributeError: + return False + + self._current_module_state = self._send_command( + "list-modules", verify_module_list + ) def turn_on(self, sink_name, source_name): """Send a command to pulseaudio to turn on the loopback.""" - self._send_command(str.format(LOAD_CMD, sink_name, source_name), False) + self._send_command(str.format(LOAD_CMD, sink_name, source_name)) def turn_off(self, module_idx): """Send a command to pulseaudio to turn off the loopback.""" - self._send_command(str.format(UNLOAD_CMD, module_idx), False) + self._send_command(str.format(UNLOAD_CMD, module_idx)) def get_module_idx(self, sink_name, source_name): """For a sink/source, return its module id in our cache, if found."""