diff --git a/engineio/packet.py b/engineio/packet.py index 5e70511a..79267044 100644 --- a/engineio/packet.py +++ b/engineio/packet.py @@ -6,6 +6,8 @@ (OPEN, CLOSE, PING, PONG, MESSAGE, UPGRADE, NOOP) = (0, 1, 2, 3, 4, 5, 6) packet_names = ['OPEN', 'CLOSE', 'PING', 'PONG', 'MESSAGE', 'UPGRADE', 'NOOP'] +binary_types = (six.binary_type, bytearray) + class Packet(object): """Engine.IO packet.""" @@ -20,7 +22,7 @@ def __init__(self, packet_type=NOOP, data=None, binary=None, self.binary = binary elif isinstance(data, six.text_type): self.binary = False - elif isinstance(data, six.binary_type): + elif isinstance(data, binary_types): self.binary = True else: self.binary = False @@ -47,14 +49,14 @@ def encode(self, b64=False, always_bytes=True): separators=(',', ':')) elif self.data is not None: encoded_packet += str(self.data) - if always_bytes and not isinstance(encoded_packet, six.binary_type): + if always_bytes and not isinstance(encoded_packet, binary_types): encoded_packet = encoded_packet.encode('utf-8') return encoded_packet def decode(self, encoded_packet): """Decode a transmitted package.""" b64 = False - if not isinstance(encoded_packet, six.binary_type): + if not isinstance(encoded_packet, binary_types): encoded_packet = encoded_packet.encode('utf-8') self.packet_type = six.byte2int(encoded_packet[0:1]) if self.packet_type == 98: # 'b' --> binary base64 encoded packet diff --git a/tests/test_packet.py b/tests/test_packet.py index d0541589..e05cfaf4 100644 --- a/tests/test_packet.py +++ b/tests/test_packet.py @@ -37,6 +37,14 @@ def test_encode_binary_packet(self): self.assertTrue(pkt.binary) self.assertEqual(pkt.encode(), b'\x04\x01\x02\x03') + def test_encode_binary_bytearray_packet(self): + pkt = packet.Packet(packet.MESSAGE, data=bytearray(b'\x01\x02\x03'), + binary=True) + self.assertEqual(pkt.packet_type, packet.MESSAGE) + self.assertEqual(pkt.data, b'\x01\x02\x03') + self.assertTrue(pkt.binary) + self.assertEqual(pkt.encode(), b'\x04\x01\x02\x03') + def test_encode_binary_b64_packet(self): pkt = packet.Packet(packet.MESSAGE, data=b'\x01\x02\x03\x04', binary=True) @@ -56,6 +64,10 @@ def test_decode_binary_packet(self): pkt = packet.Packet(encoded_packet=b'\x04\x01\x02\x03') self.assertTrue(pkt.encode(), b'\x04\x01\x02\x03') + def test_decode_binary_bytearray_packet(self): + pkt = packet.Packet(encoded_packet=bytearray(b'\x04\x01\x02\x03')) + self.assertTrue(pkt.encode(), b'\x04\x01\x02\x03') + def test_decode_binary_b64_packet(self): pkt = packet.Packet(encoded_packet=b'b4AAEC') self.assertTrue(pkt.encode(), b'\x04\x01\x02\x03')