-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmiddleware.ts
138 lines (117 loc) · 4.28 KB
/
middleware.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import { type NextRequest, NextResponse } from 'next/server'
import { createClient } from '@/lib/supabase-server'
import { SupabaseClient } from '@supabase/supabase-js'
/**
 * Shape of a subscription record as it is consumed in this file.
 *
 * The middleware parses the two date fields with `new Date(...)`, and
 * `checkResourceLimits` reads `user_id`, `credits_remaining`,
 * `models_remaining` and both date fields.
 */
interface Subscription {
  /** Owner of the subscription; used to filter the user's `models` rows. */
  user_id: string
  plan: string
  /** Checked before allowing `/api/generate` requests. */
  credits_remaining: number
  /** Compared against the user's model count for `/api/train` requests. */
  models_remaining: number
  /** Date string; start of the window in which the subscription is valid. */
  subscription_start_date: string
  /** Date string; end of the window in which the subscription is valid. */
  subscription_end_date: string
  is_active: boolean
}
/**
 * Access requirements per protected route prefix.
 *
 * A request is matched against a key when its pathname equals the key or
 * lives underneath it (`<key>/...`). Paths that match no key are not gated.
 */
const PROTECTED_ROUTES = {
  // Pages
  '/create': { requiresAuth: true, requiresSubscription: true },
  '/plans': { requiresAuth: true, requiresSubscription: false },

  // API endpoints
  '/api/generate': { requiresAuth: true, requiresSubscription: true },
  '/api/train': { requiresAuth: true, requiresSubscription: true },
  '/api/models': { requiresAuth: true, requiresSubscription: true },
} as const
/**
 * Gates the routes listed in `PROTECTED_ROUTES`.
 *
 * For a request whose pathname matches a protected route:
 *  1. If the route requires auth and no user is resolved (or resolving the
 *     user failed), redirect to `/auth/login`.
 *  2. If the route requires a subscription, redirect to `/plans` unless an
 *     active subscription row is found for the user and the current time lies
 *     inside its start/end window. An expired subscription is also flagged
 *     `is_active: false`.
 *  3. For `/api/` paths, respond with a 403 JSON body when
 *     `checkResourceLimits` reports that resources are exhausted.
 *
 * Requests that match no protected route pass straight through.
 *
 * NOTE(review): unauthenticated or unsubscribed `/api/*` requests are
 * redirected to HTML pages rather than answered with JSON — confirm that API
 * clients expect this.
 *
 * @param request - Incoming request; only its URL is inspected here.
 * @returns A pass-through, a redirect, or a 403 JSON response.
 */
export async function middleware(request: NextRequest): Promise<NextResponse> {
  const { pathname } = request.nextUrl

  // A route matches on the exact path or on any sub-path beneath it.
  const protectedRoute = Object.entries(PROTECTED_ROUTES).find(
    ([route]) => pathname === route || pathname.startsWith(`${route}/`)
  )

  // Unprotected requests never touch Supabase, so return before creating a client.
  if (!protectedRoute) {
    return NextResponse.next()
  }

  const [, requirements] = protectedRoute
  const supabase = await createClient()

  const { data: { user }, error: userError } = await supabase.auth.getUser()

  if (requirements.requiresAuth && (!user || userError)) {
    return NextResponse.redirect(new URL('/auth/login', request.url))
  }

  if (requirements.requiresSubscription && user) {
    const redirectToPlans = () => NextResponse.redirect(new URL('/plans', request.url))

    try {
      const { data: subscription } = await supabase
        .from('subscriptions')
        .select('*')
        .eq('user_id', user.id)
        .eq('is_active', true)
        .single()

      // No active subscription row came back.
      if (!subscription) {
        return redirectToPlans()
      }

      const nowMs = Date.now()
      const startMs = new Date(subscription.subscription_start_date).getTime()
      const endMs = new Date(subscription.subscription_end_date).getTime()

      // An unparseable date yields NaN, and every comparison against NaN is
      // false, so without this guard a malformed row would count as valid.
      if (Number.isNaN(startMs) || Number.isNaN(endMs)) {
        console.error('middleware: subscription has unparseable start/end date')
        return redirectToPlans()
      }

      if (nowMs < startMs || nowMs > endMs) {
        // Only a subscription that has run out is deactivated; one that has
        // not started yet is left as is.
        if (nowMs > endMs) {
          const { error: deactivateError } = await supabase
            .from('subscriptions')
            .update({ is_active: false })
            .eq('user_id', user.id)
          if (deactivateError) {
            console.error('middleware: failed to deactivate expired subscription', deactivateError)
          }
        }
        return redirectToPlans()
      }

      // API calls additionally need credits/models left on the plan.
      if (pathname.startsWith('/api/')) {
        const hasEnoughResources = await checkResourceLimits(supabase, subscription, pathname)
        if (!hasEnoughResources) {
          return new NextResponse(
            JSON.stringify({ error: 'Insufficient resources. Please upgrade your plan.' }),
            { status: 403, headers: { 'content-type': 'application/json' } }
          )
        }
      }
    } catch (error: unknown) {
      // Fail closed, but leave a trace instead of swallowing the failure.
      console.error('middleware: subscription check failed', error)
      return redirectToPlans()
    }
  }

  return NextResponse.next()
}
/**
 * Decides whether a subscription still has the resources an API call needs.
 *
 * - Paths starting with `/api/generate` are denied when `credits_remaining`
 *   is zero or negative.
 * - Paths starting with `/api/train` are denied when the number of `models`
 *   rows the user created inside the subscription window has reached
 *   `models_remaining`.
 * - Every other path is allowed.
 *
 * @param supabase - Client used to count the user's rows in `models`.
 * @param subscription - The subscription whose limits are checked.
 * @param pathname - Request path; selects which limit applies.
 * @returns `true` when the request may proceed, `false` otherwise.
 */
async function checkResourceLimits(
  supabase: SupabaseClient,
  subscription: Subscription,
  pathname: string
): Promise<boolean> {
  // For generate endpoint, check credits
  if (pathname.startsWith('/api/generate') && subscription.credits_remaining <= 0) {
    return false
  }

  // For train endpoint, check models
  if (pathname.startsWith('/api/train')) {
    // `head: true` asks only for the count, not for the rows themselves.
    const { count, error } = await supabase
      .from('models')
      .select('*', { count: 'exact', head: true })
      .eq('user_id', subscription.user_id)
      .gte('created_at', subscription.subscription_start_date)
      .lt('created_at', subscription.subscription_end_date)

    if (error) {
      console.error('checkResourceLimits: failed to count models', error)
    }

    // A missing count is treated as zero. The comparison must still run for
    // zero so that a plan with no models remaining is denied.
    const modelsCreated = count ?? 0
    if (modelsCreated >= subscription.models_remaining) {
      return false
    }
  }

  return true
}
// Exported matcher configuration for the middleware in this file.
// NOTE(review): presumably read by Next.js at build time, which expects the
// matcher to stay an inline literal — confirm before refactoring it into variables.
export const config = {
matcher: [
/*
* Match all request paths except for the ones starting with:
* - _next/static (static files)
* - _next/image (image optimization files)
* - favicon.ico (favicon file)
* Paths ending in .svg, .png, .jpg, .jpeg, .gif or .webp are excluded as well.
* Feel free to modify this pattern to include more paths.
*/
'/((?!_next/static|_next/image|favicon.ico|.*\\.(?:svg|png|jpg|jpeg|gif|webp)$).*)',
],
}