Skip to content

Commit

Permalink
Support ZodNativeEnum in discriminated unions (#256)
Browse files Browse the repository at this point in the history
  • Loading branch information
samchungy authored Apr 20, 2024
1 parent 3cf4b1b commit c33ef85
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ For example in `z.string().nullable()` will be rendered differently
- `string` `type` mapping by default
- ZodDefault
- ZodDiscriminatedUnion
- `discriminator` mapping when all schemas in the union are [registered](#creating-components). The discriminator must be a `ZodLiteral` string value. Only `ZodLiteral` values wrapped in `ZodBranded`, `ZodReadOnly` and `ZodCatch` are supported.
- `discriminator` mapping when all schemas in the union are [registered](#creating-components). The discriminator must be a `ZodLiteral`, `ZodEnum` or `ZodNativeEnum` with string values. Only values wrapped in `ZodBranded`, `ZodReadOnly` and `ZodCatch` are supported.
- ZodEffects
- `transform` support for request schemas. See [Zod Effects](#zod-effects) for how to enable response schema support
- `pre-process` support. We assume that the input type is the same as the output type. Otherwise pipe and transform can be used instead.
Expand Down
123 changes: 123 additions & 0 deletions src/create/schema/parsers/discriminatedUnion.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,89 @@ describe('createDiscriminatedUnionSchema', () => {
expect(result).toEqual(expected);
});

it('creates a oneOf schema with discriminator mapping when schemas with string nativeEnums', () => {
const expected: Schema = {
type: 'schema',
schema: {
discriminator: {
mapping: {
a: '#/components/schemas/a',
c: '#/components/schemas/a',
b: '#/components/schemas/b',
},
propertyName: 'type',
},
oneOf: [
{
$ref: '#/components/schemas/a',
},
{
$ref: '#/components/schemas/b',
},
],
},
};
enum letters {
a = 'a',
c = 'c',
}

const schema = z.discriminatedUnion('type', [
z
.object({
type: z.nativeEnum(letters),
})
.openapi({ ref: 'a' }),
z
.object({
type: z.literal('b'),
})
.openapi({ ref: 'b' }),
]);

const result = createDiscriminatedUnionSchema(schema, createOutputState());

expect(result).toEqual(expected);
});

it('creates a oneOf schema without discriminator mapping when schemas with mixed nativeEnums', () => {
const expected: Schema = {
type: 'schema',
schema: {
oneOf: [
{
$ref: '#/components/schemas/a',
},
{
$ref: '#/components/schemas/b',
},
],
},
};
enum mixed {
a = 'a',
c = 'c',
d = 1,
}

const schema = z.discriminatedUnion('type', [
z
.object({
type: z.nativeEnum(mixed),
})
.openapi({ ref: 'a' }),
z
.object({
type: z.literal('b'),
})
.openapi({ ref: 'b' }),
]);

const result = createDiscriminatedUnionSchema(schema, createOutputState());

expect(result).toEqual(expected);
});

it('handles a discriminated union with an optional type', () => {
const expected: Schema = {
type: 'schema',
Expand Down Expand Up @@ -281,6 +364,46 @@ describe('createDiscriminatedUnionSchema', () => {
expect(result).toEqual(expected);
});

it('handles a discriminated union with a branded enum type', () => {
const expected: Schema = {
type: 'schema',
schema: {
discriminator: {
mapping: {
a: '#/components/schemas/a',
c: '#/components/schemas/a',
b: '#/components/schemas/b',
},
propertyName: 'type',
},
oneOf: [
{
$ref: '#/components/schemas/a',
},
{
$ref: '#/components/schemas/b',
},
],
},
};
const schema = z.discriminatedUnion('type', [
z
.object({
type: z.enum(['a', 'c']).brand(),
})
.openapi({ ref: 'a' }),
z
.object({
type: z.literal('b'),
})
.openapi({ ref: 'b' }),
]);

const result = createDiscriminatedUnionSchema(schema, createOutputState());

expect(result).toEqual(expected);
});

it('handles a discriminated union with a readonly type', () => {
const expected: Schema = {
type: 'schema',
Expand Down
42 changes: 26 additions & 16 deletions src/create/schema/parsers/discriminatedUnion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
createSchemaObject,
} from '../../schema';

import { createNativeEnumSchema } from './nativeEnum';
import { flattenEffects } from './transform';

export const createDiscriminatedUnionSchema = <
Expand All @@ -33,6 +34,7 @@ export const createDiscriminatedUnionSchema = <
schemaObjects,
options,
zodDiscriminatedUnion.discriminator,
state,
);
return {
type: 'schema',
Expand All @@ -44,26 +46,38 @@ export const createDiscriminatedUnionSchema = <
};
};

const unwrapLiteral = (
const unwrapLiterals = (
zodType: ZodType | ZodTypeAny | undefined,
): string | undefined => {
state: SchemaState,
): string[] | undefined => {
if (isZodType(zodType, 'ZodLiteral')) {
if (typeof zodType._def.value !== 'string') {
return undefined;
}
return zodType._def.value;
return [zodType._def.value];
}

if (isZodType(zodType, 'ZodNativeEnum')) {
const schema = createNativeEnumSchema(zodType, state);
if (schema.type === 'schema' && schema.schema.type === 'string') {
return schema.schema.enum;
}
}

if (isZodType(zodType, 'ZodEnum')) {
return zodType._def.values;
}

if (isZodType(zodType, 'ZodBranded')) {
return unwrapLiteral(zodType._def.type);
return unwrapLiterals(zodType._def.type, state);
}

if (isZodType(zodType, 'ZodReadonly')) {
return unwrapLiteral(zodType._def.innerType);
return unwrapLiterals(zodType._def.innerType, state);
}

if (isZodType(zodType, 'ZodCatch')) {
return unwrapLiteral(zodType._def.innerType);
return unwrapLiterals(zodType._def.innerType, state);
}

return undefined;
Expand All @@ -73,6 +87,7 @@ export const mapDiscriminator = (
schemas: Array<oas31.SchemaObject | oas31.ReferenceObject>,
zodObjects: AnyZodObject[],
discriminator: unknown,
state: SchemaState,
): oas31.SchemaObject['discriminator'] => {
if (typeof discriminator !== 'string') {
return undefined;
Expand All @@ -88,20 +103,15 @@ export const mapDiscriminator = (

const value = (zodObject.shape as ZodRawShape)[discriminator];

if (isZodType(value, 'ZodEnum')) {
for (const enumValue of value._def.values as string[]) {
mapping[enumValue] = componentSchemaRef;
}
continue;
}
const literals = unwrapLiterals(value, state);

const literalValue = unwrapLiteral(value);

if (typeof literalValue !== 'string') {
if (!literals) {
return undefined;
}

mapping[literalValue] = componentSchemaRef;
for (const enumValue of literals) {
mapping[enumValue] = componentSchemaRef;
}
}

return {
Expand Down

0 comments on commit c33ef85

Please sign in to comment.