Skip to content
8 changes: 4 additions & 4 deletions packages/language/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -83,25 +83,25 @@ function now(): DateTime {
/**
* Generates a globally unique identifier based on the UUID specs.
*/
function uuid(version: Int?): String {
function uuid(version: Int?, format: String?): String {
} @@@expressionContext([DefaultValue])

/**
* Generates a globally unique identifier based on the CUID spec.
*/
function cuid(version: Int?): String {
function cuid(version: Int?, format: String?): String {
} @@@expressionContext([DefaultValue])

/**
* Generates an identifier based on the nanoid spec.
*/
function nanoid(length: Int?): String {
function nanoid(length: Int?, format: String?): String {
} @@@expressionContext([DefaultValue])

/**
* Generates an identifier based on the ulid spec.
*/
function ulid(): String {
function ulid(format: String?): String {
} @@@expressionContext([DefaultValue])

/**
Expand Down
10 changes: 10 additions & 0 deletions packages/language/src/validators/function-invocation-validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ export default class FunctionInvocationValidator implements AstValidator<Express
}
}

if (['uuid', 'ulid', 'cuid', 'nanoid'].includes(funcDecl.name)) {
const formatParamIdx = funcDecl.params.findIndex(param => param.name === 'format');
const formatArg = getLiteral<string>(expr.args[formatParamIdx]?.value);
if (formatArg && !formatArg.includes('%s')) {
accept('error', 'argument must include "%s"', {
node: expr.args[formatParamIdx]!,
});
}
}

// run checkers for specific functions
const checker = invocationCheckers.get(expr.function.$refText);
if (checker) {
Expand Down
209 changes: 209 additions & 0 deletions packages/language/test/function-invocation.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import { describe, it } from 'vitest';
import { loadSchema, loadSchemaWithError } from './utils';

describe('Function Invocation Tests', () => {
it('id functions should not require format strings', async () => {
await loadSchema(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(uuid())
}
`,
);

await loadSchema(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(uuid(7))
}
`,
);

await loadSchema(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(nanoid())
}
`,
);

await loadSchema(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(nanoid(8))
}
`,
);

await loadSchema(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(ulid())
}
`,
);

await loadSchema(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(cuid())
}
`,
);

await loadSchema(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(cuid(2))
}
`,
);
});

it('id functions should allow valid format strings', async () => {
await loadSchema(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(uuid(7, '%s_user'))
}
`,
);

await loadSchema(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(cuid(2, '%s'))
}
`,
);

await loadSchema(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(ulid('user_%s'))
}
`,
);

await loadSchema(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(nanoid(8, 'user_%s'))
}
`,
);
});

it('id functions should reject invalid format strings', async () => {
await loadSchemaWithError(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(uuid(7, 'user_%'))
}
`,
'argument must include',
);

await loadSchemaWithError(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(nanoid(8, 'user'))
}
`,
'argument must include',
);

await loadSchemaWithError(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(ulid('user_%'))
}
`,
'argument must include',
);

await loadSchemaWithError(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(cuid(2, 'user_%'))
}
`,
'argument must include',
);
});
});
47 changes: 32 additions & 15 deletions packages/orm/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -860,22 +860,30 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
private evalGenerator(defaultValue: Expression) {
if (ExpressionUtils.isCall(defaultValue)) {
return match(defaultValue.function)
.with('cuid', () => createId())
.with('uuid', () =>
defaultValue.args?.[0] &&
ExpressionUtils.isLiteral(defaultValue.args?.[0]) &&
defaultValue.args[0].value === 7
.with('cuid', () => this.formatGeneratedValue(createId(), defaultValue.args?.[1]))
.with('uuid', () => {
const version = defaultValue.args?.[0] && ExpressionUtils.isLiteral(defaultValue.args[0])
? defaultValue.args[0].value
: undefined;

const generated = version === 7
? uuid.v7()
: uuid.v4(),
)
.with('nanoid', () =>
defaultValue.args?.[0] &&
ExpressionUtils.isLiteral(defaultValue.args[0]) &&
typeof defaultValue.args[0].value === 'number'
? nanoid(defaultValue.args[0].value)
: nanoid(),
)
.with('ulid', () => ulid())
: uuid.v4();

return this.formatGeneratedValue(generated, defaultValue.args?.[1]);
})
.with('nanoid', () => {
const length = defaultValue.args?.[0] && ExpressionUtils.isLiteral(defaultValue.args[0])
? defaultValue.args[0].value
: undefined;

const generated = typeof length === 'number'
? nanoid(length)
: nanoid();

return this.formatGeneratedValue(generated, defaultValue.args?.[1]);
})
.with('ulid', () => this.formatGeneratedValue(ulid(), defaultValue.args?.[0]))
.otherwise(() => undefined);
} else if (
ExpressionUtils.isMember(defaultValue) &&
Expand All @@ -893,6 +901,15 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
}
}

private formatGeneratedValue(generated: string, formatExpr?: Expression) {
if (!formatExpr || !ExpressionUtils.isLiteral(formatExpr) || typeof formatExpr.value !== 'string') {
return generated;
}

// Replace non-escaped %s with the generated value, then unescape \%s to %s
return formatExpr.value.replace(/(?<!\\)%s/g, generated).replace(/\\%s/g, '%s');
}

protected async update(
kysely: AnyKysely,
model: string,
Expand Down
Loading
Loading