diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index bbeafb07f..81d52dc95 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -83,25 +83,25 @@ function now(): DateTime { /** * Generates a globally unique identifier based on the UUID specs. */ -function uuid(version: Int?): String { +function uuid(version: Int?, format: String?): String { } @@@expressionContext([DefaultValue]) /** * Generates a globally unique identifier based on the CUID spec. */ -function cuid(version: Int?): String { +function cuid(version: Int?, format: String?): String { } @@@expressionContext([DefaultValue]) /** * Generates an identifier based on the nanoid spec. */ -function nanoid(length: Int?): String { +function nanoid(length: Int?, format: String?): String { } @@@expressionContext([DefaultValue]) /** * Generates an identifier based on the ulid spec. */ -function ulid(): String { +function ulid(format: String?): String { } @@@expressionContext([DefaultValue]) /** diff --git a/packages/language/src/validators/function-invocation-validator.ts b/packages/language/src/validators/function-invocation-validator.ts index a2ff34fde..101c95893 100644 --- a/packages/language/src/validators/function-invocation-validator.ts +++ b/packages/language/src/validators/function-invocation-validator.ts @@ -87,6 +87,16 @@ export default class FunctionInvocationValidator implements AstValidator param.name === 'format'); + const formatArg = getLiteral(expr.args[formatParamIdx]?.value); + if (formatArg && !formatArg.includes('%s')) { + accept('error', 'argument must include "%s"', { + node: expr.args[formatParamIdx]!, + }); + } + } + // run checkers for specific functions const checker = invocationCheckers.get(expr.function.$refText); if (checker) { diff --git a/packages/language/test/function-invocation.test.ts b/packages/language/test/function-invocation.test.ts new file mode 100644 index 000000000..1aecf66ec --- /dev/null +++ b/packages/language/test/function-invocation.test.ts @@ -0,0 +1,209 @@ +import { describe, it } from 'vitest'; +import { loadSchema, loadSchemaWithError } from './utils'; + +describe('Function Invocation Tests', () => { + it('id functions should not require format strings', async () => { + await loadSchema( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(uuid()) + } + `, + ); + + await loadSchema( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(uuid(7)) + } + `, + ); + + await loadSchema( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(nanoid()) + } + `, + ); + + await loadSchema( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(nanoid(8)) + } + `, + ); + + await loadSchema( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(ulid()) + } + `, + ); + + await loadSchema( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(cuid()) + } + `, + ); + + await loadSchema( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(cuid(2)) + } + `, + ); + }); + + it('id functions should allow valid format strings', async () => { + await loadSchema( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(uuid(7, '%s_user')) + } + `, + ); + + await loadSchema( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(cuid(2, '%s')) + } + `, + ); + + await loadSchema( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(ulid('user_%s')) + } + `, + ); + + await loadSchema( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(nanoid(8, 'user_%s')) + } + `, + ); + }); + + it('id functions should reject invalid format strings', async () => { + await loadSchemaWithError( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(uuid(7, 'user_%')) + } + `, + 'argument must include', + ); + + await loadSchemaWithError( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(nanoid(8, 'user')) + } + `, + 'argument must include', + ); + + await loadSchemaWithError( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(ulid('user_%')) + } + `, + 'argument must include', + ); + + await loadSchemaWithError( + ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + model User { + id String @id @default(cuid(2, 'user_%')) + } + `, + 'argument must include', + ); + }); +}); diff --git a/packages/orm/src/client/crud/operations/base.ts b/packages/orm/src/client/crud/operations/base.ts index 0ac2fe8c5..8fb69846f 100644 --- a/packages/orm/src/client/crud/operations/base.ts +++ b/packages/orm/src/client/crud/operations/base.ts @@ -860,22 +860,30 @@ export abstract class BaseOperationHandler { private evalGenerator(defaultValue: Expression) { if (ExpressionUtils.isCall(defaultValue)) { return match(defaultValue.function) - .with('cuid', () => createId()) - .with('uuid', () => - defaultValue.args?.[0] && - ExpressionUtils.isLiteral(defaultValue.args?.[0]) && - defaultValue.args[0].value === 7 + .with('cuid', () => this.formatGeneratedValue(createId(), defaultValue.args?.[1])) + .with('uuid', () => { + const version = defaultValue.args?.[0] && ExpressionUtils.isLiteral(defaultValue.args[0]) + ? defaultValue.args[0].value + : undefined; + + const generated = version === 7 ? uuid.v7() - : uuid.v4(), - ) - .with('nanoid', () => - defaultValue.args?.[0] && - ExpressionUtils.isLiteral(defaultValue.args[0]) && - typeof defaultValue.args[0].value === 'number' - ? nanoid(defaultValue.args[0].value) - : nanoid(), - ) - .with('ulid', () => ulid()) + : uuid.v4(); + + return this.formatGeneratedValue(generated, defaultValue.args?.[1]); + }) + .with('nanoid', () => { + const length = defaultValue.args?.[0] && ExpressionUtils.isLiteral(defaultValue.args[0]) + ? defaultValue.args[0].value + : undefined; + + const generated = typeof length === 'number' + ? nanoid(length) + : nanoid(); + + return this.formatGeneratedValue(generated, defaultValue.args?.[1]); + }) + .with('ulid', () => this.formatGeneratedValue(ulid(), defaultValue.args?.[0])) .otherwise(() => undefined); } else if ( ExpressionUtils.isMember(defaultValue) && @@ -893,6 +901,15 @@ export abstract class BaseOperationHandler { } } + private formatGeneratedValue(generated: string, formatExpr?: Expression) { + if (!formatExpr || !ExpressionUtils.isLiteral(formatExpr) || typeof formatExpr.value !== 'string') { + return generated; + } + + // Replace non-escaped %s with the generated value, then unescape \%s to %s + return formatExpr.value.replace(/(? { + it('supports top-level ids', async () => { + const client = await createTestClient(schema); + + const user = await client.user.create({ + data: { + id: 1, + }, + }); + expect(user.uuid).toMatch(/^user_uuid_/); + expect(user.uuid7).toMatch(/^user_uuid7_/); + expect(user.cuid).toMatch(/^user_cuid_/); + expect(user.cuid2).toMatch(/^user_cuid2_/); + expect(user.nanoid).toMatch(/^user_nanoid_/); + expect(user.nanoid8).toMatch(/^user_nanoid8_/); + expect(user.ulid).toMatch(/^user_ulid_/); + }); + + it('supports nested ids', async () => { + const client = await createTestClient(schema); + + const user = await client.user.create({ + data: { + id: 1, + + posts: { + create: { + id: 1, + }, + }, + }, + }); + expect(user.uuid).toMatch(/^user_uuid_/); + expect(user.uuid7).toMatch(/^user_uuid7_/); + expect(user.cuid).toMatch(/^user_cuid_/); + expect(user.cuid2).toMatch(/^user_cuid2_/); + expect(user.nanoid).toMatch(/^user_nanoid_/); + expect(user.nanoid8).toMatch(/^user_nanoid8_/); + expect(user.ulid).toMatch(/^user_ulid_/); + + const post = await client.post.findUniqueOrThrow({ where: { id: 1 } }); + expect(post.uuid).toMatch(/^post_uuid_/); + expect(post.uuid7).toMatch(/^post_uuid7_/); + expect(post.cuid).toMatch(/^post_cuid_/); + expect(post.cuid2).toMatch(/^post_cuid2_/); + expect(post.nanoid).toMatch(/^post_nanoid_/); + expect(post.nanoid8).toMatch(/^post_nanoid8_/); + expect(post.ulid).toMatch(/^post_ulid_/); + }); + + it('supports deeply nested ids', async () => { + const client = await createTestClient(schema); + + const user = await client.user.create({ + data: { + id: 1, + + posts: { + create: { + id: 1, + + comments: { + create: { + id: 1, + }, + }, + }, + }, + }, + }); + expect(user.uuid).toMatch(/^user_uuid_/); + expect(user.uuid7).toMatch(/^user_uuid7_/); + expect(user.cuid).toMatch(/^user_cuid_/); + expect(user.cuid2).toMatch(/^user_cuid2_/); + expect(user.nanoid).toMatch(/^user_nanoid_/); + expect(user.nanoid8).toMatch(/^user_nanoid8_/); + expect(user.ulid).toMatch(/^user_ulid_/); + + const post = await client.post.findUniqueOrThrow({ where: { id: 1 } }); + expect(post.uuid).toMatch(/^post_uuid_/); + expect(post.uuid7).toMatch(/^post_uuid7_/); + expect(post.cuid).toMatch(/^post_cuid_/); + expect(post.cuid2).toMatch(/^post_cuid2_/); + expect(post.nanoid).toMatch(/^post_nanoid_/); + expect(post.nanoid8).toMatch(/^post_nanoid8_/); + expect(post.ulid).toMatch(/^post_ulid_/); + + const comment = await client.comment.findUniqueOrThrow({ where: { id: 1 } }); + expect(comment.uuid).toMatch(/^comment_uuid_/); + expect(comment.uuid7).toMatch(/^comment_uuid7_/); + expect(comment.cuid).toMatch(/^comment_cuid_/); + expect(comment.cuid2).toMatch(/^comment_cuid2_/); + expect(comment.nanoid).toMatch(/^comment_nanoid_/); + expect(comment.nanoid8).toMatch(/^comment_nanoid8_/); + expect(comment.ulid).toMatch(/^comment_ulid_/); + }); + + it('supports escaped placeholders and edge cases', async () => { + const escapedSchema = ` +model EscapedTest { + id Int @id + escaped String @default(uuid(4, "prefix_\\\\%s_suffix")) + consecutive String @default(uuid(4, "%s%s")) + mixedEscaped String @default(uuid(4, "\\\\%s_%s_end")) + startWithPattern String @default(uuid(4, "%s_suffix")) + endWithPattern String @default(uuid(4, "prefix_%s")) +} +`; + const client = await createTestClient(escapedSchema); + + const record = await client.escapedTest.create({ + data: { + id: 1, + }, + }); + + // Escaped \%s should become literal %s in output + expect(record.escaped).toMatch(/^prefix_%s_suffix$/); + + // Consecutive %s%s should both be replaced + expect(record.consecutive).toMatch(/^[0-9a-f-]{36}[0-9a-f-]{36}$/); + + // Mixed: first \%s stays as %s, second %s is replaced + expect(record.mixedEscaped).toMatch(/^%s_[0-9a-f-]{36}_end$/); + + // Pattern at start + expect(record.startWithPattern).toMatch(/^[0-9a-f-]{36}_suffix$/); + + // Pattern at end + expect(record.endWithPattern).toMatch(/^prefix_[0-9a-f-]{36}$/); + }); +});