Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clients/apps/web/src/components/Orders/DownloadInvoice.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ const DownloadInvoice = ({
render={({ field }) => (
<>
<CountryStatePicker
disabled={true}
autoComplete="billing address-level1"
country={country}
value={field.value || ''}
Expand All @@ -357,6 +358,7 @@ const DownloadInvoice = ({
render={({ field }) => (
<>
<CountryPicker
disabled={true}
autoComplete="billing country"
value={field.value || undefined}
onChange={field.onChange}
Expand Down
9 changes: 8 additions & 1 deletion clients/packages/ui/src/components/atoms/CountryPicker.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const CountryPicker = ({
value,
onChange,
autoComplete,
disabled,
className,
itemClassName,
contentClassName,
Expand All @@ -33,13 +34,19 @@ const CountryPicker = ({
value?: string
onChange: (value: string) => void
autoComplete?: string
disabled?: boolean
className?: string
itemClassName?: string
contentClassName?: string
}) => {
const countryMap = getCountryList(allowedCountries as TCountryCode[])
return (
<Select onValueChange={onChange} value={value} autoComplete={autoComplete}>
<Select
onValueChange={onChange}
value={value}
autoComplete={autoComplete}
disabled={disabled}
>
<SelectTrigger className={className}>
<SelectValue
placeholder="Country"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ const CountryStatePicker = ({
autoComplete,
itemClassName,
contentClassName,
disabled,
}: {
className?: string
contentClassName?: string
Expand All @@ -92,6 +93,7 @@ const CountryStatePicker = ({
onChange: (value: string) => void
country?: string
autoComplete?: string
disabled?: boolean
}) => {
if (country === 'US' || country === 'CA') {
const states = country === 'US' ? US_STATES : CA_PROVINCES
Expand All @@ -100,6 +102,7 @@ const CountryStatePicker = ({
onValueChange={onChange}
value={value}
autoComplete={autoComplete}
disabled={disabled}
>
<SelectTrigger className={className}>
<SelectValue
Expand Down Expand Up @@ -133,6 +136,7 @@ const CountryStatePicker = ({
placeholder="State / Province"
value={value}
onChange={(e) => onChange(e.target.value)}
disabled={disabled}
/>
)
}
Expand Down
10 changes: 4 additions & 6 deletions server/polar/order/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,14 @@ class Order(CustomFieldDataOutputMixin, MetadataOutputMixin, OrderBase):

class OrderUpdateBase(Schema):
billing_name: str | None = Field(
description=(
"The name of the customer that should appear on the invoice. "
"Can't be updated after the invoice is generated."
)
None, description="The name of the customer that should appear on the invoice."
)
billing_address: AddressInput | None = Field(
None,
description=(
"The address of the customer that should appear on the invoice. "
"Can't be updated after the invoice is generated."
)
"Country and state fields cannot be updated."
),
)


Expand Down
29 changes: 28 additions & 1 deletion server/polar/order/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
build_system_event,
)
from polar.eventstream.service import publish as eventstream_publish
from polar.exceptions import PolarError
from polar.exceptions import PolarError, PolarRequestValidationError, ValidationError
from polar.file.s3 import S3_SERVICES
from polar.held_balance.service import held_balance as held_balance_service
from polar.integrations.stripe.service import stripe as stripe_service
Expand Down Expand Up @@ -344,6 +344,33 @@ async def update(
order_update: OrderUpdate | CustomerOrderUpdate,
) -> Order:
repository = OrderRepository.from_session(session)

errors: list[ValidationError] = []

billing_address = order_update.billing_address
if billing_address is not None and order.billing_address is not None:
if str(billing_address.country) != str(order.billing_address.country):
errors.append(
{
"loc": ("body", "billing_address", "country"),
"msg": "Country cannot be changed",
"type": "value_error",
"input": billing_address.country,
}
)
if billing_address.state != order.billing_address.state:
errors.append(
{
"loc": ("body", "billing_address", "state"),
"msg": "State cannot be changed",
"type": "value_error",
"input": billing_address.state,
}
)

if errors:
raise PolarRequestValidationError(errors)

order = await repository.update(
order, update_dict=order_update.model_dump(exclude_unset=True)
)
Expand Down
75 changes: 74 additions & 1 deletion server/tests/order/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@
SubscriptionRecurringInterval,
TaxProcessor,
)
from polar.exceptions import PolarRequestValidationError
from polar.held_balance.service import held_balance as held_balance_service
from polar.integrations.stripe.service import StripeService
from polar.kit.address import Address, CountryAlpha2
from polar.kit.address import (
Address,
AddressDict,
AddressInput,
CountryAlpha2,
CountryAlpha2Input,
)
from polar.kit.db.postgres import AsyncSession
from polar.kit.math import polar_round
from polar.kit.pagination import PaginationParams
Expand Down Expand Up @@ -49,6 +56,7 @@
from polar.models.subscription import SubscriptionStatus
from polar.models.transaction import PlatformFeeType, TransactionType
from polar.models.wallet import WalletType
from polar.order.schemas import OrderUpdate
from polar.order.service import (
CardPaymentFailed,
MissingCheckoutCustomer,
Expand Down Expand Up @@ -429,6 +437,71 @@ async def test_metadata_filter(
assert order2 in orders


@pytest.mark.asyncio
class TestUpdate:
@pytest.mark.parametrize(
("set_address", "address_update"),
[
({"country": "US", "state": "CA"}, {"country": "US", "state": "NY"}),
({"country": "US", "state": "CA"}, {"country": "FR", "state": None}),
({"country": "FR", "state": None}, {"country": "US", "state": "CA"}),
],
)
async def test_invalid_country_state_update(
self,
set_address: AddressDict,
address_update: AddressDict,
save_fixture: SaveFixture,
session: AsyncSession,
customer: Customer,
) -> None:
order = await create_order(
save_fixture,
customer=customer,
billing_address=Address.model_validate(set_address),
)

with pytest.raises(PolarRequestValidationError):
await order_service.update(
session,
order,
OrderUpdate(
billing_name=None,
billing_address=AddressInput.model_validate(address_update),
),
)

async def test_valid_billing_address_update(
self, save_fixture: SaveFixture, session: AsyncSession, customer: Customer
) -> None:
order = await create_order(
save_fixture,
customer=customer,
billing_address=Address(country=CountryAlpha2("FR")),
)

updated_order = await order_service.update(
session,
order,
OrderUpdate(
billing_name="New Name",
billing_address=AddressInput(
line1="Rue de la Paix",
city="Paris",
postal_code="75000",
country=CountryAlpha2Input("FR"),
),
),
)
await session.flush()
await session.refresh(updated_order)

assert updated_order.billing_name == "New Name"
assert updated_order.billing_address is not None
assert updated_order.billing_address.country == "FR"
assert updated_order.billing_address.line1 == "Rue de la Paix"


@pytest.mark.asyncio
class TestCreateFromCheckoutOneTime:
async def test_recurring_product(
Expand Down
Loading