Skip to content

Commit 6985184

Browse files
Merge branch 'fix/url_safety_net' of https://github.com/Zipstack/unstract into fix/url_safety_net
2 parents 7588621 + 9fbb82d commit 6985184

30 files changed

+3973
-735
lines changed

backend/usage_v2/helper.py

Lines changed: 141 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,24 @@
11
import logging
2-
from datetime import datetime
2+
from datetime import date, datetime, time, timedelta
33
from typing import Any
44

5-
from django.db.models import QuerySet, Sum
6-
from rest_framework.exceptions import APIException
5+
from django.db.models import Count, QuerySet, Sum
6+
from django.utils import timezone
7+
from rest_framework.exceptions import APIException, ValidationError
78

89
from .constants import UsageKeys
910
from .models import Usage
1011

12+
# Try to import subscription plugin at module level
13+
try:
14+
from pluggable_apps.subscription_v2.subscription_helper import (
15+
SubscriptionHelper,
16+
)
17+
18+
SUBSCRIPTION_HELPER = SubscriptionHelper
19+
except ImportError:
20+
SUBSCRIPTION_HELPER = None
21+
1122
logger = logging.getLogger(__name__)
1223

1324

@@ -98,3 +109,130 @@ def format_usage_response(
98109
"aggregated_data": aggregated_data,
99110
"date_range": {"start_date": start_date, "end_date": end_date},
100111
}
112+
113+
@staticmethod
def _get_subscription_dates(
    organization_id: str,
) -> tuple[datetime | None, datetime | None]:
    """Attempt to get trial dates from the subscription plugin.

    Values that the plugin returns as plain ``date`` objects are expanded to
    UTC datetimes: the start date to the first instant of that day and the
    end date to the last. ``datetime`` values are passed through untouched,
    so the plugin is assumed to return timezone-aware datetimes.

    Args:
        organization_id: The organization identifier.

    Returns:
        Tuple of (start_date, end_date), or (None, None) when the plugin is
        not installed, lacks ``get_subscription``, or returns no plan
        exposing both ``start_date`` and ``end_date``.
    """
    # ``django.utils.timezone.utc`` was deprecated in Django 4.1 and removed
    # in 5.0, and the module-level ``timezone`` name is Django's module, so
    # the stdlib UTC tzinfo is imported here under a distinct alias.
    from datetime import timezone as dt_timezone

    # Exit early if subscription plugin is not available
    if SUBSCRIPTION_HELPER is None:
        logger.debug("Subscription plugin not available, using fallback dates")
        return None, None

    # Plugin is available, attempt to get subscription data
    try:
        org_plans = SUBSCRIPTION_HELPER.get_subscription(organization_id)
    except AttributeError as e:
        logger.warning(f"Subscription plugin missing expected methods: {e}")
        return None, None

    if (
        not org_plans
        or not hasattr(org_plans, "start_date")
        or not hasattr(org_plans, "end_date")
    ):
        return None, None

    start_date = org_plans.start_date
    end_date = org_plans.end_date

    # ``datetime`` is a subclass of ``date``, hence the explicit exclusion:
    # only bare dates are widened to start/end of day in UTC.
    if isinstance(start_date, date) and not isinstance(start_date, datetime):
        start_date = datetime.combine(
            start_date, time.min, tzinfo=dt_timezone.utc
        )
    if isinstance(end_date, date) and not isinstance(end_date, datetime):
        end_date = datetime.combine(end_date, time.max, tzinfo=dt_timezone.utc)

    logger.info(f"Using subscription dates for org {organization_id}")
    return start_date, end_date
161+
162+
@staticmethod
def _calculate_usage_metrics(
    organization: Any, trial_start_date: datetime, trial_end_date: datetime
) -> tuple[dict, int]:
    """Aggregate the usage rows recorded inside the trial window.

    Args:
        organization: The organization whose usage rows are selected.
        trial_start_date: Inclusive lower bound on ``created_at``.
        trial_end_date: Inclusive upper bound on ``created_at``.

    Returns:
        Tuple of (aggregated_data, documents_processed). ``aggregated_data``
        carries the keys ``total_cost``, ``total_tokens``, ``unique_runs``
        and ``api_calls``; ``documents_processed`` is the number of distinct
        (workflow_id, execution_id) pairs among the selected rows.
    """
    window_filter = {
        "organization": organization,
        "created_at__gte": trial_start_date,
        "created_at__lte": trial_end_date,
    }
    rows_in_window = Usage.objects.filter(**window_filter)

    # Sums and counts computed in a single aggregate query.
    totals = rows_in_window.aggregate(
        total_cost=Sum("cost_in_dollars"),
        total_tokens=Sum("total_tokens"),
        unique_runs=Count("run_id", distinct=True),
        api_calls=Count("id"),
    )

    # One "document" per distinct workflow/execution pair.
    execution_pairs = rows_in_window.values("workflow_id", "execution_id")
    document_count = execution_pairs.distinct().count()

    return totals, document_count
194+
195+
@staticmethod
def get_trial_statistics(organization) -> dict[str, Any]:
    """Build the trial usage summary for an organization.

    The trial window comes from the subscription plugin when it supplies
    dates. A missing start falls back to ``organization.created_at``; a
    missing end falls back to the end of the day 14 days after
    ``organization.created_at``.

    Args:
        organization: Object exposing ``organization_id`` and ``created_at``.

    Returns:
        dict with the keys:
            - trial_start_date (str): ISO formatted start of the window
            - trial_end_date (str): ISO formatted end of the window
            - total_cost: summed cost in dollars, 0 when there is no usage
            - documents_processed (int): distinct document operations
            - api_calls (int): number of usage rows in the window
            - etl_runs (int): number of distinct run ids in the window

    Raises:
        ValidationError: If the resolved end precedes the resolved start.
    """
    window_start, window_end = UsageHelper._get_subscription_dates(
        organization.organization_id
    )

    # Fallbacks for whichever bound the plugin did not provide.
    window_start = window_start or organization.created_at
    if not window_end:
        fourteen_days_on = organization.created_at + timedelta(days=14)
        # Stretch to the last representable instant of that day.
        window_end = fourteen_days_on.replace(
            hour=23, minute=59, second=59, microsecond=999999
        )

    # Guard against an inverted window before querying.
    if window_end < window_start:
        raise ValidationError("trial_end_date must be on or after trial_start_date")

    totals, document_count = UsageHelper._calculate_usage_metrics(
        organization, window_start, window_end
    )

    def _value_or_zero(key: str) -> Any:
        # Aggregates over an empty queryset yield None; report 0 instead.
        return totals.get(key, 0) or 0

    summary: dict[str, Any] = {}
    summary["trial_start_date"] = window_start.isoformat()
    summary["trial_end_date"] = window_end.isoformat()
    summary["total_cost"] = _value_or_zero("total_cost")
    summary["documents_processed"] = document_count
    summary["api_calls"] = _value_or_zero("api_calls")
    summary["etl_runs"] = _value_or_zero("unique_runs")
    return summary

backend/usage_v2/urls.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55

66
get_token_usage = UsageView.as_view({"get": "get_token_usage"})
77
aggregate = UsageView.as_view({"get": UsageView.aggregate.__name__})
8+
trial_statistics = UsageView.as_view({"get": "get_trial_statistics"})
89
usage_list = UsageView.as_view({"get": UsageView.list.__name__})
910
usage_detail = UsageView.as_view(
1011
{
@@ -27,6 +28,11 @@
2728
aggregate,
2829
name="aggregate",
2930
),
31+
path(
32+
"trial-statistics/",
33+
trial_statistics,
34+
name="trial-statistics",
35+
),
3036
path("<str:pk>/", usage_detail, name="usage_detail"),
3137
]
3238
)

backend/usage_v2/views.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
from permissions.permission import IsOrganizationMember
66
from rest_framework import status, viewsets
77
from rest_framework.decorators import action
8+
from rest_framework.exceptions import ValidationError
89
from rest_framework.filters import OrderingFilter
910
from rest_framework.permissions import IsAuthenticated
1011
from rest_framework.response import Response
@@ -93,3 +94,33 @@ def get_token_usage(self, request: HttpRequest) -> Response:
9394

9495
# Return the result
9596
return Response(status=status.HTTP_200_OK, data=result)
97+
98+
@action(detail=False, methods=["get"], url_path="trial-statistics")
def get_trial_statistics(self, request: HttpRequest) -> Response:
    """Return trial usage statistics for the caller's organization.

    The organization is resolved through ``UserContext``; the statistics
    themselves are produced by ``UsageHelper.get_trial_statistics``.

    Returns:
        Response: HTTP 200 whose body is the helper's dictionary
        (trial_start_date, trial_end_date, total_cost,
        documents_processed, api_calls, etl_runs).

    Raises:
        ValidationError: If no organization context is available.
    """
    user_organization = UserContext.get_organization()

    if user_organization:
        stats_payload = UsageHelper.get_trial_statistics(user_organization)

        # Audit trail entry for the successful lookup.
        logger.info(
            f"Trial statistics retrieved for organization {user_organization.organization_id}"
        )
        return Response(data=stats_payload, status=status.HTTP_200_OK)

    # No organization could be resolved for this request.
    logger.warning("No organization context found for user")
    raise ValidationError("No organization context available")

backend/workflow_manager/endpoint_v2/constants.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ class TableColumns:
2525

2626
class DBConnectionClass:
    # String identifiers, one per supported database type.
    # NOTE(review): presumably these match database connector class names;
    # the code that compares against them is not visible here, so confirm
    # against callers before relying on that.
    SNOWFLAKE = "SnowflakeDB"
    BIGQUERY = "BigQuery"
2829

2930

3031
class Snowflake:

0 commit comments

Comments
 (0)