Skip to content

Commit

Permalink
FEAT: add mlx serializer (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
qinxuye authored Feb 24, 2025
1 parent 4df7631 commit 81f3a1a
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/xoscar/serialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import cuda, exception, numpy, scipy
from . import cuda, exception, mlx, numpy, scipy
from .aio import AioDeserializer, AioSerializer
from .core import Serializer, deserialize, serialize, serialize_with_spawn

del cuda, numpy, scipy, exception
del cuda, numpy, scipy, mlx, exception
33 changes: 33 additions & 0 deletions python/xoscar/serialization/mlx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2022-2025 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List

from ..utils import lazy_import
from .core import Serializer, buffered

mx = lazy_import("mlx.core")


class MLXSerislizer(Serializer):
@buffered
def serial(self, obj: "mx.array", context: dict): # type: ignore
return ({},), [memoryview(obj)], True

def deserial(self, serialized: tuple, context: dict, subs: List[Any]):
return mx.array(subs[0])


if mx is not None:
MLXSerislizer.register(mx.array)
11 changes: 11 additions & 0 deletions python/xoscar/serialization/tests/test_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
cupy = lazy_import("cupy")
cudf = lazy_import("cudf")
pyfury = lazy_import("pyfury")
mx = lazy_import("mlx.core")


class CustomList(list):
Expand Down Expand Up @@ -244,6 +245,16 @@ def test_scipy_sparse():
assert (val != deserial).nnz == 0


@pytest.mark.skipif(mx is None, reason="need mlx to run the test")
def test_mlx():
val = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=mx.float16)
serial = serialize(val)
deserial = deserialize(*serial)
# buffer should have length 1
assert len(serial[1]) == 1
np.testing.assert_array_equal(np.asarray(val), np.asarray(deserial))


class MockSerializerForErrors(ListSerializer):
serializer_id = 25951
raises = False
Expand Down

0 comments on commit 81f3a1a

Please sign in to comment.