diff --git a/pretty_midi/pretty_midi.py b/pretty_midi/pretty_midi.py index 94ee4dd..d1b0f3a 100644 --- a/pretty_midi/pretty_midi.py +++ b/pretty_midi/pretty_midi.py @@ -36,14 +36,20 @@ class PrettyMIDI(object): ---------- midi_file : str or file Path or file pointer to a MIDI file. - Default ``None`` which means create an empty class with the supplied - values for resolution and initial tempo. + Default ``None`` would check if ``mido_object`` is populated instead. If both are ``None``, + creates an empty class with the supplied values for resolution and initial tempo. + Additionally, a ValueError is raised if both ``midi_file`` and ``mido_object`` are not ``None``. resolution : int Resolution of the MIDI data, when no file is provided. initial_tempo : float Initial tempo for the MIDI data, when no file is provided. charset : str Charset of the MIDI. + mido_object : mido.MidiFile + Pre-loaded `mido.MidiFile` object. + Default ``None`` would check if ``midi_file`` is populated instead. If both are ``None``, + creates an empty class with the supplied values for resolution and initial tempo. + Additionally, a ValueError is raised if both ``mido_object`` and ``midi_file`` are not ``None``. Attributes ---------- @@ -59,19 +65,30 @@ class PrettyMIDI(object): List of :class:`pretty_midi.Text` objects. """ - def __init__(self, midi_file=None, resolution=220, initial_tempo=120., charset='latin1'): - """Initialize either by populating it with MIDI data from a file or + def __init__(self, midi_file=None, resolution=220, initial_tempo=120., charset='latin1', mido_object=None): + """Initialize either by populating it with MIDI data from a mido.MidiFile object, file or from scratch with no data. """ - if midi_file is not None: - # Load in the MIDI data using the midi module - if isinstance(midi_file, six.string_types) or isinstance(midi_file, pathlib.PurePath): - # If a string or path was given, pass it as the filename - midi_data = mido.MidiFile(filename=midi_file, charset=charset) - else: - # Otherwise, try passing it in as a file pointer - midi_data = mido.MidiFile(file=midi_file, charset=charset) + if mido_object is not None or midi_file is not None: + + if mido_object is not None and midi_file is not None: + raise ValueError("Either the midi_file or the mido_object argument must be provided, but not both.") + + if mido_object is not None: + if isinstance(mido_object, mido.MidiFile): + midi_data = mido_object + else: + raise ValueError("Expected mido_object to be of type mido.MidiFile.") + + if midi_file is not None: + # Load in the MIDI data using the midi module + if isinstance(midi_file, six.string_types) or isinstance(midi_file, pathlib.PurePath): + # If a string or path was given, pass it as the filename + midi_data = mido.MidiFile(filename=midi_file, charset=charset) + else: + # Otherwise, try passing it in as a file pointer + midi_data = mido.MidiFile(file=midi_file, charset=charset) # Convert tick values in midi_data to absolute, a useful thing. for track in midi_data.tracks: @@ -217,7 +234,7 @@ def _load_metadata(self, midi_data): elif event.type == 'text': text_events.append(Text( event.text, self.__tick_to_time[event.time])) - + if lyrics: tracks_with_lyrics.append(lyrics) if text_events: diff --git a/tests/test_pretty_midi.py b/tests/test_pretty_midi.py index 7a7d0b3..bcda9a7 100644 --- a/tests/test_pretty_midi.py +++ b/tests/test_pretty_midi.py @@ -4,6 +4,72 @@ from tempfile import NamedTemporaryFile +def test_pm_object_initialization(): + def make_mido_track(notes_str): + track = mido.MidiTrack() + for line in notes_str.split('\n'): + line = line.strip() + if line: + track.append(mido.Message.from_str(line)) + mido_obj = mido.MidiFile() + mido_obj.tracks.append(track) + return mido_obj + + example_track_1 = """ + note_on channel=0 note=72 velocity=88 time=0 + note_on channel=0 note=72 velocity=0 time=48 + note_on channel=0 note=72 velocity=88 time=0 + note_on channel=0 note=74 velocity=88 time=48 + note_on channel=0 note=72 velocity=0 time=0 + note_on channel=0 note=72 velocity=88 time=48 + note_on channel=0 note=74 velocity=0 time=0 + note_on channel=0 note=72 velocity=0 time=48 + """ + + example_track_2 = """ + note_on channel=0 note=72 velocity=88 time=0 + note_on channel=0 note=72 velocity=0 time=48 + note_on channel=0 note=72 velocity=88 time=0 + note_on channel=0 note=74 velocity=88 time=48 + note_on channel=0 note=72 velocity=0 time=0 + note_on channel=0 note=72 velocity=88 time=48 + note_on channel=0 note=74 velocity=0 time=0 + note_on channel=0 note=72 velocity=0 time=48 + note_on channel=0 note=75 velocity=88 time=0 + note_on channel=0 note=75 velocity=0 time=48 + """ + + # Test-1: Passing pre-loaded mido.MidiFile object + example_mido_obj_1 = make_mido_track(example_track_1) + pm_song = pretty_midi.PrettyMIDI(mido_object=example_mido_obj_1) + assert len(pm_song.instruments[0].notes) == 4 + + # Test-2: Testing value error is raised when non mido.MidiFile object is passed as a mido_object argument + try: + pm_song = pretty_midi.PrettyMIDI(mido_object=mido.MidiTrack()) + raise Exception("Expected ValueError when non mido.MidiFile object is passed as a mido_object argument.") + except ValueError as val_error: + assert val_error.args[0] == "Expected mido_object to be of type mido.MidiFile." + + with NamedTemporaryFile() as file: + + # Test-3: Passing file path while mido_object argument defaults to None. + # This test will ensure <=v0.2.10 compatibility for passing other arguments without keywords + example_mido_obj_2 = make_mido_track(example_track_2) + example_mido_obj_2.save(file=file) + file.seek(0) + pm_song = pretty_midi.PrettyMIDI(file) + assert len(pm_song.instruments[0].notes) == 5 + + # Test-4: Testing value error is raised when both midi_file and mido_object arguments are provided + try: + pm_song = pretty_midi.PrettyMIDI(midi_file=file, mido_object=example_mido_obj_1) + raise Exception("Expected ValueError when both midi_file and mido_object arguments are provided.") + except ValueError as val_error: + assert val_error.args[0] == ("Either the midi_file or the mido_object argument must be provided, " + "but not both.") + + def test_get_beats(): pm = pretty_midi.PrettyMIDI() # Add a note to force get_end_time() to be non-zero