diff --git a/bumble/transport/pyusb.py b/bumble/transport/pyusb.py index 61ce17e2..68a1dfd9 100644 --- a/bumble/transport/pyusb.py +++ b/bumble/transport/pyusb.py @@ -23,11 +23,24 @@ import usb.core import usb.util +from typing import Optional +from usb.core import Device as UsbDevice +from usb.core import USBError +from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER +from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB + from .common import Transport, ParserSource from .. import hci from ..colors import color +# ----------------------------------------------------------------------------- +# Constant +# ----------------------------------------------------------------------------- +USB_PORT_FEATURE_POWER = 8 +POWER_CYCLE_DELAY = 1 +RESET_DELAY = 3 + # ----------------------------------------------------------------------------- # Logging # ----------------------------------------------------------------------------- @@ -214,6 +227,10 @@ async def close(self): usb_find = libusb_package.find # Find the device according to the spec moniker + power_cycle = False + if spec.startswith('!'): + power_cycle = True + spec = spec[1:] if ':' in spec: vendor_id, product_id = spec.split(':') device = usb_find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16)) @@ -245,6 +262,14 @@ def device_path(device): raise ValueError('device not found') logger.debug(f'USB Device: {device}') + # Power Cycle the device + if power_cycle: + try: + device = await _power_cycle(device) # type: ignore + except Exception as e: + logging.debug(e) + logging.info(f"Unable to power cycle {hex(device.idVendor)} {hex(device.idProduct)}") # type: ignore + # Collect the metadata device_metadata = {'vendor_id': device.idVendor, 'product_id': device.idProduct} @@ -308,3 +333,73 @@ def device_path(device): packet_sink.start() return UsbTransport(device, packet_source, packet_sink) + + +async def _power_cycle(device: UsbDevice) -> UsbDevice: + """ + For devices connected to compatible USB hubs: Performs a power cycle on a given USB device. + This involves temporarily disabling its port on the hub and then re-enabling it. + """ + device_path = f'{device.bus}-{".".join(map(str, device.port_numbers))}' # type: ignore + hub = _find_hub_by_device_path(device_path) + + if hub: + try: + device_port = device.port_numbers[-1] # type: ignore + _set_port_status(hub, device_port, False) + await asyncio.sleep(POWER_CYCLE_DELAY) + _set_port_status(hub, device_port, True) + await asyncio.sleep(RESET_DELAY) + + # Device needs to be find again otherwise it will appear as disconnected + return usb.core.find(idVendor=device.idVendor, idProduct=device.idProduct) # type: ignore + except USBError as e: + logger.error(f"Adjustment needed: Please revise the udev rule for device {hex(device.idVendor)}:{hex(device.idProduct)} for proper recognition.") # type: ignore + logger.error(e) + + return device + + +def _set_port_status(device: UsbDevice, port: int, on: bool): + """Sets the power status of a specific port on a USB hub.""" + device.ctrl_transfer( + bmRequestType=CTRL_TYPE_CLASS | CTRL_RECIPIENT_OTHER, + bRequest=REQ_SET_FEATURE if on else REQ_CLEAR_FEATURE, + wIndex=port, + wValue=USB_PORT_FEATURE_POWER, + ) + + +def _find_device_by_path(sys_path: str) -> Optional[UsbDevice]: + """Finds a USB device based on its system path.""" + bus_num, *port_parts = sys_path.split('-') + ports = [int(port) for port in port_parts[0].split('.')] + devices = usb.core.find(find_all=True, bus=int(bus_num)) + if devices: + for device in devices: + if device.bus == int(bus_num) and list(device.port_numbers) == ports: # type: ignore + return device + + return None + + +def _find_hub_by_device_path(sys_path: str) -> Optional[UsbDevice]: + """Finds the USB hub associated with a specific device path.""" + hub_sys_path = sys_path.rsplit('.', 1)[0] + hub_device = _find_device_by_path(hub_sys_path) + + if hub_device is None: + return None + else: + return hub_device if _is_hub(hub_device) else None + + +def _is_hub(device: UsbDevice) -> bool: + """Checks if a USB device is a hub""" + if device.bDeviceClass == CLASS_HUB: # type: ignore + return True + for config in device: + for interface in config: + if interface.bInterfaceClass == CLASS_HUB: # type: ignore + return True + return False