Skip to content

Commit

Permalink
Merge pull request #52 from jordanhemingway-revvity/type_annotations
Browse files Browse the repository at this point in the history
added typehints to midimessage and init
  • Loading branch information
FoamyGuy authored Aug 5, 2024
2 parents 36f8d5a + 9817d84 commit 7d4c2c2
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 33 deletions.
30 changes: 17 additions & 13 deletions adafruit_midi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
https://github.com/adafruit/circuitpython/releases
"""
try:
from typing import Union, Tuple, Any, List, Optional, Dict, BinaryIO
except ImportError:
pass

from .midi_message import MIDIMessage

Expand Down Expand Up @@ -54,13 +58,13 @@ class MIDI:

def __init__(
self,
midi_in=None,
midi_out=None,
midi_in: Optional[BinaryIO] = None,
midi_out: Optional[BinaryIO] = None,
*,
in_channel=None,
out_channel=0,
in_buf_size=30,
debug=False
in_channel: Optional[Union[int, Tuple[int, ...]]] = None,
out_channel: int = 0,
in_buf_size: int = 30,
debug: bool = False
):
if midi_in is None and midi_out is None:
raise ValueError("No midi_in or midi_out provided")
Expand All @@ -78,7 +82,7 @@ def __init__(
self._skipped_bytes = 0

@property
def in_channel(self):
def in_channel(self) -> Optional[Union[int, Tuple[int, ...]]]:
"""The incoming MIDI channel. Must be 0-15. Correlates to MIDI channels 1-16, e.g.
``in_channel = 3`` will listen on MIDI channel 4.
Can also listen on multiple channels, e.g. ``in_channel = (0,1,2)``
Expand All @@ -87,7 +91,7 @@ def in_channel(self):
return self._in_channel

@in_channel.setter
def in_channel(self, channel):
def in_channel(self, channel: Optional[Union[str, int, Tuple[int, ...]]]) -> None:
if channel is None or channel == "ALL":
self._in_channel = tuple(range(16))
elif isinstance(channel, int) and 0 <= channel <= 15:
Expand All @@ -98,19 +102,19 @@ def in_channel(self, channel):
raise RuntimeError("Invalid input channel")

@property
def out_channel(self):
def out_channel(self) -> int:
"""The outgoing MIDI channel. Must be 0-15. Correlates to MIDI channels 1-16, e.g.
``out_channel = 3`` will send to MIDI channel 4. Default is 0 (MIDI channel 1).
"""
return self._out_channel

@out_channel.setter
def out_channel(self, channel):
def out_channel(self, channel: int) -> None:
if not 0 <= channel <= 15:
raise RuntimeError("Invalid output channel")
self._out_channel = channel

def receive(self):
def receive(self) -> Optional[MIDIMessage]:
"""Read messages from MIDI port, store them in internal read buffer, then parse that data
and return the first MIDI message (event).
This maintains the blocking characteristics of the midi_in port.
Expand Down Expand Up @@ -141,7 +145,7 @@ def receive(self):
# msg could still be None at this point, e.g. in middle of monster SysEx
return msg

def send(self, msg, channel=None):
def send(self, msg: MIDIMessage, channel: Optional[int] = None) -> None:
"""Sends a MIDI message.
:param msg: Either a MIDIMessage object or a sequence (list) of MIDIMessage objects.
Expand All @@ -165,7 +169,7 @@ def send(self, msg, channel=None):

self._send(data, len(data))

def _send(self, packet, num):
def _send(self, packet: bytes, num: int) -> None:
if self._debug:
print("Sending: ", [hex(i) for i in packet[:num]])
self._midi_out.write(packet, num)
61 changes: 41 additions & 20 deletions adafruit_midi/midi_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,19 @@
__version__ = "0.0.0+auto.0"
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_MIDI.git"

try:
from typing import Union, Tuple, Any, List, Optional
except ImportError:
pass

# From C3 - A and B are above G
# Semitones A B C D E F G
NOTE_OFFSET = [21, 23, 12, 14, 16, 17, 19]


def channel_filter(channel, channel_spec):
def channel_filter(
channel: int, channel_spec: Optional[Union[int, Tuple[int, ...]]]
) -> bool:
"""
Utility function to return True iff the given channel matches channel_spec.
"""
Expand All @@ -41,13 +48,12 @@ def channel_filter(channel, channel_spec):
raise ValueError("Incorrect type for channel_spec" + str(type(channel_spec)))


def note_parser(note):
def note_parser(note: Union[int, str]) -> int:
"""If note is a string then it will be parsed and converted to a MIDI note (key) number, e.g.
"C4" will return 60, "C#4" will return 61. If note is not a string it will simply be returned.
:param note: Either 0-127 int or a str representing the note, e.g. "C#4"
"""
midi_note = note
if isinstance(note, str):
if len(note) < 2:
raise ValueError("Bad note format")
Expand All @@ -61,7 +67,8 @@ def note_parser(note):
sharpen = -1
# int may throw exception here
midi_note = int(note[1 + abs(sharpen) :]) * 12 + NOTE_OFFSET[noteidx] + sharpen

elif isinstance(note, int):
midi_note = note
return midi_note


Expand All @@ -82,40 +89,43 @@ class MIDIMessage:
This is an *abstract* class.
"""

_STATUS = None
_STATUS: Optional[int] = None
_STATUSMASK = None
LENGTH = None
LENGTH: Optional[int] = None
CHANNELMASK = 0x0F
ENDSTATUS = None

# Commonly used exceptions to save memory
@staticmethod
def _raise_valueerror_oor():
def _raise_valueerror_oor() -> None:
raise ValueError("Out of range")

# Each element is ((status, mask), class)
# order is more specific masks first
_statusandmask_to_class = []
# Add better type hints for status, mask, class referenced above
_statusandmask_to_class: List[
Tuple[Tuple[Optional[bytes], Optional[int]], "MIDIMessage"]
] = []

def __init__(self, *, channel=None):
def __init__(self, *, channel: Optional[int] = None) -> None:
self._channel = channel # dealing with pylint inadequacy
self.channel = channel

@property
def channel(self):
def channel(self) -> Optional[int]:
"""The channel number of the MIDI message where appropriate.
This is *updated* by MIDI.send() method.
"""
return self._channel

@channel.setter
def channel(self, channel):
def channel(self, channel: int) -> None:
if channel is not None and not 0 <= channel <= 15:
raise ValueError("Channel must be 0-15 or None")
self._channel = channel

@classmethod
def register_message_type(cls):
def register_message_type(cls) -> None:
"""Register a new message by its status value and mask.
This is called automagically at ``import`` time for each message.
"""
Expand All @@ -132,7 +142,14 @@ def register_message_type(cls):

# pylint: disable=too-many-arguments
@classmethod
def _search_eom_status(cls, buf, eom_status, msgstartidx, msgendidxplusone, endidx):
def _search_eom_status(
cls,
buf: bytearray,
eom_status: Optional[int],
msgstartidx: int,
msgendidxplusone: int,
endidx: int,
) -> Tuple[int, bool, bool]:
good_termination = False
bad_termination = False

Expand All @@ -155,7 +172,9 @@ def _search_eom_status(cls, buf, eom_status, msgstartidx, msgendidxplusone, endi
return (msgendidxplusone, good_termination, bad_termination)

@classmethod
def _match_message_status(cls, buf, msgstartidx, msgendidxplusone, endidx):
def _match_message_status(
cls, buf: bytearray, msgstartidx: int, msgendidxplusone: int, endidx: int
) -> Tuple[Optional[Any], int, bool, bool, bool, int]:
msgclass = None
status = buf[msgstartidx]
known_msg = False
Expand Down Expand Up @@ -198,7 +217,9 @@ def _match_message_status(cls, buf, msgstartidx, msgendidxplusone, endidx):

# pylint: disable=too-many-locals,too-many-branches
@classmethod
def from_message_bytes(cls, midibytes, channel_in):
def from_message_bytes(
cls, midibytes: bytearray, channel_in: Optional[Union[int, Tuple[int, ...]]]
) -> Tuple[Optional["MIDIMessage"], int, int]:
"""Create an appropriate object of the correct class for the
first message found in some MIDI bytes filtered by channel_in.
Expand Down Expand Up @@ -270,7 +291,7 @@ def from_message_bytes(cls, midibytes, channel_in):

# A default method for constructing wire messages with no data.
# Returns an (immutable) bytes with just the status code in.
def __bytes__(self):
def __bytes__(self) -> bytes:
"""Return the ``bytes`` wire protocol representation of the object
with channel number applied where appropriate."""
return bytes([self._STATUS])
Expand All @@ -280,12 +301,12 @@ def __bytes__(self):
# Returns the new object.
# pylint: disable=unused-argument
@classmethod
def from_bytes(cls, msg_bytes):
def from_bytes(cls, msg_bytes: bytes) -> "MIDIMessage":
"""Creates an object from the byte stream of the wire protocol
representation of the MIDI message."""
return cls()

def __str__(self):
def __str__(self) -> str:
"""Print an instance"""
cls = self.__class__
if slots := getattr(cls, "_message_slots", None):
Expand Down Expand Up @@ -313,7 +334,7 @@ class MIDIUnknownEvent(MIDIMessage):
_message_slots = ["status"]
LENGTH = -1

def __init__(self, status):
def __init__(self, status: int):
self.status = status
super().__init__()

Expand All @@ -333,7 +354,7 @@ class MIDIBadEvent(MIDIMessage):

_message_slots = ["msg_bytes", "exception"]

def __init__(self, msg_bytes, exception):
def __init__(self, msg_bytes: bytearray, exception: Exception):
self.data = bytes(msg_bytes)
self.exception_text = repr(exception)
super().__init__()

0 comments on commit 7d4c2c2

Please sign in to comment.