diff --git a/express-zod-api/src/deep-checks.ts b/express-zod-api/src/deep-checks.ts index e79964364..a1654109c 100644 --- a/express-zod-api/src/deep-checks.ts +++ b/express-zod-api/src/deep-checks.ts @@ -1,171 +1,81 @@ -import type { - $ZodArray, - $ZodCatch, - $ZodDefault, - $ZodDiscriminatedUnion, - $ZodInterface, - $ZodIntersection, - $ZodLazy, - $ZodNullable, - $ZodObject, - $ZodOptional, - $ZodPipe, - $ZodReadonly, - $ZodRecord, - $ZodTuple, - $ZodType, - $ZodUnion, -} from "@zod/core"; -import { fail } from "node:assert/strict"; // eslint-disable-line no-restricted-syntax -- acceptable -import { globalRegistry } from "zod"; +import type { $ZodType } from "@zod/core"; +import * as R from "ramda"; +import { globalRegistry, z } from "zod"; import { ezDateInBrand } from "./date-in-schema"; import { ezDateOutBrand } from "./date-out-schema"; -import { ezFileBrand } from "./file-schema"; +import { DeepCheckError } from "./errors"; import { ezFormBrand } from "./form-schema"; import { IOSchema } from "./io-schema"; import { metaSymbol } from "./metadata"; -import { ProprietaryBrand } from "./proprietary-schemas"; -import { ezRawBrand } from "./raw-schema"; -import { - FirstPartyKind, - HandlingRules, - NextHandlerInc, - SchemaHandler, -} from "./schema-walker"; +import { FirstPartyKind } from "./schema-walker"; import { ezUploadBrand } from "./upload-schema"; +import { ezRawBrand } from "./raw-schema"; -type CheckContext = { visited: WeakSet }; -/** @desc Check is a schema handling rule returning boolean */ -type Check = SchemaHandler; - -const onSomeUnion: Check = ( - { _zod }: $ZodUnion | $ZodDiscriminatedUnion, - { next }, -) => _zod.def.options.some(next); - -const onIntersection: Check = ({ _zod }: $ZodIntersection, { next }) => - [_zod.def.left, _zod.def.right].some(next); - -const onWrapped: Check = ( - { - _zod: { def }, - }: $ZodOptional | $ZodNullable | $ZodReadonly | $ZodDefault | $ZodCatch, - { next }, -) => next(def.innerType); - -const ioChecks: HandlingRules = { - object: ({ _zod }: $ZodObject, { next }) => - Object.values(_zod.def.shape).some(next), - interface: (int: $ZodInterface, { next, visited }) => - visited.has(int) - ? false - : visited.add(int) && Object.values(int._zod.def.shape).some(next), - union: onSomeUnion, - intersection: onIntersection, - optional: onWrapped, - nullable: onWrapped, - default: onWrapped, - record: ({ _zod }: $ZodRecord, { next }) => next(_zod.def.valueType), - array: ({ _zod }: $ZodArray, { next }) => next(_zod.def.element), -}; - -interface NestedSchemaLookupProps extends Partial { - condition?: (schema: $ZodType) => boolean; - rules?: HandlingRules< - boolean, - CheckContext, - FirstPartyKind | ProprietaryBrand - >; - maxDepth?: number; - depth?: number; +interface NestedSchemaLookupProps { + io: "input" | "output"; + condition: (zodSchema: $ZodType) => boolean; } -/** @desc The optimized version of the schema walker for boolean checks */ -export const hasNestedSchema = ( +export const findNestedSchema = ( subject: $ZodType, - { - condition, - rules = ioChecks, - depth = 1, - maxDepth = Number.POSITIVE_INFINITY, - visited = new WeakSet(), - }: NestedSchemaLookupProps, -): boolean => { - if (condition?.(subject)) return true; - if (depth >= maxDepth) return false; - const { brand } = globalRegistry.get(subject)?.[metaSymbol] ?? {}; - const handler = - brand && brand in rules - ? rules[brand as keyof typeof rules] - : rules[subject._zod.def.type]; - if (handler) { - return handler(subject, { - visited, - next: (schema) => - hasNestedSchema(schema, { - condition, - rules, - maxDepth, - visited, - depth: depth + 1, - }), - } as CheckContext & NextHandlerInc); - } - return false; -}; - -export const hasUpload = (subject: IOSchema) => - hasNestedSchema(subject, { - condition: (schema) => - globalRegistry.get(schema)?.[metaSymbol]?.brand === ezUploadBrand, - rules: { - ...ioChecks, - [ezFormBrand]: ioChecks.object, + { io, condition }: NestedSchemaLookupProps, +) => + R.tryCatch( + () => { + z.toJSONSchema(subject, { + io, + unrepresentable: "any", + override: ({ zodSchema }) => { + if (condition(zodSchema)) throw new DeepCheckError(zodSchema); // exits early + }, + }); + return undefined; }, - }); + (err: DeepCheckError) => err.cause, + )(); -export const hasRaw = (subject: IOSchema) => - hasNestedSchema(subject, { - condition: (schema) => - globalRegistry.get(schema)?.[metaSymbol]?.brand === ezRawBrand, - maxDepth: 3, +export const findRequestTypeDefiningSchema = (subject: IOSchema) => + findNestedSchema(subject, { + condition: (schema) => { + const { brand } = globalRegistry.get(schema)?.[metaSymbol] || {}; + return ( + typeof brand === "symbol" && + [ezUploadBrand, ezRawBrand, ezFormBrand].includes(brand) + ); + }, + io: "input", }); -export const hasForm = (subject: IOSchema) => - hasNestedSchema(subject, { - condition: (schema) => - globalRegistry.get(schema)?.[metaSymbol]?.brand === ezFormBrand, - maxDepth: 3, - }); +const unsupported: FirstPartyKind[] = [ + "nan", + "symbol", + "map", + "set", + "bigint", + "void", + "promise", + "never", +]; -/** @throws AssertionError with incompatible schema constructor */ -export const assertJsonCompatible = (subject: $ZodType, dir: "in" | "out") => - hasNestedSchema(subject, { - maxDepth: 300, - rules: { - ...ioChecks, - readonly: onWrapped, - catch: onWrapped, - pipe: ({ _zod }: $ZodPipe, { next }) => next(_zod.def[dir]), - lazy: ({ _zod: { def } }: $ZodLazy, { next, visited }) => - visited.has(def.getter) - ? false - : visited.add(def.getter) && next(def.getter()), - tuple: ({ _zod: { def } }: $ZodTuple, { next }) => - [...def.items].concat(def.rest ?? []).some(next), - nan: () => fail("z.nan()"), - symbol: () => fail("z.symbol()"), - map: () => fail("z.map()"), - set: () => fail("z.set()"), - bigint: () => fail("z.bigint()"), - void: () => fail("z.void()"), - promise: () => fail("z.promise()"), - never: () => fail("z.never()"), - date: () => dir === "in" && fail("z.date()"), - [ezDateOutBrand]: () => dir === "in" && fail("ez.dateOut()"), - [ezDateInBrand]: () => dir === "out" && fail("ez.dateIn()"), - [ezRawBrand]: () => dir === "out" && fail("ez.raw()"), - [ezUploadBrand]: () => dir === "out" && fail("ez.upload()"), - [ezFileBrand]: () => false, +export const findJsonIncompatible = ( + subject: $ZodType, + io: "input" | "output", +) => + findNestedSchema(subject, { + io, + condition: (zodSchema) => { + const { brand } = globalRegistry.get(zodSchema)?.[metaSymbol] || {}; + const { type } = zodSchema._zod.def; + if (unsupported.includes(type)) return true; + if (io === "input") { + if (type === "date") return true; + if (brand === ezDateOutBrand) return true; + } + if (io === "output") { + if (brand === ezDateInBrand) return true; + if (brand === ezRawBrand) return true; + if (brand === ezUploadBrand) return true; + } + return false; }, }); diff --git a/express-zod-api/src/diagnostics.ts b/express-zod-api/src/diagnostics.ts index b7092997b..86408dad6 100644 --- a/express-zod-api/src/diagnostics.ts +++ b/express-zod-api/src/diagnostics.ts @@ -1,17 +1,14 @@ -import * as R from "ramda"; import type { $ZodShape } from "@zod/core"; import { z } from "zod"; import { responseVariants } from "./api-response"; import { FlatObject, getRoutePathParams } from "./common-helpers"; import { contentTypes } from "./content-type"; -import { assertJsonCompatible } from "./deep-checks"; +import { findJsonIncompatible } from "./deep-checks"; import { AbstractEndpoint } from "./endpoint"; import { extractObjectSchema } from "./io-schema"; import { ActualLogger } from "./logger-helpers"; export class Diagnostics { - /** @desc (catcher)(...args) => bool | ReturnValue */ - readonly #trier = R.tryCatch(assertJsonCompatible); #verifiedEndpoints = new WeakSet(); #verifiedPaths = new WeakMap< AbstractEndpoint, @@ -35,23 +32,24 @@ export class Diagnostics { } } if (endpoint.requestType === "json") { - this.#trier((reason) => + const reason = findJsonIncompatible(endpoint.inputSchema, "input"); + if (reason) { this.logger.warn( "The final input schema of the endpoint contains an unsupported JSON payload type.", Object.assign(ctx, { reason }), - ), - )(endpoint.inputSchema, "in"); + ); + } } for (const variant of responseVariants) { - const catcher = this.#trier((reason) => - this.logger.warn( - `The final ${variant} response schema of the endpoint contains an unsupported JSON payload type.`, - Object.assign(ctx, { reason }), - ), - ); for (const { mimeTypes, schema } of endpoint.getResponses(variant)) { if (!mimeTypes?.includes(contentTypes.json)) continue; - catcher(schema, "out"); + const reason = findJsonIncompatible(schema, "output"); + if (reason) { + this.logger.warn( + `The final ${variant} response schema of the endpoint contains an unsupported JSON payload type.`, + Object.assign(ctx, { reason }), + ); + } } } this.#verifiedEndpoints.add(endpoint); diff --git a/express-zod-api/src/endpoint.ts b/express-zod-api/src/endpoint.ts index 5527060c7..a62a9c650 100644 --- a/express-zod-api/src/endpoint.ts +++ b/express-zod-api/src/endpoint.ts @@ -1,8 +1,8 @@ import { Request, Response } from "express"; import * as R from "ramda"; -import { z } from "zod"; +import { globalRegistry, z } from "zod"; import { NormalizedResponse, ResponseVariant } from "./api-response"; -import { hasForm, hasRaw, hasUpload } from "./deep-checks"; +import { findRequestTypeDefiningSchema } from "./deep-checks"; import { FlatObject, getActualMethod, @@ -15,16 +15,20 @@ import { OutputValidationError, ResultHandlerError, } from "./errors"; +import { ezFormBrand } from "./form-schema"; import { IOSchema } from "./io-schema"; import { lastResortHandler } from "./last-resort"; import { ActualLogger } from "./logger-helpers"; import { LogicalContainer } from "./logical-container"; +import { metaSymbol } from "./metadata"; import { AuxMethod, Method } from "./method"; import { AbstractMiddleware, ExpressMiddleware } from "./middleware"; import { ContentType } from "./content-type"; +import { ezRawBrand } from "./raw-schema"; import { Routable } from "./routable"; import { AbstractResultHandler } from "./result-handler"; import { Security } from "./security"; +import { ezUploadBrand } from "./upload-schema"; export type Handler = (params: { /** @desc The inputs from the enabled input sources validated against the final input schema (incl. Middlewares) */ @@ -137,13 +141,14 @@ export class Endpoint< /** @internal */ public override get requestType() { - return hasUpload(this.#def.inputSchema) - ? "upload" - : hasRaw(this.#def.inputSchema) - ? "raw" - : hasForm(this.#def.inputSchema) - ? "form" - : "json"; + const found = findRequestTypeDefiningSchema(this.#def.inputSchema); + if (found) { + const { brand } = globalRegistry.get(found)?.[metaSymbol] || {}; + if (brand === ezUploadBrand) return "upload"; + if (brand === ezRawBrand) return "raw"; + if (brand === ezFormBrand) return "form"; + } + return "json"; } /** @internal */ diff --git a/express-zod-api/src/errors.ts b/express-zod-api/src/errors.ts index 39794d62e..d152c98a2 100644 --- a/express-zod-api/src/errors.ts +++ b/express-zod-api/src/errors.ts @@ -1,3 +1,4 @@ +import type { $ZodType } from "@zod/core"; import { z } from "zod"; import { getMessageFromError } from "./common-helpers"; import { OpenAPIContext } from "./documentation-helpers"; @@ -34,6 +35,14 @@ export class IOSchemaError extends Error { public override name = "IOSchemaError"; } +export class DeepCheckError extends IOSchemaError { + public override name = "DeepCheckError"; + + constructor(public override readonly cause: $ZodType) { + super("Found", { cause }); + } +} + /** @desc An error of validating the Endpoint handler's returns against the Endpoint output schema */ export class OutputValidationError extends IOSchemaError { public override name = "OutputValidationError"; diff --git a/express-zod-api/tests/deep-checks.spec.ts b/express-zod-api/tests/deep-checks.spec.ts index 68f27c228..b140d6bc9 100644 --- a/express-zod-api/tests/deep-checks.spec.ts +++ b/express-zod-api/tests/deep-checks.spec.ts @@ -2,17 +2,19 @@ import { UploadedFile } from "express-fileupload"; import { globalRegistry, z } from "zod"; import type { $brand, $ZodType } from "@zod/core"; import { ez } from "../src"; -import { hasNestedSchema } from "../src/deep-checks"; +import { findNestedSchema } from "../src/deep-checks"; import { metaSymbol } from "../src/metadata"; import { ezUploadBrand } from "../src/upload-schema"; describe("Checks", () => { - describe("hasNestedSchema()", () => { + describe("findNestedSchema()", () => { const condition = (subject: $ZodType) => globalRegistry.get(subject)?.[metaSymbol]?.brand === ezUploadBrand; test("should return true for given argument satisfying condition", () => { - expect(hasNestedSchema(ez.upload(), { condition })).toBeTruthy(); + expect( + findNestedSchema(ez.upload(), { condition, io: "input" }), + ).toBeTruthy(); }); test.each([ @@ -26,7 +28,9 @@ describe("Checks", () => { ez.upload().refine(() => true), z.array(ez.upload()), ])("should return true for wrapped needle %#", (subject) => { - expect(hasNestedSchema(subject, { condition })).toBeTruthy(); + expect( + findNestedSchema(subject, { condition, io: "input" }), + ).toBeTruthy(); }); test.each([ @@ -36,10 +40,12 @@ describe("Checks", () => { z.boolean().and(z.literal(true)), z.number().or(z.string()), ])("should return false in other cases %#", (subject) => { - expect(hasNestedSchema(subject, { condition })).toBeFalsy(); + expect( + findNestedSchema(subject, { condition, io: "input" }), + ).toBeUndefined(); }); - test("should finish early", () => { + test("should finish early (from bottom to top)", () => { const subject = z.object({ one: z.object({ two: z.object({ @@ -47,9 +53,10 @@ describe("Checks", () => { }), }), }); - const check = vi.fn((schema) => schema instanceof z.ZodObject); - hasNestedSchema(subject, { + const check = vi.fn((schema) => schema instanceof z.ZodNumber); + findNestedSchema(subject, { condition: check, + io: "input", }); expect(check.mock.calls.length).toBe(1); }); diff --git a/express-zod-api/tests/errors.spec.ts b/express-zod-api/tests/errors.spec.ts index 4e6aa2334..c594ceddf 100644 --- a/express-zod-api/tests/errors.spec.ts +++ b/express-zod-api/tests/errors.spec.ts @@ -6,6 +6,7 @@ import { MissingPeerError, OutputValidationError, ResultHandlerError, + DeepCheckError, } from "../src/errors"; describe("Errors", () => { @@ -59,6 +60,24 @@ describe("Errors", () => { }); }); + describe("DeepCheckError", () => { + const schema = z.any(); + const error = new DeepCheckError(schema); + + test("should be an instance of IOSchemaError and Error", () => { + expect(error).toBeInstanceOf(IOSchemaError); + expect(error).toBeInstanceOf(Error); + }); + + test("should have the name matching its class", () => { + expect(error.name).toBe("DeepCheckError"); + }); + + test("should have the cause matching the schema", () => { + expect(error.cause).toBe(schema); + }); + }); + describe("OutputValidationError", () => { const zodError = new z.ZodError([]); const error = new OutputValidationError(zodError); diff --git a/express-zod-api/tests/routing.spec.ts b/express-zod-api/tests/routing.spec.ts index d4baabe94..d0cb1d8cc 100644 --- a/express-zod-api/tests/routing.spec.ts +++ b/express-zod-api/tests/routing.spec.ts @@ -423,11 +423,11 @@ describe("Routing", () => { expect(logger._getLogs().warn).toEqual([ [ "The final input schema of the endpoint contains an unsupported JSON payload type.", - { method: "get", path: "/path", reason: expect.any(Error) }, + { method: "get", path: "/path", reason: expect.any(z.ZodType) }, ], [ "The final positive response schema of the endpoint contains an unsupported JSON payload type.", - { method: "get", path: "/path", reason: expect.any(Error) }, + { method: "get", path: "/path", reason: expect.any(z.ZodType) }, ], ]); });