Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions apps/api/src/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { auth } from "@repo/auth"
import { cors } from "hono/cors"
import createApp from "@/lib/helpers/app/create-app"
import configureOpenAPI from "@/lib/helpers/openapi/configure-openapi"
import { globalRateLimit } from "@/middleware/rate-limit"
import index from "@/routes/index.route"
import channelsRouter from "@/routes/v1/channels/index"
import dmsRouter from "@/routes/v1/dms/index"
Expand All @@ -20,6 +21,8 @@ app.use(
})
)

app.use("*", globalRateLimit)

app.on(["POST", "GET"], "/api/auth/*", (c) => {
return auth.handler(c.req.raw)
})
Expand Down
90 changes: 90 additions & 0 deletions apps/api/src/middleware/rate-limit.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import type { Context, Next } from "hono"
import * as HttpStatusCodes from "@/lib/helpers/http/status-codes"
import { getRedisClient } from "@/lib/redis"
import type { AppBindings } from "@/lib/types/app-types"

const WINDOW_SECONDS = 60
const KEY_TTL_SECONDS = 90

interface RateLimitConfig {
/** Requests per window */
max: number
/** Window size in seconds (default 60) */
window?: number
/** Key prefix for Redis */
prefix: string
/** Extract the identifier from the request (default: IP address) */
keyExtractor?: (c: Context<AppBindings>) => string
}

function getIp(c: Context<AppBindings>): string {
return (
c.req.header("x-forwarded-for")?.split(",")[0]?.trim() ||
c.req.header("x-real-ip") ||
"unknown"
)
}

function getWindowNumber(windowSeconds: number): number {
return Math.floor(Date.now() / (windowSeconds * 1000))
}

function getRetryAfterSeconds(windowSeconds: number): number {
const elapsed = Math.floor(Date.now() / 1000) % windowSeconds
return Math.max(1, windowSeconds - elapsed)
}

export function rateLimiter(config: RateLimitConfig) {
const windowSeconds = config.window ?? WINDOW_SECONDS

return async (c: Context<AppBindings>, next: Next) => {
try {
const redis = await getRedisClient()
const identifier = config.keyExtractor ? config.keyExtractor(c) : getIp(c)

const windowNum = getWindowNumber(windowSeconds)
const key = `ratelimit:api:${config.prefix}:${identifier}:${windowNum}`

const count = await redis.incr(key)
if (count === 1) {
await redis.expire(
key,
config.window ? config.window + 30 : KEY_TTL_SECONDS
)
}

c.header("X-RateLimit-Limit", String(config.max))
c.header("X-RateLimit-Remaining", String(Math.max(0, config.max - count)))

if (count > config.max) {
const retryAfter = getRetryAfterSeconds(windowSeconds)
c.header("Retry-After", String(retryAfter))
c.header("X-RateLimit-Remaining", "0")
return c.json(
{
success: false,
message: `Rate limit exceeded. Try again in ${retryAfter} seconds`,
},
HttpStatusCodes.TOO_MANY_REQUESTS
)
}
} catch (err) {
console.error("[rate-limit] Redis unavailable, failing open:", err)
}

await next()
}
}
Comment thread
BuckyMcYolo marked this conversation as resolved.

/** Global rate limit: 100 requests/min per IP */
export const globalRateLimit = rateLimiter({
prefix: "global",
max: 100,
})

/** Stricter rate limit for write operations: 30 requests/min per user */
export const writeRateLimit = rateLimiter({
prefix: "write",
max: 30,
keyExtractor: (c) => c.get("user")?.id ?? getIp(c),
})
229 changes: 227 additions & 2 deletions apps/api/src/routes/v1/channels/handlers.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import { db } from "@repo/db"
import { channel } from "@repo/db/schema"
import { and, asc, eq, inArray } from "drizzle-orm"
import {
channel,
message,
messageMention,
messageReaction,
user,
} from "@repo/db/schema"
import { and, asc, desc, eq, inArray } from "drizzle-orm"
import * as HttpStatusCodes from "@/lib/helpers/http/status-codes"
import { assertGuildPermission } from "@/lib/permissions"
import { fetchMessagePage } from "@/lib/queries/messages"
Expand All @@ -11,7 +17,9 @@ import type {
GetChannelRoute,
ListChannelMessagesRoute,
ListChannelsRoute,
ListPinnedMessagesRoute,
ReorderChannelsRoute,
ToggleMessagePinRoute,
UpdateChannelRoute,
} from "./routes"

Expand Down Expand Up @@ -227,3 +235,220 @@ export const listChannelMessages: AppRouteHandler<
HttpStatusCodes.OK
)
}

export const toggleMessagePin: AppRouteHandler<ToggleMessagePinRoute> = async (
c
) => {
const guild = c.var.guild
const member = c.var.member
const { channelId, messageId } = c.req.valid("param")

assertGuildPermission(member, guild, {
message: ["pin"],
})

// Verify message exists in this channel and guild
const msg = await db
.select({
id: message.id,
pinned: message.pinned,
channelId: message.channelId,
})
.from(message)
.innerJoin(channel, eq(message.channelId, channel.id))
.where(
and(
eq(message.id, messageId),
eq(message.channelId, channelId),
eq(channel.guildId, guild.id)
)
)
.limit(1)
.then((rows) => rows[0])

if (!msg) {
return c.json(
{ success: false, message: "Message not found" },
HttpStatusCodes.NOT_FOUND
)
}

const newPinned = !msg.pinned

await db
.update(message)
.set({ pinned: newPinned })
.where(eq(message.id, messageId))

return c.json(
{ success: true as const, pinned: newPinned },
HttpStatusCodes.OK
)
}

export const listPinnedMessages: AppRouteHandler<
ListPinnedMessagesRoute
> = async (c) => {
const guild = c.var.guild
const currentUser = c.var.user
const { channelId } = c.req.valid("param")

// Verify channel belongs to guild
const ch = await db
.select({ id: channel.id })
.from(channel)
.where(and(eq(channel.id, channelId), eq(channel.guildId, guild.id)))
.limit(1)
.then((rows) => rows[0])

if (!ch) {
return c.json(
{ success: false, message: "Channel not found" },
HttpStatusCodes.NOT_FOUND
)
}

const messages = await db
.select({
id: message.id,
channelId: message.channelId,
content: message.content,
type: message.type,
pinned: message.pinned,
attachments: message.attachments,
embeds: message.embeds,
referencedMessageId: message.referencedMessageId,
editedAt: message.editedAt,
createdAt: message.createdAt,
authorId: message.authorId,
author: {
id: user.id,
name: user.name,
username: user.username,
displayUsername: user.displayUsername,
image: user.image,
},
})
.from(message)
.innerJoin(user, eq(message.authorId, user.id))
.where(and(eq(message.channelId, channelId), eq(message.pinned, true)))
.orderBy(desc(message.createdAt))

const messageIds = messages.map((msg) => msg.id)

const mentionRows =
messageIds.length > 0
? await db
.select({
messageId: messageMention.messageId,
id: user.id,
name: user.name,
username: user.username,
displayUsername: user.displayUsername,
image: user.image,
})
.from(messageMention)
.innerJoin(user, eq(messageMention.mentionedUserId, user.id))
.where(
and(
inArray(messageMention.messageId, messageIds),
eq(messageMention.mentionType, "direct")
)
)
: []

const reactionRows =
messageIds.length > 0
? await db
.select({
messageId: messageReaction.messageId,
emoji: messageReaction.emoji,
userId: messageReaction.userId,
})
.from(messageReaction)
.where(inArray(messageReaction.messageId, messageIds))
: []

const referencedMessageIds = messages
.map((msg) => msg.referencedMessageId)
.filter((id): id is string => id !== null)

const referencedMessageRows =
referencedMessageIds.length > 0
? await db
.select({
id: message.id,
content: message.content,
authorId: user.id,
authorName: user.name,
authorUsername: user.username,
authorDisplayUsername: user.displayUsername,
authorImage: user.image,
})
.from(message)
.innerJoin(user, eq(message.authorId, user.id))
.where(inArray(message.id, referencedMessageIds))
: []

const referencedMessagesById = new Map(
referencedMessageRows.map((row) => [
row.id,
{
id: row.id,
content: row.content,
author: {
id: row.authorId,
name: row.authorName,
username: row.authorUsername,
displayUsername: row.authorDisplayUsername,
image: row.authorImage,
},
},
])
)

const mentionsByMessageId = new Map<string, typeof mentionRows>()
for (const row of mentionRows) {
const existing = mentionsByMessageId.get(row.messageId) ?? []
existing.push(row)
mentionsByMessageId.set(row.messageId, existing)
}

const reactionsByMessageId = new Map<
string,
Map<string, { emoji: string; count: number; reactedByCurrentUser: boolean }>
>()
for (const row of reactionRows) {
const reactionsByEmoji =
reactionsByMessageId.get(row.messageId) ?? new Map()
const existing = reactionsByEmoji.get(row.emoji) ?? {
emoji: row.emoji,
count: 0,
reactedByCurrentUser: false,
}
existing.count += 1
if (row.userId === currentUser.id) {
existing.reactedByCurrentUser = true
}
reactionsByEmoji.set(row.emoji, existing)
reactionsByMessageId.set(row.messageId, reactionsByEmoji)
}

const data = messages.map((msg) => ({
...msg,
embeds: msg.embeds ?? [],
mentions: (mentionsByMessageId.get(msg.id) ?? []).map((m) => ({
id: m.id,
name: m.name,
username: m.username,
displayUsername: m.displayUsername,
image: m.image,
})),
reactions: Array.from(reactionsByMessageId.get(msg.id)?.values() ?? []),
referencedMessage: msg.referencedMessageId
? (referencedMessagesById.get(msg.referencedMessageId) ?? null)
: null,
}))

return c.json({ data }, HttpStatusCodes.OK)
}
2 changes: 2 additions & 0 deletions apps/api/src/routes/v1/channels/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@ const channelsRouter = createRouter()
.openapi(routes.updateChannel, handlers.updateChannel)
.openapi(routes.deleteChannel, handlers.deleteChannel)
.openapi(routes.listChannelMessages, handlers.listChannelMessages)
.openapi(routes.toggleMessagePin, handlers.toggleMessagePin)
.openapi(routes.listPinnedMessages, handlers.listPinnedMessages)

export default channelsRouter
Loading