Skip to content

Commit

Permalink
rename the ratelimiter functions
Browse files Browse the repository at this point in the history
This is to improve the decorator's api better.
  • Loading branch information
shtlrs authored and supakeen committed Mar 11, 2024
1 parent 7025c3a commit db64a8a
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 31 deletions.
8 changes: 5 additions & 3 deletions src/pinnwand/defensive.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
] = {}


def ratelimit(request: HTTPServerRequest, area: str = "global") -> bool:
def should_be_ratelimited(
request: HTTPServerRequest, area: str = "global"
) -> bool:
"""Test if a requesting IP is ratelimited for a certain area. Areas are
different functionalities of the website, for example 'view' or 'input' to
differentiate between creating new pastes (low volume) or high volume
Expand Down Expand Up @@ -55,13 +57,13 @@ def ratelimit(request: HTTPServerRequest, area: str = "global") -> bool:
return False


def ratelimit_endpoint(area: str):
def ratelimit(area: str):
"""A ratelimiting decorator for tornado's request handlers."""

def wrapper(func):
@wraps(func)
def inner(request_handler: RequestHandler, *args, **kwargs):
if ratelimit(request_handler.request, area):
if should_be_ratelimited(request_handler.request, area):
raise error.RatelimitError()
return func(request_handler, *args, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion src/pinnwand/handler/api_curl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def write_error(self, status_code: int, **kwargs: Any) -> None:
else:
super().write_error(status_code, **kwargs)

@defensive.ratelimit_endpoint(area="create")
@defensive.ratelimit(area="create")
def post(self) -> None:

configuration: Configuration = ConfigurationProvider.get_config()
Expand Down
12 changes: 5 additions & 7 deletions src/pinnwand/handler/api_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def post(self) -> None:
class Show(Base):
"""Show a paste on the deprecated API."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, slug: str) -> None: # type: ignore
with manager.DatabaseManager.get_session() as session:
paste = (
Expand Down Expand Up @@ -119,9 +119,8 @@ def check_xsrf_cookie(self) -> None:
async def get(self) -> None:
raise tornado.web.HTTPError(405)

@defensive.ratelimit_endpoint(area="create")
@defensive.ratelimit(area="create")
async def post(self) -> None:

configuration: Configuration = ConfigurationProvider.get_config()

lexer = self.get_body_argument("lexer")
Expand Down Expand Up @@ -175,9 +174,8 @@ def check_xsrf_cookie(self) -> None:
"""No XSRF cookies on the API."""
return

@defensive.ratelimit_endpoint(area="delete")
@defensive.ratelimit(area="delete")
async def post(self) -> None:

with manager.DatabaseManager.get_session() as session:
paste = (
session.query(models.Paste)
Expand Down Expand Up @@ -207,15 +205,15 @@ async def post(self) -> None:
class Lexer(Base):
"""List lexers through the deprecated API."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self) -> None:
self.write(utility.list_languages())


class Expiry(Base):
"""List expiries through the deprecated API."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self) -> None:
configuration: Configuration = ConfigurationProvider.get_config()

Expand Down
12 changes: 4 additions & 8 deletions src/pinnwand/handler/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,13 @@ def write_error(self, status_code: int, **kwargs: Any) -> None:


class Lexer(Base):

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self) -> None:
self.write(utility.list_languages())


class Expiry(Base):

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self) -> None:
configuration: Configuration = ConfigurationProvider.get_config()

Expand All @@ -46,9 +44,8 @@ def check_xsrf_cookie(self) -> None:
async def get(self) -> None:
raise tornado.web.HTTPError(405)

@defensive.ratelimit_endpoint(area="create")
@defensive.ratelimit(area="create")
async def post(self) -> None:

try:
data = tornado.escape.json_decode(self.request.body)
except json.decoder.JSONDecodeError:
Expand Down Expand Up @@ -127,8 +124,7 @@ async def post(self) -> None:


class PasteDetail(Base):

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, slug: str) -> None:
with manager.DatabaseManager.get_session() as session:
paste = (
Expand Down
23 changes: 11 additions & 12 deletions src/pinnwand/handler/website.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class Create(Base):
"""The index page shows the new paste page with a list of all available
lexers from Pygments."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, lexers: str = "") -> None:
"""Render the new paste form, optionally have a lexer preselected from
the URL."""
Expand Down Expand Up @@ -112,7 +112,7 @@ async def get(self, lexers: str = "") -> None:
paste=None,
)

@defensive.ratelimit_endpoint(area="create")
@defensive.ratelimit(area="create")
async def post(self) -> None:
"""This is a historical endpoint to create pastes, pastes are marked as
old-web and will get a warning on top of them to remove any access to
Expand Down Expand Up @@ -174,7 +174,7 @@ class CreateAction(Base):
"""The create action is the 'new' way to create pastes and supports multi
file pastes."""

@defensive.ratelimit_endpoint(area="create")
@defensive.ratelimit(area="create")
def post(self) -> None: # type: ignore
"""POST handler for the 'web' side of things."""

Expand Down Expand Up @@ -260,7 +260,7 @@ class Repaste(Base):
"""Repaste is a specific case of the paste page. It only works for pre-
existing pastes and will prefill the textarea and lexer."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, slug: str) -> None: # type: ignore
"""Render the new paste form, optionally have a lexer preselected from
the URL."""
Expand Down Expand Up @@ -293,7 +293,7 @@ async def get(self, slug: str) -> None: # type: ignore
class Show(Base):
"""Show a paste."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, slug: str) -> None: # type: ignore
"""Fetch paste from database by slug and render the paste."""

Expand Down Expand Up @@ -360,7 +360,7 @@ async def get(self, slug: str) -> None: # type: ignore
class FileRaw(Base):
"""Show a file as plaintext."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, file_id: str) -> None: # type: ignore
"""Get a file from the database and show it in the plain."""

Expand Down Expand Up @@ -391,7 +391,7 @@ async def get(self, file_id: str) -> None: # type: ignore
class FileHex(Base):
"""Show a file as hexadecimal."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, file_id: str) -> None: # type: ignore
"""Get a file from the database and show it in hex."""

Expand Down Expand Up @@ -422,7 +422,7 @@ async def get(self, file_id: str) -> None: # type: ignore
class PasteDownload(Base):
"""Download an entire paste."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, paste_id: str) -> None: # type: ignore
"""Get all files from the database and download them as a zipfile."""

Expand Down Expand Up @@ -469,7 +469,7 @@ async def get(self, paste_id: str) -> None: # type: ignore
class FileDownload(Base):
"""Download a file."""

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self, file_id: str) -> None: # type: ignore
"""Get a file from the database and download it in the plain."""

Expand Down Expand Up @@ -511,7 +511,7 @@ async def get(self, file_id: str) -> None: # type: ignore
class Remove(Base):
"""Remove a paste."""

@defensive.ratelimit_endpoint(area="delete")
@defensive.ratelimit(area="delete")
async def get(self, removal: str) -> None: # type: ignore
"""Look up if the user visiting this page has the removal id for a
certain paste. If they do they're authorized to remove the paste."""
Expand Down Expand Up @@ -549,7 +549,7 @@ class RestructuredTextPage(Base):
def initialize(self, file: str) -> None:
self.file = file

@defensive.ratelimit_endpoint(area="read")
@defensive.ratelimit(area="read")
async def get(self) -> None:
try:
with open(path.page / self.file) as f:
Expand All @@ -573,7 +573,6 @@ def initialize(self, path: str) -> None:
self.path = path

async def get(self) -> None:

try:
with open(self.path, "rb") as f:
self.write(f.read())
Expand Down

0 comments on commit db64a8a

Please sign in to comment.