Skip to content
This repository has been archived by the owner on Mar 13, 2023. It is now read-only.

Commit

Permalink
fix: Ensure sets can be serialized (#608)
Browse files Browse the repository at this point in the history
  • Loading branch information
silasary committed Aug 19, 2022
1 parent bf323a5 commit b7f495b
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 12 deletions.
8 changes: 4 additions & 4 deletions naff/api/http/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from naff.client.errors import DiscordError, Forbidden, GatewayNotFound, HTTPException, NotFound, LoginError
from naff.client.utils.input_utils import response_decode, OverriddenJson
from naff.client.utils.serializer import dict_filter_missing
from naff.client.utils.serializer import dict_filter
from naff.models import CooldownSystem
from naff.models.discord.file import UPLOADABLE_TYPE
from .route import Route
Expand Down Expand Up @@ -212,9 +212,9 @@ def _process_payload(payload: dict | list[dict], files: Absent[list[UPLOADABLE_T
return None

if isinstance(payload, dict):
payload = dict_filter_missing(payload)
payload = dict_filter(payload)
else:
payload = [dict_filter_missing(x) if isinstance(x, dict) else x for x in payload]
payload = [dict_filter(x) if isinstance(x, dict) else x for x in payload]

if not files:
return payload
Expand Down Expand Up @@ -262,7 +262,7 @@ async def request(
if isinstance(payload, (list, dict)) and not files:
kwargs["headers"]["Content-Type"] = "application/json"
if isinstance(params, dict):
kwargs["params"] = dict_filter_missing(params)
kwargs["params"] = dict_filter(params)

lock = self.get_ratelimit(route)
# this gets a BucketLock for this route.
Expand Down
4 changes: 2 additions & 2 deletions naff/api/http/http_requests/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import discord_typings

from naff.client.const import Absent, MISSING
from naff.client.utils.serializer import dict_filter_missing, dict_filter_none
from naff.client.utils.serializer import dict_filter, dict_filter_none


from ..route import Route
Expand Down Expand Up @@ -661,7 +661,7 @@ async def create_guild(
) -> dict:
return await self.request(
Route("POST", "/guilds"),
payload=dict_filter_missing(
payload=dict_filter(
{
"name": name,
"icon": icon,
Expand Down
14 changes: 10 additions & 4 deletions naff/client/utils/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from naff.client.const import MISSING, T
from naff.models.discord.file import UPLOADABLE_TYPE, File

__all__ = ("no_export_meta", "export_converter", "to_dict", "dict_filter_none", "dict_filter_missing", "to_image_data")
__all__ = ("no_export_meta", "export_converter", "to_dict", "dict_filter_none", "dict_filter", "to_image_data")

no_export_meta = {"no_export": True}

Expand Down Expand Up @@ -95,9 +95,9 @@ def dict_filter_none(data: dict) -> dict:
return {k: v for k, v in data.items() if v is not None}


def dict_filter_missing(data: dict) -> dict:
def dict_filter(data: dict) -> dict:
"""
Filters out all values that are MISSING sentinel.
Filters out all values that are MISSING sentinel and converts all sets to lists.
Args:
data: The dict data to filter.
Expand All @@ -106,7 +106,13 @@ def dict_filter_missing(data: dict) -> dict:
The filtered dict data.
"""
return {k: v for k, v in data.items() if v is not MISSING}
filtered = data.copy()
for k, v in data.items():
if v is MISSING:
filtered.pop(k)
elif isinstance(v, set):
filtered[k] = list(v)
return filtered


def to_image_data(imagefile: Optional[UPLOADABLE_TYPE]) -> Optional[str]:
Expand Down
4 changes: 2 additions & 2 deletions naff/models/discord/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from naff.client.const import MISSING, Absent, T
from naff.client.utils.attr_utils import define, field
from naff.client.utils.attr_converters import optional as optional_c
from naff.client.utils.serializer import dict_filter_missing
from naff.client.utils.serializer import dict_filter
from naff.models.discord.asset import Asset
from naff.models.discord.emoji import PartialEmoji
from naff.models.discord.color import Color
Expand Down Expand Up @@ -190,7 +190,7 @@ async def edit(
if isinstance(color, Color):
color = color.value

payload = dict_filter_missing(
payload = dict_filter(
{"name": name, "permissions": permissions, "color": color, "hoist": hoist, "mentionable": mentionable}
)

Expand Down

0 comments on commit b7f495b

Please sign in to comment.