diff --git a/adafruit_midi/__init__.py b/adafruit_midi/__init__.py index 48fe745..fa6b4cd 100755 --- a/adafruit_midi/__init__.py +++ b/adafruit_midi/__init__.py @@ -25,6 +25,10 @@ https://github.com/adafruit/circuitpython/releases """ +try: + from typing import Union, Tuple, Any, List, Optional, Dict, BinaryIO +except ImportError: + pass from .midi_message import MIDIMessage @@ -54,13 +58,13 @@ class MIDI: def __init__( self, - midi_in=None, - midi_out=None, + midi_in: Optional[BinaryIO] = None, + midi_out: Optional[BinaryIO] = None, *, - in_channel=None, - out_channel=0, - in_buf_size=30, - debug=False + in_channel: Optional[Union[int, Tuple[int, ...]]] = None, + out_channel: int = 0, + in_buf_size: int = 30, + debug: bool = False ): if midi_in is None and midi_out is None: raise ValueError("No midi_in or midi_out provided") @@ -78,7 +82,7 @@ def __init__( self._skipped_bytes = 0 @property - def in_channel(self): + def in_channel(self) -> Optional[Union[int, Tuple[int, ...]]]: """The incoming MIDI channel. Must be 0-15. Correlates to MIDI channels 1-16, e.g. ``in_channel = 3`` will listen on MIDI channel 4. Can also listen on multiple channels, e.g. ``in_channel = (0,1,2)`` @@ -87,7 +91,7 @@ def in_channel(self): return self._in_channel @in_channel.setter - def in_channel(self, channel): + def in_channel(self, channel: Optional[Union[str, int, Tuple[int, ...]]]) -> None: if channel is None or channel == "ALL": self._in_channel = tuple(range(16)) elif isinstance(channel, int) and 0 <= channel <= 15: @@ -98,19 +102,19 @@ def in_channel(self, channel): raise RuntimeError("Invalid input channel") @property - def out_channel(self): + def out_channel(self) -> int: """The outgoing MIDI channel. Must be 0-15. Correlates to MIDI channels 1-16, e.g. ``out_channel = 3`` will send to MIDI channel 4. Default is 0 (MIDI channel 1). """ return self._out_channel @out_channel.setter - def out_channel(self, channel): + def out_channel(self, channel: int) -> None: if not 0 <= channel <= 15: raise RuntimeError("Invalid output channel") self._out_channel = channel - def receive(self): + def receive(self) -> Optional[MIDIMessage]: """Read messages from MIDI port, store them in internal read buffer, then parse that data and return the first MIDI message (event). This maintains the blocking characteristics of the midi_in port. @@ -141,7 +145,7 @@ def receive(self): # msg could still be None at this point, e.g. in middle of monster SysEx return msg - def send(self, msg, channel=None): + def send(self, msg: MIDIMessage, channel: Optional[int] = None) -> None: """Sends a MIDI message. :param msg: Either a MIDIMessage object or a sequence (list) of MIDIMessage objects. @@ -165,7 +169,7 @@ def send(self, msg, channel=None): self._send(data, len(data)) - def _send(self, packet, num): + def _send(self, packet: bytes, num: int) -> None: if self._debug: print("Sending: ", [hex(i) for i in packet[:num]]) self._midi_out.write(packet, num) diff --git a/adafruit_midi/midi_message.py b/adafruit_midi/midi_message.py index d9ee4de..315efb0 100755 --- a/adafruit_midi/midi_message.py +++ b/adafruit_midi/midi_message.py @@ -25,12 +25,19 @@ __version__ = "0.0.0+auto.0" __repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_MIDI.git" +try: + from typing import Union, Tuple, Any, List, Optional +except ImportError: + pass + # From C3 - A and B are above G # Semitones A B C D E F G NOTE_OFFSET = [21, 23, 12, 14, 16, 17, 19] -def channel_filter(channel, channel_spec): +def channel_filter( + channel: int, channel_spec: Optional[Union[int, Tuple[int, ...]]] +) -> bool: """ Utility function to return True iff the given channel matches channel_spec. """ @@ -41,13 +48,12 @@ def channel_filter(channel, channel_spec): raise ValueError("Incorrect type for channel_spec" + str(type(channel_spec))) -def note_parser(note): +def note_parser(note: Union[int, str]) -> int: """If note is a string then it will be parsed and converted to a MIDI note (key) number, e.g. "C4" will return 60, "C#4" will return 61. If note is not a string it will simply be returned. :param note: Either 0-127 int or a str representing the note, e.g. "C#4" """ - midi_note = note if isinstance(note, str): if len(note) < 2: raise ValueError("Bad note format") @@ -61,7 +67,8 @@ def note_parser(note): sharpen = -1 # int may throw exception here midi_note = int(note[1 + abs(sharpen) :]) * 12 + NOTE_OFFSET[noteidx] + sharpen - + elif isinstance(note, int): + midi_note = note return midi_note @@ -82,40 +89,43 @@ class MIDIMessage: This is an *abstract* class. """ - _STATUS = None + _STATUS: Optional[int] = None _STATUSMASK = None - LENGTH = None + LENGTH: Optional[int] = None CHANNELMASK = 0x0F ENDSTATUS = None # Commonly used exceptions to save memory @staticmethod - def _raise_valueerror_oor(): + def _raise_valueerror_oor() -> None: raise ValueError("Out of range") # Each element is ((status, mask), class) # order is more specific masks first - _statusandmask_to_class = [] + # Add better type hints for status, mask, class referenced above + _statusandmask_to_class: List[ + Tuple[Tuple[Optional[bytes], Optional[int]], "MIDIMessage"] + ] = [] - def __init__(self, *, channel=None): + def __init__(self, *, channel: Optional[int] = None) -> None: self._channel = channel # dealing with pylint inadequacy self.channel = channel @property - def channel(self): + def channel(self) -> Optional[int]: """The channel number of the MIDI message where appropriate. This is *updated* by MIDI.send() method. """ return self._channel @channel.setter - def channel(self, channel): + def channel(self, channel: int) -> None: if channel is not None and not 0 <= channel <= 15: raise ValueError("Channel must be 0-15 or None") self._channel = channel @classmethod - def register_message_type(cls): + def register_message_type(cls) -> None: """Register a new message by its status value and mask. This is called automagically at ``import`` time for each message. """ @@ -132,7 +142,14 @@ def register_message_type(cls): # pylint: disable=too-many-arguments @classmethod - def _search_eom_status(cls, buf, eom_status, msgstartidx, msgendidxplusone, endidx): + def _search_eom_status( + cls, + buf: bytearray, + eom_status: Optional[int], + msgstartidx: int, + msgendidxplusone: int, + endidx: int, + ) -> Tuple[int, bool, bool]: good_termination = False bad_termination = False @@ -155,7 +172,9 @@ def _search_eom_status(cls, buf, eom_status, msgstartidx, msgendidxplusone, endi return (msgendidxplusone, good_termination, bad_termination) @classmethod - def _match_message_status(cls, buf, msgstartidx, msgendidxplusone, endidx): + def _match_message_status( + cls, buf: bytearray, msgstartidx: int, msgendidxplusone: int, endidx: int + ) -> Tuple[Optional[Any], int, bool, bool, bool, int]: msgclass = None status = buf[msgstartidx] known_msg = False @@ -198,7 +217,9 @@ def _match_message_status(cls, buf, msgstartidx, msgendidxplusone, endidx): # pylint: disable=too-many-locals,too-many-branches @classmethod - def from_message_bytes(cls, midibytes, channel_in): + def from_message_bytes( + cls, midibytes: bytearray, channel_in: Optional[Union[int, Tuple[int, ...]]] + ) -> Tuple[Optional["MIDIMessage"], int, int]: """Create an appropriate object of the correct class for the first message found in some MIDI bytes filtered by channel_in. @@ -270,7 +291,7 @@ def from_message_bytes(cls, midibytes, channel_in): # A default method for constructing wire messages with no data. # Returns an (immutable) bytes with just the status code in. - def __bytes__(self): + def __bytes__(self) -> bytes: """Return the ``bytes`` wire protocol representation of the object with channel number applied where appropriate.""" return bytes([self._STATUS]) @@ -280,12 +301,12 @@ def __bytes__(self): # Returns the new object. # pylint: disable=unused-argument @classmethod - def from_bytes(cls, msg_bytes): + def from_bytes(cls, msg_bytes: bytes) -> "MIDIMessage": """Creates an object from the byte stream of the wire protocol representation of the MIDI message.""" return cls() - def __str__(self): + def __str__(self) -> str: """Print an instance""" cls = self.__class__ if slots := getattr(cls, "_message_slots", None): @@ -313,7 +334,7 @@ class MIDIUnknownEvent(MIDIMessage): _message_slots = ["status"] LENGTH = -1 - def __init__(self, status): + def __init__(self, status: int): self.status = status super().__init__() @@ -333,7 +354,7 @@ class MIDIBadEvent(MIDIMessage): _message_slots = ["msg_bytes", "exception"] - def __init__(self, msg_bytes, exception): + def __init__(self, msg_bytes: bytearray, exception: Exception): self.data = bytes(msg_bytes) self.exception_text = repr(exception) super().__init__()