Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aana/api/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import orjson
from fastapi.responses import JSONResponse

from aana.utils.json import orjson_serializer
from aana.utils.json import jsonify


class AanaJSONResponse(JSONResponse):
Expand All @@ -22,4 +22,4 @@ def __init__(self, option: int | None = orjson.OPT_SERIALIZE_NUMPY, **kwargs):

def render(self, content: Any) -> bytes:
"""Override the render method to use orjson.dumps instead of json.dumps."""
return orjson_serializer(content, option=self.option)
return jsonify(content, option=self.option, as_bytes=True)
6 changes: 3 additions & 3 deletions aana/storage/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from aana.exceptions.runtime import EmptyMigrationsException
from aana.utils.core import get_module_dir
from aana.utils.json import orjson_serializer
from aana.utils.json import jsonify


class DbType(str, Enum):
Expand All @@ -30,7 +30,7 @@ def create_postgresql_engine(config):
connection_string = f"postgresql+psycopg://{config['user']}:{config['password']}@{config['host']}:{config['port']}/{config['database']}"
return create_engine(
connection_string,
json_serializer=lambda obj: orjson_serializer(obj).decode(),
json_serializer=lambda obj: jsonify(obj),
json_deserializer=orjson.loads,
)

Expand All @@ -47,7 +47,7 @@ def create_sqlite_engine(config):
connection_string = f"sqlite:///{config['path']}"
return create_engine(
connection_string,
json_serializer=lambda obj: orjson_serializer(obj).decode(),
json_serializer=lambda obj: jsonify(obj),
json_deserializer=orjson.loads,
)

Expand Down
33 changes: 9 additions & 24 deletions aana/utils/json.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import json
from pathlib import Path
from typing import Any

import numpy as np
import orjson
from pydantic import BaseModel
from sqlalchemy import Engine

__all__ = ["jsonify", "orjson_serializer", "json_serializer_default"]
__all__ = ["jsonify", "json_serializer_default"]


def json_serializer_default(obj: object) -> object:
Expand Down Expand Up @@ -38,37 +36,24 @@ def json_serializer_default(obj: object) -> object:
return str(obj)
if isinstance(obj, type):
return str(type)
if isinstance(obj, np.ndarray):
return str(orjson_serializer(obj))
from aana.core.models.media import Media

from aana.core.models.media import Media
if isinstance(obj, Media):
return str(obj)
raise TypeError(type(obj))


def jsonify(data: Any) -> str:
"""Convert data to JSON string.

Args:
data (Any): the data

Returns:
str: the JSON string
"""
return json.dumps(data, default=json_serializer_default, sort_keys=True)
raise TypeError(type(obj))


def orjson_serializer(
content: Any, option: int | None = orjson.OPT_SERIALIZE_NUMPY
) -> bytes:
def jsonify(data: Any, option: int | None = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SORT_KEYS, as_bytes: bool = False) -> str | bytes:
"""Serialize content using orjson.

Args:
content (Any): The content to serialize.
data (Any): The content to serialize.
option (int | None): The option for orjson.dumps.
as_bytes (bool): Return output as bytes instead of string

Returns:
bytes: The serialized content.
bytes | str: The serialized data as desired format.
"""
return orjson.dumps(content, option=option, default=json_serializer_default)
output = orjson.dumps(data, option=option, default=json_serializer_default)
return output if as_bytes else output.decode()