|
1 | 1 | import logging |
2 | | -from datetime import datetime |
| 2 | +from datetime import date, datetime, time, timedelta |
3 | 3 | from typing import Any |
4 | 4 |
|
5 | | -from django.db.models import QuerySet, Sum |
6 | | -from rest_framework.exceptions import APIException |
| 5 | +from django.db.models import Count, QuerySet, Sum |
| 6 | +from django.utils import timezone |
| 7 | +from rest_framework.exceptions import APIException, ValidationError |
7 | 8 |
|
8 | 9 | from .constants import UsageKeys |
9 | 10 | from .models import Usage |
10 | 11 |
|
| 12 | +# Try to import subscription plugin at module level |
| 13 | +try: |
| 14 | + from pluggable_apps.subscription_v2.subscription_helper import ( |
| 15 | + SubscriptionHelper, |
| 16 | + ) |
| 17 | + |
| 18 | + SUBSCRIPTION_HELPER = SubscriptionHelper |
| 19 | +except ImportError: |
| 20 | + SUBSCRIPTION_HELPER = None |
| 21 | + |
11 | 22 | logger = logging.getLogger(__name__) |
12 | 23 |
|
13 | 24 |
|
@@ -98,3 +109,130 @@ def format_usage_response( |
98 | 109 | "aggregated_data": aggregated_data, |
99 | 110 | "date_range": {"start_date": start_date, "end_date": end_date}, |
100 | 111 | } |
| 112 | + |
| 113 | + @staticmethod |
| 114 | + def _get_subscription_dates( |
| 115 | + organization_id: str, |
| 116 | + ) -> tuple[datetime | None, datetime | None]: |
| 117 | + """Attempt to get trial dates from subscription plugin. |
| 118 | +
|
| 119 | + Assumes subscription plugin returns timezone-aware datetimes or dates |
| 120 | + that can be safely used without normalization (uniform storage). |
| 121 | +
|
| 122 | + Args: |
| 123 | + organization_id: The organization identifier |
| 124 | +
|
| 125 | + Returns: |
| 126 | + Tuple of (start_date, end_date) or (None, None) if not available |
| 127 | + """ |
| 128 | + # Exit early if subscription plugin is not available |
| 129 | + if SUBSCRIPTION_HELPER is None: |
| 130 | + logger.debug("Subscription plugin not available, using fallback dates") |
| 131 | + return None, None |
| 132 | + |
| 133 | + # Plugin is available, attempt to get subscription data |
| 134 | + try: |
| 135 | + org_plans = SUBSCRIPTION_HELPER.get_subscription(organization_id) |
| 136 | + except AttributeError as e: |
| 137 | + logger.warning(f"Subscription plugin missing expected methods: {e}") |
| 138 | + return None, None |
| 139 | + |
| 140 | + if ( |
| 141 | + not org_plans |
| 142 | + or not hasattr(org_plans, "start_date") |
| 143 | + or not hasattr(org_plans, "end_date") |
| 144 | + ): |
| 145 | + return None, None |
| 146 | + |
| 147 | + # If subscription returns dates, convert to datetime with proper times |
| 148 | + start_date = org_plans.start_date |
| 149 | + end_date = org_plans.end_date |
| 150 | + |
| 151 | + # Handle date objects by converting to start/end of day (assuming UTC) |
| 152 | + if isinstance(start_date, date) and not isinstance(start_date, datetime): |
| 153 | + start_date = datetime.combine(start_date, time.min).replace( |
| 154 | + tzinfo=timezone.utc |
| 155 | + ) |
| 156 | + if isinstance(end_date, date) and not isinstance(end_date, datetime): |
| 157 | + end_date = datetime.combine(end_date, time.max).replace(tzinfo=timezone.utc) |
| 158 | + |
| 159 | + logger.info(f"Using subscription dates for org {organization_id}") |
| 160 | + return start_date, end_date |
| 161 | + |
| 162 | + @staticmethod |
| 163 | + def _calculate_usage_metrics( |
| 164 | + organization: Any, trial_start_date: datetime, trial_end_date: datetime |
| 165 | + ) -> tuple[dict, int]: |
| 166 | + """Calculate usage metrics for the trial period. |
| 167 | +
|
| 168 | + Args: |
| 169 | + organization: The organization object |
| 170 | + trial_start_date: Start of trial period |
| 171 | + trial_end_date: End of trial period |
| 172 | +
|
| 173 | + Returns: |
| 174 | + Tuple of (aggregated_data, documents_processed) |
| 175 | + """ |
| 176 | + usage_queryset = Usage.objects.filter( |
| 177 | + organization=organization, |
| 178 | + created_at__gte=trial_start_date, |
| 179 | + created_at__lte=trial_end_date, |
| 180 | + ) |
| 181 | + |
| 182 | + aggregated_data = usage_queryset.aggregate( |
| 183 | + total_cost=Sum("cost_in_dollars"), |
| 184 | + total_tokens=Sum("total_tokens"), |
| 185 | + unique_runs=Count("run_id", distinct=True), |
| 186 | + api_calls=Count("id"), |
| 187 | + ) |
| 188 | + |
| 189 | + documents_processed = ( |
| 190 | + usage_queryset.values("workflow_id", "execution_id").distinct().count() |
| 191 | + ) |
| 192 | + |
| 193 | + return aggregated_data, documents_processed |
| 194 | + |
| 195 | + @staticmethod |
| 196 | + def get_trial_statistics(organization) -> dict[str, Any]: |
| 197 | + """Get comprehensive trial usage statistics for an organization. |
| 198 | +
|
| 199 | + Args: |
| 200 | + organization: The organization object for which to retrieve trial statistics. |
| 201 | + Must have 'organization_id' and 'created_at' attributes. |
| 202 | +
|
| 203 | + Returns: |
| 204 | + dict: A dictionary containing comprehensive trial usage statistics with keys: |
| 205 | + - trial_start_date (str): ISO formatted trial start date |
| 206 | + - trial_end_date (str): ISO formatted trial end date |
| 207 | + - total_cost (float): Total cost in dollars during trial period |
| 208 | + - documents_processed (int): Number of unique document processing operations |
| 209 | + - api_calls (int): Total number of API calls made |
| 210 | + - etl_runs (int): Number of unique ETL pipeline runs |
| 211 | + """ |
| 212 | + trial_start_date, trial_end_date = UsageHelper._get_subscription_dates( |
| 213 | + organization.organization_id |
| 214 | + ) |
| 215 | + # Use fallback dates if needed (assuming uniform UTC storage) |
| 216 | + if not trial_start_date: |
| 217 | + trial_start_date = organization.created_at |
| 218 | + if not trial_end_date: |
| 219 | + # For end date, set to end of day (23:59:59.999999) |
| 220 | + end_of_trial_date = organization.created_at + timedelta(days=14) |
| 221 | + trial_end_date = end_of_trial_date.replace( |
| 222 | + hour=23, minute=59, second=59, microsecond=999999 |
| 223 | + ) |
| 224 | + # Validate trial window - guard against inverted date range |
| 225 | + if trial_end_date < trial_start_date: |
| 226 | + raise ValidationError("trial_end_date must be on or after trial_start_date") |
| 227 | + # Calculate usage metrics |
| 228 | + aggregated_data, documents_processed = UsageHelper._calculate_usage_metrics( |
| 229 | + organization, trial_start_date, trial_end_date |
| 230 | + ) |
| 231 | + return { |
| 232 | + "trial_start_date": trial_start_date.isoformat(), |
| 233 | + "trial_end_date": trial_end_date.isoformat(), |
| 234 | + "total_cost": aggregated_data.get("total_cost", 0) or 0, |
| 235 | + "documents_processed": documents_processed, |
| 236 | + "api_calls": aggregated_data.get("api_calls", 0) or 0, |
| 237 | + "etl_runs": aggregated_data.get("unique_runs", 0) or 0, |
| 238 | + } |
0 commit comments