This repository has been archived by the owner on Mar 13, 2023. It is now read-only.

NAFF 1.10.0 #624

Merged: 8 commits, Sep 5, 2022
2 changes: 1 addition & 1 deletion naff/api/events/processors/guild_events.py
@@ -39,7 +39,7 @@ async def _on_raw_guild_create(self, event: "RawGatewayEvent") -> None:

        self._guild_event.set()

-        if self.fetch_members: # noqa
+        if self.fetch_members and not guild.chunked.is_set(): # noqa
            # delays events until chunking has completed
            await guild.chunk()

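For readers skimming the hunk: the added guild.chunked.is_set() check stops a guild from being chunked a second time, for example when a GUILD_CREATE is re-dispatched for a guild whose members are already cached. A minimal, library-agnostic sketch of that guard pattern (the names below are illustrative stand-ins, not naff APIs):

import asyncio

class ChunkGuard:
    """Stand-in showing the one-shot chunk guard used above."""

    def __init__(self) -> None:
        self.chunked = asyncio.Event()  # set once members have been cached

    async def chunk(self) -> None:
        # ...request and cache members here...
        self.chunked.set()

    async def on_guild_create(self, fetch_members: bool) -> None:
        # A repeated GUILD_CREATE skips the expensive chunk once the event is set.
        if fetch_members and not self.chunked.is_set():
            await self.chunk()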
16 changes: 9 additions & 7 deletions naff/api/gateway/gateway.py
@@ -7,7 +7,7 @@
from typing import TypeVar, TYPE_CHECKING

from naff.api import events
-from naff.client.const import logger, MISSING
+from naff.client.const import logger, MISSING, __api_version__
from naff.client.utils.input_utils import OverriddenJson
from naff.client.utils.serializer import dict_filter_none
from naff.models.discord.enums import Status
@@ -48,7 +48,6 @@ class GatewayClient(WebsocketClient):
    Multiple `WebsocketClient` instances can be used to implement same-process sharding.

    Attributes:
-        buffer: A buffer to hold incoming data until its complete
        sequence: The sequence of this connection
        session_id: The session ID of this connection

@@ -83,7 +82,7 @@ def __init__(self, state: "ConnectionState", shard: tuple[int, int]) -> None:
        self._ready = asyncio.Event()
        self._close_gateway = asyncio.Event()

-        # Santity check, it is extremely important that an instance isn't reused.
+        # Sanity check, it is extremely important that an instance isn't reused.
        self._entered = False

    async def __aenter__(self: SELF) -> SELF:
@@ -177,6 +176,7 @@ async def dispatch_opcode(self, data, op: OPCODE) -> None:
        match op:

            case OPCODE.HEARTBEAT:
+                logger.debug("Received heartbeat request from gateway")
                return await self.send_heartbeat()

            case OPCODE.HEARTBEAT_ACK:
@@ -192,12 +192,12 @@ async def dispatch_opcode(self, data, op: OPCODE) -> None:
                return self._acknowledged.set()

            case OPCODE.RECONNECT:
-                logger.info("Gateway requested reconnect. Reconnecting...")
+                logger.debug("Gateway requested reconnect. Reconnecting...")
                return await self.reconnect(resume=True, url=self.ws_resume_url)

            case OPCODE.INVALIDATE_SESSION:
                logger.warning("Gateway has invalidated session! Reconnecting...")
-                return await self.reconnect(resume=data, url=self.ws_resume_url if data else None)
+                return await self.reconnect()

            case _:
                return logger.debug(f"Unhandled OPCODE: {op} = {OPCODE(op).name}")
@@ -209,7 +209,9 @@ async def dispatch_event(self, data, seq, event) -> None:
            self._trace = data.get("_trace", [])
            self.sequence = seq
            self.session_id = data["session_id"]
-            self.ws_resume_url = data["resume_gateway_url"]
+            self.ws_resume_url = (
+                f"{data['resume_gateway_url']}?encoding=json&v={__api_version__}&compress=zlib-stream"
+            )
            logger.info(f"Shard {self.shard[0]} has connected to gateway!")
            logger.debug(f"Session ID: {self.session_id} Trace: {self._trace}")
            # todo: future polls, improve guild caching here. run the debugger. you'll see why
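The resume URL from READY is now stored with the same query string used for the initial gateway connection, so reconnecting to it carries the encoding, API version, and compression parameters the client expects. A standalone illustration of the URL being assembled (example values only, not taken from the library):

api_version = 10  # stand-in for naff.client.const.__api_version__
resume_gateway_url = "wss://gateway-us-east1-b.discord.gg"  # as delivered in the READY payload

ws_resume_url = f"{resume_gateway_url}?encoding=json&v={api_version}&compress=zlib-stream"
print(ws_resume_url)
# wss://gateway-us-east1-b.discord.gg?encoding=json&v=10&compress=zlib-stream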
@@ -287,7 +289,7 @@ async def _resume_connection(self) -> None:
        logger.debug(f"{self.shard[0]} is attempting to resume a connection")

    async def send_heartbeat(self) -> None:
-        await self.send_json({"op": OPCODE.HEARTBEAT, "d": self.sequence}, True)
+        await self.send_json({"op": OPCODE.HEARTBEAT, "d": self.sequence}, bypass=True)
        logger.debug(f"❤ Shard {self.shard[0]} is sending a Heartbeat")

    async def change_presence(self, activity=None, status: Status = Status.ONLINE, since=None) -> None:
13 changes: 7 additions & 6 deletions naff/api/gateway/websocket.py
@@ -1,5 +1,6 @@
import asyncio
import collections
+import random
import time
import zlib
from abc import abstractmethod
@@ -275,12 +276,12 @@ async def _start_bee_gees(self) -> None:
        if self.heartbeat_interval is None:
            raise RuntimeError

-        # try:
-        #     await asyncio.wait_for(self._kill_bee_gees.wait(), timeout=self.heartbeat_interval * random.uniform(0, 0.5))
-        # except asyncio.TimeoutError:
-        #     pass
-        # else:
-        #     return
+        try:
+            await asyncio.wait_for(self._kill_bee_gees.wait(), timeout=self.heartbeat_interval * random.uniform(0, 0.5))
+        except asyncio.TimeoutError:
+            pass
+        else:
+            return

        logger.debug(f"Sending heartbeat every {self.heartbeat_interval} seconds")
        while not self._kill_bee_gees.is_set():
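The previously commented-out jitter block is re-enabled here: before entering the heartbeat loop, the shard waits a random 0-50% of the heartbeat interval (Discord's gateway documentation suggests jittering the first heartbeat), and bails out early if shutdown is signalled during that wait. A self-contained sketch of the same logic, using hypothetical names:

import asyncio
import random

async def first_heartbeat_jitter(heartbeat_interval: float, kill_switch: asyncio.Event) -> bool:
    """Return True if the caller should start heartbeating, False if shutdown fired first."""
    try:
        await asyncio.wait_for(kill_switch.wait(), timeout=heartbeat_interval * random.uniform(0, 0.5))
    except asyncio.TimeoutError:
        return True   # the wait timed out normally, so begin the heartbeat loop
    return False      # the kill switch was set while we were waiting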
5 changes: 5 additions & 0 deletions naff/client/errors.py
@@ -104,6 +104,11 @@ def __init__(
        self.text = data
        super().__init__(f"{self.status}|{self.response.reason}: {f'({self.code}) ' if self.code else ''}{self.text}")

+    def __str__(self) -> str:
+        errors = self.search_for_message(self.errors)
+        out = f"HTTPException: {self.status}|{self.response.reason}: " + "\n".join(errors)
+        return out
+
    @staticmethod
    def search_for_message(errors: dict, lookup: Optional[dict] = None) -> list[str]:
        """
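The new __str__ makes a raised HTTPException print the API's nested validation errors as readable lines, via the existing search_for_message helper. To give a feel for what is being flattened, here is a generic, hypothetical walk over Discord's nested error format; it is not naff's actual implementation:

def flatten_discord_errors(errors: dict, path: str = "") -> list[str]:
    # Discord nests field errors under their field path, with the messages in "_errors" lists.
    messages: list[str] = []
    for key, value in errors.items():
        if key == "_errors" and isinstance(value, list):
            messages.extend(f"{path or 'request'}: {err.get('message', '')}" for err in value)
        elif isinstance(value, dict):
            messages.extend(flatten_discord_errors(value, f"{path}.{key}".lstrip(".")))
    return messages

payload = {"embeds": {"0": {"fields": {"0": {"name": {"_errors": [{"code": "BASE_TYPE_REQUIRED", "message": "This field is required"}]}}}}}}
print("\n".join(flatten_discord_errors(payload)))
# embeds.0.fields.0.name: This field is required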
3 changes: 2 additions & 1 deletion naff/client/smart_cache.py
@@ -232,7 +232,8 @@ def delete_member(self, guild_id: "Snowflake_Type", user_id: "Snowflake_Type") -
        guild_id = to_snowflake(guild_id)

        if member := self.member_cache.pop((guild_id, user_id), None):
-            member.guild._member_ids.discard(user_id)
+            if member.guild:
+                member.guild._member_ids.discard(user_id)

        self.delete_user_guild(user_id, guild_id)

8 changes: 8 additions & 0 deletions naff/models/discord/auto_mod.py
@@ -139,6 +139,13 @@ class KeywordPresetTrigger(BaseTrigger):
    )


+@define()
+class MentionSpamTrigger(BaseTrigger):
+    """A trigger that checks if content contains more mentions than allowed"""
+
+    mention_total_limit: int = field(default=3, repr=True, metadata=docs("The maximum number of mentions allowed"))
+
+
@define()
class BlockMessage(BaseAction):
    """blocks the content of a message according to the rule"""
@@ -320,4 +327,5 @@ def message(self) -> "Optional[Message]":
    AutoModTriggerType.KEYWORD: KeywordTrigger,
    AutoModTriggerType.HARMFUL_LINK: HarmfulLinkFilter,
    AutoModTriggerType.KEYWORD_PRESET: KeywordPresetTrigger,
+    AutoModTriggerType.MENTION_SPAM: MentionSpamTrigger,
}
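The second hunk adds the new trigger to the mapping from AutoModTriggerType to trigger class, so a rule carrying trigger type 5 can be resolved to MentionSpamTrigger rather than falling through. A simplified, self-contained stand-in for that lookup (string placeholders instead of naff's actual classes; the real mapping's name is not shown in the diff):

from enum import IntEnum

class AutoModTriggerType(IntEnum):  # trimmed to the values relevant here
    KEYWORD = 1
    MENTION_SPAM = 5

# Stand-in for the mapping the diff extends: trigger type -> model class.
TRIGGER_MAPPING = {
    AutoModTriggerType.KEYWORD: "KeywordTrigger",
    AutoModTriggerType.MENTION_SPAM: "MentionSpamTrigger",
}

incoming = {"trigger_type": 5, "trigger_metadata": {"mention_total_limit": 3}}
print(TRIGGER_MAPPING[AutoModTriggerType(incoming["trigger_type"])])  # MentionSpamTrigger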
7 changes: 4 additions & 3 deletions naff/models/discord/enums.py
@@ -865,20 +865,21 @@ class AuditLogEventType(CursedIntEnum):
    GUILD_HOME_FEATURE_ITEM_UPDATE = 172


-class AutoModTriggerType(IntEnum):
+class AutoModTriggerType(CursedIntEnum):
    KEYWORD = 1
    HARMFUL_LINK = 2
    SPAM = 3
    KEYWORD_PRESET = 4
+    MENTION_SPAM = 5


-class AutoModAction(IntEnum):
+class AutoModAction(CursedIntEnum):
    BLOCK_MESSAGE = 1
    ALERT_MESSAGE = 2
    TIMEOUT_USER = 3


-class AutoModEvent(IntEnum):
+class AutoModEvent(CursedIntEnum):
    MESSAGE_SEND = 1


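Why the base-class switch matters: a plain IntEnum raises as soon as Discord ships a value the library does not know about yet, whereas CursedIntEnum, the base already used by AuditLogEventType above and by the library's other Discord enums, is intended to cope with such unknown values more gracefully. The failure mode being avoided, shown with the stdlib enum only:

from enum import IntEnum

class AutoModTriggerType(IntEnum):  # the pre-PR definition, reproduced for illustration
    KEYWORD = 1
    HARMFUL_LINK = 2
    SPAM = 3
    KEYWORD_PRESET = 4

try:
    AutoModTriggerType(5)  # MENTION_SPAM did not exist yet
except ValueError as exc:
    print(exc)  # 5 is not a valid AutoModTriggerType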
51 changes: 28 additions & 23 deletions naff/models/discord/guild.py
@@ -1,5 +1,6 @@
import asyncio
import time
+from asyncio import QueueEmpty
from collections import namedtuple
from functools import cmp_to_key
from typing import List, Optional, Union, Set, Dict, Any, TYPE_CHECKING
@@ -119,6 +120,26 @@ def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any]
        return super()._process_dict(data, client)


+class MemberIterator(AsyncIterator):
+    def __init__(self, guild: "Guild", limit: int = 0) -> None:
+        super().__init__(limit)
+        self.guild = guild
+        self._more = True
+
+    async def fetch(self) -> list:
+        if self._more:
+            expected = self.get_limit
+
+            rcv = await self.guild._client.http.list_members(
+                self.guild.id, limit=expected, after=self.last["id"] if self.last else MISSING
+            )
+            if not rcv:
+                raise QueueEmpty
+            self._more = len(rcv) == expected
+            return rcv
+        raise QueueEmpty


@define()
class Guild(BaseGuild):
"""Guilds in Discord represent an isolated collection of users and channels, and are often referred to as "servers" in the UI."""
@@ -501,31 +522,15 @@ async def edit_nickname(self, new_nickname: Absent[str] = MISSING, reason: Absen
    async def http_chunk(self) -> None:
        """Populates all members of this guild using the REST API."""
        start_time = time.perf_counter()
-        members = []

-        # request all guild members
-        after = MISSING
-        while True:
-            if members:
-                after = members[-1]["user"]["id"]
-            rcv: list = await self._client.http.list_members(self.id, limit=1000, after=after)
-            members.extend(rcv)
-            if len(rcv) < 1000:
-                # we're done
-                break
-
-        # process all members
-        s = time.monotonic()
-        for member in members:
+        iterator = MemberIterator(self)
+        async for member in iterator:
            self._client.cache.place_member_data(self.id, member)
-            if (time.monotonic() - s) > 0.05:
-                # look, i get this *could* be a thread, but because it needs to modify data in the main thread,
-                # it is still blocking. So by periodically yielding to the event loop, we can avoid blocking, and still
-                # process this data properly
-                await asyncio.sleep(0)
-                s = time.monotonic()

        self.chunked.set()
-        logger.info(f"Cached {len(members)} members for {self.id} in {time.perf_counter() - start_time:.2f} seconds")
+        logger.info(
+            f"Cached {iterator.total_retrieved} members for {self.id} in {time.perf_counter() - start_time:.2f} seconds"
+        )

    async def gateway_chunk(self, wait=True, presences=True) -> None:
        """
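The hand-rolled while-loop in http_chunk is replaced by the new MemberIterator, built on the library's AsyncIterator base: each fetch() asks for up to get_limit members after the last ID seen, and raises QueueEmpty once there is nothing more to fetch. Stripped of naff specifics, the pagination pattern looks roughly like this (fetch_page is a hypothetical stand-in for http.list_members):

from collections.abc import AsyncIterator, Awaitable, Callable

async def paginate(
    fetch_page: Callable[[int, int | None], Awaitable[list[dict]]],
    page_size: int = 100,
) -> AsyncIterator[dict]:
    after: int | None = None
    while True:
        page = await fetch_page(page_size, after)
        if not page:
            return
        for item in page:
            yield item
        if len(page) < page_size:
            return  # a short page means there is nothing left to request
        after = int(page[-1]["user"]["id"])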
5 changes: 5 additions & 0 deletions naff/models/misc/iterator.py
@@ -36,6 +36,11 @@ def get_limit(self) -> int:
        """Get how the maximum number of items that should be retrieved."""
        return min(self._limit - len(self._retrieved_objects), 100) if self._limit else 100

+    @property
+    def total_retrieved(self) -> int:
+        """Get the total number of objects this iterator has retrieved."""
+        return len(self._retrieved_objects)
+
    async def add_object(self, obj) -> None:
        """Add an object to iterator's queue."""
        return await self._queue.put(obj)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "naff"
version = "1.9.0"
version = "1.10.0"
description = "Not another freaking fork"
authors = ["LordOfPolls <[email protected]>"]
