Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add @encrypted enhancer #1922

Open
wants to merge 17 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/runtime/src/enhancements/edge/encrypted.ts
13 changes: 12 additions & 1 deletion packages/runtime/src/enhancements/node/create-enhancement.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import { withJsonProcessor } from './json-processor';
import { Logger } from './logger';
import { withOmit } from './omit';
import { withPassword } from './password';
import { withEncrypted } from './encrypted';
import { policyProcessIncludeRelationPayload, withPolicy } from './policy';
import type { PolicyDef } from './types';

Expand Down Expand Up @@ -100,6 +101,7 @@ export function createEnhancement<DbClient extends object>(
}

const hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password'));
const hasEncrypted = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@encrypted'));
const hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit'));
const hasDefaultAuth = allFields.some((field) => field.defaultValueProvider);
const hasTypeDefField = allFields.some((field) => field.isTypeDef);
Expand All @@ -120,13 +122,22 @@ export function createEnhancement<DbClient extends object>(
}
}

// password enhancement must be applied prior to policy because it changes then length of the field
// password and encrypted enhancement must be applied prior to policy because it changes then length of the field
// and can break validation rules like `@length`
if (hasPassword && kinds.includes('password')) {
// @password proxy
result = withPassword(result, options);
}

if (hasEncrypted && kinds.includes('encrypted')) {
if (!options.encryption) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we validate the shape of options.encryption? Here or inside the EncryptedHandler constructor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like validating inside of the handler constructor makes better sense, and its together with the rest of the encrypted logic

throw new Error('Encryption options are required for @encrypted enhancement');
}

// @encrypted proxy
result = withEncrypted(result, options);
}

// 'policy' and 'validation' enhancements are both enabled by `withPolicy`
if (kinds.includes('policy') || kinds.includes('validation')) {
result = withPolicy(result, options, context);
Expand Down
167 changes: 167 additions & 0 deletions packages/runtime/src/enhancements/node/encrypted.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/no-unused-vars */

import {
FieldInfo,
NestedWriteVisitor,
enumerate,
getModelFields,
resolveField,
type PrismaWriteActionType,
} from '../../cross';
import { DbClientContract, CustomEncryption, SimpleEncryption } from '../../types';
import { InternalEnhancementOptions } from './create-enhancement';
import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy';
import { QueryUtils } from './query-utils';

/**
* Gets an enhanced Prisma client that supports `@encrypted` attribute.
*
* @private
*/
export function withEncrypted<DbClient extends object = any>(
prisma: DbClient,
options: InternalEnhancementOptions
): DbClient {
return makeProxy(
prisma,
options.modelMeta,
(_prisma, model) => new EncryptedHandler(_prisma as DbClientContract, model, options),
'encrypted'
);
}

class EncryptedHandler extends DefaultPrismaProxyHandler {
private queryUtils: QueryUtils;
private encoder = new TextEncoder();
private decoder = new TextDecoder();

constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) {
super(prisma, model, options);

this.queryUtils = new QueryUtils(prisma, options);

if (!options.encryption) throw new Error('Encryption options must be provided');

if (this.isCustomEncryption(options.encryption!)) {
if (!options.encryption.encrypt || !options.encryption.decrypt)
throw new Error('Custom encryption must provide encrypt and decrypt functions');
} else {
if (!options.encryption.encryptionKey) throw new Error('Encryption key must be provided');
if (options.encryption.encryptionKey.length !== 32) throw new Error('Encryption key must be 32 bytes');
}
}

private async getKey(secret: Uint8Array): Promise<CryptoKey> {
return crypto.subtle.importKey('raw', secret, 'AES-GCM', false, ['encrypt', 'decrypt']);
}

private isCustomEncryption(encryption: CustomEncryption | SimpleEncryption): encryption is CustomEncryption {
return 'encrypt' in encryption && 'decrypt' in encryption;
}

private async encrypt(field: FieldInfo, data: string): Promise<string> {
if (this.isCustomEncryption(this.options.encryption!)) {
return this.options.encryption.encrypt(this.model, field, data);
}

const key = await this.getKey(this.options.encryption!.encryptionKey);
const iv = crypto.getRandomValues(new Uint8Array(12));

const encrypted = await crypto.subtle.encrypt(
{
name: 'AES-GCM',
iv,
},
key,
this.encoder.encode(data)
);

// Combine IV and encrypted data into a single array of bytes
const bytes = [...iv, ...new Uint8Array(encrypted)];

// Convert bytes to base64 string
return btoa(String.fromCharCode(...bytes));
}

private async decrypt(field: FieldInfo, data: string): Promise<string> {
if (this.isCustomEncryption(this.options.encryption!)) {
return this.options.encryption.decrypt(this.model, field, data);
}

const key = await this.getKey(this.options.encryption!.encryptionKey);

// Convert base64 back to bytes
const bytes = Uint8Array.from(atob(data), (c) => c.charCodeAt(0));

// First 12 bytes are IV, rest is encrypted data
const decrypted = await crypto.subtle.decrypt(
{
name: 'AES-GCM',
iv: bytes.slice(0, 12),
},
key,
bytes.slice(12)
);

return this.decoder.decode(decrypted);
}
Comment on lines +87 to +108
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Improve error handling in decrypt method.

As suggested in the past review, consider returning the original ciphertext on decryption failure to allow gradual adoption of encryption.

 private async decrypt(field: FieldInfo, data: string): Promise<string> {
     if (this.isCustomEncryption(this.options.encryption!)) {
         return this.options.encryption.decrypt(this.model, field, data);
     }

     const key = await this.getKey(this.options.encryption!.encryptionKey);
+    try {
         // Convert base64 back to bytes
         const bytes = Uint8Array.from(atob(data), (c) => c.charCodeAt(0));

         // First 12 bytes are IV, rest is encrypted data
         const decrypted = await crypto.subtle.decrypt(
             {
                 name: 'AES-GCM',
                 iv: bytes.slice(0, 12),
             },
             key,
             bytes.slice(12)
         );

         return this.decoder.decode(decrypted);
+    } catch (error) {
+        // Return original text if decryption fails
+        // This allows gradual adoption of encryption
+        console.warn(`Failed to decrypt field ${field.name} in model ${this.model}:`, error);
+        return data;
+    }
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
private async decrypt(field: FieldInfo, data: string): Promise<string> {
if (this.isCustomEncryption(this.options.encryption!)) {
return this.options.encryption.decrypt(this.model, field, data);
}
const key = await this.getKey(this.options.encryption!.encryptionKey);
// Convert base64 back to bytes
const bytes = Uint8Array.from(atob(data), (c) => c.charCodeAt(0));
// First 12 bytes are IV, rest is encrypted data
const decrypted = await crypto.subtle.decrypt(
{
name: 'AES-GCM',
iv: bytes.slice(0, 12),
},
key,
bytes.slice(12)
);
return this.decoder.decode(decrypted);
}
private async decrypt(field: FieldInfo, data: string): Promise<string> {
if (this.isCustomEncryption(this.options.encryption!)) {
return this.options.encryption.decrypt(this.model, field, data);
}
const key = await this.getKey(this.options.encryption!.encryptionKey);
try {
// Convert base64 back to bytes
const bytes = Uint8Array.from(atob(data), (c) => c.charCodeAt(0));
// First 12 bytes are IV, rest is encrypted data
const decrypted = await crypto.subtle.decrypt(
{
name: 'AES-GCM',
iv: bytes.slice(0, 12),
},
key,
bytes.slice(12)
);
return this.decoder.decode(decrypted);
} catch (error) {
// Return original text if decryption fails
// This allows gradual adoption of encryption
console.warn(`Failed to decrypt field ${field.name} in model ${this.model}:`, error);
return data;
}
}


// base override
protected async preprocessArgs(action: PrismaProxyActions, args: any) {
const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert'];
if (args && args.data && actionsOfInterest.includes(action)) {
await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args);
}
return args;
}

// base override
protected async processResultEntity<T>(method: PrismaProxyActions, data: T): Promise<T> {
if (!data || typeof data !== 'object') {
return data;
}

for (const value of enumerate(data)) {
await this.doPostProcess(value, this.model);
}

return data;
}

private async doPostProcess(entityData: any, model: string) {
const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData);

for (const field of getModelFields(entityData)) {
const fieldInfo = await resolveField(this.options.modelMeta, realModel, field);

if (!fieldInfo) {
continue;
}

const shouldDecrypt = fieldInfo.attributes?.find((attr) => attr.name === '@encrypted');
if (shouldDecrypt) {
// Don't decrypt null, undefined or empty string values
if (!entityData[field]) return;

entityData[field] = await this.decrypt(fieldInfo, entityData[field]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If decryption fails, should we return the original cipher text? I'm thinking this will allow easier adoption: the @encrypted attribute can be added and deployed and then a background script is run to migrate the existing plain-text data.

}
}
}
Comment on lines +132 to +150
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix early return and add nullish check in doPostProcess.

There are two issues in this method:

  1. The early return on line 146 prevents processing other fields
  2. Missing nullish check for data as suggested in past review
 private async doPostProcess(entityData: any, model: string) {
+    if (entityData == null) {
+        return;
+    }
     const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData);

     for (const field of getModelFields(entityData)) {
         const fieldInfo = await resolveField(this.options.modelMeta, realModel, field);

         if (!fieldInfo) {
             continue;
         }

         const shouldDecrypt = fieldInfo.attributes?.find((attr) => attr.name === '@encrypted');
         if (shouldDecrypt) {
             // Don't decrypt null, undefined or empty string values
-            if (!entityData[field]) return;
+            if (!entityData[field]) continue;

             entityData[field] = await this.decrypt(fieldInfo, entityData[field]);
         }
     }
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
private async doPostProcess(entityData: any, model: string) {
const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData);
for (const field of getModelFields(entityData)) {
const fieldInfo = await resolveField(this.options.modelMeta, realModel, field);
if (!fieldInfo) {
continue;
}
const shouldDecrypt = fieldInfo.attributes?.find((attr) => attr.name === '@encrypted');
if (shouldDecrypt) {
// Don't decrypt null, undefined or empty string values
if (!entityData[field]) return;
entityData[field] = await this.decrypt(fieldInfo, entityData[field]);
}
}
}
private async doPostProcess(entityData: any, model: string) {
if (entityData == null) {
return;
}
const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData);
for (const field of getModelFields(entityData)) {
const fieldInfo = await resolveField(this.options.modelMeta, realModel, field);
if (!fieldInfo) {
continue;
}
const shouldDecrypt = fieldInfo.attributes?.find((attr) => attr.name === '@encrypted');
if (shouldDecrypt) {
// Don't decrypt null, undefined or empty string values
if (!entityData[field]) continue;
entityData[field] = await this.decrypt(fieldInfo, entityData[field]);
}
}
}


private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) {
const visitor = new NestedWriteVisitor(this.options.modelMeta, {
field: async (field, _action, data, context) => {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we add a nullish check to data here?

// Don't encrypt null, undefined or empty string values
if (!data) return;

const encAttr = field.attributes?.find((attr) => attr.name === '@encrypted');
if (encAttr && field.type === 'String') {
context.parent[field.name] = await this.encrypt(field, data);
}
},
});

await visitor.visit(model, action, args);
}
}
15 changes: 14 additions & 1 deletion packages/runtime/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import type { z } from 'zod';
import { FieldInfo } from './cross';

export type PrismaPromise<T> = Promise<T> & Record<string, (args?: any) => PrismaPromise<any>>;

Expand Down Expand Up @@ -133,6 +134,11 @@ export type EnhancementOptions = {
* The `isolationLevel` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack.
*/
transactionIsolationLevel?: TransactionIsolationLevel;

/**
* The encryption options for using the `encrypted` enhancement.
*/
encryption?: SimpleEncryption | CustomEncryption;
};

/**
Expand All @@ -145,7 +151,7 @@ export type EnhancementContext<User extends AuthUser = AuthUser> = {
/**
* Kinds of enhancements to `PrismaClient`
*/
export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate';
export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate' | 'encrypted';

/**
* Function for transforming errors.
Expand All @@ -166,3 +172,10 @@ export type ZodSchemas = {
*/
input?: Record<string, Record<string, z.ZodSchema>>;
};

export type CustomEncryption = {
encrypt: (model: string, field: FieldInfo, plain: string) => Promise<string>;
decrypt: (model: string, field: FieldInfo, cipher: string) => Promise<string>;
};

export type SimpleEncryption = { encryptionKey: Uint8Array };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add key validation to SimpleEncryption type

Consider adding validation constraints for the encryption key:

  • Define minimum/maximum key length
  • Specify required key format/encoding
  • Add helper methods for key validation

8 changes: 8 additions & 0 deletions packages/schema/src/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,14 @@ attribute @@auth() @@@supportTypeDef
*/
attribute @password(saltLength: Int?, salt: String?) @@@targetField([StringField])


/**
* Indicates that the field is encrypted when storing in the DB and should be decrypted when read
*
* ZenStack uses the Web Crypto API to encrypt and decrypt the field.
*/
attribute @encrypted() @@@targetField([StringField])

/**
* Indicates that the field should be omitted when read from the generated services.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import { FieldInfo } from '@zenstackhq/runtime';
import { loadSchema } from '@zenstackhq/testtools';
import path from 'path';

describe('Encrypted test', () => {
let origDir: string;

beforeAll(async () => {
origDir = path.resolve('.');
});

afterEach(async () => {
process.chdir(origDir);
});

it('Simple encryption test', async () => {
const { enhance } = await loadSchema(`
model User {
id String @id @default(cuid())
encrypted_value String @encrypted()

@@allow('all', true)
}`);

const sudoDb = enhance(undefined, { kinds: [] });
const db = enhance(undefined, {
kinds: ['encrypted'],
encryption: { encryptionKey: 'c558Gq0YQK2QcqtkMF9BGXHCQn4dMF8w' },
});

const create = await db.user.create({
data: {
id: '1',
encrypted_value: 'abc123',
},
});

const read = await db.user.findUnique({
where: {
id: '1',
},
});

const sudoRead = await sudoDb.user.findUnique({
where: {
id: '1',
},
});

expect(create.encrypted_value).toBe('abc123');
expect(read.encrypted_value).toBe('abc123');
expect(sudoRead.encrypted_value).not.toBe('abc123');
});

it('Custom encryption test', async () => {
const { enhance } = await loadSchema(`
model User {
id String @id @default(cuid())
encrypted_value String @encrypted()

@@allow('all', true)
}`);

const sudoDb = enhance(undefined, { kinds: [] });
const db = enhance(undefined, {
kinds: ['encrypted'],
encryption: {
encrypt: async (model: string, field: FieldInfo, data: string) => {
// Add _enc to the end of the input
return data + '_enc';
},
decrypt: async (model: string, field: FieldInfo, cipher: string) => {
// Remove _enc from the end of the input explicitly
if (cipher.endsWith('_enc')) {
return cipher.slice(0, -4); // Remove last 4 characters (_enc)
}

return cipher;
},
},
});

const create = await db.user.create({
data: {
id: '1',
encrypted_value: 'abc123',
},
});

const read = await db.user.findUnique({
where: {
id: '1',
},
});

const sudoRead = await sudoDb.user.findUnique({
where: {
id: '1',
},
});

expect(create.encrypted_value).toBe('abc123');
expect(read.encrypted_value).toBe('abc123');
expect(sudoRead.encrypted_value).toBe('abc123_enc');
});
});
Loading