Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions server/polar/order/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from pydantic import UUID4, AliasChoices, AliasPath, Field, computed_field
from pydantic.json_schema import SkipJsonSchema

from polar.custom_field.data import CustomFieldDataOutputMixin
from polar.custom_field.data import (
CustomFieldDataInputMixin,
CustomFieldDataOutputMixin,
)
from polar.customer.schemas.customer import CustomerBase
from polar.discount.schemas import DiscountMinimal
from polar.exceptions import ResourceNotFound
Expand Down Expand Up @@ -177,18 +180,20 @@ class Order(CustomFieldDataOutputMixin, MetadataOutputMixin, OrderBase):
items: list[OrderItemSchema] = Field(description="Line items composing the order.")


class OrderUpdateBase(Schema):
class OrderUpdateBase(CustomFieldDataInputMixin, Schema):
billing_name: str | None = Field(
default=None,
description=(
"The name of the customer that should appear on the invoice. "
"Can't be updated after the invoice is generated."
)
),
)
billing_address: Address | None = Field(
default=None,
description=(
"The address of the customer that should appear on the invoice. "
"Can't be updated after the invoice is generated."
)
),
)


Expand Down
23 changes: 17 additions & 6 deletions server/polar/order/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
import stripe as stripe_lib
import structlog
from sqlalchemy import select
from sqlalchemy.orm import contains_eager, joinedload
from sqlalchemy.orm import contains_eager, joinedload, selectinload

from polar.account.repository import AccountRepository
from polar.auth.models import AuthSubject
from polar.billing_entry.service import billing_entry as billing_entry_service
from polar.checkout.eventstream import CheckoutEvent, publish_checkout_event
from polar.checkout.repository import CheckoutRepository
from polar.config import settings
from polar.custom_field.data import validate_custom_field_data
from polar.customer.repository import CustomerRepository
from polar.customer_portal.schemas.order import (
CustomerOrderPaymentConfirmation,
Expand Down Expand Up @@ -402,8 +403,9 @@ async def get(
.options(
*repository.get_eager_options(
customer_load=contains_eager(Order.customer),
product_load=joinedload(Order.product).joinedload(
Product.organization
product_load=joinedload(Order.product).options(
joinedload(Product.organization),
selectinload(Product.attached_custom_fields),
),
)
)
Expand Down Expand Up @@ -436,10 +438,19 @@ async def update(
if errors:
raise PolarRequestValidationError(errors)

update_dict = order_update.model_dump(exclude_unset=True)

if "custom_field_data" in update_dict:
# Validate custom field data against the product's attached custom fields
custom_field_data = validate_custom_field_data(
order.product.attached_custom_fields,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will likely raise an eager loading error at runtime (this is a common shortcomings of our tests which usually don't catch that). That's also why @rishi-raj-jain asked for a video showing it working.

For simplicity, you should add that eager load instruction above in the get method:

    async def get(
        self,
        session: AsyncSession,
        auth_subject: AuthSubject[User | Organization],
        id: uuid.UUID,
    ) -> Order | None:
        repository = OrderRepository.from_session(session)
        statement = (
            repository.get_readable_statement(auth_subject)
            .options(
                *repository.get_eager_options(
                    customer_load=contains_eager(Order.customer),
                    product_load=joinedload(Order.product).options(
                        joinedload(Product.organization),
                        selectinload(Product.attached_custom_fields),
                    )
                )
            )
            .where(Order.id == id)
        )
        return await repository.get_one_or_none(statement)

Copy link
Contributor Author

@amankitsingh amankitsingh Sep 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@frankie567 are you suggesting this because of the performance, so that update function doesnt lazy load the data?

order_update.custom_field_data,
validate_required=False, # Allow merchants to update even if required fields are missing
)
update_dict["custom_field_data"] = custom_field_data

repository = OrderRepository.from_session(session)
order = await repository.update(
order, update_dict=order_update.model_dump(exclude_unset=True)
)
order = await repository.update(order, update_dict=update_dict)

await self.send_webhook(session, order, WebhookEventType.order_updated)

Expand Down
192 changes: 190 additions & 2 deletions server/tests/order/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,17 @@
from httpx import AsyncClient

from polar.auth.scope import Scope
from polar.models import Customer, Order, Product, UserOrganization
from polar.enums import SubscriptionRecurringInterval
from polar.models import Customer, Order, Organization, Product, UserOrganization
from polar.models.custom_field import CustomFieldType
from tests.fixtures.auth import AuthSubjectFixture
from tests.fixtures.database import SaveFixture
from tests.fixtures.random_objects import create_order
from tests.fixtures.random_objects import (
create_custom_field,
create_customer,
create_order,
create_product,
)


@pytest_asyncio.fixture
Expand Down Expand Up @@ -166,6 +173,187 @@ async def test_custom_field(
assert json["custom_field_data"] == {"test": None}


@pytest.mark.asyncio
class TestUpdateOrder:
async def test_anonymous(self, client: AsyncClient, orders: list[Order]) -> None:
response = await client.patch(
f"/v1/orders/{orders[0].id}",
json={"custom_field_data": {"test": "updated"}},
)

assert response.status_code == 401

@pytest.mark.auth
async def test_not_existing(self, client: AsyncClient) -> None:
response = await client.patch(
f"/v1/orders/{uuid.uuid4()}",
json={"custom_field_data": {"test": "updated"}},
)

assert response.status_code == 404

@pytest.mark.auth
async def test_user_not_organization_member(
self, client: AsyncClient, orders: list[Order]
) -> None:
response = await client.patch(
f"/v1/orders/{orders[0].id}",
json={"custom_field_data": {"test": "updated"}},
)

assert response.status_code == 404

@pytest.mark.auth(
AuthSubjectFixture(scopes={Scope.web_write}),
AuthSubjectFixture(scopes={Scope.orders_write}),
)
async def test_user_valid(
self,
save_fixture: SaveFixture,
client: AsyncClient,
user_organization: UserOrganization,
organization: Organization,
) -> None:
# Create a product with custom fields
text_field = await create_custom_field(
save_fixture,
type=CustomFieldType.text,
slug="text",
organization=organization,
)
select_field = await create_custom_field(
save_fixture,
type=CustomFieldType.select,
slug="select",
organization=organization,
properties={
"options": [{"value": "a", "label": "A"}, {"value": "b", "label": "B"}],
},
)
product = await create_product(
save_fixture,
organization=organization,
recurring_interval=SubscriptionRecurringInterval.month,
attached_custom_fields=[(text_field, False), (select_field, True)],
)

# Create an order with custom field data
order = await create_order(
save_fixture,
product=product,
customer=await create_customer(save_fixture, organization=organization),
custom_field_data={"text": "original", "select": "a"},
)

response = await client.patch(
f"/v1/orders/{order.id}",
json={"custom_field_data": {"text": "updated", "select": "b"}},
)

assert response.status_code == 200

json = response.json()
assert json["custom_field_data"] == {"text": "updated", "select": "b"}

@pytest.mark.auth(
AuthSubjectFixture(subject="organization", scopes={Scope.orders_write}),
)
async def test_organization(
self, save_fixture: SaveFixture, client: AsyncClient, organization: Organization
) -> None:
# Create a product with custom fields
text_field = await create_custom_field(
save_fixture,
type=CustomFieldType.text,
slug="text",
organization=organization,
)
select_field = await create_custom_field(
save_fixture,
type=CustomFieldType.select,
slug="select",
organization=organization,
properties={
"options": [{"value": "a", "label": "A"}, {"value": "b", "label": "B"}],
},
)
product = await create_product(
save_fixture,
organization=organization,
recurring_interval=SubscriptionRecurringInterval.month,
attached_custom_fields=[(text_field, False), (select_field, True)],
)

# Create an order with custom field data
order = await create_order(
save_fixture,
product=product,
customer=await create_customer(save_fixture, organization=organization),
custom_field_data={"text": "original", "select": "a"},
)

response = await client.patch(
f"/v1/orders/{order.id}",
json={"custom_field_data": {"text": "updated", "select": "b"}},
)

assert response.status_code == 200

json = response.json()
assert json["custom_field_data"] == {"text": "updated", "select": "b"}

@pytest.mark.auth(
AuthSubjectFixture(scopes={Scope.web_write}),
)
async def test_update_existing_custom_field_data(
self,
save_fixture: SaveFixture,
client: AsyncClient,
user_organization: UserOrganization,
organization: Organization,
) -> None:
# Create a product with custom fields
text_field = await create_custom_field(
save_fixture,
type=CustomFieldType.text,
slug="text",
organization=organization,
)
select_field = await create_custom_field(
save_fixture,
type=CustomFieldType.select,
slug="select",
organization=organization,
properties={
"options": [{"value": "a", "label": "A"}, {"value": "b", "label": "B"}],
},
)
product = await create_product(
save_fixture,
organization=organization,
recurring_interval=SubscriptionRecurringInterval.month,
attached_custom_fields=[(text_field, False), (select_field, True)],
)

# Create an order with custom field data
order = await create_order(
save_fixture,
product=product,
customer=await create_customer(save_fixture, organization=organization),
custom_field_data={"text": "original", "select": "a"},
)

response = await client.patch(
f"/v1/orders/{order.id}",
json={"custom_field_data": {"text": "updated", "select": "b"}},
)

assert response.status_code == 200

json = response.json()
assert json["custom_field_data"] == {"text": "updated", "select": "b"}


@pytest.mark.asyncio
class TesGetOrdersStatistics:
async def test_anonymous(self, client: AsyncClient) -> None:
Expand Down
Loading