diff --git a/clients/apps/web/src/components/Orders/DownloadInvoice.tsx b/clients/apps/web/src/components/Orders/DownloadInvoice.tsx index d6c9d0ab3d..8f8666e6bb 100644 --- a/clients/apps/web/src/components/Orders/DownloadInvoice.tsx +++ b/clients/apps/web/src/components/Orders/DownloadInvoice.tsx @@ -337,6 +337,7 @@ const DownloadInvoice = ({ render={({ field }) => ( <> ( <> void autoComplete?: string + disabled?: boolean className?: string itemClassName?: string contentClassName?: string }) => { const countryMap = getCountryList(allowedCountries as TCountryCode[]) return ( - void country?: string autoComplete?: string + disabled?: boolean }) => { if (country === 'US' || country === 'CA') { const states = country === 'US' ? US_STATES : CA_PROVINCES @@ -100,6 +102,7 @@ const CountryStatePicker = ({ onValueChange={onChange} value={value} autoComplete={autoComplete} + disabled={disabled} > onChange(e.target.value)} + disabled={disabled} /> ) } diff --git a/server/polar/order/schemas.py b/server/polar/order/schemas.py index ceb7eba395..a5e630a03d 100644 --- a/server/polar/order/schemas.py +++ b/server/polar/order/schemas.py @@ -245,16 +245,14 @@ class Order(CustomFieldDataOutputMixin, MetadataOutputMixin, OrderBase): class OrderUpdateBase(Schema): billing_name: str | None = Field( - description=( - "The name of the customer that should appear on the invoice. " - "Can't be updated after the invoice is generated." - ) + None, description="The name of the customer that should appear on the invoice." ) billing_address: AddressInput | None = Field( + None, description=( "The address of the customer that should appear on the invoice. " - "Can't be updated after the invoice is generated." - ) + "Country and state fields cannot be updated." + ), ) diff --git a/server/polar/order/service.py b/server/polar/order/service.py index d01a1d0e77..0ac946a3c6 100644 --- a/server/polar/order/service.py +++ b/server/polar/order/service.py @@ -33,7 +33,7 @@ build_system_event, ) from polar.eventstream.service import publish as eventstream_publish -from polar.exceptions import PolarError +from polar.exceptions import PolarError, PolarRequestValidationError, ValidationError from polar.file.s3 import S3_SERVICES from polar.held_balance.service import held_balance as held_balance_service from polar.integrations.stripe.service import stripe as stripe_service @@ -344,6 +344,33 @@ async def update( order_update: OrderUpdate | CustomerOrderUpdate, ) -> Order: repository = OrderRepository.from_session(session) + + errors: list[ValidationError] = [] + + billing_address = order_update.billing_address + if billing_address is not None and order.billing_address is not None: + if str(billing_address.country) != str(order.billing_address.country): + errors.append( + { + "loc": ("body", "billing_address", "country"), + "msg": "Country cannot be changed", + "type": "value_error", + "input": billing_address.country, + } + ) + if billing_address.state != order.billing_address.state: + errors.append( + { + "loc": ("body", "billing_address", "state"), + "msg": "State cannot be changed", + "type": "value_error", + "input": billing_address.state, + } + ) + + if errors: + raise PolarRequestValidationError(errors) + order = await repository.update( order, update_dict=order_update.model_dump(exclude_unset=True) ) diff --git a/server/tests/order/test_service.py b/server/tests/order/test_service.py index f01f265448..af17161969 100644 --- a/server/tests/order/test_service.py +++ b/server/tests/order/test_service.py @@ -19,9 +19,16 @@ SubscriptionRecurringInterval, TaxProcessor, ) +from polar.exceptions import PolarRequestValidationError from polar.held_balance.service import held_balance as held_balance_service from polar.integrations.stripe.service import StripeService -from polar.kit.address import Address, CountryAlpha2 +from polar.kit.address import ( + Address, + AddressDict, + AddressInput, + CountryAlpha2, + CountryAlpha2Input, +) from polar.kit.db.postgres import AsyncSession from polar.kit.math import polar_round from polar.kit.pagination import PaginationParams @@ -49,6 +56,7 @@ from polar.models.subscription import SubscriptionStatus from polar.models.transaction import PlatformFeeType, TransactionType from polar.models.wallet import WalletType +from polar.order.schemas import OrderUpdate from polar.order.service import ( CardPaymentFailed, MissingCheckoutCustomer, @@ -429,6 +437,71 @@ async def test_metadata_filter( assert order2 in orders +@pytest.mark.asyncio +class TestUpdate: + @pytest.mark.parametrize( + ("set_address", "address_update"), + [ + ({"country": "US", "state": "CA"}, {"country": "US", "state": "NY"}), + ({"country": "US", "state": "CA"}, {"country": "FR", "state": None}), + ({"country": "FR", "state": None}, {"country": "US", "state": "CA"}), + ], + ) + async def test_invalid_country_state_update( + self, + set_address: AddressDict, + address_update: AddressDict, + save_fixture: SaveFixture, + session: AsyncSession, + customer: Customer, + ) -> None: + order = await create_order( + save_fixture, + customer=customer, + billing_address=Address.model_validate(set_address), + ) + + with pytest.raises(PolarRequestValidationError): + await order_service.update( + session, + order, + OrderUpdate( + billing_name=None, + billing_address=AddressInput.model_validate(address_update), + ), + ) + + async def test_valid_billing_address_update( + self, save_fixture: SaveFixture, session: AsyncSession, customer: Customer + ) -> None: + order = await create_order( + save_fixture, + customer=customer, + billing_address=Address(country=CountryAlpha2("FR")), + ) + + updated_order = await order_service.update( + session, + order, + OrderUpdate( + billing_name="New Name", + billing_address=AddressInput( + line1="Rue de la Paix", + city="Paris", + postal_code="75000", + country=CountryAlpha2Input("FR"), + ), + ), + ) + await session.flush() + await session.refresh(updated_order) + + assert updated_order.billing_name == "New Name" + assert updated_order.billing_address is not None + assert updated_order.billing_address.country == "FR" + assert updated_order.billing_address.line1 == "Rue de la Paix" + + @pytest.mark.asyncio class TestCreateFromCheckoutOneTime: async def test_recurring_product(